In [None]:
from pathlib import Path
from typing import Any, List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import io, transforms, models
import torchvision.transforms.functional as TF

# Weights & Biases login: the API key is fetched from the Kaggle secret named
# "wandb_api_key" instead of being hard-coded in the notebook.
from kaggle_secrets import UserSecretsClient
import wandb
user_secrets = UserSecretsClient()
secret_value = user_secrets.get_secret("wandb_api_key")
wandb.login(key=secret_value)

%matplotlib inline

In [None]:
# Directory that is globbed for "*.jpg" files to build the train/validation sets.
FILES = "/kaggle/input/celeba-dataset/img_align_celeba/img_align_celeba/"
# Alternative dataset locations (COCO 2017), currently disabled:
# TRAIN_FILES = "/kaggle/input/coco-2017-dataset/coco2017/train2017/"
# VALID_FILES = "/kaggle/input/coco-2017-dataset/coco2017/val2017/"
# Side length, in pixels, that every image is resized to (IMAGE_SIZE x IMAGE_SIZE).
IMAGE_SIZE = 64
# Training batch size; the validation loader uses twice this value.
BATCH_SIZE = 64
# Passed to the trainer as max_epochs.
EPOCHS = 10
# Learning rate handed to the optimizer.
LR = 1e-3
# Number of channels produced by the decoder's final convolutions.
CHANNELS = 3
# Number of image/reconstruction pairs sampled for the wandb table each validation epoch.
VALID_IMAGES = 5

In [None]:
class ImageData(Dataset):
    """Dataset that turns image files into square, shifted float tensors.

    Every item is read from disk, resized to ``IMAGE_SIZE`` x ``IMAGE_SIZE``,
    replicated to three channels when it has exactly one channel, then divided
    by 255 and shifted down by 0.5.
    """

    def __init__(self, files: List[str]):
        """Keep the list of image paths and build the resize transform once."""
        self.files = files
        self.resize = transforms.Resize((IMAGE_SIZE, IMAGE_SIZE))

    def __len__(self):
        """Return how many image paths the dataset holds."""
        return len(self.files)

    def __getitem__(self, i):
        """Load, resize and rescale the image stored at position ``i``."""
        path = self.files[i]
        resized = self.resize(io.read_image(path))

        # Single-channel images are copied along the channel axis so that
        # every sample has three channels.
        n_channels = resized.shape[0]
        if n_channels == 1:
            resized = resized.repeat(3, 1, 1)

        return resized / 255.0 - 0.5

# Sort the glob result so the file list (and therefore the split below) does
# not depend on the order in which the filesystem yields directory entries.
files = sorted(str(file) for file in Path(FILES).glob("*.jpg"))
# Hold out 10% of the images for validation. A fixed random_state keeps the
# split identical across kernel restarts, so validation metrics stay comparable
# between runs.
train_files, valid_files = train_test_split(files, test_size=0.1, random_state=42)
# train_files = [str(file) for file in Path(TRAIN_FILES).glob("*.jpg")]
# valid_files = [str(file) for file in Path(VALID_FILES).glob("*.jpg")]
train_ds = ImageData(train_files)
valid_ds = ImageData(valid_files)
# Training batches are shuffled and the last incomplete batch is dropped.
train_dl = DataLoader(
    train_ds,
    BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    num_workers=4,
    pin_memory=True,
)
# Validation keeps every sample in a fixed order and uses a doubled batch size.
valid_dl = DataLoader(
    valid_ds,
    BATCH_SIZE*2,
    shuffle=False,
    drop_last=False,
    num_workers=4,
    pin_memory=True,
)

In [None]:
# Sanity check: pull one training batch and display the dataset sizes together
# with the batch's shape, mean and standard deviation.
# `x` is reused further down by the reconstruction and plotting cells.
x = next(iter(train_dl))
len(train_ds), len(valid_ds), x.shape, x.mean(), x.std()

## Model

In [None]:
class DownSample(nn.Module):
    """Two 3x3 convolutions with GELU (the second one strided), then batch norm.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by both convolutions.
        scale_factor: stride of the second convolution, which controls how
            strongly height and width are reduced (default 2).
    """

    def __init__(self, in_channels, out_channels, scale_factor=2):
        super().__init__()
        self.conv2d_1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        # The stride follows ``scale_factor`` so that the argument actually
        # controls the amount of downsampling instead of being ignored.
        self.conv2d_2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, padding=1, stride=scale_factor
        )
        self.batch_norm = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Apply conv -> GELU -> strided conv -> GELU -> batch norm to ``x``."""
        x = F.gelu(self.conv2d_1(x))
        x = F.gelu(self.conv2d_2(x))

        return self.batch_norm(x)

class Encoder(nn.Module):
    """Convolutional encoder that outputs a mean and a log-variance map.

    ``channels`` lists the channel count before and after each ``DownSample``
    stage: every consecutive pair ``(channels[k], channels[k + 1])`` defines
    one stage. Two independent convolutions with a 2x2 kernel are then applied
    to the final feature map to produce ``mu`` and ``log_var``.
    """

    def __init__(self, channels):
        super().__init__()
        stages = []
        for k in range(len(channels) - 1):
            stages.append(DownSample(channels[k], channels[k + 1]))
        self.downsampler = nn.ModuleList(stages)

        latent_channels = channels[-1]
        self.squeeze_wh_1 = nn.Conv2d(latent_channels, latent_channels, kernel_size=2)
        self.squeeze_wh_2 = nn.Conv2d(latent_channels, latent_channels, kernel_size=2)

    def forward(self, x):
        """Run all downsampling stages and return ``(mu, log_var)``."""
        features = x
        for stage in self.downsampler:
            features = stage(features)

        return self.squeeze_wh_1(features), self.squeeze_wh_2(features)

In [None]:
class UpSample(nn.Module):
    """Upsampling followed by two 3x3 convolutions with GELU and batch norm.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by both convolutions.
        scale_factor: factor handed to ``nn.Upsample`` (default 2).
    """

    def __init__(self, in_channels, out_channels, scale_factor=2):
        super().__init__()
        self.up_sample = nn.Upsample(scale_factor=scale_factor)
        # Both convolutions share the same size-preserving kernel settings.
        conv_kwargs = dict(kernel_size=3, padding=1)
        self.conv2d_1 = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
        self.conv2d_2 = nn.Conv2d(out_channels, out_channels, **conv_kwargs)
        self.batch_norm = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Upsample ``x``, apply conv + GELU twice, then normalise."""
        hidden = self.up_sample(x)
        for conv in (self.conv2d_1, self.conv2d_2):
            hidden = F.gelu(conv(hidden))

        return self.batch_norm(hidden)
    
class Decoder(nn.Module):
    """Convolutional decoder that maps a latent tensor back to an image.

    Consecutive pairs of ``channels`` define the ``UpSample`` stages. Two 1x1
    convolutions then map the result to ``CHANNELS`` channels, and the output
    is a sigmoid shifted down by 0.5.
    """

    def __init__(self, channels):
        super().__init__()
        self.upsampler = nn.ModuleList()
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            self.upsampler.append(UpSample(c_in, c_out))

        last_width = channels[-1]
        self.conv_1 = nn.Conv2d(last_width, CHANNELS, kernel_size=1)
        self.conv_2 = nn.Conv2d(CHANNELS, CHANNELS, kernel_size=1)

    def forward(self, z):
        """Decode ``z`` through every upsampling stage and the output head."""
        features = z
        for stage in self.upsampler:
            features = stage(features)

        features = F.leaky_relu(self.conv_1(features))
        squashed = torch.sigmoid(self.conv_2(features))
        # Shift the sigmoid output down by 0.5, the same offset the dataset
        # subtracts from its images.
        return squashed - 0.5

## KL Divergence
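The loss below uses the closed-form KL divergence between the encoder's diagonal Gaussian $\mathcal{N}(\mu, \sigma^2)$ and a unit Gaussian $\mathcal{N}(0, I)$. See this [Stack Exchange question](https://stats.stackexchange.com/questions/318184/kl-loss-with-a-unit-gaussian) for the definition and derivation: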

In [None]:
def kl_divergence(mu: torch.FloatTensor, log_var: torch.FloatTensor) -> torch.FloatTensor:
    """Closed-form KL term between N(mu, exp(log_var)) and a unit Gaussian.

    The element-wise quantity ``-0.5 * (1 + log_var - mu**2 - exp(log_var))``
    is averaged over every element of the inputs, so the result is a scalar
    tensor.
    """
    variance = log_var.exp()
    mu_squared = mu.square()
    elementwise = -0.5 * (1 + log_var - mu_squared - variance)
    return elementwise.mean()

## Lightning Trainer

In [None]:
def inv_transform(image: torch.FloatTensor):
    """Add 0.5 back to ``image``, undoing the ``- 0.5`` shift applied by the dataset."""
    restored = image + 0.5
    return restored

def get_wandb_images(image: torch.FloatTensor, reconstruction: torch.FloatTensor):
    """Build one table row: ``[original, reconstruction]`` as ``wandb.Image`` objects.

    Each tensor is passed through ``inv_transform`` and converted to a PIL
    image before being wrapped for wandb.
    """
    return [
        wandb.Image(TF.to_pil_image(inv_transform(tensor)))
        for tensor in (image, reconstruction)
    ]

In [None]:
class LightningModel(pl.LightningModule):
    """Lightning module that trains an encoder/decoder pair as a VAE.

    Args:
        encoder: module mapping an image batch to ``(mu, log_var)``.
        decoder: module mapping a latent sample back to an image batch.
        learning_rate: learning rate passed to AdamW.
    """

    def __init__(
        self,
        encoder: nn.Module,
        decoder: nn.Module,
        learning_rate: float,
    ):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.learning_rate = learning_rate

    def common_step(
        self,
        x: torch.FloatTensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode, sample, decode and compute the losses for one batch.

        Returns:
            ``(loss, reconstruction_loss, reconstruction)`` where ``loss`` is
            the sum of the KL term and the mean-squared reconstruction error.
        """
        mu, log_var = self.encoder(x)
        # Reparameterisation trick. randn_like draws the noise directly on
        # mu's device and in mu's dtype, avoiding a CPU allocation and copy.
        z = mu + torch.exp(0.5 * log_var) * torch.randn_like(mu)
        out = self.decoder(z)

        kl_loss = kl_divergence(mu, log_var)
        # F.mse_loss expects (input, target): the reconstruction is the
        # prediction and the original batch is the target.
        reconstruction_loss = F.mse_loss(out, x)
        loss = kl_loss + reconstruction_loss

        return loss, reconstruction_loss, out

    def training_step(
        self, x: torch.FloatTensor, *args: List[Any]
    ) -> torch.Tensor:
        """Compute and log the training losses; return the total loss."""
        loss, reconstruction_loss, reconstruction = self.common_step(x)
        self.log(name="Training loss", value=loss, on_step=True, on_epoch=True)
        self.log(name="Training reconstruction loss", value=reconstruction_loss, on_step=True, on_epoch=True)
        return loss

    def validation_step(
        self, x: torch.FloatTensor, *args: List[Any]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Log the validation losses and return a few examples for visualisation.

        Returns:
            ``(images, reconstructions)`` on the CPU, holding at most
            ``VALID_IMAGES`` examples from this batch.
        """
        loss, reconstruction_loss, reconstruction = self.common_step(x)
        self.log(name="Validation loss", value=loss, on_step=True, on_epoch=True)
        self.log(name="Validation reconstruction loss", value=reconstruction_loss, on_step=True, on_epoch=True)
        # Only VALID_IMAGES examples are logged per epoch, so keep a handful
        # per batch instead of holding the whole validation set (inputs and
        # reconstructions) in CPU memory until the end of the epoch.
        # Mixed precision may yield a half-precision reconstruction; cast to
        # float32 for the CPU-side image conversion.
        return x[:VALID_IMAGES].cpu(), reconstruction[:VALID_IMAGES].float().cpu()

    def validation_epoch_end(self, validation_step_outputs):
        """Log a wandb table with randomly chosen image/reconstruction pairs."""
        if not validation_step_outputs:
            # Nothing was validated (e.g. validation batches limited to 0).
            return

        images, preds = zip(*validation_step_outputs)
        images = torch.cat(images, dim=0)
        preds = torch.cat(preds, dim=0)
        columns = ["image", "reconstruction"]
        # Sampling without replacement fails when asked for more items than
        # exist, so cap the sample size for small validation runs.
        n_samples = min(VALID_IMAGES, len(images))
        indices = np.random.choice(len(images), n_samples, replace=False)
        rows = [get_wandb_images(images[i], preds[i]) for i in indices]
        table = wandb.Table(data=rows, columns=columns)
        self.logger.experiment.log({f"epoch {self.current_epoch + 1} results": table})

    def on_after_backward(self):
        """Every 50 global steps, log weight and gradient histograms to wandb."""
        if self.trainer.global_step % 50 != 0:  # log sparsely to keep the wandb run small
            return

        histograms = {}
        for name, param in self.named_parameters():
            # Only trainable, non-normalisation weight tensors are logged.
            if "weight" not in name or "norm" in name or not param.requires_grad:
                continue
            # detach() is required before the numpy conversion done by
            # wandb.Histogram; .cpu() alone is a no-op for CPU parameters.
            histograms[f"{name}"] = wandb.Histogram(param.detach().cpu())
            # Parameters that took no part in the backward pass have no grad.
            if param.grad is not None:
                histograms[f"{name}_grad"] = wandb.Histogram(param.grad.detach().cpu())

        # One log call per step instead of two calls per parameter.
        if histograms:
            self.logger.experiment.log(histograms)

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Optimise all parameters with AdamW at the configured learning rate."""
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)

In [None]:
# Create the log directory idempotently: unlike a plain `mkdir`, this does not
# complain when the cell is re-run and the directory already exists.
Path("/kaggle/working/logs").mkdir(parents=True, exist_ok=True)
# Five stride-2 stages take a 64x64 input down to 2x2; the encoder's final
# 2x2 convolutions then reduce it to a 1x1 latent with 64 channels.
encoder = Encoder([3, 4, 8, 16, 32, 64])
# Six x2 upsampling stages bring the 1x1 latent back to 64x64.
decoder = Decoder([64, 32, 16, 8, 4, 4, 4])
lightning_model = LightningModel(encoder, decoder, LR)
logger = WandbLogger("VAE 2", "/kaggle/working/logs/", project="VAE")
trainer = pl.Trainer(
    max_epochs=EPOCHS,
    gpus=torch.cuda.device_count(),
    gradient_clip_val=1.0,
    logger=logger,
    precision=16,
    # Optional switches for quick debugging runs:
#     auto_lr_find=True,
#     limit_train_batches=10,
#     limit_val_batches=10,
)
trainer.fit(lightning_model, train_dl, valid_dl)

In [None]:
# Reconstruct the sample batch with the trained networks.
# Evaluation mode makes the BatchNorm layers use their running statistics
# rather than the statistics of this single batch, and no_grad avoids building
# an autograd graph for tensors that are only inspected and plotted.
encoder.eval()
decoder.eval()
# Run on whichever device the trained weights currently live on.
model_device = next(encoder.parameters()).device
with torch.no_grad():
    mu, log_var = encoder(x.to(model_device))
    z = mu + torch.exp(0.5*log_var) * torch.randn_like(mu)
    # Bring the reconstruction back to the CPU so it can be compared with `x`.
    out = decoder(z).cpu()

out.shape, out.mean(), out.std(), out.min(), out.max(), x.min(), x.max()

In [None]:
# Show one input image next to its reconstruction.
i = 16  # index of the example within the sample batch
print(out[i].std(), x[i].std())
# inv_transform undoes the -0.5 shift, the same helper the wandb logging uses.
# detach()/cpu() keep the conversion working even if `out` still carries
# autograd history or lives on another device.
img1 = TF.to_pil_image(inv_transform(x[i]))
img2 = TF.to_pil_image(inv_transform(out[i].detach().cpu()))
fig, (ax_original, ax_reconstruction) = plt.subplots(1, 2, figsize=(12, 5))
ax_original.imshow(img1)
ax_original.set_title("Original")
ax_reconstruction.imshow(img2)
ax_reconstruction.set_title("Reconstruction")
plt.show()

### 