# Preliminary steps and imports

In [None]:
# Colab only: mount Google Drive at /content/drive so the project files stored there are accessible.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
cd drive/MyDrive/MLinA_project/

/content/drive/MyDrive/MLinA_project


In [None]:
import torch
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import re
import numpy as np
from tqdm.auto import tqdm
from torch import cuda

# Define function to extract latent embeddings given a model

In [None]:
# extract latent vectors for the entire dataset
def extract_latent_vectors(model, dataloader, model_name, device=None, save_every=100, use_mean=False):
    """Encode every batch of `dataloader` with `model` and return latent vectors and labels.

    Args:
        model: VAE-like module exposing ``encode(x) -> (mu, logvar)`` and
            ``reparameterize(mu, logvar)``.
        dataloader: iterable of ``(data, target, _, _)`` batches supporting ``len()``.
        model_name: suffix of the checkpoint files ``features_{model_name}.npy``
            and ``labels_{model_name}.npy``.
        device: device on which to run the encoder. Defaults to the device of the
            model's first parameter, so no global ``device`` variable is needed.
        save_every: write the partial results every this many batches
            (``0``/``None`` disables intermediate saves). The complete arrays are
            always written once all batches have been processed.
        use_mean: if True, return the posterior mean ``mu`` (deterministic).
            The default (False) returns a sample from ``reparameterize``, so
            repeated runs produce slightly different features.

    Returns:
        Tuple ``(latent_vectors, labels)`` of NumPy arrays concatenated along the
        first (batch) axis.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    if device is None:
        device = next(model.parameters()).device

    def _save(features, targets):
        # Checkpoint results on disk so a long extraction is not lost if interrupted.
        np.save(f'features_{model_name}.npy', features)
        np.save(f'labels_{model_name}.npy', targets)

    latent_vectors = []
    labels = []
    model.eval()
    with torch.no_grad(), tqdm(total=len(dataloader)) as progress_bar:
        for batch_idx, (data, target, _, _) in enumerate(dataloader, start=1):
            mu, logvar = model.encode(data.to(device))
            latent_vector = mu if use_mean else model.reparameterize(mu, logvar)
            latent_vectors.append(latent_vector.cpu().numpy())
            # Labels are only collected, never fed to the model: keep them on the CPU.
            labels.append(target.cpu().numpy())

            if save_every and batch_idx % save_every == 0:
                _save(np.concatenate(latent_vectors), np.concatenate(labels))

            progress_bar.update(1)

    if not latent_vectors:
        raise ValueError('dataloader yielded no batches; nothing to extract')

    features = np.concatenate(latent_vectors)
    targets = np.concatenate(labels)
    # Final save so the checkpoint files always hold the complete result.
    _save(features, targets)
    return features, targets

# Dataset creation

In [None]:
import dataset_patches
from dataset_patches import patchesDataset

dataset = patchesDataset(root_dir='patches')
# Batches are produced in dataset order (no shuffling).
# NOTE(review): this same loader is also used for VAE training below; a separate,
# shuffled training loader would likely train better — confirm before changing.
dataloader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)

In [None]:
# for images, labels, patient_ids, coordinates in dataloader:
#     print(images[0].size())  # Tensor shape: (batch_size, channels, height, width)
#     print(labels)  # Labels: 0 (no cancer) or 1 (cancer)
#     print(patient_ids)  # Patient IDs: 1 to 24
#     print(coordinates)  # Coordinates: (x, y) tuples
#     break

# Basic Variational Autoencoder

## Network definition

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm

# Device configuration: use the GPU when PyTorch can see one, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Define the VAE model
class VAE(nn.Module):
    """Convolutional variational autoencoder for square 3-channel images.

    The encoder uses four stride-2 convolutions, which downsample by 16
    (512x512 -> 32x32 with the default ``input_size``). The decoder mirrors them
    with transposed convolutions and ends in a sigmoid, so reconstructions lie in [0, 1].

    Args:
        latent_dim: dimensionality of the latent space.
        input_size: side length of the square input images; must be divisible by 16.
            The default 512 reproduces the original hard-coded ``256*32*32`` bottleneck,
            so existing checkpoints still load.

    Raises:
        ValueError: if ``input_size`` is not divisible by 16.
    """

    def __init__(self, latent_dim, input_size=512):
        super().__init__()
        if input_size % 16 != 0:
            raise ValueError(f'input_size must be divisible by 16, got {input_size}')
        self.latent_dim = latent_dim
        feat = input_size // 16          # side of the encoder's last feature map
        flat = 256 * feat * feat         # flattened encoder output size

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten()
        )

        self.fc_mu = nn.Linear(flat, latent_dim)
        self.fc_logvar = nn.Linear(flat, latent_dim)

        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, flat),
            nn.ReLU(),
            nn.Unflatten(1, (256, feat, feat)),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid()
        )

    def encode(self, x):
        """Return ``(mu, logvar)``, each of shape (batch, latent_dim)."""
        x = self.encoder(x)
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(logvar / 2)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent vectors back to images in [0, 1]."""
        return self.decoder(z)

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for a batch of images."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(recon_x, x, mu, logvar):
    """Negative ELBO: binary cross-entropy reconstruction term plus KL to N(0, I).

    Args:
        recon_x: decoder output with values in [0, 1], same shape as ``x``.
        x: target images with values in [0, 1].
        mu, logvar: mean and log-variance of the approximate posterior.

    Returns:
        Scalar tensor: the BCE plus the analytic KL(N(mu, exp(logvar)) || N(0, I)),
        both summed over every element and the whole batch.
    """
    reconstruction_term = nn.functional.binary_cross_entropy(recon_x, x, reduction='sum')
    kl_elementwise = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_term = -0.5 * torch.sum(kl_elementwise)
    return reconstruction_term + kl_term

