In [13]:
# Adapted from the UvA-DL "Deep Autoencoders" tutorial (PyTorch Lightning docs): https://pytorch-lightning.readthedocs.io/en/latest/notebooks/course_UvA-DL/08-deep-autoencoders.html

import os
import urllib.request
from urllib.error import HTTPError

import matplotlib
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import torch.optim as optim
import torch.utils.data as data
import torchvision
from IPython.display import set_matplotlib_formats
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from torchvision.datasets import FashionMNIST
from tqdm.notebook import tqdm



# %matplotlib inline
# set_matplotlib_formats("svg", "pdf")  # For export
matplotlib.rcParams["lines.linewidth"] = 2.0
sns.reset_orig()
sns.set()

# Tensorboard extension (for visualization purposes later)
# %load_ext tensorboard

# Path to the folder where the datasets are/should be downloaded (FashionMNIST in this notebook)
DATASET_PATH = "../datasets"
# Path to the folder where the pretrained models are saved
CHECKPOINT_PATH = "saved_models/autoencoder"

# Setting the seed (the random train/val split in the next cell depends on it)
pl.seed_everything(42)

# Ensure that cuDNN operations are deterministic on GPU (if used) for reproducibility.
# The attribute must be spelled `deterministic`: a misspelled name just creates a new,
# unused attribute on the module and silently has no effect.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

Global seed set to 42


Device: cuda:0


In [14]:
# Transformations applied on each image: convert to a tensor, then normalize with
# mean 0.5 / std 0.5, i.e. (x - 0.5) / 0.5, so pixels lie in [-1, 1] like the decoder's Tanh output
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

# Loading the training dataset. We need to split it into a training and validation part
train_dataset = FashionMNIST(root=DATASET_PATH, train=True, transform=transform, download=True)
# 90/10 split. The validation size is the remainder so the two sizes always sum to
# len(train_dataset); truncating both fractions independently can leave the total short,
# which makes random_split raise.
n_train_samples = int(len(train_dataset) * 0.9)
n_val_samples = len(train_dataset) - n_train_samples
train_set, val_set = torch.utils.data.random_split(train_dataset, [n_train_samples, n_val_samples])

print(f"val-set size:   {len(val_set)}")
print(f"train-set size: {len(train_set)}")


# Loading the test set
test_set = FashionMNIST(root=DATASET_PATH, train=False, transform=transform, download=True)

# We define a set of data loaders that we can use for various purposes later.
train_loader = data.DataLoader(train_set, batch_size=256, shuffle=True, drop_last=True, pin_memory=True, num_workers=4)
val_loader = data.DataLoader(val_set, batch_size=256, shuffle=False, drop_last=False, num_workers=4)
test_loader = data.DataLoader(test_set, batch_size=256, shuffle=False, drop_last=False, num_workers=4)

def get_train_images(num):
    """Stack the images of the first `num` samples of `val_set` into one batch tensor.

    Labels are discarded. Despite the function's name, the samples come from the
    validation split, so the reconstructions logged by GenerateCallback are of images
    that are not in `train_loader`.
    """
    picked = []
    for idx in range(num):
        sample = val_set[idx]
        picked.append(sample[0])
    return torch.stack(picked, dim=0)

val-set size:   6000
train-set size: 54000


In [15]:
class Encoder(nn.Module):
    """Fully-connected encoder: flattened input -> 512 -> 128 -> latent_dim.

    An activation (`act_fn`) follows every linear layer, including the last one.
    """

    def __init__(self, input_size: int, latent_dim: int, act_fn: object = nn.ReLU):
        """
        Args:
           input_size : Number of values per sample after flattening (28*28 for FashionMNIST)
           latent_dim : Dimensionality of latent representation z
           act_fn : Activation *class*; a fresh instance is created after each linear layer
        """
        super().__init__()
        # The position of each layer inside `self.net` determines the state_dict keys
        # (net.1.*, net.3.*, net.5.*), so keep this order stable to remain compatible
        # with checkpoints that were already saved.
        # NOTE(review): act_fn is also applied to the latent layer, so with the default
        # ReLU every latent coordinate is >= 0 and some can be exactly 0. Confirm that a
        # non-negative bottleneck is intended.
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(input_size, 512),
            act_fn(),
            nn.Linear(512, 128),
            act_fn(),
            nn.Linear(128, latent_dim),
            act_fn(),
        )

    def forward(self, x):
        """Flatten `x` from dim 1 onward and map it to a (batch, latent_dim) code."""
        return self.net(x)

