In [1]:
# Mount Google Drive at /content/drive; the checkpoint directory used later in
# this notebook is located under this mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split

In [None]:
#! pip install pytorch_lightning
import pytorch_lightning as pl

In [6]:
from torchvision import datasets

In [7]:
class Encoder(nn.Module):
  """Two-layer fully connected encoder: flattened input -> latent code.

  Args:
    input_dim: Length of each flattened input vector (default 28*28).
    hidden_dim: Width of the hidden layer (default 64).
    latent_dim: Size of the latent code that is returned (default 3).

  The defaults reproduce the original fixed 784 -> 64 -> 3 architecture.
  """

  def __init__(self, input_dim=28*28, hidden_dim=64, latent_dim=3):
    super().__init__()
    # The attribute name `l1` and the Sequential layout are kept on purpose:
    # they determine the state_dict keys (l1.0.*, l1.2.*), which must stay
    # stable for previously saved checkpoints to load.
    self.l1 = nn.Sequential(
        nn.Linear(input_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, latent_dim),
    )

  def forward(self, x):
    """Map a batch of shape (N, input_dim) to latent codes of shape (N, latent_dim)."""
    return self.l1(x)

class Decoder(nn.Module):
  """Two-layer fully connected decoder: latent code -> flattened reconstruction.

  Args:
    latent_dim: Size of each incoming latent code (default 3).
    hidden_dim: Width of the hidden layer (default 64).
    output_dim: Length of each reconstructed flat vector (default 28*28).

  The defaults reproduce the original fixed 3 -> 64 -> 784 architecture.
  """

  def __init__(self, latent_dim=3, hidden_dim=64, output_dim=28*28):
    super().__init__()
    # The attribute name `l1` and the Sequential layout are kept on purpose:
    # they determine the state_dict keys (l1.0.*, l1.2.*), which must stay
    # stable for previously saved checkpoints to load.
    self.l1 = nn.Sequential(
        nn.Linear(latent_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, output_dim),
    )

  def forward(self, x):
    """Map latent codes of shape (N, latent_dim) to outputs of shape (N, output_dim)."""
    return self.l1(x)

In [15]:
class LitAutoEncoder(pl.LightningModule):
  """LightningModule that trains an encoder/decoder pair as an autoencoder.

  Every step flattens the input batch, encodes and decodes it, and scores the
  reconstruction against the flattened input with mean squared error.

  Logged metrics:
    train_loss: reconstruction MSE during training.
    val_loss: reconstruction MSE during validation.
    test_loss: reconstruction MSE during testing.

  Args:
    encoder: Module mapping flattened inputs to latent codes.
    decoder: Module mapping latent codes back to flattened inputs.
  """

  def __init__(self, encoder, decoder):
    super().__init__()
    self.encoder = encoder
    self.decoder = decoder
    # Records the constructor arguments (the two modules) in the checkpoint's
    # hyper_parameters; this is what lets load_from_checkpoint() be called
    # without passing encoder/decoder again.
    self.save_hyperparameters()

  def _reconstruction_loss(self, batch):
    """Return the MSE between the flattened inputs and their reconstruction."""
    x, _ = batch  # the second batch element is not used
    x = x.view(x.size(0), -1)  # flatten everything except the batch dimension
    x_hat = self.decoder(self.encoder(x))
    return F.mse_loss(x_hat, x)

  def training_step(self, batch, batch_idx):
    """Compute, log and return the training loss for one batch."""
    loss = self._reconstruction_loss(batch)
    self.log("train_loss", loss)
    return loss

  def validation_step(self, batch, batch_idx):
    """Compute and log the validation loss for one batch."""
    val_loss = self._reconstruction_loss(batch)
    # Logged under its own name so it cannot be confused with the test metric.
    self.log("val_loss", val_loss)

  def configure_optimizers(self):
    """Optimise all parameters (encoder and decoder) with Adam, lr=1e-3."""
    optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
    return optimizer

  def test_step(self, batch, batch_idx):
    """Compute and log the test loss for one batch."""
    test_loss = self._reconstruction_loss(batch)
    self.log("test_loss", test_loss)


In [16]:
# Fetch MNIST into ./MNIST (downloading if necessary) and keep the training
# split, with every image converted to a tensor.
train_set = datasets.MNIST(
    root="MNIST",
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)


In [17]:
# hold out 20% of the training data for validation
total_samples = len(train_set)
train_set_size = int(total_samples*0.8)
valid_set_size = total_samples - train_set_size