## Training

In [None]:
# Model, optimizer, and training parameters.
# To resume, set `current_epochs` to the epoch of an existing `vae_model_{N}.pth` file.
current_epochs = 0
num_epochs = 100            # additional epochs to run in this session
total_epochs = current_epochs + num_epochs
checkpoint_every = 2        # save `vae_model_{epoch}.pth` every this many (global) epochs

model = VAE(latent_dim=100).to(device)
checkpoint_file = f'vae_model_{current_epochs}.pth'
try:
    # map_location lets a checkpoint written on a GPU runtime load on a CPU-only one.
    model.load_state_dict(torch.load(checkpoint_file, map_location=device))
    print('Model loaded correctly!')
except FileNotFoundError:
    # Only a missing file is expected here; key/shape mismatches must still raise
    # instead of silently training from random weights.
    print(f'{checkpoint_file} not found, training from scratch.')
# NOTE(review): the optimizer state is not checkpointed, so Adam's moments restart on resume.
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Training loop
for epoch in range(num_epochs):
    global_epoch = current_epochs + epoch + 1
    model.train()
    total_loss = 0
    for data, _, _, _ in tqdm(dataloader, desc=f"Epoch {global_epoch}/{total_epochs}"):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # The loss is summed over pixels and batch; report it per sample.
    print(f"Epoch {global_epoch}/{total_epochs}, Loss: {total_loss / len(dataloader.dataset):.4f}")
    if global_epoch % checkpoint_every == 0:
        # Same naming scheme as the loader above and the feature-extraction cell.
        torch.save(model.state_dict(), f'vae_model_{global_epoch}.pth')

# Save the trained model
torch.save(model.state_dict(), f'vae_model_{total_epochs}.pth')

Epoch 1/100: 100%|██████████| 510/510 [21:08<00:00,  2.49s/it]


Epoch 1/100, Loss: 472607.2335


Epoch 2/100: 100%|██████████| 510/510 [02:42<00:00,  3.14it/s]


Epoch 2/100, Loss: 460213.8909


Epoch 3/100: 100%|██████████| 510/510 [02:41<00:00,  3.16it/s]


Epoch 3/100, Loss: 456497.6767


Epoch 4/100: 100%|██████████| 510/510 [02:41<00:00,  3.17it/s]


Epoch 4/100, Loss: 455264.1543


Epoch 5/100: 100%|██████████| 510/510 [02:45<00:00,  3.07it/s]


Epoch 5/100, Loss: 453340.5198


Epoch 6/100: 100%|██████████| 510/510 [02:40<00:00,  3.17it/s]


Epoch 6/100, Loss: 452499.7678


Epoch 7/100: 100%|██████████| 510/510 [02:44<00:00,  3.10it/s]


Epoch 7/100, Loss: 451029.2463


Epoch 8/100: 100%|██████████| 510/510 [02:45<00:00,  3.08it/s]


Epoch 8/100, Loss: 450278.6699


Epoch 9/100: 100%|██████████| 510/510 [02:43<00:00,  3.12it/s]


Epoch 9/100, Loss: 449788.7218


Epoch 10/100: 100%|██████████| 510/510 [02:40<00:00,  3.17it/s]


Epoch 10/100, Loss: 448215.9512


Epoch 11/100: 100%|██████████| 510/510 [02:40<00:00,  3.18it/s]


Epoch 11/100, Loss: 448432.1664