In [16]:
class Decoder(nn.Module):
    """Fully-connected decoder: latent code -> 128 -> 512 -> input_size, reshaped to an image.

    The output goes through Tanh, so reconstructed values lie in [-1, 1].
    """

    def __init__(
        self,
        input_size: int,
        latent_dim: int,
        act_fn: object = nn.ReLU,
        output_shape: tuple = (1, 28, 28),
    ):
        """
        Args:
           input_size : Number of values per reconstructed sample; must equal prod(output_shape)
           latent_dim : Dimensionality of latent representation z
           act_fn : Activation *class* used between the linear layers
           output_shape : Per-sample shape (batch dim excluded) the flat output is reshaped to.
                          Defaults to (1, 28, 28), a single-channel 28x28 image.
        Raises:
           ValueError : if prod(output_shape) != input_size
        """
        super().__init__()
        output_shape = tuple(output_shape)
        # Validate up front: a mismatch would otherwise only surface as an Unflatten
        # error on the first forward pass.
        n_values = 1
        for dim in output_shape:
            n_values *= dim
        if n_values != input_size:
            raise ValueError(
                f"output_shape {output_shape} holds {n_values} values, but input_size is {input_size}"
            )
        # Layer order is kept stable so state_dict keys (net.0.*, net.2.*, net.4.*) match
        # checkpoints that were already saved.
        self.net = nn.Sequential(
            nn.Linear(latent_dim, 128),
            act_fn(),
            nn.Linear(128, 512),
            act_fn(),
            nn.Linear(512, input_size),
            nn.Tanh(),  # The input images are scaled to [-1, 1], hence the output has to be bounded as well
            nn.Unflatten(1, output_shape),
        )

    def forward(self, x):
        """Map a (batch, latent_dim) code to a (batch, *output_shape) reconstruction."""
        return self.net(x)

