# Libraries

In [None]:
# Install the third-party packages used by this notebook into the runtime.
# No versions are pinned; the recorded output of this cell shows torchinfo 1.8.0
# and medmnist 3.0.1 being downloaded.
!pip install torch torchinfo torchvision medmnist
# Upgrade matplotlib to the newest release available to pip.
!pip install --upgrade matplotlib

Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Collecting medmnist
  Downloading medmnist-3.0.1-py3-none-any.whl (25 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch)
  Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)
Collecting nvidia-cufft-cu12==11.0.2.54 (from torch)
  Using cached nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)


In [None]:
import os, time
from collections import OrderedDict

In [None]:
import matplotlib.pyplot as plt
import medmnist
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from google.colab import drive
from matplotlib import cm
from medmnist import INFO, PathMNIST, ChestMNIST, BloodMNIST, PneumoniaMNIST, DermaMNIST
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from tqdm.notebook import tqdm

In [None]:
# Mount point for Google Drive; train() creates its checkpoint directory
# underneath this path.
GDRIVE = "/content/drive"
drive.mount(GDRIVE)

Mounted at /content/drive


In [None]:
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 is normally a debugging setting and
# presumably slows GPU training -- confirm it is still wanted.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
# Global device used by the model and training code: GPU when available, else CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEVICE

device(type='cpu')

# Model and data classes

In [None]:
class FilteredDataset(Dataset):
    """View over another dataset that keeps only the samples accepted by a predicate.

    Args:
        original_dataset: Indexable dataset supporting ``len()`` whose items can
            be unpacked into positional arguments.
        filter_fn: Callable invoked as ``filter_fn(*sample)``; a truthy result
            keeps the sample.

    The predicate is evaluated once per sample at construction time, which means
    every item of ``original_dataset`` is fetched once up front.
    """

    def __init__(self, original_dataset, filter_fn):
        self.original_dataset = original_dataset
        kept = []
        for position in range(len(original_dataset)):
            sample = original_dataset[position]
            if filter_fn(*sample):
                kept.append(position)
        # Positions (into the wrapped dataset) of the samples that passed.
        self.indices = kept

    def __len__(self):
        """Return the number of samples that passed the filter."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Return the ``idx``-th kept sample, fetched from the wrapped dataset."""
        source_position = self.indices[idx]
        return self.original_dataset[source_position]

In [None]:
class Flatten(nn.Module):
    """Module that collapses every dimension after the first into one."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` viewed as ``(x.shape[0], -1)``."""
        leading = x.size(0)
        return x.view(leading, -1)


class Encoder(nn.Module):
    """Convolutional encoder of the conditional VAE.

    The class vector is projected to a single extra image plane, concatenated
    with the (1x1-convolved) input image along the channel axis, and passed
    through five stride-2 convolution blocks. Two linear heads then produce the
    mean and log-variance of the latent distribution.

    Args:
        shape: Sequence whose first three entries are the channel count and the
            two spatial sizes of the input images.
        nhid: Size of the latent vector.
        nclass: Length of the class vector ``y``.

    Note:
        Both linear heads take 512 input features, so the convolution stack must
        reduce the image to 1x1. That holds for 28x28 inputs
        (28 -> 14 -> 7 -> 4 -> 2 -> 1).
    """

    def __init__(self, shape, nhid=128, nclass=0):
        super().__init__()
        # Stored under the original attribute names; shape[1] and shape[2] are
        # used in that order when reshaping the class plane in forward().
        self.channel, self.image_w, self.image_h = shape[0], shape[1], shape[2]

        self.embed_input = nn.Conv2d(self.channel, self.channel, kernel_size=1)
        self.embed_class = nn.Linear(nclass, self.image_w * self.image_h)

        # One extra input channel for the class plane, then the five block widths.
        widths = [self.channel + 1, 32, 64, 128, 256, 512]
        blocks = [
            nn.Sequential(
                nn.Conv2d(n_in, n_out, kernel_size=3, stride=2, padding=1),
                nn.BatchNorm2d(n_out),
                nn.ReLU(),
            )
            for n_in, n_out in zip(widths[:-1], widths[1:])
        ]
        self.encode = nn.Sequential(*blocks)

        self.calc_mean = nn.Linear(512, nhid)
        self.calc_logvar = nn.Linear(512, nhid)

    def forward(self, x, y):
        """Encode images ``x`` conditioned on class vectors ``y``.

        Returns:
            Tuple ``(mean, logvar)`` produced by the two linear heads.
        """
        image_features = self.embed_input(x)
        # Class vector -> one plane of size (shape[1], shape[2]) per sample.
        class_plane = self.embed_class(y.float())
        class_plane = class_plane.view(-1, self.image_w, self.image_h).unsqueeze(1)

        stacked = torch.cat([image_features, class_plane], dim=1)
        features = torch.flatten(self.encode(stacked), start_dim=1)
        return self.calc_mean(features), self.calc_logvar(features)