Epoch 12/100: 100%|██████████| 510/510 [02:39<00:00,  3.19it/s]


Epoch 12/100, Loss: 446868.9502


Epoch 13/100: 100%|██████████| 510/510 [02:39<00:00,  3.20it/s]


Epoch 13/100, Loss: 446073.9155


Epoch 14/100: 100%|██████████| 510/510 [02:39<00:00,  3.19it/s]


Epoch 14/100, Loss: 443283.4474


Epoch 15/100: 100%|██████████| 510/510 [02:40<00:00,  3.18it/s]


Epoch 15/100, Loss: 441782.3061


Epoch 16/100: 100%|██████████| 510/510 [02:40<00:00,  3.18it/s]


Epoch 16/100, Loss: 441548.8330


Epoch 17/100: 100%|██████████| 510/510 [02:40<00:00,  3.19it/s]


Epoch 17/100, Loss: 440520.4613


Epoch 18/100: 100%|██████████| 510/510 [02:40<00:00,  3.18it/s]


Epoch 18/100, Loss: 439309.6442


Epoch 19/100: 100%|██████████| 510/510 [02:40<00:00,  3.19it/s]


Epoch 19/100, Loss: 439470.5268


Epoch 20/100: 100%|██████████| 510/510 [02:40<00:00,  3.18it/s]


Epoch 20/100, Loss: 438753.5188


Epoch 21/100: 100%|██████████| 510/510 [02:40<00:00,  3.18it/s]


Epoch 21/100, Loss: 437344.1273


Epoch 22/100: 100%|██████████| 510/510 [02:40<00:00,  3.18it/s]


Epoch 22/100, Loss: 436683.6216


Epoch 23/100: 100%|██████████| 510/510 [02:43<00:00,  3.11it/s]


Epoch 23/100, Loss: 436192.5903


Epoch 24/100: 100%|██████████| 510/510 [02:42<00:00,  3.15it/s]


Epoch 24/100, Loss: 436841.1710


Epoch 25/100: 100%|██████████| 510/510 [02:43<00:00,  3.13it/s]


Epoch 25/100, Loss: 436464.5814


Epoch 26/100: 100%|██████████| 510/510 [02:40<00:00,  3.17it/s]


Epoch 26/100, Loss: 435183.0895


Epoch 27/100: 100%|██████████| 510/510 [02:41<00:00,  3.17it/s]


Epoch 27/100, Loss: 434648.6175


Epoch 28/100: 100%|██████████| 510/510 [02:40<00:00,  3.17it/s]


Epoch 28/100, Loss: 434532.0893


Epoch 29/100: 100%|██████████| 510/510 [02:40<00:00,  3.17it/s]


Epoch 29/100, Loss: 434102.9588


Epoch 30/100: 100%|██████████| 510/510 [02:41<00:00,  3.17it/s]


Epoch 30/100, Loss: 434041.5383


Epoch 31/100: 100%|██████████| 510/510 [02:40<00:00,  3.18it/s]


Epoch 31/100, Loss: 433791.0779


Epoch 32/100: 100%|██████████| 510/510 [02:40<00:00,  3.18it/s]


Epoch 32/100, Loss: 433329.2456


Epoch 33/100: 100%|██████████| 510/510 [02:41<00:00,  3.16it/s]


Epoch 33/100, Loss: 432793.3968


Epoch 34/100: 100%|██████████| 510/510 [02:41<00:00,  3.17it/s]


Epoch 34/100, Loss: 432855.1450


Epoch 35/100: 100%|██████████| 510/510 [02:39<00:00,  3.20it/s]


Epoch 35/100, Loss: 432510.2789


Epoch 36/100: 100%|██████████| 510/510 [02:39<00:00,  3.19it/s]


Epoch 36/100, Loss: 432455.7602


Epoch 37/100: 100%|██████████| 510/510 [02:40<00:00,  3.18it/s]


Epoch 37/100, Loss: 432562.0608


Epoch 38/100: 100%|██████████| 510/510 [02:40<00:00,  3.17it/s]


Epoch 38/100, Loss: 432166.7637


Epoch 39/100: 100%|██████████| 510/510 [02:39<00:00,  3.19it/s]


Epoch 39/100, Loss: 431593.6047


Epoch 40/100: 100%|██████████| 510/510 [02:41<00:00,  3.16it/s]


Epoch 40/100, Loss: 431390.8113


Epoch 41/100: 100%|██████████| 510/510 [02:41<00:00,  3.16it/s]


Epoch 41/100, Loss: 431285.5754


Epoch 42/100: 100%|██████████| 510/510 [02:44<00:00,  3.10it/s]