In [17]:
class Autoencoder(pl.LightningModule):
    """Encoder/decoder pair trained to reconstruct its input with a per-image summed MSE."""

    def __init__(
        self,
        input_size: int,
        latent_dim: int,
        encoder_class: object = Encoder,
        decoder_class: object = Decoder,
    ):
        """
        Args:
           input_size : Number of values per flattened image (28*28 for FashionMNIST)
           latent_dim : Dimensionality of latent representation z
           encoder_class : Class instantiated as encoder_class(input_size, latent_dim)
           decoder_class : Class instantiated as decoder_class(input_size, latent_dim)
        """
        super().__init__()
        # Saving hyperparameters of autoencoder (load_from_checkpoint rebuilds the model from them)
        self.save_hyperparameters()
        # Creating encoder and decoder
        self.encoder = encoder_class(input_size, latent_dim)
        self.decoder = decoder_class(input_size, latent_dim)

    def forward(self, x):
        """The forward function takes in an image and returns the reconstructed image."""
        z = self.encoder(x)
        x_hat = self.decoder(z)
        return x_hat

    def _get_reconstruction_loss(self, batch):
        """Return the reconstruction loss for a batch of (images, labels).

        The squared error is summed over every non-batch dimension of each image and then
        averaged over the batch. The value therefore does not grow with batch size, and the
        smaller final validation/test batch stays comparable to the full ones.
        """
        x, _ = batch  # We do not need the labels
        x_hat = self(x)
        loss = F.mse_loss(x_hat, x, reduction="none")
        # Sum per image, then average over the batch. A bare `.sum()` already yields a
        # scalar, so a following `.mean()` would be a no-op.
        loss = loss.flatten(start_dim=1).sum(dim=1).mean()
        return loss

    def configure_optimizers(self):
        """Adam (lr=1e-3) plus a ReduceLROnPlateau scheduler monitoring `val_loss`."""
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        # Using a scheduler is optional but can be helpful.
        # The LR is multiplied by 0.2 if val_loss hasn't improved for 20 epochs (never below 5e-5)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.2, patience=20, min_lr=5e-5)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}

    def training_step(self, batch, batch_idx):
        """Compute and log the training reconstruction loss."""
        loss = self._get_reconstruction_loss(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the validation loss (also monitored by the LR scheduler)."""
        loss = self._get_reconstruction_loss(batch)
        self.log("val_loss", loss)

    def test_step(self, batch, batch_idx):
        """Log the loss on whichever loader `trainer.test` was given."""
        loss = self._get_reconstruction_loss(batch)
        self.log("test_loss", loss)

In [18]:
class GenerateCallback(pl.Callback):
    """Logs an input/reconstruction image grid to TensorBoard every N training epochs."""

    def __init__(self, input_imgs, every_n_epochs=1):
        """
        Args:
           input_imgs : Batch of images (stacked along dim 0) to reconstruct during training
           every_n_epochs : Only log every N epochs (otherwise tensorboard gets quite large)
        """
        super().__init__()
        self.input_imgs = input_imgs  # Images to reconstruct during training
        self.every_n_epochs = every_n_epochs

    def on_train_epoch_end(self, trainer, pl_module):
        """Reconstruct `input_imgs` and add them side by side with the originals to TensorBoard."""
        # Use the training-epoch hook. In Lightning 1.x the generic `on_epoch_end` can also
        # fire after validation/test epochs, which logs duplicate grids; Lightning 2.x removed it.
        if trainer.current_epoch % self.every_n_epochs != 0:
            return
        # Reconstruct images
        input_imgs = self.input_imgs.to(pl_module.device)
        was_training = pl_module.training
        pl_module.eval()
        with torch.no_grad():
            reconst_imgs = pl_module(input_imgs)
        pl_module.train(was_training)  # restore whatever mode the module was in
        # Interleave so each grid row shows (input, reconstruction)
        imgs = torch.stack([input_imgs, reconst_imgs], dim=1).flatten(0, 1)
        # Map [-1, 1] -> [0, 1] explicitly. This is what make_grid(normalize=True, range=(-1, 1))
        # did, without relying on the `range` keyword that torchvision renamed to `value_range`.
        imgs = (imgs.clamp(-1, 1) + 1) / 2
        grid = torchvision.utils.make_grid(imgs, nrow=2)
        trainer.logger.experiment.add_image("Reconstructions", grid, global_step=trainer.global_step)

In [None]:
def train_fmnist(latent_dim):
    """Train (or load) a FashionMNIST autoencoder with the given latent size and evaluate it.

    Args:
        latent_dim: Dimensionality of the latent code.
    Returns:
        (model, result) where result is {"test": ..., "val": ...} holding what
        `trainer.test` returned for the test and validation loaders.
    """
    # Create a PyTorch Lightning trainer with the generation callback
    trainer = pl.Trainer(
        default_root_dir=os.path.join(CHECKPOINT_PATH, "fmnist_%i" % latent_dim),
        gpus=1 if str(device).startswith("cuda") else 0,
        max_epochs=500,
        callbacks=[
            # Monitor val_loss so the "best" checkpoint is the lowest-validation-loss one,
            # not simply the most recent one.
            ModelCheckpoint(save_weights_only=True, monitor="val_loss", mode="min"),
            GenerateCallback(get_train_images(10), every_n_epochs=10),
            LearningRateMonitor("epoch"),
        ],
    )
    trainer.logger._log_graph = True  # If True, we plot the computation graph in tensorboard
    trainer.logger._default_hp_metric = None  # Optional logging argument that we don't need

    # Check whether pretrained model exists. If yes, load it and skip training
    pretrained_filename = os.path.join(CHECKPOINT_PATH, "fmnist_%i.ckpt" % latent_dim)
    if os.path.isfile(pretrained_filename):
        print("Found pretrained model, loading...")
        model = Autoencoder.load_from_checkpoint(pretrained_filename)
    else:
        model = Autoencoder(input_size=28 * 28, latent_dim=latent_dim)
        trainer.fit(model, train_loader, val_loader)
        # After fit, `model` holds the final-epoch weights; reload the best checkpoint instead.
        best_path = trainer.checkpoint_callback.best_model_path
        if best_path:
            model = Autoencoder.load_from_checkpoint(best_path)
    # Test best model on validation and test set. Loaders are passed positionally because
    # the keyword name (`test_dataloaders` vs `dataloaders`) differs between Lightning versions.
    val_result = trainer.test(model, val_loader, verbose=False)
    test_result = trainer.test(model, test_loader, verbose=False)
    result = {"test": test_result, "val": val_result}
    return model, result
train_fmnist(16)

In [52]:
# Encode entirety of FMNIST to latent space codes
from torch.utils.data import Dataset, DataLoader

class EncodedFashionMNIST():
    """Latent-code versions of a train and a test DataLoader, produced by one encoder."""

    class LabeledCodes(Dataset):
        """Dataset of (latent_code, target) pairs obtained by encoding every batch of a DataLoader."""

        latent_codes: Tensor  # concatenated encoder outputs, in loader order
        targets: Tensor  # concatenated labels, aligned with latent_codes

        def __init__(self, encoder: nn.Module, dataloader: DataLoader):
            """
            Args:
               encoder : Module mapping a batch of inputs to a batch of latent codes
               dataloader : Iterated once; each item is an (x, y) batch
            Raises:
               ValueError : if the dataloader yields no batches
            """
            # Run on the device holding the encoder's parameters (CPU if it has none);
            # codes are moved back to the CPU so the dataset can be pickled/indexed anywhere.
            first_param = next(encoder.parameters(), None)
            device = first_param.device if first_param is not None else torch.device("cpu")
            was_training = encoder.training
            encoder.eval()
            code_batches = []
            target_batches = []
            try:
                with torch.no_grad():
                    for x, y in dataloader:
                        code_batches.append(encoder(x.to(device)).cpu())
                        target_batches.append(y)
            finally:
                encoder.train(was_training)  # do not leave the caller's module in eval mode

            if not code_batches:
                raise ValueError("dataloader yielded no batches to encode")
            # Concatenate once: concatenating inside the loop re-copies all previous codes each step.
            self.latent_codes = torch.cat(code_batches)
            self.targets = torch.cat(target_batches)

            print(self.latent_codes.size())

        def __len__(self):
            return len(self.targets)

        def __getitem__(self, index):
            return (self.latent_codes[index], self.targets[index])

    testset: LabeledCodes
    trainset: LabeledCodes

    def __init__(self, encoder, train_loader, test_loader) -> None:
        """Encode `test_loader` then `train_loader` with `encoder`."""
        self.testset = EncodedFashionMNIST.LabeledCodes(encoder, test_loader)
        self.trainset = EncodedFashionMNIST.LabeledCodes(encoder, train_loader)

# Checkpoint of the trained latent_dim=16 autoencoder. The version/epoch part is specific to
# one training run; update it after retraining.
ENCODER_CHECKPOINT = os.path.join(
    CHECKPOINT_PATH, "fmnist_16", "lightning_logs", "version_10", "checkpoints", "epoch=368-step=77489.ckpt"
)
model = Autoencoder.load_from_checkpoint(ENCODER_CHECKPOINT)

# Encode the full training set (train + validation split, in fixed order) and the test set
train_encoder_loader = data.DataLoader(train_dataset, batch_size=256, pin_memory=True, num_workers=4, shuffle=False)
test_encoder_loader = data.DataLoader(test_set, batch_size=256, pin_memory=True, num_workers=4)

my_dataset = EncodedFashionMNIST(model.encoder, train_encoder_loader, test_encoder_loader)


torch.Size([10000, 16])
torch.Size([60000, 16])


In [53]:
import dill

first_test_pair = my_dataset.testset[0]
print(first_test_pair)

# Serialize the whole EncodedFashionMNIST object with dill.
# NOTE(review): dill (recurse=True) rather than pickle is presumably needed because
# LabeledCodes is a class nested inside EncodedFashionMNIST — confirm before switching.
# Like pickle, only load this file from a trusted source.
with open("latent_fashion_mnist.pkl", "bw") as pickle_file:
    dill.dump(my_dataset, pickle_file, recurse=True)

(tensor([0.0000, 3.1287, 6.3555, 0.0000, 8.0942, 0.0000, 4.8815, 0.0000, 3.5234,
        3.7017, 6.4643, 3.2514, 4.7463, 0.0000, 2.4219, 7.5642]), tensor(9))


In [2]:
from avalanche.benchmarks.generators import nc_benchmark


# TODO(review): work in progress. This call passes no arguments, but avalanche's
# nc_benchmark presumably expects at least the train/test datasets, the number of
# experiences and a task_labels flag (verify against the installed avalanche version).
# As written it is expected to raise a TypeError. This cell also ran in a fresh kernel
# (In [2]), so `my_dataset` is not defined here; load latent_fashion_mnist.pkl first.
nc_benchmark()



# class LatentFashionMnist():