### 2. Define a LightningModule

In [10]:
import os
import torch
from torch import optim, nn, utils, Tensor
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import lightning as L

# define any number of nn.Module 
encoder = nn.Sequential(nn.Linear(28*28, 64), nn.ReLU(), nn.Linear(64,3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28*28))



In [22]:
# define the LightningModule

class LitAutoEncoder(L.LightningModule):
  def __init__(self, encoder, decoder):
    super().__init__()
    self.encoder = encoder
    self.decoder = decoder

  def training_step(self, batch, batch_idx):
    # training_step define training loop.
    # it is independent of forward
    x, y = batch
    x = x.view(x.size(0), -1)
    z = self.encoder(x)
    x_hat = self.decoder(z)
    loss = nn.functional.mse_loss(x_hat, x)
    # Logging to Tensorboard
    self.log("train loss: ", loss)
    return loss

  def validation_step(self, batch, batch_idx):
    x, y = batch
    x = x.view(x.size(0), -1)
    z = self.encoder(x)
    x_hat = self.decoder(z)
    val_loss = nn.functional.mse_loss(x_hat, x)
    self.log("val_loss", val_loss)
  
  # define test loop
  def test_step(self, batch, batch_idx):
    x, y = batch
    x = x.view(x.size(0), -1)  # (batch, -1)
    z = self.encoder(x)
    x_hat = self.decoder(z)
    test_loss = nn.functional.mse_loss(x_hat, x)
    self.log("test loss: ", test_loss)

  def configure_optimizers(self):
    optimizer = optim.AdamW(self.parameters(),lr=1e-3)
    return optimizer
  
# init the autoencoder
autoencorder = LitAutoEncoder(encoder, decoder)

### 3. Load Dataset

In [23]:
import torch.utils.data as data
from torchvision import datasets
import torchvision.transforms as transforms

# Load Dataset
transform = transforms.ToTensor()
train_set = datasets.MNIST(root='MNIST', download=True, train=True, transform=transform)
test_set = datasets.MNIST(root='MNIST', download=True, train=False, transform=transform)

# train_loader = data.DataLoader(train_set)

# add validation step
# use 20% of training data for validation
train_set_size = int(len(train_set)*0.8)
valid_set_size = len(train_set) - train_set_size

# split the train set into two
seed = torch.Generator().manual_seed(42)
train_set, valid_set = data.random_split(train_set, [train_set_size, valid_set_size], generator=seed)

train_loader = DataLoader(train_set)
valid_loader = DataLoader(valid_set)

In [21]:
print(f"train_set size: {len(train_set)}, valid_set size: {len(valid_set)}")

train_set size: 48000, valid_set size: 12000


### 4. Train the model

In [24]:
# train the model
trainer = L.Trainer(limit_train_batches=100, max_epochs=1)
trainer.fit(autoencorder, train_loader, valid_loader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name    | Type       | Params
---------------------------------------
0 | encoder | Sequential | 50.4 K
1 | decoder | Sequential | 51.2 K
---------------------------------------
101 K     Trainable params
0         Non-trainable params
101 K     Total params
0.407     Total estimated model params size (MB)


                                                                            

/home/whpark/miniconda3/envs/dev/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
/home/whpark/miniconda3/envs/dev/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Epoch 0: 100%|██████████| 100/100 [00:46<00:00,  2.13it/s, v_num=3] 

`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 100/100 [00:46<00:00,  2.13it/s, v_num=3]


In [17]:
# Train with test loop
from torch.utils.data import DataLoader

trainer.test(autoencorder, dataloaders=DataLoader(test_set))

/home/whpark/miniconda3/envs/dev/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Testing DataLoader 0:   0%|          | 0/10000 [01:39<?, ?it/s]
Testing: |          | 10000/? [00:18<00:00, 541.26it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       test loss:           0.0677996352314949
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test loss: ': 0.0677996352314949}]

In [8]:
# load checkpoint
checkpoint = "./lightning_logs/version_0/checkpoints/epoch=0-step=100.ckpt"
autoencorder = LitAutoEncoder.load_from_checkpoint(checkpoint, encoder=encoder, decoder=decoder)

# choose your trained nn.Module
encoder = autoencorder.encoder
encoder.eval()

# embed 4 fake images!
fake_image_batch = torch.rand(4, 28*28, device=autoencorder.device)
embeddings = encoder(fake_image_batch)
print("⚡" * 20, "\nPredictions (4 image embeddings):\n", embeddings, "\n", "⚡" * 20)

⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡ 
Predictions (4 image embeddings):
 tensor([[-0.2780,  0.5091,  0.9623],
        [-0.3035,  0.5207,  0.8997],
        [-0.2540,  0.4455,  0.9214],
        [-0.3283,  0.4872,  1.0196]], grad_fn=<AddmmBackward0>) 
 ⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡
