# Lightning Module:

In [None]:
import zarr

def read_zarr_file(file_path, array_or_group_key=None):
    """
    Open a Zarr store read-only and return its root or one of its members.

    Parameters:
    - file_path: str, path to the .zarr store.
    - array_or_group_key: str, optional; key to look up under the root.
      When omitted (None), the root itself is returned.

    Returns:
    The opened root (array or group) when no key is given, otherwise the
    result of indexing the root with the given key.
    """
    store_root = zarr.open(file_path, mode='r')

    # No key requested: hand back whatever sits at the top of the store.
    if array_or_group_key is None:
        return store_root

    # Key requested: plain item lookup on the root.
    return store_root[array_or_group_key]

# Usage example for read_zarr_file.
# "your_file.zarr" is a placeholder; point it at a real store before running.
file_path = "your_file.zarr"

# Without a key, the root array or group of the store is returned.
root = read_zarr_file(file_path)

# With a key, the matching array or group under the root is returned.
# "array_or_group_key_here" is a placeholder key.
specific_array_or_group = read_zarr_file(file_path, "array_or_group_key_here")


In [3]:
# load autoreload:
%load_ext autoreload
import os
from torch import optim, nn, utils, Tensor
from torchvision.transforms import ToTensor
import pytorch_lightning as pl
import torch.nn as nn
import torch.optim as optim


from Models.PhiNet import PhiNet

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# Smoke test: push one dummy batch through PhiNet to check that the forward
# pass runs.
# `Tensor(1, 2, 128, 128)` on its own only allocates memory, so the values
# would be arbitrary (possibly NaN/inf). Fill the tensor in place with
# standard-normal samples so the check runs on well-defined input.
# NOTE(review): the (1, 2, 128, 128) shape is taken as-is; confirm it matches
# the layout PhiNet expects.
x = Tensor(1, 2, 128, 128).normal_()

model = PhiNet(use_SA=False)
y = model(x)

In [None]:


class LitModel(pl.LightningModule):
    """Lightning wrapper that trains a wrapped network as an autoencoder.

    The wrapped network must expose ``encoder`` and ``decoder`` attributes.
    Every step flattens the input batch to ``(batch, -1)``, encodes it,
    decodes it, and scores the reconstruction against the flattened input
    with mean-squared error.
    """

    def __init__(self, PhiNet):
        """Build the wrapped network.

        Parameters:
        - PhiNet: either a zero-argument callable (typically the model class),
          which is invoked to create the network, or an already constructed
          ``nn.Module`` instance, which is used as is.
        """
        super().__init__()
        # An nn.Module instance is itself callable (calling it runs forward),
        # so it has to be recognised before falling back to the factory call.
        self.model = PhiNet if isinstance(PhiNet, nn.Module) else PhiNet()

    def forward(self, x):
        """Run the wrapped network on ``x`` and return its output."""
        return self.model(x)

    def common_step(self, batch, batch_idx, stage=None):
        """Compute the reconstruction loss for one batch.

        Parameters:
        - batch: a pair ``(x, y)``; only ``x`` is used.
        - batch_idx: index of the batch (unused).
        - stage: 'train', 'val' or 'test' to log the loss as
          ``"<stage>_loss"``; any other value skips logging.

        Returns:
        The MSE between the decoded output and the flattened input.
        """
        x, _ = batch
        # Flatten everything except the batch dimension.
        # NOTE(review): this assumes the encoder consumes flat vectors;
        # confirm against PhiNet's encoder.
        x = x.view(x.size(0), -1)
        x_hat = self.model.decoder(self.model.encoder(x))
        loss = nn.functional.mse_loss(x_hat, x)

        if stage in ("train", "val", "test"):
            self.log(f"{stage}_loss", loss)

        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss for one batch (logged as ``train_loss``)."""
        return self.common_step(batch, batch_idx, 'train')

    def validation_step(self, batch, batch_idx):
        """Return ``{"val_loss": loss}`` for one validation batch."""
        return {"val_loss": self.common_step(batch, batch_idx, 'val')}

    def test_step(self, batch, batch_idx):
        """Return ``{"test_loss": loss}`` for one test batch."""
        return {"test_loss": self.common_step(batch, batch_idx, 'test')}

    def _log_epoch_mean(self, outputs, key, name):
        """Log the mean of ``out[key]`` over all step outputs under ``name``."""
        # The notebook's import cell binds names *from* torch (nn, optim, ...)
        # but never the top-level ``torch`` name itself, so import it here.
        import torch

        # torch.stack rejects an empty list; there is nothing to average then.
        if not outputs:
            return
        self.log(name, torch.stack([out[key] for out in outputs]).mean())

    # NOTE(review): `validation_epoch_end` / `test_epoch_end` are hooks of the
    # PyTorch Lightning 1.x API and were removed in 2.0 -- confirm the
    # installed version before relying on them.
    def validation_epoch_end(self, outputs):
        """Log the epoch-average validation loss as ``avg_val_loss``."""
        self._log_epoch_mean(outputs, "val_loss", "avg_val_loss")

    def test_epoch_end(self, outputs):
        """Log the epoch-average test loss as ``avg_test_loss``."""
        self._log_epoch_mean(outputs, "test_loss", "avg_test_loss")

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with lr=1e-3."""
        return optim.Adam(self.parameters(), lr=1e-3)

