In [11]:
# Mount Google Drive at /content/drive so the project folder stored there is reachable.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [14]:
cd drive/MyDrive/MLinA_project/

[Errno 2] No such file or directory: 'drive/MyDrive/MLinA_project/'
/content/drive/MyDrive/MLinA_project


In [13]:
import torch
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import re
import numpy as np

# Dataset creation

In [15]:
import dataset_patches
from dataset_patches import patchesDataset

# Build the patch dataset from the 'patches' folder (relative to the current working
# directory) and wrap it in a loader that yields shuffled batches of 8 samples.
dataset = patchesDataset(root_dir='patches')
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

In [16]:
# Peek at a single batch to sanity-check what the loader yields, then stop.
for images, labels, patient_ids, coordinates in dataloader:
    print(images[0].size())  # Shape of the FIRST image in the batch: (channels, height, width)
    print(labels)  # Labels: 0 (no cancer) or 1 (cancer)
    print(patient_ids)  # Patient IDs: 1 to 24
    print(coordinates)  # Coordinates: the (x, y) tuples are collated into [tensor of x values, tensor of y values]
    break

torch.Size([3, 512, 512])
tensor([0, 0, 0, 0, 1, 0, 1, 1])
tensor([14, 18, 21,  8,  2, 20, 20, 23])
[tensor([33792,  4608, 10752,  1536,  9216,  7168, 11776,  1536]), tensor([14848,  2048,  5120,  5120,  6144, 27648, 27136,  9216])]


# Variational Autoencoder

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm

# Device configuration: use the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Define the VAE model
class VAE(nn.Module):
    """Convolutional variational autoencoder for square RGB images.

    The encoder applies four stride-2 convolutions, each halving the spatial
    size, so an ``input_size`` x ``input_size`` image becomes a
    256-channel feature map of side ``input_size // 16`` (32 for the default
    512). Two linear heads map the flattened features to the mean and
    log-variance of the latent Gaussian. The decoder mirrors the encoder with
    transposed convolutions and ends in a sigmoid, so reconstructions lie in
    [0, 1].

    Args:
        latent_dim: Dimensionality of the latent space.
        input_size: Height/width of the (square) input images. Must be a
            positive multiple of 16 so the four halvings are exact. Defaults
            to 512, which reproduces the original fixed 256*32*32 bottleneck
            and keeps previously saved checkpoints loadable.

    Raises:
        ValueError: If ``input_size`` is not a positive multiple of 16.
    """

    def __init__(self, latent_dim, input_size=512):
        super().__init__()
        if input_size <= 0 or input_size % 16 != 0:
            raise ValueError(
                f"input_size must be a positive multiple of 16, got {input_size}"
            )
        self.latent_dim = latent_dim
        self.input_size = input_size

        # Side length of the feature map after four stride-2 convolutions.
        feature_size = input_size // 16
        flat_features = 256 * feature_size * feature_size

        # Encoder (layer order is part of the state-dict key layout; do not reorder)
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten()
        )

        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_logvar = nn.Linear(flat_features, latent_dim)

        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, flat_features),
            nn.ReLU(),
            nn.Unflatten(1, (256, feature_size, feature_size)),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid()
        )

    def encode(self, x):
        """Return ``(mu, logvar)`` of the latent Gaussian for a batch of images."""
        x = self.encoder(x)
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        Writing the sample this way keeps it differentiable with respect to
        ``mu`` and ``logvar``. Noise is drawn on every call, regardless of
        train/eval mode.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent vectors back to image space (values in [0, 1])."""
        return self.decoder(z)

    def forward(self, x):
        """Encode, sample and decode; returns ``(reconstruction, mu, logvar)``."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(recon_x, x, mu, logvar):
    """VAE objective: summed reconstruction BCE plus KL divergence to N(0, I).

    Both terms are summed (not averaged) over every element of the batch, so
    the result scales with batch size and image resolution.
    """
    reconstruction_term = nn.functional.binary_cross_entropy(
        recon_x, x, reduction='sum'
    )
    # Closed-form KL between N(mu, exp(logvar)) and the standard normal prior.
    kl_elements = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_term = -0.5 * torch.sum(kl_elements)
    return reconstruction_term + kl_term

