In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import lightning as L

In [7]:
DEVICE='cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# import mnist
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import v2 as T

In [None]:
tfms=T.Compose([
  T.ToImage(),
  T.ToDtype(torch.float32),
  T.Normalize((0.5,),(0.5,))
])

train_ds=datasets.MNIST("./data/mnist",train=True,transform=tfms,download=True)
valid_ds=datasets.MNIST("./data/mnist",train=False,transform=tfms,download=False)

len(train_ds),len(valid_ds)

In [None]:
train_dl=DataLoader(train_ds,batch_size=64,shuffle=True,num_workers=2,pin_memory=True)
valid_dl=DataLoader(train_ds,batch_size=128,shuffle=False,num_workers=2,pin_memory=True)

len(train_dl),len(valid_dl)

In [45]:
from torchvision.utils import make_grid
import matplotlib.pyplot  as plt

def show_images(images:torch.Tensor,mean=(0.5,),std=(0.5,)):
  mean,std=torch.tensor(mean),torch.tensor(std)
  images=images*std+mean

  nrows=images.shape[0]//8
  grid=make_grid(images,nrows=nrows)
  grid=grid.permute(1,2,0)
  plt.imshow(grid)
  plt.show()

In [None]:
for images,_ in train_dl:
  show_images(images)
  break

In [28]:
import itertools
from torchvision.utils import make_grid

class ConvBlock(nn.Module):
  def __init__(self,out_channels:int,kernel_size:int,*,dropout=0.2,**kwargs):
    super().__init__()
    self.conv_block=nn.Sequential(
      nn.LazyConv2d(out_channels,kernel_size,**kwargs),
      nn.LeakyReLU(inplace=True),
      nn.BatchNorm2d(out_channels),
      nn.Dropout2d(dropout)
    )

  def forward(self,X:torch.Tensor):
    return self.conv_block(X)
  

class DeConvBlock(nn.Module):
  def __init__(self,out_channels:int,kernel_size:int,*,dropout=0.2,**kwargs):
    super().__init__()
    self.deconv_block=nn.Sequential(
      nn.LazyConvTranspose2d(out_channels,kernel_size,**kwargs),
      nn.LeakyReLU(negative_slope=0.2,inplace=True),
      nn.BatchNorm2d(out_channels)
    )

  def forward(self,X:torch.Tensor):
    return self.deconv_block(X)
  



class ConvAutoEncoder(L.LightningModule):
  def __init__(self,input_shape:tuple|list,encoder_config:list[int,int,int]|tuple[int,int,int],projection_dim:int,decoder_config:list[int,int]|tuple[int,int]):
    super().__init__()

    self.input_shape=input_shape

    encoder=[]
    for out_channels,kernel_size,apply_pool in encoder_config:
      encoder.append(ConvBlock(out_channels,kernel_size,padding='same'))
      if apply_pool:
        encoder.append(nn.MaxPool2d((2,2)))
    encoder.append(nn.AdaptiveAvgPool2d(projection_dim))
    self.encoder=nn.Sequential(*encoder)

    decoder=[]
    for out_channels,kernel_size in decoder_config:
      decoder.append(DeConvBlock(out_channels,kernel_size))
    decoder.append(nn.AdaptiveAvgPool2d(input_shape[1:]))
    self.decoder=nn.Sequential(*decoder)

    self.training_step_outputs=[]
    self.validation_step_outputs=[]
  
  def forward(self, X: torch.Tensor) -> torch.Tensor:
    X=self.encoder(X)
    print(X.shape)
    X=self.decoder(X)
    print(X.shape)
    return X

  def configure_optimizers(self):
    optimizer = torch.optim.AdamW(
    itertools.chain(self.encoder.parameters(), self.decoder.parameters()),lr=3e-4)
    return optimizer



  def training_step(self, batch):
    X,_=batch
    out = self.forward(X)
    loss = F.mse_loss(X, out)
    self.training_step_output.append(loss)
    return  loss

  def validation_step(self, batch,batch_idx):
    X,_=batch
    # Generate images
    if batch_idx == 0:
      gen_images = self.forward(X)

      # Normalize the generated images to be between 0 and 1
      gen_images = gen_images * 0.5 + 0.5
      grid = make_grid(gen_images[:32], 8)
      self.logger.experiment.add_image("Generated Images Grid", grid, self.current_epoch)

    out = self.forward(X)
    loss = F.mse_loss(X, out)
    self.validation_step_output.append(loss)
    return loss

  def predict_step(self, batch):
      X,_=batch
      out = self.forward(X)
      return out



  def on_train_epoch_end(self) -> None:
      self.on_epoch_end(self.training_step_output, "training_loss")
      print()

  def on_validation_epoch_end(self) -> None:
      self.on_epoch_end(self.validation_step_output, "validation_loss")
      self.epoch+=1


  def on_epoch_end(self, data: list, log_str: str):
      avg_loss = torch.tensor(data).mean()
      print(f"{log_str}: {avg_loss}")
      self.log(log_str, avg_loss)
      # free up the memory
      data.clear()