# LitModel.common_step calls `self.model.encoder` and `self.model.decoder`,
# so PhiNet must expose both sub-networks under exactly those attribute names.
# The PhiNet class itself is passed here; LitModel instantiates it with no
# arguments.
autoencoder = LitModel(PhiNet)


# Dataloader:

In [None]:
# setup data
# MNIST is not part of the notebook's import cell, so import it here;
# otherwise this cell fails with a NameError.
from torchvision.datasets import MNIST

# Downloads MNIST into the current working directory if it is not there yet
# and converts each image with ToTensor().
dataset = MNIST(os.getcwd(), download=True, transform=ToTensor())
# NOTE(review): DataLoader is used with its defaults (no batch_size/shuffle
# given) -- confirm that is intended for training.
train_loader = utils.data.DataLoader(dataset)

# Trainer Fit:

In [None]:
# train the model (hint: here are some helpful Trainer arguments for rapid idea iteration)
# limit_train_batches=100 caps each epoch at 100 training batches and
# max_epochs=100 bounds the run at 100 epochs.
# Only a training dataloader is passed to fit(), so no validation data is
# supplied by this call.
trainer = pl.Trainer(limit_train_batches=100, max_epochs=100)
trainer.fit(model=autoencoder, train_dataloaders=train_loader)

# Use the model:

In [None]:
# load checkpoint
# NOTE(review): the checkpoint path is the one from the original cell; adjust
# it to the file your training run actually produced.
checkpoint = "./lightning_logs/version_0/checkpoints/epoch=0-step=100.ckpt"
# The module defined in this notebook is LitModel, and its constructor takes a
# single `PhiNet` argument. LitModel.__init__ does not save its arguments in
# the checkpoint, so `PhiNet` has to be passed again when restoring.
autoencoder = LitModel.load_from_checkpoint(checkpoint, PhiNet=PhiNet)

# choose your trained nn.Module
# LitModel stores the wrapped network as `self.model`, so the encoder lives at
# `autoencoder.model.encoder`.
encoder = autoencoder.model.encoder
encoder.eval()

# embed 4 fake images!
# `Tensor(4, 28 * 28)` alone is uninitialized memory; fill it with uniform
# [0, 1) values. Move it to whatever device the restored model is on instead
# of unconditionally calling .cuda(), which fails without a GPU and can
# mismatch the model's device.
# NOTE(review): the (4, 28 * 28) shape is kept from the original cell --
# confirm it matches what the encoder expects.
fake_image_batch = Tensor(4, 28 * 28).uniform_().to(autoencoder.device)
embeddings = encoder(fake_image_batch)
print("⚡" * 20, "\nPredictions (4 image embeddings):\n", embeddings, "\n", "⚡" * 20)