# Training

In [None]:
# Model, optimizer, and training parameters

# Epoch count of the checkpoint to resume from (0 = start from scratch).
current_epochs = 0
# Total number of epochs the model should have been trained for at the end.
num_epochs = 100
# One naming scheme for every checkpoint, matching the name the
# embedding-extraction cell loads ('vae_model_60.pth').
checkpoint_template = 'vae_model_{}.pth'

model = VAE(latent_dim=100).to(device)
checkpoint_path = checkpoint_template.format(current_epochs)
try:
    # map_location lets a checkpoint written on a GPU runtime load on a CPU one.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    print('Model loaded correctly!')
except FileNotFoundError:
    # Only a missing file is tolerated; a corrupt or mismatched checkpoint
    # should surface as an error instead of silently restarting training.
    print(f'No checkpoint found at {checkpoint_path}; training from scratch.')
    current_epochs = 0
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Training loop: `epoch` is the 1-based global epoch number, so resuming
# continues the count and stops at `num_epochs` total.
for epoch in range(current_epochs + 1, num_epochs + 1):
    model.train()
    total_loss = 0
    for data, _, _, _ in tqdm(dataloader, desc=f"Epoch {epoch}/{num_epochs}"):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # The loss is summed over the batch, so dividing by the dataset size gives a per-sample average.
    print(f"Epoch {epoch}/{num_epochs}, Loss: {total_loss / len(dataloader.dataset):.4f}")
    # Periodic checkpoint every second epoch.
    if epoch % 2 == 0:
        torch.save(model.state_dict(), checkpoint_template.format(epoch))

# Save the trained model
torch.save(model.state_dict(), checkpoint_template.format(num_epochs))

Epoch 1/100: 100%|██████████| 510/510 [21:08<00:00,  2.49s/it]


Epoch 1/100, Loss: 472607.2335


Epoch 2/100: 100%|██████████| 510/510 [02:42<00:00,  3.14it/s]


Epoch 2/100, Loss: 460213.8909


Epoch 3/100: 100%|██████████| 510/510 [02:41<00:00,  3.16it/s]


Epoch 3/100, Loss: 456497.6767


Epoch 4/100: 100%|██████████| 510/510 [02:41<00:00,  3.17it/s]


Epoch 4/100, Loss: 455264.1543


Epoch 5/100: 100%|██████████| 510/510 [02:45<00:00,  3.07it/s]


Epoch 5/100, Loss: 453340.5198


Epoch 6/100: 100%|██████████| 510/510 [02:40<00:00,  3.17it/s]


Epoch 6/100, Loss: 452499.7678


Epoch 7/100: 100%|██████████| 510/510 [02:44<00:00,  3.10it/s]


Epoch 7/100, Loss: 451029.2463


Epoch 8/100: 100%|██████████| 510/510 [02:45<00:00,  3.08it/s]


Epoch 8/100, Loss: 450278.6699


Epoch 9/100: 100%|██████████| 510/510 [02:43<00:00,  3.12it/s]


Epoch 9/100, Loss: 449788.7218


Epoch 10/100: 100%|██████████| 510/510 [02:40<00:00,  3.17it/s]


Epoch 10/100, Loss: 448215.9512


Epoch 11/100: 100%|██████████| 510/510 [02:40<00:00,  3.18it/s]


Epoch 11/100, Loss: 448432.1664


Epoch 12/100: 100%|██████████| 510/510 [02:39<00:00,  3.19it/s]


Epoch 12/100, Loss: 446868.9502


Epoch 13/100: 100%|██████████| 510/510 [02:39<00:00,  3.20it/s]


