# AutoEncoders

> A collection of autoencoder models

In [None]:
#| default_exp models.autoencoders

In [None]:
#| hide
%load_ext autoreload
%autoreload 2
from nbdev.showdoc import *

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
#| export

import torch.nn as nn
import torch
from torch.utils.data import DataLoader
from lightning import LightningModule, Trainer

from hydra.utils import instantiate
from omegaconf import OmegaConf

from nimrod.image.datasets import ImageDataset
from nimrod.modules import Encoder, Decoder
from nimrod.models.conv import ConvLayer, ConvNet
from nimrod.utils import time_it

Seed set to 42


## Overview
A flexible, powerful autoencoder implementation using PyTorch and PyTorch Lightning, designed for representation learning, dimensionality reduction, and generative modeling.

## Core Components

### AutoEncoder Class
A modular autoencoder with configurable encoder and decoder:
- Supports custom encoder and decoder architectures
- Simple forward pass for data reconstruction
- Input shape is whatever the encoder accepts (the Lightning wrapper flattens image batches to B x L before encoding)

#### Key Parameters
- `encoder`: Custom encoder layer
- `decoder`: Custom decoder layer

### AutoEncoderPL Class
A PyTorch Lightning wrapper for autoencoder training:
- Integrated loss computation
- Standardized training, validation, and test steps
- Automatic logging of reconstruction loss
- Adam optimizer (learning rate 1e-3 by default)

## Features
- Modular design with separate encoder and decoder
- PyTorch Lightning integration
- Mean Squared Error (MSE) reconstruction loss
- Supports batch processing
- Easy to extend and customize

## Architectural Variants (enabled by the modular design; this notebook implements the standard autoencoder)
- Standard Autoencoders
- Variational Autoencoders (VAE)
- Denoising Autoencoders
- Sparse Autoencoders
- Convolutional Autoencoders

## Supported Operations
- Data reconstruction
- Representation learning
- Dimensionality reduction
- Feature extraction
- Batch prediction
- Model evaluation

## Dependencies
- PyTorch
- PyTorch Lightning
- Nimrod custom modules (Encoder, Decoder)

## Usage Example
```python
# Create encoder and decoder
enc = Encoder()
dec = Decoder()

# Instantiate autoencoder
autoencoder = AutoEncoder(enc, dec)

# Wrap with Lightning module
pl_model = AutoEncoderPL(autoencoder)

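# Train with a PyTorch Lightning Trainer; the module defines no dataloaders,
# so pass loaders that yield (x, y) batches
trainer = Trainer(max_epochs=5)
trainer.fit(pl_model, train_dl, val_dl)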
```

In [None]:
# Fashion-MNIST datamodule built from the shared image data config
data_cfg = OmegaConf.load('../config/data/image/image.yaml')
dm = instantiate(data_cfg, name='fashion_mnist', data_dir='../data/image/')
dm.prepare_data()
dm.setup()
dm.num_classes

[23:02:45] INFO - Init ImageDataModule for fashion_mnist
[23:02:45] INFO - fashion_mnist Dataset: init
[23:02:50] INFO - fashion_mnist Dataset: init
[23:02:53] INFO - split train into train/val [0.8, 0.2]
[23:02:53] INFO - train: 48000 val: 12000, test: 10000


10


In [None]:
# Conv classifier from the model config, sized to the datamodule's number of labels
model_cfg = OmegaConf.load('../config/model/image/conv.yaml')
model = instantiate(model_cfg, num_classes=dm.num_classes)

[23:04:48] INFO - ConvNetX: init
[23:04:48] INFO - Classifier: init
/Users/slegroux/miniforge3/envs/nimrod/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:208: Attribute 'nnet' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['nnet'])`.


In [None]:
#| notest

# NOTE: no learning rate is set in this cell; the optimizer comes from the model
# instantiated from the model config above (presumably configured in conv.yaml).
# Assumes dm.train_dataloader() reads dm.batch_size at call time — TODO confirm.
dm.batch_size = 256
MAX_EPOCHS = 5

trainer = Trainer(max_epochs=MAX_EPOCHS)
trainer.fit(model, dm.train_dataloader(), dm.val_dataloader())
trainer.test(model, dm.test_dataloader())

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
[23:08:25] INFO - Optimizer: Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.0001
    maximize: False
    weight_decay: 0
)
[23:08:25] INFO - Scheduler: <torch.optim.lr_scheduler.OneCycleLR object>
/Users/slegroux/miniforge3/envs/nimrod/lib/python3.11/site-packages/lightning/pytorch/core/optimizer.py:316: The lr scheduler dict contains the key(s) ['monitor', 'strict'], but the keys will be ignored. You need to call `lr_scheduler.step()` manually in manual optimization.

  | Name         | Type               | Params | Mode 
------------------------------------------------------------
0 | loss         | CrossEntropyLoss   | 0      | train
1 | train_acc    | MulticlassAccuracy | 0      | train
2 | val_acc      | MulticlassAccuracy | 0      | train
3 

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/slegroux/miniforge3/envs/nimrod/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/Users/slegroux/miniforge3/envs/nimrod/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.
/Users/slegroux/miniforge3/envs/nimrod/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]

[{'test/loss': 0.595595121383667, 'test/acc': 0.8274000287055969}]

In [None]:
#| export
class AutoEncoder(nn.Module):
    """ Encoder/decoder pair that reconstructs its input.

    The encoder maps the input to a latent code; the decoder maps that code back
    to a reconstruction. The accepted input shape is whatever `encoder` accepts
    (the example in this notebook feeds flattened B x 784 tensors).
    """
    def __init__(self,
        encoder:nn.Module, # Module mapping input -> latent code
        decoder:nn.Module # Module mapping latent code -> reconstruction
        ):
        super().__init__()
        # registered as submodules in this order: encoder first, then decoder
        self.encoder, self.decoder = encoder, decoder

    def forward(
        self,
        x:torch.Tensor # Input batch, in the shape the encoder expects
        )->torch.Tensor: # Reconstruction, in the shape the decoder produces
        """
        Encode `x` into a latent code and decode it back into a reconstruction.
        """
        return self.decoder(self.encoder(x))

In [None]:
# Render the nbdev API docs (signature + parameter table) for AutoEncoder.forward
show_doc(AutoEncoder.forward)

---

### AutoEncoder.forward

>      AutoEncoder.forward (x:torch.Tensor)

*Forward pass of the AutoEncoder model.*

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| x | Tensor | Tensor B x C X H X W |
| **Returns** | **Tensor** | **Reconstructed input tensor of shape B x C X H X W** |

In [None]:
# Smoke test: the default Encoder/Decoder reconstruct a flattened 28x28 batch
enc = Encoder()
dec = Decoder()
a = AutoEncoder(enc, dec)
batch = torch.rand((10, 28*28))
y = a(batch)
# a reconstruction must have the same shape as its input
assert y.shape == batch.shape, f"unexpected reconstruction shape {tuple(y.shape)}"
print(y.shape)

torch.Size([10, 784])


In [None]:
# Peek at one sample of the raw Fashion-MNIST dataset (DataLoader default batch_size=1)
ds = ImageDataset(name='fashion_mnist', data_dir='../data/image/')
dl = DataLoader(ds)
b = next(iter(dl))
print(len(b), *(t.shape for t in b[:2]))


[22:49:28] INFO - fashion_mnist Dataset: init


Downloading data:   0%|          | 0.00/30.9M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.18M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/60000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/10000 [00:00<?, ? examples/s]

2 torch.Size([1, 1, 28, 28]) torch.Size([1])


## AutoEncoderPL (Lightning wrapper)

In [None]:
#| export
class AutoEncoderPL(LightningModule):
    """ LightningModule that trains an `AutoEncoder` to reconstruct its flattened input with MSE loss """
    def __init__(
        self,
        autoencoder:AutoEncoder, # AutoEncoder instance
        lr:float=1e-3 # Learning rate for the Adam optimizer
        ):
        super().__init__()
        # the autoencoder is an nn.Module whose weights are already saved in the checkpoint
        # state dict, so it is excluded from the hyperparameters; `lr` is recorded in self.hparams
        self.save_hyperparameters(ignore=['autoencoder'])
        self.autoencoder = autoencoder
        self.lr = lr
        # reconstruction loss (attribute name `metric` kept for backward compatibility)
        self.metric = torch.nn.MSELoss()

    @staticmethod
    def _flatten(x:torch.Tensor)->torch.Tensor:
        "Flatten all non-batch dims, e.g. B x C x H x W -> B x L (no-op on B x L input)."
        # reshape (unlike view) also works on non-contiguous tensors
        return x.reshape(x.size(0), -1)

    def forward(
        self,
        x: torch.Tensor # Tensor B x L
        )->torch.Tensor: # Reconstructed input tensor of shape B x L
        """
        Forward pass of the AutoEncoder model.
        """
        return self.autoencoder(x)

    def predict_step(self, batch, batch_idx):
        """
        Reconstruct the flattened inputs of an `(x, y)` batch; labels are ignored.
        """
        x, _ = batch
        with torch.no_grad():
            return self.autoencoder(self._flatten(x))

    def _shared_eval(self, batch, batch_idx, prefix, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True):
        """Compute the MSE reconstruction loss of an `(x, y)` batch and log it as `{prefix}/loss`."""
        x, _ = batch
        x = self._flatten(x) # flatten B x C x H x W to B x L (grey pic)
        x_hat = self.autoencoder(x)
        loss = self.metric(x_hat, x)
        # prog_bar was previously accepted but never forwarded to self.log
        self.log(f"{prefix}/loss", loss, prog_bar=prog_bar, on_step=on_step, on_epoch=on_epoch, sync_dist=sync_dist)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_eval(batch, batch_idx, "train")

    def test_step(self, batch, batch_idx):
        # return the loss like the train/val steps for consistency
        return self._shared_eval(batch, batch_idx, "test")

    def validation_step(self, batch, batch_idx):
        return self._shared_eval(batch, batch_idx, "val")

    def configure_optimizers(self):
        """Adam over all parameters, using the learning rate given at construction."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)


In [None]:
# Render the nbdev API docs (signature + parameter table) for AutoEncoderPL.forward
show_doc(AutoEncoderPL.forward)

---

### AutoEncoderPL.forward

>      AutoEncoderPL.forward (x:torch.Tensor)

*Forward pass of the AutoEncoder model.*

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| x | Tensor | Tensor B x L |
| **Returns** | **Tensor** | **Reconstructed input tensor of shape B x L** |

In [None]:
#| hide
# NOTE(review): inactive, commented-out sketch of a W&B validation-image logging hook; nothing here runs.
# Before enabling it:
# - `wadn` in the signature looks like a typo.
# - `wandb_logger` is not defined in this notebook.
# - AutoEncoderPL.validation_step returns the scalar MSE loss, so `outputs[:n]`
#   would not give per-sample predictions.
# def on_validation_batch_end(
#         self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx, wadn
#         ):
#         # `outputs` comes from `LightningModule.validation_step`
#         # which corresponds to our model predictions in this case
        
#         # Let's log 20 sample image predictions from the first batch
#         if batch_idx == 0:
#             n = 20
#             x, y = batch
#             images = [img for img in x[:n]]
#             captions = [f'Ground Truth: {y_i} - Prediction: {y_pred}' 
#                 for y_i, y_pred in zip(y[:n], outputs[:n])]
            
            
#             # Option 1: log images with `WandbLogger.log_image`
#             wandb_logger.log_image(
#                 key='sample_images', 
#                 images=images, 
#                 caption=captions)


In [None]:
# Wrap the AutoEncoder built above in the Lightning module and run a forward pass
autoencoder_pl = AutoEncoderPL(a)
b = torch.rand((5,28*28))
y = autoencoder_pl(b)
# a reconstruction must have the same shape as its input
assert y.shape == b.shape, f"unexpected reconstruction shape {tuple(y.shape)}"
print(y.shape)

torch.Size([5, 784])


In [None]:
#| hide
# Regenerate the library modules from the exported notebook cells
import nbdev
nbdev.nbdev_export()