class Decoder(nn.Module):
    """Convolutional decoder of the conditional VAE.

    The latent vector and the class vector are concatenated, projected to a
    512x2x2 feature map, upsampled by four stride-2 transposed-convolution
    blocks (2 -> 4 -> 8 -> 16 -> 32) and refined by a fixed tail
    (32 -> 61 -> 30 -> 28) that ends in ``Tanh``. The output size is therefore
    always 28x28, independent of ``shape[1]`` and ``shape[2]``.

    Args:
        shape: Sequence whose first three entries are the channel count and the
            two spatial sizes of the images; only the channel count affects the
            layers built here.
        nhid: Size of the latent vector.
        nclass: Length of the class vector ``y``.
    """

    def __init__(self, shape, nhid=128, nclass=0):
        super().__init__()
        self.channel = shape[0]
        self.image_w = shape[1]
        self.image_h = shape[2]
        self.decode_latent = nn.Linear(nhid + nclass, 512 * 4)

        # Channel widths of the four doubling blocks.
        widths = [512, 256, 128, 64, 32]
        layers = [
            nn.Sequential(
                nn.ConvTranspose2d(
                    n_in,
                    n_out,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    output_padding=1,
                ),
                nn.BatchNorm2d(n_out),
                nn.ReLU(),
            )
            for n_in, n_out in zip(widths[:-1], widths[1:])
        ]
        # Tail that brings the 32x32 map to 28x28 and maps to image channels.
        layers += [
            nn.ConvTranspose2d(
                32, 32, kernel_size=3, stride=2, padding=2, output_padding=0
            ),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=0),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, self.channel, kernel_size=3, stride=1, padding=0),
            nn.Tanh(),
        ]
        self.decode = nn.Sequential(*layers)

    def forward(self, z, y):
        """Decode latent vectors ``z`` conditioned on class vectors ``y``.

        Returns:
            Image batch with values in [-1, 1] (the final layer is ``Tanh``).
        """
        conditioned = torch.cat([z, y], dim=1)
        feature_map = self.decode_latent(conditioned).view(-1, 512, 2, 2)
        return self.decode(feature_map)


class CVAE(nn.Module):
    """Conditional variational autoencoder built from ``Encoder`` and ``Decoder``.

    Args:
        shape: Image shape forwarded unchanged to the encoder and decoder.
        nhid: Size of the latent vector (stored as ``self.dim``).
        nclass: Length of the class-conditioning vector.
    """

    def __init__(self, shape, nhid=128, nclass=0):
        super().__init__()
        self.nclass = nclass
        self.dim = nhid
        self.encoder = Encoder(shape, nhid, nclass)
        self.decoder = Decoder(shape, nhid, nclass)

    def sampling(self, mean, logvar):
        """Draw ``z = mean + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``.

        The noise is created with the same shape, dtype and device as ``mean``,
        so the method works wherever the model lives instead of relying on a
        notebook-level device global.
        """
        eps = torch.randn_like(mean)
        sigma = torch.exp(0.5 * logvar)
        return mean + eps * sigma

    def forward(self, x, y):
        """Encode ``x`` given ``y``, sample a latent and decode it.

        Returns:
            Tuple ``(reconstruction, mean, logvar)``.
        """
        mean, logvar = self.encoder(x, y)
        z = self.sampling(mean, logvar)
        return self.decoder(z, y), mean, logvar

    def generate(self, y):
        """Decode standard-normal latents conditioned on the class vectors ``y``.

        One latent of size ``self.dim`` is drawn per row of ``y``, on the device
        that ``y`` is on.
        """
        # y may be an integer one-hot tensor, so only its device is reused.
        z = torch.randn((y.shape[0], self.dim), device=y.device)
        return self.decoder(z, y)