Epoch 13/100, Loss: 446073.9155


Epoch 14/100: 100%|██████████| 510/510 [02:39<00:00,  3.19it/s]


Epoch 14/100, Loss: 443283.4474


Epoch 15/100: 100%|██████████| 510/510 [02:40<00:00,  3.18it/s]


Epoch 15/100, Loss: 441782.3061


Epoch 16/100: 100%|██████████| 510/510 [02:40<00:00,  3.18it/s]


Epoch 16/100, Loss: 441548.8330


Epoch 17/100: 100%|██████████| 510/510 [02:40<00:00,  3.19it/s]


Epoch 17/100, Loss: 440520.4613


Epoch 18/100: 100%|██████████| 510/510 [02:40<00:00,  3.18it/s]


Epoch 18/100, Loss: 439309.6442


Epoch 19/100: 100%|██████████| 510/510 [02:40<00:00,  3.19it/s]


Epoch 19/100, Loss: 439470.5268


Epoch 20/100: 100%|██████████| 510/510 [02:40<00:00,  3.18it/s]


Epoch 20/100, Loss: 438753.5188


Epoch 21/100: 100%|██████████| 510/510 [02:40<00:00,  3.18it/s]


Epoch 21/100, Loss: 437344.1273


Epoch 22/100: 100%|██████████| 510/510 [02:40<00:00,  3.18it/s]


Epoch 22/100, Loss: 436683.6216


Epoch 23/100: 100%|██████████| 510/510 [02:43<00:00,  3.11it/s]


Epoch 23/100, Loss: 436192.5903


Epoch 24/100: 100%|██████████| 510/510 [02:42<00:00,  3.15it/s]


Epoch 24/100, Loss: 436841.1710


Epoch 25/100: 100%|██████████| 510/510 [02:43<00:00,  3.13it/s]


Epoch 25/100, Loss: 436464.5814


Epoch 26/100: 100%|██████████| 510/510 [02:40<00:00,  3.17it/s]


Epoch 26/100, Loss: 435183.0895


Epoch 27/100: 100%|██████████| 510/510 [02:41<00:00,  3.17it/s]


Epoch 27/100, Loss: 434648.6175


Epoch 28/100: 100%|██████████| 510/510 [02:40<00:00,  3.17it/s]


Epoch 28/100, Loss: 434532.0893


Epoch 29/100: 100%|██████████| 510/510 [02:40<00:00,  3.17it/s]


Epoch 29/100, Loss: 434102.9588


Epoch 30/100: 100%|██████████| 510/510 [02:41<00:00,  3.17it/s]


Epoch 30/100, Loss: 434041.5383


Epoch 31/100: 100%|██████████| 510/510 [02:40<00:00,  3.18it/s]


Epoch 31/100, Loss: 433791.0779


Epoch 32/100: 100%|██████████| 510/510 [02:40<00:00,  3.18it/s]


Epoch 32/100, Loss: 433329.2456


Epoch 33/100: 100%|██████████| 510/510 [02:41<00:00,  3.16it/s]


Epoch 33/100, Loss: 432793.3968


Epoch 34/100: 100%|██████████| 510/510 [02:41<00:00,  3.17it/s]


Epoch 34/100, Loss: 432855.1450


Epoch 35/100: 100%|██████████| 510/510 [02:39<00:00,  3.20it/s]


Epoch 35/100, Loss: 432510.2789


Epoch 36/100: 100%|██████████| 510/510 [02:39<00:00,  3.19it/s]


Epoch 36/100, Loss: 432455.7602


Epoch 37/100: 100%|██████████| 510/510 [02:40<00:00,  3.18it/s]


Epoch 37/100, Loss: 432562.0608


Epoch 38/100: 100%|██████████| 510/510 [02:40<00:00,  3.17it/s]


Epoch 38/100, Loss: 432166.7637