# fixed-seed generator so the split is the same on every run
seed = torch.Generator().manual_seed(42)
split_lengths = [train_set_size, valid_set_size]
train_set, valid_set = random_split(train_set, split_lengths, generator=seed)

In [18]:
# Wrap both splits in DataLoaders. No batch_size/shuffle arguments are passed,
# so DataLoader's defaults apply. Note that `train_set` is rebound from the
# dataset split to its loader here.
train_set, val_set = DataLoader(train_set), DataLoader(valid_set)

In [19]:
# Directory on the mounted Drive that is passed to the Trainer as
# default_root_dir (logs and checkpoints are written beneath it).
checkpoint_path = "/content/drive/MyDrive/Checkpoints"

In [20]:
# Seed the Python, NumPy and torch RNGs before the model is built so that the
# weight initialisation (and therefore the reported losses) is reproducible.
pl.seed_everything(42, workers=True)

# model
autoencoder = LitAutoEncoder(Encoder(), Decoder())

# initialise the Trainer; logs and checkpoints go under checkpoint_path
trainer = pl.Trainer(max_epochs=5, default_root_dir=checkpoint_path)

# train with both splits (train_set / val_set are DataLoaders at this point)
trainer.fit(autoencoder, train_set, val_set)

  f"Attribute {k!r} is an instance of `nn.Module` and is already saved during checkpointing."
  f"Attribute {k!r} is an instance of `nn.Module` and is already saved during checkpointing."
GPU available: True, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: /content/drive/MyDrive/Checkpoints/lightning_logs

  | Name    | Type    | Params
------------------------------------
0 | encoder | Encoder | 50.4 K
1 | decoder | Decoder | 51.2 K
------------------------------------
101 K     Trainable params
0         Non-trainable params
101 K     Total params
0.407     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

In [21]:
# Held-out MNIST test split (train=False), converted to tensors like the
# training data.
test_set = datasets.MNIST(
    root="MNIST",
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)

In [22]:
# Evaluate the trained model on the test split; the returned list of metric
# dicts is the cell's displayed result.
test_loader = DataLoader(test_set)
trainer.test(autoencoder, dataloaders=test_loader)

Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_loss           0.03933792561292648
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.03933792561292648}]

In [30]:
# Ask the trainer's checkpoint callback where it saved the checkpoint instead
# of hard-coding ".../version_0/checkpoints/epoch=4-step=240000.ckpt": the
# version folder and the step count change whenever training is re-run.
ckpt_file = trainer.checkpoint_callback.best_model_path
new_autoencoder = LitAutoEncoder.load_from_checkpoint(ckpt_file)

# disable randomness, dropout, etc...
new_autoencoder.eval()

# To predict, call the sub-modules on flattened inputs; LitAutoEncoder does not
# define forward(), so new_autoencoder(x) is not available.
#   x_hat = new_autoencoder.decoder(new_autoencoder.encoder(x.view(x.size(0), -1)))

  f"Attribute {k!r} is an instance of `nn.Module` and is already saved during checkpointing."
  f"Attribute {k!r} is an instance of `nn.Module` and is already saved during checkpointing."


LitAutoEncoder(
  (encoder): Encoder(
    (l1): Sequential(
      (0): Linear(in_features=784, out_features=64, bias=True)
      (1): ReLU()
      (2): Linear(in_features=64, out_features=3, bias=True)
    )
  )
  (decoder): Decoder(
    (l1): Sequential(
      (0): Linear(in_features=3, out_features=64, bias=True)
      (1): ReLU()
      (2): Linear(in_features=64, out_features=784, bias=True)
    )
  )
)

In [31]:
# Load the raw checkpoint dict from the file the trainer actually wrote (the
# version folder and step count in a hard-coded path go stale on re-runs).
checkpoint = torch.load(trainer.checkpoint_callback.best_model_path)
# Constructor arguments recorded by save_hyperparameters()
print(checkpoint["hyper_parameters"])

{'encoder': Encoder(
  (l1): Sequential(
    (0): Linear(in_features=784, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=3, bias=True)
  )
), 'decoder': Decoder(
  (l1): Sequential(
    (0): Linear(in_features=3, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=784, bias=True)
  )
)}


In [None]:
# Show only the top-level keys; displaying the whole dict would print every
# weight and optimizer tensor stored in the checkpoint.
list(checkpoint)