### 2. Define a LightningModule

In [10]:
import os
import torch
from torch import optim, nn, utils, Tensor
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import lightning as L

# define any number of nn.Module 
encoder = nn.Sequential(nn.Linear(28*28, 64), nn.ReLU(), nn.Linear(64,3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28*28))



In [3]:
# define the LightningModule

class LitAutoEncoder(L.LightningModule):
  def __init__(self, encoder, decoder):
    super().__init__()
    self.encoder = encoder
    self.decoder = decoder

  def training_step(self, batch, batch_idx):
    # training_step define training loop.
    # it is independent of forward
    x, y = batch
    x = x.view(x.size(0), -1)
    z = self.encoder(x)
    x_hat = self.decoder(z)
    loss = nn.functional.mse_loss(x_hat, x)
    # Logging to Tensorboard
    self.log("train loss: ", loss)
    return loss
  
  def configure_optimizers(self):
    optimizer = optim.AdamW(self.parameters(),lr=1e-3)
    return optimizer
  
# init the autoencoder
autoencorder = LitAutoEncoder(encoder, decoder)

### 3. Define a dataset

In [4]:
# Load dataset
dataset = MNIST(os.getcwd(), download=True, transform=ToTensor())
train_loader = utils.data.DataLoader(dataset)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /home/whpark/playground/torch/lightning/autoencoder/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:01<00:00, 5036270.37it/s]


Extracting /home/whpark/playground/torch/lightning/autoencoder/MNIST/raw/train-images-idx3-ubyte.gz to /home/whpark/playground/torch/lightning/autoencoder/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /home/whpark/playground/torch/lightning/autoencoder/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 1436686.91it/s]


Extracting /home/whpark/playground/torch/lightning/autoencoder/MNIST/raw/train-labels-idx1-ubyte.gz to /home/whpark/playground/torch/lightning/autoencoder/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /home/whpark/playground/torch/lightning/autoencoder/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 1736379.81it/s]


Extracting /home/whpark/playground/torch/lightning/autoencoder/MNIST/raw/t10k-images-idx3-ubyte.gz to /home/whpark/playground/torch/lightning/autoencoder/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /home/whpark/playground/torch/lightning/autoencoder/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 9283883.42it/s]

Extracting /home/whpark/playground/torch/lightning/autoencoder/MNIST/raw/t10k-labels-idx1-ubyte.gz to /home/whpark/playground/torch/lightning/autoencoder/MNIST/raw






### 4. Train the model

In [5]:
# train the model
trainer = L.Trainer(limit_train_batches=100, max_epochs=1)
trainer.fit(autoencorder, train_dataloaders=train_loader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: /home/whpark/playground/torch/lightning/autoencoder/lightning_logs

  | Name    | Type       | Params
---------------------------------------
0 | encoder | Sequential | 50.4 K
1 | decoder | Sequential | 51.2 K
---------------------------------------
101 K     Trainable params
0         Non-trainable params
101 K     Total params
0.407     Total estimated model params size (MB)
/home/whpark/miniconda3/envs/dev/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Epoch 0: 100%|██████████| 100/100 [00:00<00:00, 155.43it/s, v_num=0]

`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 100/100 [00:00<00:00, 152.75it/s, v_num=0]


In [8]:
# load checkpoint
checkpoint = "./lightning_logs/version_0/checkpoints/epoch=0-step=100.ckpt"
autoencorder = LitAutoEncoder.load_from_checkpoint(checkpoint, encoder=encoder, decoder=decoder)

# choose your trained nn.Module
encoder = autoencorder.encoder
encoder.eval()

# embed 4 fake images!
fake_image_batch = torch.rand(4, 28*28, device=autoencorder.device)
embeddings = encoder(fake_image_batch)
print("⚡" * 20, "\nPredictions (4 image embeddings):\n", embeddings, "\n", "⚡" * 20)

⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡ 
Predictions (4 image embeddings):
 tensor([[-0.2780,  0.5091,  0.9623],
        [-0.3035,  0.5207,  0.8997],
        [-0.2540,  0.4455,  0.9214],
        [-0.3283,  0.4872,  1.0196]], grad_fn=<AddmmBackward0>) 
 ⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡
