In [1]:
 # "excellent-shard-422915-n5"
from google.colab import auth

# PROJECT_ID = "excellent-shard-422915-n5"  # @param {type:"string"}

auth.authenticate_user()

!echo "deb http://packages.cloud.google.com/apt gcsfuse-bionic main" > /etc/apt/sources.list.d/gcsfuse.list

!curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key add -

!apt -qq update

!apt -qq install gcsfuse

!mkdir colab_directory

!gcsfuse --implicit-dirs testopolito colab_directory

!ls colab_directory

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  2659  100  2659    0     0  13795      0 --:--:-- --:--:-- --:--:-- 13848
OK
46 packages can be upgraded. Run 'apt list --upgradable' to see them.
[1;33mW: [0mhttp://packages.cloud.google.com/apt/dists/gcsfuse-bionic/InRelease: Key is stored in legacy trusted.gpg keyring (/etc/apt/trusted.gpg), see the DEPRECATION section in apt-key(8) for details.[0m
The following NEW packages will be installed:
  gcsfuse
0 upgraded, 1 newly installed, 0 to remove and 46 not upgraded.
Need to get 10.4 MB of archives.
After this operation, 0 B of additional disk space will be used.
Selecting previously unselected package gcsfuse.
(Reading database ... 121918 files and directories currently installed.)
Preparing to unpack .../gcsfuse_2.0.1_amd64.deb ...
Unpacking gcsfuse (2.0.1) ...
Setting up gcsfuse (2.0.1) ...
{"timestamp":{"seconds":171

In [58]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.utils import save_image
from PIL import Image
import numpy as np
import random
import matplotlib.pyplot as plt

In [43]:
# Define the generator
class Generator(nn.Module):
    """Convolutional encoder-decoder used as the GAN generator.

    The input has 6 channels (3 RGB + 3 mask). Four stride-2 convolutions
    halve the spatial size each time, then four stride-2 transposed
    convolutions double it back, ending in a 3-channel ``Tanh`` output.
    """

    def __init__(self):
        super().__init__()

        # Encoder stem: no batch norm on the very first convolution.
        layers = [
            nn.Conv2d(6, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
        ]

        # Remaining encoder stages: conv -> batch norm -> ReLU.
        for in_ch, out_ch in ((64, 128), (128, 256), (256, 512)):
            layers.extend((
                nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ))

        # Decoder stages: transposed conv -> batch norm -> ReLU.
        for in_ch, out_ch in ((512, 256), (256, 128), (128, 64)):
            layers.extend((
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ))

        # Output head: back to 3 channels, squashed into [-1, 1] by Tanh.
        layers.extend((
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        ))

        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the full encoder-decoder stack."""
        return self.main(x)

In [44]:
# Define the discriminator
class Discriminator(nn.Module):
    """Convolutional binary classifier used as the GAN discriminator.

    Takes a 3-channel image, downsamples it with four stride-2 convolutions,
    then applies a 4x4 stride-1 unpadded convolution and a ``Sigmoid`` so the
    result is a probability-like score.
    """

    def __init__(self):
        super().__init__()

        # Stem: no batch norm on the first convolution.
        stages = [
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # Downsampling stages: conv -> batch norm -> LeakyReLU.
        widths = (64, 128, 256, 512)
        for in_ch, out_ch in zip(widths, widths[1:]):
            stages.extend((
                nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2, inplace=True),
            ))

        # Scoring head: collapse to a single channel and squash into (0, 1).
        stages.extend((
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0),
            nn.Sigmoid(),
        ))

        self.main = nn.Sequential(*stages)

    def forward(self, x):
        """Score ``x`` with the full convolutional stack."""
        return self.main(x)

In [45]:
# Utility function to create a masked image with random position and size
def create_random_masked_image(image):
    """Blank out one random rectangle of a 4-D image batch.

    The rectangle's top-left corner is drawn from the upper-left quadrant
    (inclusive of the midpoint) and each side is between a quarter and a half
    of the corresponding image dimension, so it always fits inside the image.
    The same rectangle is applied to every batch element and channel.

    Returns:
        (masked_image, mask): a copy of ``image`` with the rectangle set to 0,
        and a tensor shaped like ``image`` that is 1 inside the rectangle and
        0 elsewhere. ``image`` itself is not modified.
    """
    _batch, _channels, height, width = image.shape

    # Draw order (top, left, height, width) is kept so seeded runs reproduce.
    row_start = random.randint(0, height // 2)
    col_start = random.randint(0, width // 2)
    row_stop = row_start + random.randint(height // 4, height // 2)
    col_stop = col_start + random.randint(width // 4, width // 2)

    # One index expression shared by the mask and the masked copy.
    region = (slice(None), slice(None),
              slice(row_start, row_stop), slice(col_start, col_stop))

    mask = torch.zeros_like(image)
    mask[region] = 1

    masked_image = image.clone()
    masked_image[region] = 0

    return masked_image, mask

In [46]:
# Load and preprocess the image
def load_image(image_path):
    """Read an image file and return it as a normalized, batched tensor.

    The image is converted to RGB, resized to 64x64, turned into a tensor and
    normalized with mean 0.5 / std 0.5, then given a leading batch dimension,
    so the result has shape (1, 3, 64, 64).
    """
    rgb_image = Image.open(image_path).convert('RGB')

    preprocess = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        # A single mean/std value is applied to every channel.
        transforms.Normalize((0.5,), (0.5,)),
    ])

    batched = preprocess(rgb_image).unsqueeze(0)
    return batched

In [59]:
# Function to convert a normalized image tensor to a PIL image (displaying it is left to the caller)
def tensor_to_pil_image(tensor):
    """Convert a normalized image tensor with a batch dimension to a PIL image.

    The tensor is copied to the CPU, its leading batch dimension is squeezed
    out, the mean-0.5 / std-0.5 normalization is undone, and the result is
    clipped to [0, 1] before being scaled to 8-bit pixel values.
    """
    # (x - (-1)) / 2 maps [-1, 1] back to [0, 1].
    unnormalize = transforms.Normalize((-1, -1, -1), (2, 2, 2))

    chw = tensor.clone().detach().cpu().squeeze(0)  # drop the batch dimension
    hwc = unnormalize(chw).permute(1, 2, 0)         # channels-last layout

    pixels = np.clip(hwc.numpy(), 0, 1)
    return Image.fromarray((pixels * 255).astype(np.uint8))

In [47]:
# Pick the compute device (GPU when available) and build both networks on it
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# The generator is constructed before the discriminator.
generator = Generator()
generator = generator.to(device)
discriminator = Discriminator()
discriminator = discriminator.to(device)

In [52]:
# Load the source patch and hide a random rectangle of it
image_path = "/content/colab_directory/CRC_WSIs_original/in_roi_patches/1.svs/10240_10240_mag1.png"
image = load_image(image_path)
image = image.to(device)
masked_image, mask = create_random_masked_image(image)

# Generator input: the masked image followed by the mask, stacked on the channel axis (dim 1)
z = torch.cat([masked_image, mask], 1)

In [60]:
# Display masked image
masked_pil_image = tensor_to_pil_image(masked_image)

# PIL's Image.show() hands the image to an external viewer program, which a
# hosted notebook kernel does not have, so render it inline with matplotlib.
fig, ax = plt.subplots(figsize=(3, 3))
ax.imshow(masked_pil_image)
ax.set_title("Masked input")
ax.axis("off")
plt.show()

masked_pil_image.save("masked_image.png")  # Optionally save the image to disk

In [53]:
# Loss and optimizers for adversarial training
criterion = nn.BCELoss()  # binary cross-entropy on the discriminator's sigmoid output

# Both networks get an Adam optimizer with identical hyperparameters (generator first).
optimizerG, optimizerD = (
    optim.Adam(network.parameters(), lr=0.0002, betas=(0.5, 0.999))
    for network in (generator, discriminator)
)

In [57]:
# Training loop
# NOTE(review): the generator is optimized on the adversarial term only; nothing
# ties its output to the unmasked pixels of `image`. Consider adding a
# reconstruction loss if the goal is inpainting.
num_epochs = 1000

# The target labels never change, so build them once instead of on every iteration.
real_labels = torch.ones(image.size(0), device=device)
fake_labels = torch.zeros(image.size(0), device=device)

for epoch in range(num_epochs):
    # Train Discriminator: real image should score 1, generated image 0.
    optimizerD.zero_grad()

    outputs = discriminator(image).view(-1)
    d_loss_real = criterion(outputs, real_labels)
    d_loss_real.backward()

    fake_image = generator(z)
    # detach() keeps the discriminator update from backpropagating into the generator.
    outputs = discriminator(fake_image.detach()).view(-1)
    d_loss_fake = criterion(outputs, fake_labels)
    d_loss_fake.backward()

    optimizerD.step()

    # Train Generator: it is rewarded when the discriminator labels its output as real.
    optimizerG.zero_grad()
    outputs = discriminator(fake_image).view(-1)
    g_loss = criterion(outputs, real_labels)
    g_loss.backward()

    optimizerG.step()

    if epoch % 100 == 0:
        d_loss = (d_loss_real + d_loss_fake).item()
        print(f'Epoch [{epoch}/{num_epochs}], d_loss: {d_loss:.4f}, g_loss: {g_loss.item():.4f}')
        # The generator ends in Tanh, so its output lies in [-1, 1], while
        # save_image expects [0, 1] and clamps everything below 0 to black.
        # Map the range back before writing the snapshot.
        save_image((fake_image.detach() + 1) / 2, f'output_{epoch}.png')

Epoch [0/1000], d_loss: 0.0088, g_loss: 6.1539
Epoch [100/1000], d_loss: 0.0028, g_loss: 6.7476
Epoch [200/1000], d_loss: 0.0017, g_loss: 7.2293
Epoch [300/1000], d_loss: 0.0010, g_loss: 7.5938
Epoch [400/1000], d_loss: 0.0007, g_loss: 7.9658
Epoch [500/1000], d_loss: 0.0005, g_loss: 8.2244
Epoch [600/1000], d_loss: 0.0004, g_loss: 8.3837
Epoch [700/1000], d_loss: 0.0003, g_loss: 8.7450
Epoch [800/1000], d_loss: 0.0003, g_loss: 8.6618
Epoch [900/1000], d_loss: 0.0003, g_loss: 8.8313


In [56]:
# Save the final generated image
# The generator's Tanh output lies in [-1, 1]; save_image expects [0, 1] and
# clamps negative values to black, so map the range back before saving.
save_image((fake_image.detach() + 1) / 2, 'final_output.png')