Epoch 39/100: 100%|██████████| 510/510 [02:39<00:00,  3.19it/s]


Epoch 39/100, Loss: 431593.6047


Epoch 40/100: 100%|██████████| 510/510 [02:41<00:00,  3.16it/s]


Epoch 40/100, Loss: 431390.8113


Epoch 41/100: 100%|██████████| 510/510 [02:41<00:00,  3.16it/s]


Epoch 41/100, Loss: 431285.5754


Epoch 42/100: 100%|██████████| 510/510 [02:44<00:00,  3.10it/s]


Epoch 42/100, Loss: 431218.5502


Epoch 43/100: 100%|██████████| 510/510 [02:45<00:00,  3.08it/s]


Epoch 43/100, Loss: 431401.4483


Epoch 44/100: 100%|██████████| 510/510 [02:44<00:00,  3.11it/s]


Epoch 44/100, Loss: 431173.2488


Epoch 45/100: 100%|██████████| 510/510 [02:46<00:00,  3.05it/s]


Epoch 45/100, Loss: 430708.3409


Epoch 46/100: 100%|██████████| 510/510 [02:44<00:00,  3.09it/s]


Epoch 46/100, Loss: 430704.2701


Epoch 47/100: 100%|██████████| 510/510 [02:43<00:00,  3.13it/s]


Epoch 47/100, Loss: 430562.0105


Epoch 48/100: 100%|██████████| 510/510 [02:41<00:00,  3.16it/s]


Epoch 48/100, Loss: 430901.6449


Epoch 49/100: 100%|██████████| 510/510 [02:40<00:00,  3.17it/s]


Epoch 49/100, Loss: 431869.3346


Epoch 50/100: 100%|██████████| 510/510 [02:41<00:00,  3.16it/s]


Epoch 50/100, Loss: 430319.5532


Epoch 51/100: 100%|██████████| 510/510 [02:40<00:00,  3.18it/s]


Epoch 51/100, Loss: 429857.6732


Epoch 52/100: 100%|██████████| 510/510 [02:41<00:00,  3.16it/s]


Epoch 52/100, Loss: 429731.8997


Epoch 53/100: 100%|██████████| 510/510 [02:40<00:00,  3.18it/s]


Epoch 53/100, Loss: 429808.7932


Epoch 54/100: 100%|██████████| 510/510 [02:43<00:00,  3.12it/s]


Epoch 54/100, Loss: 429753.8902


Epoch 55/100: 100%|██████████| 510/510 [02:41<00:00,  3.16it/s]


Epoch 55/100, Loss: 429620.9039


Epoch 56/100: 100%|██████████| 510/510 [02:40<00:00,  3.18it/s]


Epoch 56/100, Loss: 429682.7313


Epoch 57/100: 100%|██████████| 510/510 [02:42<00:00,  3.15it/s]


Epoch 57/100, Loss: 430335.2606


Epoch 58/100: 100%|██████████| 510/510 [02:41<00:00,  3.16it/s]


Epoch 58/100, Loss: 429420.3647


Epoch 59/100: 100%|██████████| 510/510 [02:43<00:00,  3.11it/s]


Epoch 59/100, Loss: 429029.4118


Epoch 60/100: 100%|██████████| 510/510 [02:45<00:00,  3.08it/s]


Epoch 60/100, Loss: 429236.6064


Epoch 61/100:  25%|██▌       | 128/510 [00:42<02:11,  2.90it/s]

# Load the encoder and extract latent embeddings

In [None]:
# Rebuild the architecture and restore the trained weights. map_location remaps the
# stored tensors onto the current device, so a checkpoint written on a GPU runtime
# also loads when only a CPU is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VAE(100).to(device)
model.load_state_dict(torch.load('vae_model_60.pth', map_location=device))

<All keys matched successfully>