Epoch 42/100, Loss: 431218.5502


Epoch 43/100: 100%|██████████| 510/510 [02:45<00:00,  3.08it/s]


Epoch 43/100, Loss: 431401.4483


Epoch 44/100: 100%|██████████| 510/510 [02:44<00:00,  3.11it/s]


Epoch 44/100, Loss: 431173.2488


Epoch 45/100: 100%|██████████| 510/510 [02:46<00:00,  3.05it/s]


Epoch 45/100, Loss: 430708.3409


Epoch 46/100: 100%|██████████| 510/510 [02:44<00:00,  3.09it/s]


Epoch 46/100, Loss: 430704.2701


Epoch 47/100: 100%|██████████| 510/510 [02:43<00:00,  3.13it/s]


Epoch 47/100, Loss: 430562.0105


Epoch 48/100: 100%|██████████| 510/510 [02:41<00:00,  3.16it/s]


Epoch 48/100, Loss: 430901.6449


Epoch 49/100: 100%|██████████| 510/510 [02:40<00:00,  3.17it/s]


Epoch 49/100, Loss: 431869.3346


Epoch 50/100: 100%|██████████| 510/510 [02:41<00:00,  3.16it/s]


Epoch 50/100, Loss: 430319.5532


Epoch 51/100: 100%|██████████| 510/510 [02:40<00:00,  3.18it/s]


Epoch 51/100, Loss: 429857.6732


Epoch 52/100: 100%|██████████| 510/510 [02:41<00:00,  3.16it/s]


Epoch 52/100, Loss: 429731.8997


Epoch 53/100: 100%|██████████| 510/510 [02:40<00:00,  3.18it/s]


Epoch 53/100, Loss: 429808.7932


Epoch 54/100: 100%|██████████| 510/510 [02:43<00:00,  3.12it/s]


Epoch 54/100, Loss: 429753.8902


Epoch 55/100: 100%|██████████| 510/510 [02:41<00:00,  3.16it/s]


Epoch 55/100, Loss: 429620.9039


Epoch 56/100: 100%|██████████| 510/510 [02:40<00:00,  3.18it/s]


Epoch 56/100, Loss: 429682.7313


Epoch 57/100: 100%|██████████| 510/510 [02:42<00:00,  3.15it/s]


Epoch 57/100, Loss: 430335.2606


Epoch 58/100: 100%|██████████| 510/510 [02:41<00:00,  3.16it/s]


Epoch 58/100, Loss: 429420.3647


Epoch 59/100: 100%|██████████| 510/510 [02:43<00:00,  3.11it/s]


Epoch 59/100, Loss: 429029.4118


Epoch 60/100: 100%|██████████| 510/510 [02:45<00:00,  3.08it/s]


Epoch 60/100, Loss: 429236.6064


Epoch 61/100:  25%|██▌       | 128/510 [00:42<02:11,  2.90it/s]

## Extract and save the features using the basic VAE

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VAE(100).to(device)
# map_location lets a checkpoint saved from a GPU runtime load on a CPU-only runtime.
model.load_state_dict(torch.load('vae_model_60.pth', map_location=device))

<All keys matched successfully>

In [None]:
# Encode the whole dataset with the basic VAE (also checkpoints features_vae.npy / labels_vae.npy).
latent_vectors, labels = extract_latent_vectors(model=model, dataloader=dataloader, model_name='vae')

  0%|          | 0/892 [00:00<?, ?it/s]

In [None]:
# Sanity check: one latent vector per label.
for array in (latent_vectors, labels):
    print(array.shape)

(14269, 100)
(14269,)


In [None]:
import pickle

# Persist the VAE features both as separate .npy arrays and as a single pickle.
for filename, array in (('latent_vectors_vae_final.npy', latent_vectors),
                        ('labels_vae_final.npy', labels)):
    np.save(filename, array)

with open('extracted_features.pkl', 'wb') as pickle_file:
    pickle.dump((latent_vectors, labels), pickle_file)


## Train an SVM model using the basic VAE features


In [None]:
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, accuracy_score

from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Split the data into training and testing sets.
# `stratify` keeps the no-cancer / cancer ratio the same in both splits.
# NOTE(review): this split is at patch level, so patches from the same patient can end
# up in both train and test, which likely inflates the score. A patient-grouped split
# (e.g. GroupShuffleSplit) needs patient ids, which extract_latent_vectors discards.
X_train, X_test, y_train, y_test = train_test_split(
    latent_vectors, labels, test_size=0.2, random_state=42, stratify=labels
)

