In [1]:
 # "excellent-shard-422915-n5"
from google.colab import auth

# PROJECT_ID = "excellent-shard-422915-n5"  # @param {type:"string"}

auth.authenticate_user()

!echo "deb http://packages.cloud.google.com/apt gcsfuse-bionic main" > /etc/apt/sources.list.d/gcsfuse.list

!curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key add -

!apt -qq update

!apt -qq install gcsfuse

!mkdir colab_directory

!gcsfuse --implicit-dirs testopolito colab_directory

!ls colab_directory

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  2659  100  2659    0     0  10466      0 --:--:-- --:--:-- --:--:-- 10509
OK
57 packages can be upgraded. Run 'apt list --upgradable' to see them.
[1;33mW: [0mhttp://packages.cloud.google.com/apt/dists/gcsfuse-bionic/InRelease: Key is stored in legacy trusted.gpg keyring (/etc/apt/trusted.gpg), see the DEPRECATION section in apt-key(8) for details.[0m
The following NEW packages will be installed:
  gcsfuse
0 upgraded, 1 newly installed, 0 to remove and 57 not upgraded.
Need to get 10.4 MB of archives.
After this operation, 0 B of additional disk space will be used.
Selecting previously unselected package gcsfuse.
(Reading database ... 121920 files and directories currently installed.)
Preparing to unpack .../gcsfuse_2.0.1_amd64.deb ...
Unpacking gcsfuse (2.0.1) ...
Setting up gcsfuse (2.0.1) ...
{"timestamp":{"seconds":171

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.utils import save_image
from PIL import Image
import numpy as np
import random
from sklearn import svm
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
from sklearn.svm import SVC
import matplotlib.pyplot as plt

In [3]:
# Define the generator with embedding extraction
class Generator(nn.Module):
    """Encoder-decoder generator that also exposes its bottleneck features.

    The input has 6 channels: 3 RGB channels of the masked image followed by
    3 mask channels. Four Conv2d(k=4, s=2, p=1) layers halve even spatial sizes
    four times (H/16 x W/16 with 512 channels). Four ConvTranspose2d(k=4, s=2, p=1)
    layers double them back, ending in Tanh, so the output is in [-1, 1].

    forward(x) returns ``(embedding, output)``.
    """

    def __init__(self):
        super().__init__()

        # Input: 6 channels (3 RGB + 3 mask); the first block has no BatchNorm.
        down = [nn.Conv2d(6, 64, 4, 2, 1), nn.ReLU(True)]
        for c_in, c_out in ((64, 128), (128, 256), (256, 512)):
            down += [nn.Conv2d(c_in, c_out, 4, 2, 1), nn.BatchNorm2d(c_out), nn.ReLU(True)]
        self.encoder = nn.Sequential(*down)

        up = []
        for c_in, c_out in ((512, 256), (256, 128), (128, 64)):
            up += [nn.ConvTranspose2d(c_in, c_out, 4, 2, 1), nn.BatchNorm2d(c_out), nn.ReLU(True)]
        up += [nn.ConvTranspose2d(64, 3, 4, 2, 1), nn.Tanh()]
        self.decoder = nn.Sequential(*up)

    def forward(self, x):
        latent = self.encoder(x)
        reconstruction = self.decoder(latent)
        return latent, reconstruction

In [4]:
# Define the discriminator
class Discriminator(nn.Module):
    """Convolutional real/fake classifier producing sigmoid scores.

    Four Conv2d(k=4, s=2, p=1) layers halve even spatial sizes four times. A
    final 4x4 valid convolution then gives one score per image for 64x64 inputs,
    or a spatial map of scores for larger inputs. Outputs are in [0, 1] (Sigmoid).
    """

    def __init__(self):
        super().__init__()
        stack = [nn.Conv2d(3, 64, 4, 2, 1), nn.LeakyReLU(0.2, inplace=True)]
        for c_in, c_out in ((64, 128), (128, 256), (256, 512)):
            stack.extend([
                nn.Conv2d(c_in, c_out, 4, 2, 1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        stack.extend([nn.Conv2d(512, 1, 4, 1, 0), nn.Sigmoid()])
        self.main = nn.Sequential(*stack)

    def forward(self, x):
        scores = self.main(x)
        return scores

In [5]:
# Utility function to create a masked image with random position and size
def create_random_masked_image(image, rng=None):
    """Zero out one randomly placed rectangular patch of ``image``.

    The patch's top-left corner is drawn from [0, H//2] x [0, W//2] and its size
    from [H//4, H//2] x [W//4, W//2], so the patch always fits inside the image.
    One patch is drawn per call and applied to every leading (batch/channel)
    slice of the tensor.

    Args:
        image: tensor whose last two dimensions are (height, width), e.g. a
            (N, C, H, W) batch or a single (C, H, W) image. It is not modified.
        rng: object with a ``randint(a, b)`` method, e.g. ``random.Random(seed)``,
            for reproducible masks. Defaults to the global ``random`` module.

    Returns:
        (masked_image, mask):
            masked_image: a copy of ``image`` with the patch set to 0.
            mask: same shape/dtype/device as ``image``, 1 inside the patch and
                0 elsewhere.
    """
    if rng is None:
        rng = random
    height, width = image.shape[-2:]

    # Draw order (top, left, height, width) matches the original implementation,
    # so runs seeded through the global `random` module reproduce the same masks.
    top = rng.randint(0, height // 2)
    left = rng.randint(0, width // 2)
    patch_height = rng.randint(height // 4, height // 2)
    patch_width = rng.randint(width // 4, width // 2)

    rows = slice(top, top + patch_height)
    cols = slice(left, left + patch_width)

    mask = torch.zeros_like(image)
    mask[..., rows, cols] = 1
    masked_image = image.clone()
    masked_image[..., rows, cols] = 0
    return masked_image, mask

In [6]:
# Load and preprocess the image
def load_image(image_path, size=(64, 64)):
    """Load an image file as a (1, 3, H, W) float tensor in [-1, 1].

    Args:
        image_path: path to an image file readable by PIL.
        size: ``(height, width)`` passed to ``transforms.Resize``; the default
            64x64 keeps the previous behaviour. An int resizes the shorter edge.

    Returns:
        Tensor with a leading batch dimension of 1. ``ToTensor`` scales pixels
        to [0, 1] and ``Normalize(0.5, 0.5)`` maps them to [-1, 1], the range of
        the generator's Tanh output.
    """
    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    # The context manager releases the file handle as soon as the pixels are read.
    with Image.open(image_path) as img:
        rgb = img.convert('RGB')
    return transform(rgb).unsqueeze(0)

In [7]:
# Function to convert a tensor to a PIL image and display it
def tensor_to_pil_image(tensor):
    """Convert a [-1, 1]-normalized RGB image tensor to a PIL image.

    Args:
        tensor: a (C, H, W) image or a (1, C, H, W) batch of one, with values
            in [-1, 1] (e.g. from ``load_image`` or the generator's Tanh).

    Returns:
        ``PIL.Image`` built from an (H, W, C) uint8 array.

    Raises:
        ValueError: if ``tensor`` holds more than one image or has an
            unexpected rank.
    """
    tensor = tensor.detach().cpu()
    if tensor.dim() == 4:
        if tensor.size(0) != 1:
            raise ValueError(f"expected a single image, got a batch of {tensor.size(0)}")
        tensor = tensor[0]  # remove batch dimension
    if tensor.dim() != 3:
        raise ValueError(f"expected a (C, H, W) or (1, C, H, W) tensor, got shape {tuple(tensor.shape)}")

    # Undo Normalize(0.5, 0.5): (x + 1) / 2 maps [-1, 1] back to [0, 1].
    image = ((tensor + 1) / 2).permute(1, 2, 0).numpy()  # CHW -> HWC
    image = np.clip(image, 0, 1)
    # Round to the nearest 8-bit level instead of truncating, like torchvision's
    # save_image, so both output paths agree.
    return Image.fromarray(np.round(image * 255).astype(np.uint8))

In [8]:
# Seed the RNGs this notebook draws from: the `random` module (mask placement) and
# torch (weight init, DataLoader shuffling). Reruns then reproduce the same run,
# up to nondeterministic GPU kernels.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# Initialize models
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator = Generator().to(device)
discriminator = Discriminator().to(device)

In [91]:
# Load one in-ROI patch of patient 1 (resized/normalized by load_image) and hide a
# randomly placed rectangle of it.
image_path = "/content/colab_directory/CRC_WSIs_original/in_roi_patches/1.svs/10240_10240_mag1.png"
image = load_image(image_path).to(device)
masked_image, mask = create_random_masked_image(image)

# Generator input: the masked RGB image and its 3-channel mask concatenated along
# the channel dimension (6 channels).
# NOTE(review): `z` is not read before the training loop, which overwrites it.
z = torch.cat((masked_image, mask), dim=1)

In [83]:
# Display the masked image inline. PIL's Image.show() hands the image to an external
# viewer, which a headless Colab VM cannot show. Ending the cell with the image
# object lets IPython render it through its PNG repr instead.
masked_pil_image = tensor_to_pil_image(masked_image)
masked_pil_image.save("masked_image.png")  # also keep a copy on disk
masked_pil_image

In [9]:
# GAN training setup: binary cross-entropy on the discriminator's sigmoid scores,
# and Adam with the same learning rate and betas for both networks.
ADAM_LR = 0.0002
ADAM_BETAS = (0.5, 0.999)

criterion = nn.BCELoss()
optimizerG = optim.Adam(generator.parameters(), lr=ADAM_LR, betas=ADAM_BETAS)
optimizerD = optim.Adam(discriminator.parameters(), lr=ADAM_LR, betas=ADAM_BETAS)

# Temporary dataset with patient 1 only

In [10]:
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import re
import numpy as np

class patchesDataset(Dataset):
    """Patch images of one patient, labelled by ROI membership.

    Expected layout under ``root_dir``::

        not_roi_patches/<patient_id>.svs/<x>_<y>_<anything>.png   -> label 0
        in_roi_patches/<patient_id>.svs/<x>_<y>_<anything>.png    -> label 1

    Each item is ``(image, label, patient_id, (x, y))``, where ``(x, y)`` are the
    integers parsed from the filename. Files are listed in sorted order, because
    ``os.listdir`` order is arbitrary, so indices mean the same thing on every run.

    Raises:
        ValueError: at construction, if a filename does not match the pattern.
    """

    # (sub-directory, label) pairs, scanned in this order.
    _SPLITS = (('not_roi_patches', 0), ('in_roi_patches', 1))

    def __init__(self, root_dir, patient_id, transform=None):
        self.root_dir = root_dir
        self.patient_id = patient_id
        self.transform = transform
        self.image_paths = []
        self.labels = []
        self.patients = []
        self.coordinates = []

        for subdir, label in self._SPLITS:
            self._add_patient_dir(os.path.join(root_dir, subdir, f'{patient_id}.svs'), label)

    def _add_patient_dir(self, patient_dir, label):
        """Register every file in ``patient_dir`` with the given label."""
        for img_name in sorted(os.listdir(patient_dir)):
            self.image_paths.append(os.path.join(patient_dir, img_name))
            self.labels.append(label)
            self.patients.append(self.patient_id)
            self.coordinates.append(self._extract_coordinates(img_name))

    def _extract_coordinates(self, img_name):
        """Parse ``(x, y)`` from a ``<x>_<y>_....png`` filename.

        Raises:
            ValueError: if the name does not match that pattern.
        """
        match = re.match(r'(\d+)_(\d+)_.*\.png', img_name)
        if match:
            return int(match.group(1)), int(match.group(2))
        raise ValueError(f"Filename {img_name} does not match the expected pattern.")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        # Convert via RGBA and drop alpha so every item has exactly three channels;
        # the context manager closes the file handle once the pixels are read.
        with Image.open(img_path) as pil_image:
            rgb = np.array(pil_image.convert("RGBA"))[:, :, :3]
        image = Image.fromarray(rgb)

        if self.transform:
            image = self.transform(image)
        return image, self.labels[idx], self.patients[idx], self.coordinates[idx]

In [12]:
# Specify the patient ID you want to process
patient_id = 1

# ToTensor scales pixels to [0, 1]; Normalize(0.5, 0.5) then maps them to [-1, 1].
# That is the range of the generator's Tanh output and the range load_image and
# tensor_to_pil_image assume. Without it, real images ([0, 1]) and fakes ([-1, 1])
# could be told apart by value range alone.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

dataset_temp = patchesDataset(root_dir='/content/colab_directory/CRC_WSIs_original', patient_id=patient_id, transform=transform)
dataloader_temp = DataLoader(dataset_temp, batch_size=8, shuffle=True)

# Inspect a single batch from the dataloader
images, labels, patient_ids, coords = next(iter(dataloader_temp))
print(f'Batch of images shape: {images.shape}')
print(f'Batch of labels: {labels}')
print(f'Batch of patient IDs: {patient_ids}')
print(f'Batch of coordinates: {coords}')

Batch of images shape: torch.Size([8, 3, 512, 512])
Batch of labels: tensor([0, 0, 1, 1, 0, 1, 1, 1])
Batch of patient IDs: tensor([1, 1, 1, 1, 1, 1, 1, 1])
Batch of coordinates: [tensor([17408, 10240,  9216, 22016,  6144, 12800, 16896, 23040]), tensor([ 8192,  7680, 13312, 17920,  8704, 10752, 12288, 11776])]


In [13]:
# Build the training dataset and dataloader for patient 1
patient_id = 1
dataset = patchesDataset(
    root_dir='/content/colab_directory/CRC_WSIs_original',
    patient_id=patient_id,
    transform=transform,
)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

In [None]:
# Training loop.
# Each batch: one discriminator step (real vs. detached fake), then one generator
# step (fool the discriminator). Generator embeddings from the final epoch are
# stored with the dataset's ROI labels (0 = not_roi, 1 = in_roi) as SVM features.
num_epochs = 1  # the epoch loop was previously commented out (one pass); raise for longer training
embeddings_list = []
labels_list = []

for epoch in range(num_epochs):
    collect_embeddings = epoch == num_epochs - 1

    for images, labels, patient_ids, coords in dataloader:
        images = images.to(device)
        batch_size = images.size(0)

        # One independent random mask per image: the helper draws a single patch
        # per call and applies it to the whole tensor it receives.
        masked_parts, mask_parts = [], []
        for image in images:
            masked_image, mask = create_random_masked_image(image.unsqueeze(0))
            masked_parts.append(masked_image)
            mask_parts.append(mask)
        masked_images = torch.cat(masked_parts)
        masks = torch.cat(mask_parts)

        real_labels = torch.ones(batch_size, device=device)
        fake_labels = torch.zeros(batch_size, device=device)

        # Train Discriminator. For inputs larger than 64x64 it returns a map of
        # scores, which is averaged to one score per image.
        optimizerD.zero_grad()
        outputs = discriminator(images).view(batch_size, -1).mean(1)
        d_loss_real = criterion(outputs, real_labels)
        d_loss_real.backward()

        z = torch.cat((masked_images, masks), dim=1)  # 3 masked RGB + 3 mask channels
        embedding, fake_image = generator(z)
        outputs = discriminator(fake_image.detach()).view(batch_size, -1).mean(1)
        d_loss_fake = criterion(outputs, fake_labels)
        d_loss_fake.backward()
        optimizerD.step()

        # Train Generator
        optimizerG.zero_grad()
        outputs = discriminator(fake_image).view(batch_size, -1).mean(1)
        g_loss = criterion(outputs, real_labels)
        g_loss.backward()
        optimizerG.step()

        # SVM features: one flattened embedding per image, labelled with the
        # dataset's ROI label. The GAN real/fake targets say nothing about ROI
        # membership, and appending the same rows under both would make the
        # classes indistinguishable.
        if collect_embeddings:
            embeddings_list.append(embedding.detach().view(batch_size, -1).cpu().numpy())
            labels_list.append(labels.numpy())

    print(f'Epoch [{epoch + 1}/{num_epochs}], d_loss: {(d_loss_real + d_loss_fake).item():.4f}, g_loss: {g_loss.item():.4f}')
    # The generator ends in Tanh ([-1, 1]); save_image expects [0, 1].
    save_image(fake_image.detach() * 0.5 + 0.5, f'output_{epoch}.png')

In [None]:
# Save the GAN models: weights and optimizer states of both networks in one file
# (the optimizer states are what a later resume of training would need).
checkpoint = {
    'generator_state_dict': generator.state_dict(),
    'discriminator_state_dict': discriminator.state_dict(),
    'optimizerG_state_dict': optimizerG.state_dict(),
    'optimizerD_state_dict': optimizerD.state_dict(),
}
torch.save(checkpoint, 'gan_random_masks.pth')

In [None]:
# Save the last generated batch. The generator ends in Tanh, so its output is in
# [-1, 1], while save_image clamps to [0, 1]. Map back with x * 0.5 + 0.5 first;
# otherwise every negative value would be written as black.
save_image(fake_image.detach() * 0.5 + 0.5, 'final_output.png')

In [None]:
# Prepare data for SVM: one row of flattened generator embeddings per sample.
# NOTE(review): each row has 512 * (H/16) * (W/16) features, i.e. 524,288 for the
# 512x512 patches loaded above. Pooling the embedding would make the SVM much
# cheaper; this is left unchanged here.
X = np.vstack(embeddings_list)
y = np.hstack(labels_list)
assert len(X) == len(y), "embeddings and labels are misaligned"

# A stratified split keeps the class proportions equal in train and test, so the
# test metrics are not skewed by an unlucky split of the (likely imbalanced) classes.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)

# Initialize and train the SVM classifier
svm_classifier = SVC(kernel='linear', C=1.0)
svm_classifier.fit(X_train, y_train)

# Predict on the test set
y_pred = svm_classifier.predict(X_test)

# Evaluate the classifier
accuracy = accuracy_score(y_test, y_pred)
print("Accuracy:", accuracy)
print("Classification Report:\n", classification_report(y_test, y_pred))