In [41]:
input_shape=(1,28,28)
encoder_config=[(8,7,1),(16,5,1),(32,1,0),(64,3,0),(16,3,1)]
projection_dim=8
decoder_config=[(8,7),(16,5),(64,3),(128,1),(1,3)]

model=ConvAutoEncoder(input_shape,encoder_config,projection_dim,decoder_config).to(DEVICE)
model

ConvAutoEncoder(
  (encoder): Sequential(
    (0): ConvBlock(
      (conv_block): Sequential(
        (0): LazyConv2d(0, 8, kernel_size=(7, 7), stride=(1, 1), padding=same)
        (1): LeakyReLU(negative_slope=0.01, inplace=True)
        (2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): Dropout2d(p=0.2, inplace=False)
      )
    )
    (1): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    (2): ConvBlock(
      (conv_block): Sequential(
        (0): LazyConv2d(0, 16, kernel_size=(5, 5), stride=(1, 1), padding=same)
        (1): LeakyReLU(negative_slope=0.01, inplace=True)
        (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): Dropout2d(p=0.2, inplace=False)
      )
    )
    (3): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    (4): ConvBlock(
      (conv_block): Sequential(
        (0): LazyConv2d(0, 32, kerne

In [42]:
imgs=torch.randn((2,1,28,28))
out=model(imgs)
out.shape

torch.Size([2, 16, 8, 8])
torch.Size([2, 1, 28, 28])


torch.Size([2, 1, 28, 28])

In [43]:
from torchsummary import summary
summary(model,(1,28,28))

torch.Size([2, 16, 8, 8])
torch.Size([2, 1, 28, 28])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 8, 28, 28]             400
         LeakyReLU-2            [-1, 8, 28, 28]               0
       BatchNorm2d-3            [-1, 8, 28, 28]              16
         Dropout2d-4            [-1, 8, 28, 28]               0
         ConvBlock-5            [-1, 8, 28, 28]               0
         MaxPool2d-6            [-1, 8, 14, 14]               0
            Conv2d-7           [-1, 16, 14, 14]           3,216
         LeakyReLU-8           [-1, 16, 14, 14]               0
       BatchNorm2d-9           [-1, 16, 14, 14]              32
        Dropout2d-10           [-1, 16, 14, 14]               0
        ConvBlock-11           [-1, 16, 14, 14]               0
        MaxPool2d-12             [-1, 16, 7, 7]               0
           Conv2d-13             [-1, 32, 7, 7]   

In [44]:
print("Encoder:")
summary(model.encoder,(1,28,28))
print("\nDecoder:")
summary(model.decoder,(16,8,8))

Encoder:
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 8, 28, 28]             400
         LeakyReLU-2            [-1, 8, 28, 28]               0
       BatchNorm2d-3            [-1, 8, 28, 28]              16
         Dropout2d-4            [-1, 8, 28, 28]               0
         ConvBlock-5            [-1, 8, 28, 28]               0
         MaxPool2d-6            [-1, 8, 14, 14]               0
            Conv2d-7           [-1, 16, 14, 14]           3,216
         LeakyReLU-8           [-1, 16, 14, 14]               0
       BatchNorm2d-9           [-1, 16, 14, 14]              32
        Dropout2d-10           [-1, 16, 14, 14]               0
        ConvBlock-11           [-1, 16, 14, 14]               0
        MaxPool2d-12             [-1, 16, 7, 7]               0
           Conv2d-13             [-1, 32, 7, 7]             544
        LeakyReLU-14          

In [None]:
from lightning import Trainer
from lightning.pytorch.loggers import TensorBoardLogger

logger=TensorBoardLogger("./logs/tensorboard")
DEVICE='cuda' if torch.cuda.is_available() else 'cpu'
epochs=12

trainer=Trainer(accelerator=DEVICE,devices='auto',logger=logger,fast_dev_run=True,max_epochs=epochs,enable_progress_bar=True)

In [None]:
trainer.fit(model,train_dl,valid_dl)