In [2]:
!pip install pytorch_lightning
!pip install wandb

Collecting pytorch_lightning
  Downloading pytorch_lightning-2.2.1-py3-none-any.whl (801 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m801.6/801.6 kB[0m [31m15.1 MB/s[0m eta [36m0:00:00[0m
Collecting torchmetrics>=0.7.0 (from pytorch_lightning)
  Downloading torchmetrics-1.3.1-py3-none-any.whl (840 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m840.4/840.4 kB[0m [31m51.3 MB/s[0m eta [36m0:00:00[0m
Collecting lightning-utilities>=0.8.0 (from pytorch_lightning)
  Downloading lightning_utilities-0.10.1-py3-none-any.whl (24 kB)
Installing collected packages: lightning-utilities, torchmetrics, pytorch_lightning
Successfully installed lightning-utilities-0.10.1 pytorch_lightning-2.2.1 torchmetrics-1.3.1
Collecting wandb
  Downloading wandb-0.16.3-py3-none-any.whl (2.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m27.7 MB/s[0m eta [36m0:00:00[0m
Collecting GitPython!=3.1.29,>=1.0.0 (from wandb)
 

In [3]:
import torch
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader, random_split
import pytorch_lightning as pl
import wandb

class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, batch_size=128, data_dir: str = "~/torch_datasets"):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.5,), (0.5,))
        ])
        self.dims = (1, 28, 28)
        self.num_classes = 10

    def prepare_data(self):
        torchvision.datasets.MNIST(self.data_dir, train=True, download=True)
        torchvision.datasets.MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        if stage == 'fit' or stage is None:
            mnist_full = torchvision.datasets.MNIST(self.data_dir, train=True, transform=self.transform)
            self.train_dataset, self.val_dataset = random_split(mnist_full, [55000, 5000])
        if stage == 'test' or stage is None:
            self.test_dataset = torchvision.datasets.MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size)

class AE(pl.LightningModule):
    def __init__(self, input_shape=784, learning_rate=1e-3, epochs=10, batch_size=128):
        super().__init__()
        self.save_hyperparameters()  # Save hyperparameters for logging
        self.encoder_hidden_layer = nn.Linear(in_features=input_shape, out_features=128)
        self.encoder_output_layer = nn.Linear(in_features=128, out_features=128)
        self.decoder_hidden_layer = nn.Linear(in_features=128, out_features=128)
        self.decoder_output_layer = nn.Linear(in_features=128, out_features=input_shape)

    def forward(self, features):
        activation = self.encoder_hidden_layer(features)
        activation = torch.relu(activation)
        code = self.encoder_output_layer(activation)
        code = torch.relu(code)
        activation = self.decoder_hidden_layer(code)
        activation = torch.relu(activation)
        activation = self.decoder_output_layer(activation)
        reconstructed = torch.relu(activation)
        return reconstructed

    def training_step(self, batch, batch_idx):
        x, _ = batch
        reconstructions = self(x.view(x.size(0), -1))
        loss = F.mse_loss(reconstructions, x.view(x.size(0), -1))
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)

        # Log images
        if batch_idx % 100 == 0:
            original = x[:8]  # Log the first 8 images
            reconstructed = reconstructions[:8]  # Log the first 8 reconstructed images
            self.log_images(original, reconstructed, prefix='train')

        return loss

    def validation_step(self, batch, batch_idx):
        x, _ = batch
        reconstructions = self(x.view(x.size(0), -1))
        loss = F.mse_loss(reconstructions, x.view(x.size(0), -1))
        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)

        # Log images
        if batch_idx % 100 == 0:
            original = x[:8]  # Log the first 8 images
            reconstructed = reconstructions[:8]  # Log the first 8 reconstructed images
            self.log_images(original, reconstructed, prefix='val')

        return loss

    def test_step(self, batch, batch_idx):
        x, _ = batch
        reconstructions = self(x.view(x.size(0), -1))
        loss = F.mse_loss(reconstructions, x.view(x.size(0), -1))
        self.log('test_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)

        # Log images
        if batch_idx % 100 == 0:
            original = x[:8]  # Log the first 8 images
            reconstructed = reconstructions[:8]  # Log the first 8 reconstructed images
            self.log_images(original, reconstructed, prefix='test')

        return loss

    def log_images(self, original, reconstructed, prefix='train'):
        original_grid = torchvision.utils.make_grid(original, nrow=8, normalize=True)
        reconstructed_grid = torchvision.utils.make_grid(reconstructed, nrow=8, normalize=True)
        wandb.log({f'{prefix}_original': [wandb.Image(original_grid)]})
        wandb.log({f'{prefix}_reconstructed': [wandb.Image(reconstructed_grid)]})

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        return optimizer

# Initialize Wandb
wandb.init(project='autoencoder')
"""
run = wandb.init()
artifact = run.use_artifact('oudaisalameh/autoencoder/run-u2ri2izv-history:v0', type='wandb-history')
artifact_dir = artifact.download()
"""

"""
from my_module import MyModel  # Import your LightningModule subclass
checkpoint = torch.load("checkpoint.ckpt")
# Load the model from the checkpoint
model = MyModel.load_from_checkpoint(checkpoint_path)
"""

# Create AE model
model = AE()

# DataModule
dm = MNISTDataModule()
dm.prepare_data()
dm.setup()

# Callbacks
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    monitor='val_loss',
    dirpath='./checkpoints',
    filename='autoencoder-{epoch:02d}-{val_loss:.2f}',
    save_top_k=1,
    mode='min',
    save_weights_only=False  # Save entire model
)

# Trainer
trainer = pl.Trainer(max_epochs=model.hparams.epochs, logger=pl.loggers.WandbLogger(), callbacks=[checkpoint_callback])#,resume_from_checkpoint='checkpoint.ckpt'

# Training
trainer.fit(model, dm)

# Testing
trainer.test(datamodule=dm)

wandb.finish()  # Close the wandb run after training and testing finishes


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /root/torch_datasets/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 117038851.57it/s]

Extracting /root/torch_datasets/MNIST/raw/train-images-idx3-ubyte.gz to /root/torch_datasets/MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /root/torch_datasets/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 18370593.54it/s]


Extracting /root/torch_datasets/MNIST/raw/train-labels-idx1-ubyte.gz to /root/torch_datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /root/torch_datasets/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 95194651.02it/s]

Extracting /root/torch_datasets/MNIST/raw/t10k-images-idx3-ubyte.gz to /root/torch_datasets/MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /root/torch_datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 8659331.26it/s]


Extracting /root/torch_datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz to /root/torch_datasets/MNIST/raw



INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loggers/wandb.py:391: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
INFO:pytorch_lightning.callbacks.model_summary:
  | Name                 | Type   | Params
------------------------------------------------
0 | encoder_hidden_layer | Linear | 100 K 
1 | encoder_output_layer | Linear | 16.5 K
2 | decoder_hidden_layer | Linear | 16.5 K
3 | decoder_output_layer | Linear | 101 K 
------------------------------------------------
234 K     Trainable params
0         Non-trainable 

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=10` reached.
INFO:pytorch_lightning.utilities.rank_zero:Restoring states from the checkpoint path at /content/checkpoints/autoencoder-epoch=09-val_loss=0.89.ckpt
INFO:pytorch_lightning.utilities.rank_zero:Loaded model weights from the checkpoint at /content/checkpoints/autoencoder-epoch=09-val_loss=0.89.ckpt


Testing: |          | 0/? [00:00<?, ?it/s]

VBox(children=(Label(value='0.287 MB of 0.287 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▁▂▂▂▂▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▅▅▅▅▆▆▆▇▇▇▇▇▇▇▇█
test_loss,▁
train_loss_epoch,█▃▂▂▂▂▁▁▁▁
train_loss_step,█▆▅▄▃▄▃▂▃▂▃▃▂▂▃▃▃▃▃▂▂▂▃▂▂▂▁▂▂▁▂▂▂▃▂▂▂▂▃▂
trainer/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇████
val_loss,█▄▃▂▂▂▁▁▁▁

0,1
epoch,10.0
test_loss,0.8901
train_loss_epoch,0.89093
train_loss_step,0.8895
trainer/global_step,4300.0
val_loss,0.89142