# SVMs are not scale invariant, so standardize the latent features first. The scaler
# sits inside the pipeline so it is fitted on the training split only.
svm_classifier = make_pipeline(StandardScaler(), SVC(kernel='linear', C=1.0))
svm_classifier.fit(X_train, y_train)

# Predict on the test set
y_pred = svm_classifier.predict(X_test)

# Evaluate the classifier
accuracy = accuracy_score(y_test, y_pred)
print("Accuracy:", accuracy)
print("Classification Report:\n", classification_report(y_test, y_pred))


Accuracy: 0.7186405045550105
Classification Report:
               precision    recall  f1-score   support

           0       0.73      0.73      0.73      1475
           1       0.71      0.71      0.71      1379

    accuracy                           0.72      2854
   macro avg       0.72      0.72      0.72      2854
weighted avg       0.72      0.72      0.72      2854



# ResNet based VAE
Something similar was used in [CLINICALLY RELEVANT LATENT SPACE EMBEDDING OF CANCER HISTOPATHOLOGY SLIDES THROUGH VARIATIONAL AUTOENCODER BASED IMAGE COMPRESSION](https://arxiv.org/pdf/2303.13332). Its code can be found in [this](https://github.com/jacobluber/uta_cancer_search) GitHub repository, but it no longer runs because of changes in PyTorch Lightning.
The specific VAE we use was taken from [this repository](https://github.com/julianstastny/VAE-ResNet18-PyTorch/blob/master/model.py).
With respect to the original VAE, the following changes were applied:
1. all the dimensions had to be adjusted to our data;
2. `num_blocks` was changed from [2,2,2,2] to [1,1,1,1] to reduce the complexity of the network, going from 14.7M to 13.5M parameters;
3. adding a max-pooling layer before the encoder's linear layer reduced the number of parameters from 13.5M to 5.7M.

In [None]:
import torch
from torch import nn
import torch.nn.functional as F

class ResizeConv2d(nn.Module):
    """Upsample with ``interpolate`` and then apply a stride-1 convolution.

    Used by the decoder as an alternative to transposed convolutions.

    Args:
        in_channels, out_channels, kernel_size: forwarded to ``nn.Conv2d``.
        scale_factor: spatial upsampling factor passed to ``interpolate``.
        mode: interpolation mode (default ``'nearest'``).
        padding: convolution padding. Defaults to ``kernel_size // 2``, so odd kernels
            preserve the upsampled size. The previous hard-coded ``padding=1`` enlarged
            the output of 1x1 kernels by 2 pixels in each spatial dimension; pass
            ``padding=1`` explicitly to reproduce that behaviour.
    """

    def __init__(self, in_channels, out_channels, kernel_size, scale_factor, mode='nearest', padding=None):
        super().__init__()
        self.scale_factor = scale_factor
        self.mode = mode
        if padding is None:
            if isinstance(kernel_size, int):
                padding = kernel_size // 2
            else:
                padding = tuple(k // 2 for k in kernel_size)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=padding)

    def forward(self, x):
        x = nn.functional.interpolate(x, scale_factor=self.scale_factor, mode=self.mode)
        x = self.conv(x)
        return x

class BasicBlockEnc(nn.Module):
    """ResNet basic block for the encoder: two 3x3 conv+BN stages plus a residual shortcut.

    The first convolution applies ``stride``. When the stride or the channel count
    changes, the shortcut is a strided 1x1 conv + BN projection; otherwise it is
    the identity (an empty ``nn.Sequential``).
    """

    def __init__(self, in_planes, out_planes, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_planes)
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes)

        needs_projection = stride != 1 or in_planes != out_planes
        self.shortcut = (
            nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
            if needs_projection
            else nn.Sequential()
        )

    def forward(self, x):
        residual = self.shortcut(x)
        hidden = torch.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        return torch.relu(hidden + residual)

class ResNet18Enc(nn.Module):
    """ResNet-style encoder that maps an image to Gaussian posterior parameters.

    Spatial path for the default 512x512 input: conv1 (stride 2) -> 256,
    layer1..layer4 (stride 2 each) -> 16, 4x4 max-pool -> 4. The linear head
    therefore sees ``128 * 4 * 4`` features and outputs ``2 * z_dim`` values.

    Args:
        num_blocks: number of BasicBlockEnc blocks in each of the four stages.
        z_dim: latent dimensionality.
        nc: number of input channels.
        input_size: side length of the square input; must be a multiple of 128
            (total downsampling of conv1, the four stages and the max-pool).

    Raises:
        ValueError: if ``input_size`` is not a multiple of 128.
    """

    def __init__(self, num_blocks=(1, 1, 1, 1), z_dim=128, nc=3, input_size=512):
        super().__init__()
        if input_size % 128 != 0:
            raise ValueError(f'input_size must be a multiple of 128, got {input_size}')
        self.in_planes = 32
        self.conv1 = nn.Conv2d(nc, 32, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.z_dim = z_dim
        self.layer1 = self._make_layer(BasicBlockEnc, 16, num_blocks[0], stride=2)
        self.layer2 = self._make_layer(BasicBlockEnc, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(BasicBlockEnc, 64, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(BasicBlockEnc, 128, num_blocks[3], stride=2)
        self.maxpool = nn.MaxPool2d(4, 4)
        # Side of the pooled map: /2 (conv1) /2**4 (layer1-4) /4 (max-pool) = /128.
        pooled = input_size // 128
        self.linear = nn.Linear(128 * pooled * pooled, 2 * z_dim)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first changes stride/channel count."""
        layers = []
        for i in range(num_blocks):
            if i == 0:
                layers.append(block(self.in_planes, planes, stride))
            else:
                layers.append(block(planes, planes, 1))
            self.in_planes = planes
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return ``(mu, logvar)``, each of shape (batch, z_dim)."""
        x = torch.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.maxpool(x)
        x = x.view(x.size(0), -1)
        x = self.linear(x)
        # The linear head packs both posterior parameters side by side.
        mu = x[:, :self.z_dim]
        logvar = x[:, self.z_dim:]
        return mu, logvar

class BasicBlockDec(nn.Module):
    """Decoder counterpart of BasicBlockEnc.

    The block applies a 3x3 conv+BN stage at the input resolution, then a ResizeConv2d
    that upsamples by ``stride``, plus a shortcut. The shortcut is the identity, or a
    ResizeConv2d 1x1 + BN projection when the stride or channel count changes.
    """

    def __init__(self, in_planes, out_planes, stride=1):
        super().__init__()
        self.conv2 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(in_planes)
        self.conv1 = ResizeConv2d(in_planes, out_planes, kernel_size=3, scale_factor=stride)
        self.bn1 = nn.BatchNorm2d(out_planes)
        needs_projection = stride != 1 or in_planes != out_planes
        self.shortcut = (
            nn.Sequential(
                ResizeConv2d(in_planes, out_planes, kernel_size=1, scale_factor=stride),
                nn.BatchNorm2d(out_planes),
            )
            if needs_projection
            else nn.Sequential()
        )

    def forward(self, x):
        hidden = torch.relu(self.bn2(self.conv2(x)))
        hidden = self.bn1(self.conv1(hidden))
        residual = self.shortcut(x)
        # Guard against the shortcut and the main path disagreeing in spatial size
        # (this depends on the padding ResizeConv2d uses); nearest-resize to match.
        if residual.shape[2:] != hidden.shape[2:]:
            residual = F.interpolate(residual, size=tuple(hidden.shape[2:]), mode='nearest')
        return torch.relu(hidden + residual)

class ResNet18Dec(nn.Module):
    """ResNet-style decoder that maps a latent vector back to an image in [0, 1].

    The latent vector is projected to a 128x16x16 map. Each of the four stages
    upsamples x2 (16 -> 256), the final ResizeConv2d upsamples once more (-> 512),
    and a sigmoid squashes the output.

    Args:
        num_blocks: number of BasicBlockDec blocks in each of the four stages.
        z_dim: latent dimensionality (input size of the linear projection).
        nc: number of output channels.
    """

    def __init__(self, num_blocks=(1, 1, 1, 1), z_dim=128, nc=3):
        super().__init__()
        # The first stage consumes the 128-channel map produced by `linear`, so
        # in_planes must be 128 regardless of z_dim. It used to be set to z_dim,
        # which only worked because z_dim happened to be 128.
        self.in_planes = 128
        self.linear = nn.Linear(z_dim, 128 * 16 * 16)
        self.layer4 = self._make_layer(BasicBlockDec, 128, num_blocks[3], stride=2)
        self.layer3 = self._make_layer(BasicBlockDec, 64, num_blocks[2], stride=2)
        self.layer2 = self._make_layer(BasicBlockDec, 32, num_blocks[1], stride=2)
        self.layer1 = self._make_layer(BasicBlockDec, 16, num_blocks[0], stride=2)
        self.conv1 = ResizeConv2d(16, nc, kernel_size=3, scale_factor=2)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first changes stride/channel count."""
        layers = []
        for i in range(num_blocks):
            if i == 0:
                layers.append(block(self.in_planes, planes, stride))
            else:
                layers.append(block(planes, planes, 1))
            self.in_planes = planes
        return nn.Sequential(*layers)

    def forward(self, z):
        """Decode ``z`` of shape (batch, z_dim) into images in [0, 1]."""
        x = self.linear(z)
        x = x.view(z.size(0), 128, 16, 16)
        x = self.layer4(x)
        x = self.layer3(x)
        x = self.layer2(x)
        x = self.layer1(x)
        x = torch.sigmoid(self.conv1(x))
        return x

class ResVAE(nn.Module):
    """Variational autoencoder built from the ResNet18-style encoder/decoder pair.

    Args:
        z_dim: dimensionality of the latent space.
    """

    def __init__(self, z_dim):
        super().__init__()
        self.encoder = ResNet18Enc(z_dim=z_dim)
        self.decoder = ResNet18Dec(z_dim=z_dim)

    def encode(self, x):
        """Return ``(mean, logvar)`` of the approximate posterior for ``x``."""
        mean, logvar = self.encoder(x)
        return mean, logvar

    def forward(self, x):
        """Encode, sample with the reparameterization trick, and decode.

        Returns ``(reconstruction, mean, logvar)``.
        """
        mean, logvar = self.encode(x)
        reconstruction = self.decoder(self.reparameterize(mean, logvar))
        return reconstruction, mean, logvar

    @staticmethod
    def reparameterize(mean, logvar):
        """Draw ``z = mean + eps * exp(logvar / 2)`` with ``eps ~ N(0, I)``."""
        noise = torch.randn_like(logvar)
        return mean + noise * torch.exp(logvar / 2)

    def loss_function(self, recon_x, x, mu, logvar):
        """Summed binary cross-entropy plus the analytic KL divergence to N(0, I)."""
        reconstruction_term = nn.functional.binary_cross_entropy(recon_x, x, reduction='sum')
        kl_elementwise = 1 + logvar - mu.pow(2) - logvar.exp()
        kl_term = -0.5 * torch.sum(kl_elementwise)
        return reconstruction_term + kl_term

In [None]:
from torchsummary import summary
# Layer-by-layer summary of the ResVAE for a single 3x512x512 input.
device = 'cpu'
if cuda.is_available():
    device = 'cuda'
print(device)
model = ResVAE(128).to(device)
summary(model, (3, 512, 512))

cpu
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 256, 256]             864
       BatchNorm2d-2         [-1, 32, 256, 256]              64
            Conv2d-3         [-1, 16, 128, 128]           4,608
       BatchNorm2d-4         [-1, 16, 128, 128]              32
            Conv2d-5         [-1, 16, 128, 128]           2,304
       BatchNorm2d-6         [-1, 16, 128, 128]              32
            Conv2d-7         [-1, 16, 128, 128]             512
       BatchNorm2d-8         [-1, 16, 128, 128]              32
     BasicBlockEnc-9         [-1, 16, 128, 128]               0
           Conv2d-10           [-1, 32, 64, 64]           4,608
      BatchNorm2d-11           [-1, 32, 64, 64]              64
           Conv2d-12           [-1, 32, 64, 64]           9,216
      BatchNorm2d-13           [-1, 32, 64, 64]              64
           Conv2d-14           [-1,

## Training

In [None]:
import torch
from torch import optim
from torch.utils.data import DataLoader
from tqdm import tqdm

# Hyperparameters
z_dim = 128
num_epochs = 10          # total epochs; a resumed run continues from the checkpoint's epoch
learning_rate = 1e-3
checkpoint_path = 'ResVAE_checkpoint.pth'

# Initialize model and optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ResVAE(z_dim=z_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Load checkpoint if available
start_epoch = 0
# Defined up front so the final save below works even when no epoch runs
# (e.g. when resuming from a checkpoint that already reached `num_epochs`).
avg_loss = float('nan')

try:
    # map_location lets a checkpoint written on a GPU runtime be resumed on a CPU-only one.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']
    avg_loss = checkpoint.get('loss', avg_loss)
    print('Model loaded correctly from checkpoint!')
except FileNotFoundError:
    print('No checkpoint found, starting from scratch.')

remaining_epochs = max(num_epochs - start_epoch, 0)

# Training loop. One progress bar covers every remaining batch; the context
# manager closes it even if training is interrupted.
with tqdm(total=len(dataloader) * remaining_epochs) as progress_bar:
    for epoch in range(start_epoch, num_epochs):
        model.train()
        epoch_loss = 0
        for data, _, _, _ in dataloader:
            data = data.to(device)
            optimizer.zero_grad()
            recon_batch, mu, logvar = model(data)
            loss = model.loss_function(recon_batch, data, mu, logvar)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            progress_bar.update(1)

        # The loss is summed over pixels and batch; report it per sample.
        avg_loss = epoch_loss / len(dataloader.dataset)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")

        # Save a resumable checkpoint after every epoch.
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_loss,
        }, checkpoint_path)

# Save the final model
torch.save({
    'epoch': num_epochs,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': avg_loss,
}, f'ResVAE_{num_epochs}.pth')

No checkpoint found, starting from scratch.


  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
 10%|█         | 892/8920 [50:57<7:53:01,  3.54s/it]

Epoch 1/10, Loss: 453197.6523


 20%|██        | 1784/8920 [57:49<41:18,  2.88it/s]

Epoch 2/10, Loss: 449137.7002


 30%|███       | 2676/8920 [1:04:42<39:40,  2.62it/s]

Epoch 3/10, Loss: 447773.0669


 40%|████      | 3568/8920 [1:11:38<36:57,  2.41it/s]

Epoch 4/10, Loss: 447791.7133


 50%|█████     | 4460/8920 [1:18:43<24:54,  2.98it/s]

Epoch 5/10, Loss: 446503.5644


 60%|██████    | 5352/8920 [1:25:50<20:52,  2.85it/s]

Epoch 6/10, Loss: 443056.0729


 70%|███████   | 6244/8920 [1:33:01<17:53,  2.49it/s]

Epoch 7/10, Loss: 439472.1356


 80%|████████  | 7136/8920 [1:39:55<13:21,  2.23it/s]

Epoch 8/10, Loss: 473489.0946


 90%|█████████ | 8028/8920 [1:46:57<04:45,  3.12it/s]

Epoch 9/10, Loss: 442792.9457


100%|██████████| 8920/8920 [1:53:51<00:00,  2.98it/s]

Epoch 10/10, Loss: 439307.9617


100%|██████████| 8920/8920 [1:53:51<00:00,  1.31it/s]


## Extract and save the features using ResVAE

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ResVAE(128).to(device)
# map_location lets a checkpoint saved from a GPU runtime load on a CPU-only runtime.
model.load_state_dict(torch.load('ResVAE_10.pth', map_location=device)['model_state_dict'])

<All keys matched successfully>

In [None]:
# Encode the whole dataset with the ResVAE (also checkpoints features_ResVAE.npy / labels_ResVAE.npy).
latent_vectors, labels = extract_latent_vectors(model=model, dataloader=dataloader, model_name='ResVAE')

100%|██████████| 892/892 [06:30<00:00,  2.29it/s]


In [None]:
# Sanity check (one latent vector per label), then persist the ResVAE features.
for array in (latent_vectors, labels):
    print(array.shape)
for filename, array in (('features_ResVAE_10.npy', latent_vectors),
                        ('labels_ResVAE_10.npy', labels)):
    np.save(filename, array)

(14269, 128)
(14269,)


## Train an SVM model using ResVAE features


In [None]:
# Reload the cached ResVAE features so the SVM below can run without re-extraction.
latent_vectors, labels = (np.load(path) for path in ('features_ResVAE_10.npy', 'labels_ResVAE_10.npy'))

In [None]:
# (n_samples, z_dim) of the reloaded features.
np.shape(latent_vectors)

(14269, 128)

In [None]:
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, accuracy_score

from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Split the data into training and testing sets.
# `stratify` keeps the no-cancer / cancer ratio the same in both splits.
# NOTE(review): this split is at patch level, so patches from the same patient can end
# up in both train and test, which likely inflates the score. A patient-grouped split
# (e.g. GroupShuffleSplit) needs patient ids, which extract_latent_vectors discards.
X_train, X_test, y_train, y_test = train_test_split(
    latent_vectors, labels, test_size=0.2, random_state=42, stratify=labels
)

# SVMs are not scale invariant, so standardize the latent features first. The scaler
# sits inside the pipeline so it is fitted on the training split only.
svm_classifier = make_pipeline(StandardScaler(), SVC(kernel='linear', C=1.0))
svm_classifier.fit(X_train, y_train)

# Predict on the test set
y_pred = svm_classifier.predict(X_test)

# Evaluate the classifier
accuracy = accuracy_score(y_test, y_pred)
print("Accuracy:", accuracy)
print("Classification Report:\n", classification_report(y_test, y_pred))


Accuracy: 0.7214435879467415
Classification Report:
               precision    recall  f1-score   support

           0       0.73      0.73      0.73      1475
           1       0.71      0.72      0.71      1379

    accuracy                           0.72      2854
   macro avg       0.72      0.72      0.72      2854
weighted avg       0.72      0.72      0.72      2854