In [None]:
def loss(X, X_hat, mean, logvar, kld_weight):
    """Compute the VAE training loss: reconstruction error plus weighted KL term.

    Args:
        X: Batch of target images.
        X_hat: Batch of reconstructions with the same shape as ``X``.
        mean: Latent means, one row per sample.
        logvar: Latent log-variances with the same shape as ``mean``.
        kld_weight: Scalar multiplier applied to the KL divergence term.

    Returns:
        Scalar tensor ``mse(X_hat, X) + kld_weight * KL``, where KL is the
        divergence between N(mean, exp(logvar)) and N(0, I), summed over latent
        dimensions and averaged over the batch.
    """
    reconstruction_loss = F.mse_loss(X_hat, X)
    kl_per_sample = 0.5 * torch.sum(mean.pow(2) + logvar.exp() - logvar - 1, dim=1)
    KL_divergence = kl_per_sample.mean(dim=0)
    # The weight was previously accepted but never applied.
    return reconstruction_loss + kld_weight * KL_divergence


def adjust_lr(optimizer, decay_rate=0.95):
    """Multiply the learning rate of every parameter group by ``decay_rate``.

    Args:
        optimizer: Object exposing ``param_groups``, a list of dicts that each
            hold an ``"lr"`` entry.
        decay_rate: Factor applied in place to each group's learning rate.
    """
    for group in optimizer.param_groups:
        group["lr"] *= decay_rate


def generate_test_images(model):
    """Render a grid of generated samples: five rows, one column per class.

    Args:
        model: Conditional generator exposing ``nclass``, ``generate(labels)``
            and the ``nn.Module`` training-mode API.

    Returns:
        The matplotlib figure containing the grid; the caller is responsible
        for saving and closing it.
    """
    rows = 5

    # Sample in eval mode so BatchNorm uses its running statistics and they are
    # not updated by generated batches; the previous mode is restored afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            class_ids = torch.arange(rows * model.nclass) % model.nclass
            # Explicit num_classes keeps the one-hot width independent of the ids.
            labels = F.one_hot(class_ids, num_classes=model.nclass).to(DEVICE)
            # Map the decoder's [-1, 1] output back to [0, 1] for display.
            x = model.generate(labels) * 0.5 + 0.5
    finally:
        model.train(was_training)

    # squeeze=False keeps `axes` two-dimensional even when nclass == 1.
    figure, axes = plt.subplots(rows, model.nclass, squeeze=False)
    for row in range(rows):
        for col in range(model.nclass):
            img = x[row * model.nclass + col].permute(1, 2, 0).cpu().numpy()
            axes[row, col].axis("off")
            axes[row, col].imshow(img, vmin=0, vmax=1)

    return figure


def load_model(model, optimizer, path="model_checkpoint.pth"):
    """Restore model and optimizer state from a checkpoint written by ``save_model``.

    Args:
        model: Module whose ``load_state_dict`` receives ``"model_state_dict"``.
        optimizer: Optimizer whose ``load_state_dict`` receives
            ``"optimizer_state_dict"``.
        path: Checkpoint file to read.

    The checkpoint tensors are mapped onto the device of the model's first
    parameter, so a checkpoint saved on a GPU can be loaded on a CPU-only
    runtime (and vice versa).
    """
    first_param = next(model.parameters(), None)
    map_location = first_param.device if first_param is not None else None
    # torch.load unpickles the file: only load checkpoints from a trusted source.
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])


def save_model(model, optimizer, path="model_checkpoint.pth"):
    """Write the model and optimizer state dicts to ``path`` as one checkpoint.

    The file holds a dict with the keys ``"model_state_dict"`` and
    ``"optimizer_state_dict"``.
    """
    checkpoint = dict(
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
    )
    torch.save(checkpoint, path)

# Prepare datasets