In [None]:
# extract latent vectors for the entire dataset
def extract_latent_vectors(model, dataloader, use_mean=False):
    """Encode every batch of ``dataloader`` and collect latent vectors and labels.

    Args:
        model: Module exposing ``encode(x) -> (mu, logvar)`` and
            ``reparameterize(mu, logvar)``. It is switched to eval mode.
        dataloader: Iterable yielding 4-tuples whose first two items are the
            image batch and the label batch (both tensors); the remaining two
            items are ignored.
        use_mean: When True, return the posterior mean ``mu`` as the embedding,
            which is deterministic. When False (default, original behaviour),
            return a random sample drawn via ``model.reparameterize``.

    Returns:
        ``(latent_vectors, labels)`` as NumPy arrays concatenated along the
        first axis, in the order the batches were yielded.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    # Run on the device the model already lives on instead of relying on a
    # notebook-global `device`; fall back to CPU for parameter-free models.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')

    latent_batches = []
    label_batches = []
    model.eval()
    with torch.no_grad():
        for data, target, _, _ in dataloader:
            mu, logvar = model.encode(data.to(model_device))
            latent = mu if use_mean else model.reparameterize(mu, logvar)
            latent_batches.append(latent.cpu().numpy())
            # Labels never pass through the model, so no device transfer is needed.
            label_batches.append(target.cpu().numpy())

    if not latent_batches:
        raise ValueError("dataloader yielded no batches; nothing to encode")
    return np.concatenate(latent_batches), np.concatenate(label_batches)

In [None]:
# Encode the dataset behind `dataloader` into latent vectors with their matching labels.
latent_vectors, labels = extract_latent_vectors(model, dataloader)

In [None]:
# Sanity check: both arrays should have the same length along the first axis.
print(latent_vectors.shape)
print(labels.shape)

In [None]:
# Persist the embeddings and labels as .npy files in the current working directory.
np.save('latent_vectors_vae.npy', latent_vectors)
np.save('labels_vae.npy', labels)

# Create SVM model


In [None]:
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

# Hold out 20% of the latent embeddings for evaluation (fixed seed for the split)
X_train, X_test, y_train, y_test = train_test_split(
    latent_vectors,
    labels,
    test_size=0.2,
    random_state=42,
)

# Fit a linear-kernel SVM on the training embeddings; fit() returns the estimator itself
svm_classifier = SVC(kernel='linear', C=1.0).fit(X_train, y_train)

# Score the held-out split
y_pred = svm_classifier.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print("Accuracy:", accuracy)
print("Classification Report:\n", classification_report(y_test, y_pred))


# Temporary dataset with patient 1 only

In [None]:
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import re
import numpy as np

class patchesDataset(Dataset):
    """Dataset of image patches labelled by whether they come from a region of interest.

    Expected layout under ``root_dir``::

        not_roi_patches/<patient_id>.svs/<x>_<y>_<anything>.png   -> label 0
        in_roi_patches/<patient_id>.svs/<x>_<y>_<anything>.png    -> label 1

    Args:
        root_dir: Directory containing the two patch folders above.
        transform: Optional callable applied to each PIL image before it is
            returned.
        patient_ids: Iterable of integer patient ids to include. Defaults to
            ``range(1, 2)`` (patient 1 only), the original hard-coded subset.

    Raises:
        FileNotFoundError: If an expected patient directory does not exist.
        ValueError: If a file name does not match ``<x>_<y>_*.png``.
    """

    # Patients skipped when collecting in-ROI patches.
    # NOTE(review): presumably patient 21 has no in-ROI directory - confirm.
    _PATIENTS_WITHOUT_ROI = frozenset({21})
    _FILENAME_PATTERN = re.compile(r'(\d+)_(\d+)_.*\.png')

    def __init__(self, root_dir, transform=None, patient_ids=range(1, 2)):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []
        self.patients = []
        self.coordinates = []

        # Materialise once so a one-shot iterator can be walked for both folders.
        patient_ids = list(patient_ids)

        # Process not_roi_patches (label 0)
        self._collect(os.path.join(root_dir, 'not_roi_patches'), patient_ids, label=0)

        # Process in_roi_patches (label 1)
        roi_patients = [pid for pid in patient_ids if pid not in self._PATIENTS_WITHOUT_ROI]
        self._collect(os.path.join(root_dir, 'in_roi_patches'), roi_patients, label=1)

    def _collect(self, split_dir, patient_ids, label):
        """Register every patch under ``split_dir/<patient_id>.svs`` with ``label``."""
        for patient_id in patient_ids:
            patient_dir = os.path.join(split_dir, f'{patient_id}.svs')
            # os.listdir returns entries in arbitrary order; sorting keeps the
            # index -> sample mapping stable across runs and machines.
            for img_name in sorted(os.listdir(patient_dir)):
                # Parse first so a bad file name cannot leave the lists out of sync.
                coordinates = self._extract_coordinates(img_name)
                self.image_paths.append(os.path.join(patient_dir, img_name))
                self.labels.append(label)
                self.patients.append(patient_id)
                self.coordinates.append(coordinates)

    def _extract_coordinates(self, img_name):
        """Return the ``(x, y)`` integers encoded at the start of ``img_name``."""
        match = self._FILENAME_PATTERN.match(img_name)
        if match:
            return (int(match.group(1)), int(match.group(2)))
        raise ValueError(f"Filename {img_name} does not match the expected pattern.")

    def __len__(self):
        """Number of patches collected across both folders."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label, patient_id, coordinates)`` for sample ``idx``.

        ``image`` is a 3-channel PIL image, or whatever ``transform`` returns
        when a transform was supplied.
        """
        img_path = self.image_paths[idx]
        # The context manager releases the file handle as soon as pixels are read.
        with Image.open(img_path) as raw_image:
            rgba = raw_image.convert("RGBA")  # Ensure image is RGBA
        image = np.array(rgba)[:, :, :3]  # Drop the alpha channel
        image = Image.fromarray(image)  # Convert back to PIL image
        label = self.labels[idx]
        patient_id = self.patients[idx]
        coordinates = self.coordinates[idx]

        if self.transform:
            image = self.transform(image)
        return image, label, patient_id, coordinates

# Preprocessing pipeline handed to the dataset: a single step turning each image into a tensor.
transform = transforms.Compose([transforms.ToTensor()])

In [None]:
# Patient-1-only dataset (uses the patchesDataset class defined in this notebook)
# wrapped in a loader that yields shuffled batches of 8 samples.
dataset_temp = patchesDataset(root_dir='patches', transform=transform)
dataloader_temp = DataLoader(dataset_temp, batch_size=8, shuffle=True)

# Results with only patient 1

In [None]:
# RESULTS ONLY WITH PATIENT 1

from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, accuracy_score

# Extract the embeddings from the patient-1 loader right here, under their own names,
# so this cell does not depend on whatever `latent_vectors` / `labels` currently hold
# in the kernel (an earlier cell fills those from the full-dataset loader).
latent_vectors_p1, labels_p1 = extract_latent_vectors(model, dataloader_temp)

# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(latent_vectors_p1, labels_p1, test_size=0.2, random_state=42)

# Initialize and train the SVM classifier
svm_classifier = SVC(kernel='linear', C=1.0)
svm_classifier.fit(X_train, y_train)

# Predict on the test set
y_pred = svm_classifier.predict(X_test)

# Evaluate the classifier
accuracy = accuracy_score(y_test, y_pred)
print("Accuracy:", accuracy)
print("Classification Report:\n", classification_report(y_test, y_pred))


Accuracy: 0.8571428571428571
Classification Report:
               precision    recall  f1-score   support

           0       0.85      0.87      0.86       102
           1       0.87      0.84      0.85       101

    accuracy                           0.86       203
   macro avg       0.86      0.86      0.86       203
weighted avg       0.86      0.86      0.86       203