In [None]:
# Image preprocessing shared by all datasets below: convert to a tensor, then
# apply (x - 0.5) / 0.5. generate_test_images() undoes this with `* 0.5 + 0.5`.
transform = transforms.Compose(
    [
        # presumably yields float tensors scaled to [0, 1] -- TODO confirm
        transforms.ToTensor(),
        # A single mean/std value is given although the datasets report 3 channels.
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

In [None]:
# Training split of PathMNIST with labels one-hot encoded over its 9 classes
# (the class names are listed in the dataset summary printed by this cell).
path_dataset = PathMNIST(
    split="train",
    download=True,
    transform=transform,
    # presumably each label arrives as a length-1 array, so one_hot gives (1, 9)
    # and squeeze() drops the leading axis -- TODO confirm
    target_transform=lambda x: (F.one_hot(torch.tensor(x), num_classes=9).squeeze()),
)
path_dataset

Downloading https://zenodo.org/records/10519652/files/pathmnist.npz?download=1 to /root/.medmnist/pathmnist.npz


100%|██████████| 205615438/205615438 [00:34<00:00, 5914617.94it/s] 


Dataset PathMNIST of size 28 (pathmnist)
    Number of datapoints: 89996
    Root location: /root/.medmnist
    Split: train
    Task: multi-class
    Number of channels: 3
    Meaning of labels: {'0': 'adipose', '1': 'background', '2': 'debris', '3': 'lymphocytes', '4': 'mucus', '5': 'smooth muscle', '6': 'normal colon mucosa', '7': 'cancer-associated stroma', '8': 'colorectal adenocarcinoma epithelium'}
    Number of samples: {'train': 89996, 'val': 10004, 'test': 7180}
    Description: The PathMNIST is based on a prior study for predicting survival from colorectal cancer histology slides, providing a dataset (NCT-CRC-HE-100K) of 100,000 non-overlapping image patches from hematoxylin & eosin stained histological images, and a test dataset (CRC-VAL-HE-7K) of 7,180 image patches from a different clinical center. The dataset is comprised of 9 types of tissues, resulting in a multi-class classification task. We resize the source images of 3×224×224 into 3×28×28, and split NCT-CRC-HE-100K

In [None]:
# Training split of BloodMNIST with labels one-hot encoded over its 8 classes
# (the class names are listed in the dataset summary printed by this cell).
blood_dataset = BloodMNIST(
    split="train",
    download=True,
    transform=transform,
    # presumably each label arrives as a length-1 array, so one_hot gives (1, 8)
    # and squeeze() drops the leading axis -- TODO confirm
    target_transform=lambda x: (F.one_hot(torch.tensor(x), num_classes=8).squeeze()),
)
blood_dataset

Downloading https://zenodo.org/records/10519652/files/bloodmnist.npz?download=1 to /root/.medmnist/bloodmnist.npz


100%|██████████| 35461855/35461855 [00:18<00:00, 1932672.84it/s]


Dataset BloodMNIST of size 28 (bloodmnist)
    Number of datapoints: 11959
    Root location: /root/.medmnist
    Split: train
    Task: multi-class
    Number of channels: 3
    Meaning of labels: {'0': 'basophil', '1': 'eosinophil', '2': 'erythroblast', '3': 'immature granulocytes(myelocytes, metamyelocytes and promyelocytes)', '4': 'lymphocyte', '5': 'monocyte', '6': 'neutrophil', '7': 'platelet'}
    Number of samples: {'train': 11959, 'val': 1712, 'test': 3421}
    Description: The BloodMNIST is based on a dataset of individual normal cells, captured from individuals without infection, hematologic or oncologic disease and free of any pharmacologic treatment at the moment of blood collection. It contains a total of 17,092 images and is organized into 8 classes. We split the source dataset with a ratio of 7:1:2 into training, validation and test set. The source images with resolution 3×360×363 pixels are center-cropped into 3×200×200, and then resized into 3×28×28.
    License: CC B

In [None]:
# Training split of DermaMNIST with labels one-hot encoded over its 7 classes
# (the class names are listed in the dataset summary printed by this cell).
derma_dataset = DermaMNIST(
    split="train",
    download=True,
    transform=transform,
    # presumably each label arrives as a length-1 array, so one_hot gives (1, 7)
    # and squeeze() drops the leading axis -- TODO confirm
    target_transform=lambda x: (F.one_hot(torch.tensor(x), num_classes=7).squeeze()),
)
derma_dataset

Downloading https://zenodo.org/records/10519652/files/dermamnist.npz?download=1 to /root/.medmnist/dermamnist.npz


100%|██████████| 19725078/19725078 [00:01<00:00, 10394792.36it/s]


Dataset DermaMNIST of size 28 (dermamnist)
    Number of datapoints: 7007
    Root location: /root/.medmnist
    Split: train
    Task: multi-class
    Number of channels: 3
    Meaning of labels: {'0': 'actinic keratoses and intraepithelial carcinoma', '1': 'basal cell carcinoma', '2': 'benign keratosis-like lesions', '3': 'dermatofibroma', '4': 'melanoma', '5': 'melanocytic nevi', '6': 'vascular lesions'}
    Number of samples: {'train': 7007, 'val': 1003, 'test': 2005}
    Description: The DermaMNIST is based on the HAM10000, a large collection of multi-source dermatoscopic images of common pigmented skin lesions. The dataset consists of 10,015 dermatoscopic images categorized as 7 different diseases, formulized as a multi-class classification task. We split the images into training, validation and test set with a ratio of 7:1:2. The source images of 3×600×450 are resized into 3×28×28.
    License: CC BY-NC 4.0

# Training

In [None]:
def train(model, optimizer, dataset, epochs=50, batch_size=128, kld_weight=0.005):
    """Train ``model`` on ``dataset`` and checkpoint it after every epoch.

    A fresh, timestamp-named directory is created under the mounted Drive. After
    each epoch the model/optimizer state (``checkpoint-<epoch>.pt``) and a grid of
    generated samples (``checkpoint-<epoch>.png``) are written there.

    Args:
        model: Network called as ``model(X, y)`` and returning
            ``(X_hat, mean, logvar)``.
        optimizer: Optimizer over ``model``'s parameters.
        dataset: Dataset yielding ``(image, label)`` pairs.
        epochs: Number of passes over ``dataset``.
        batch_size: Mini-batch size of the shuffled data loader.
        kld_weight: Value passed to ``loss`` as its KL weight.
    """
    train_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)
    checkpoint_directory = os.path.join(
        GDRIVE, "MyDrive/tu-darmstadt/dgm-project", str(time.time_ns())
    )
    # makedirs also creates missing parent folders, which os.mkdir would not.
    os.makedirs(checkpoint_directory, exist_ok=True)

    print(f"Saving checkpoints to {checkpoint_directory}")

    epoch_iter = tqdm(range(epochs))
    for epoch in epoch_iter:
        # Make sure training-mode layers (BatchNorm) are active for this epoch.
        model.train()
        train_loss = 0.0
        batch_iter = tqdm(train_loader)
        for X, y in batch_iter:
            X = X.to(DEVICE)
            y = y.to(DEVICE)

            # Clear gradients first so stale values can never leak into a step.
            optimizer.zero_grad()
            # Use the model passed in, not a notebook-level global.
            X_hat, mean, logvar = model(X, y)
            batch_loss_tensor = loss(X, X_hat, mean, logvar, kld_weight)
            batch_loss_tensor.backward()
            optimizer.step()

            batch_loss = batch_loss_tensor.item()
            train_loss += batch_loss

            batch_iter.set_postfix({"loss": batch_loss})
            batch_iter.refresh()

        # Report the mean batch loss of the finished epoch.
        epoch_iter.set_postfix({"loss": train_loss / len(train_loader)})
        epoch_iter.refresh()

        figure = generate_test_images(model)
        checkpoint_state_path = os.path.join(
            checkpoint_directory, f"checkpoint-{epoch}.pt"
        )
        checkpoint_image_path = os.path.join(
            checkpoint_directory, f"checkpoint-{epoch}.png"
        )

        save_model(model, optimizer, checkpoint_state_path)
        figure.savefig(checkpoint_image_path)
        # Close the figure so that one figure per epoch does not accumulate.
        plt.close(figure)

In [None]:
# Conditional VAE for 3x28x28 images with a 128-dimensional latent and 9 classes
# (the class count used for path_dataset's one-hot labels), moved to DEVICE.
net = CVAE(shape=(3, 28, 28), nhid=128, nclass=9).to(DEVICE)
# Adam over all model parameters with a fixed learning rate of 1e-4.
optimizer = torch.optim.Adam(net.parameters(), lr=0.0001)

In [None]:
# Train the CVAE on the PathMNIST training split; checkpoints go to Google Drive.
train(net, optimizer, path_dataset)

Saving checkpoints to /content/drive/MyDrive/tu-darmstadt/dgm-project/1720776158726438607


  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/704 [00:00<?, ?it/s]

  0%|          | 0/704 [00:00<?, ?it/s]