In [None]:
from ImageProcessor import *
from Evaluation import *
# Slide BRACS_1494: build its ImageProcessor and cut tiles of side `tile_size`.
tile_size = 512
svs_path, json_path = (
    '../prelimary_data/BRACS_1494.svs',
    '../prelimary_data/BRACS_1494.geojson',
)
imageProcessor_1494 = ImageProcessor(json_path, svs_path)
tiles_1494 = imageProcessor_1494.generate_tile(tile_size=tile_size)

In [None]:
# Slide BRACS_1496, tiled with the same `tile_size` as the first slide.
svs_path, json_path = (
    '../prelimary_data/BRACS_1496.svs',
    '../prelimary_data/BRACS_1496.geojson',
)
imageProcessor_1496 = ImageProcessor(json_path, svs_path)
tiles_1496 = imageProcessor_1496.generate_tile(tile_size=tile_size)

In [None]:
# Slide BRACS_1286, tiled with the same `tile_size` as the other slides.
svs_path, json_path = (
    '../prelimary_data/BRACS_1286.svs',
    '../prelimary_data/BRACS_1286.geojson',
)
imageProcessor_1286 = ImageProcessor(json_path, svs_path)
tiles_1286 = imageProcessor_1286.generate_tile(tile_size=tile_size)

In [None]:
tuple(map(len, (tiles_1494, tiles_1496, tiles_1286)))

In [None]:
# Pool the tiles from all three slides into a single collection.
tiles = (tiles_1494
         + tiles_1496
         + tiles_1286)

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.functional import relu
import json
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt
import numpy as np

class UNet(nn.Module):
    """
    Base U-Net: four down-sampling stages, a 1024-channel bottleneck and four
    up-sampling stages joined by skip connections.

    Every 3x3 convolution uses padding=1, so convolutions keep the spatial size.
    Only the 2x2 max-pools (halve H/W) and the 2x2 stride-2 transposed
    convolutions (double H/W) change it. Input height and width must therefore
    be divisible by 16 for the skip concatenations to line up.
    Channel widths: 3 -> 64 -> 128 -> 256 -> 512 -> 1024 and back down to 64,
    followed by a 1x1 convolution to `n_class` channels. Raw logits are
    returned (no activation is applied).
    """

    def __init__(self, n_class:int = 2) -> None:
        """
        Build the layers. `n_class` is the number of output channels (one per
        segmentation class); it defaults to 2 for the two ROI types.
        """
        super().__init__()

        # Encoder. Layers are registered as e{k}1, e{k}2, pool{k} for k = 1..5
        # (no pool after the bottleneck stage 5). This keeps the module and
        # state_dict layout, and the order in which parameters are randomly
        # initialised, as e11, e12, pool1, e21, ..., e51, e52.
        encoder_widths = [(3, 64), (64, 128), (128, 256), (256, 512), (512, 1024)]
        for stage, (c_in, c_out) in enumerate(encoder_widths, start=1):
            setattr(self, f"e{stage}1", nn.Conv2d(c_in, c_out, kernel_size=3, padding=1))
            setattr(self, f"e{stage}2", nn.Conv2d(c_out, c_out, kernel_size=3, padding=1))
            if stage < len(encoder_widths):
                setattr(self, f"pool{stage}", nn.MaxPool2d(kernel_size=2, stride=2))

        # Decoder. upconv{k} halves the channel count and doubles H/W. Its output
        # is concatenated with the matching encoder feature map, which restores
        # the channel count, and d{k}1 / d{k}2 then reduce it to the halved width.
        for stage in range(1, 5):
            c_in = 1024 // 2 ** (stage - 1)
            c_out = c_in // 2
            setattr(self, f"upconv{stage}", nn.ConvTranspose2d(c_in, c_out, kernel_size=2, stride=2))
            setattr(self, f"d{stage}1", nn.Conv2d(c_in, c_out, kernel_size=3, padding=1))
            setattr(self, f"d{stage}2", nn.Conv2d(c_out, c_out, kernel_size=3, padding=1))

        # Output layer: 1x1 convolution to per-class logits.
        self.outconv = nn.Conv2d(64, n_class, kernel_size=1)

    def forward(self, x):
        """Map an (N, 3, H, W) batch to (N, n_class, H, W) logits."""
        skips = []
        h = x
        for stage in range(1, 5):
            h = relu(getattr(self, f"e{stage}1")(h))
            h = relu(getattr(self, f"e{stage}2")(h))
            skips.append(h)  # pre-pool activation, reused by the decoder
            h = getattr(self, f"pool{stage}")(h)

        # Bottleneck (not pooled).
        h = relu(self.e52(relu(self.e51(h))))

        # Decoder consumes the skip connections deepest-first.
        for stage in range(1, 5):
            h = getattr(self, f"upconv{stage}")(h)
            h = torch.cat([h, skips.pop()], dim=1)
            h = relu(getattr(self, f"d{stage}1")(h))
            h = relu(getattr(self, f"d{stage}2")(h))

        return self.outconv(h)

In [None]:
# @title
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.functional import relu
import json
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt
import numpy as np

class SmallUNet(nn.Module):
    """
    Compact U-Net: three down-sampling stages, a 512-channel bottleneck and
    three up-sampling stages joined by skip connections.

    Layer names start at e21/pool2 (there is no e1x/pool1 stage). All 3x3
    convolutions use padding=1, so only the max-pools and transposed
    convolutions change H/W. Input height and width must be divisible by 8.
    Raw logits are returned. `self.sigmoid` holds an nn.LogSoftmax(dim=1),
    despite its name, and forward() does not apply it.
    """

    def __init__(self, n_class:int = 2) -> None:
        """
        Build the layers. `n_class` is the number of output channels (one per
        segmentation class); it defaults to 2 for the two ROI types.
        """
        super().__init__()

        # Encoder stages 2..5 registered as e{k}1, e{k}2, pool{k}. The
        # bottleneck (stage 5) is not pooled. Registration order is preserved so
        # the state_dict layout and the random-init order stay the same.
        encoder_widths = {2: (3, 64), 3: (64, 128), 4: (128, 256), 5: (256, 512)}
        for stage, (c_in, c_out) in encoder_widths.items():
            setattr(self, f"e{stage}1", nn.Conv2d(c_in, c_out, kernel_size=3, padding=1))
            setattr(self, f"e{stage}2", nn.Conv2d(c_out, c_out, kernel_size=3, padding=1))
            if stage != 5:
                setattr(self, f"pool{stage}", nn.MaxPool2d(kernel_size=2, stride=2))

        # Decoder stages 1..3. Upsample and halve the channels, concatenate the
        # skip, then two 3x3 convolutions down to the halved width.
        for stage in range(1, 4):
            c_in = 512 // 2 ** (stage - 1)
            c_out = c_in // 2
            setattr(self, f"upconv{stage}", nn.ConvTranspose2d(c_in, c_out, kernel_size=2, stride=2))
            setattr(self, f"d{stage}1", nn.Conv2d(c_in, c_out, kernel_size=3, padding=1))
            setattr(self, f"d{stage}2", nn.Conv2d(c_out, c_out, kernel_size=3, padding=1))

        # Output layer: 1x1 convolution to per-class logits.
        self.outconv = nn.Conv2d(64, n_class, kernel_size=1)
        # Kept for compatibility; not used in forward().
        self.sigmoid = nn.LogSoftmax(dim=1)

    def forward(self, x):
        """Map an (N, 3, H, W) batch to (N, n_class, H, W) logits."""
        skips = []
        h = x
        for stage in (2, 3, 4):
            h = relu(getattr(self, f"e{stage}1")(h))
            h = relu(getattr(self, f"e{stage}2")(h))
            skips.append(h)  # pre-pool activation for the matching decoder stage
            h = getattr(self, f"pool{stage}")(h)

        # Bottleneck (not pooled).
        h = relu(self.e52(relu(self.e51(h))))

        # Decoder consumes the skip connections deepest-first.
        for stage in (1, 2, 3):
            h = getattr(self, f"upconv{stage}")(h)
            h = torch.cat([h, skips.pop()], dim=1)
            h = relu(getattr(self, f"d{stage}1")(h))
            h = relu(getattr(self, f"d{stage}2")(h))

        return self.outconv(h)

In [None]:
import random
import numpy as np
import torch

def set_seeds(seed=42):
    """Seed Python's stdlib RNG, NumPy's global RNG and torch (CPU and all CUDA devices).

    Args:
        seed: integer seed applied to every generator.
    """
    # Import the stdlib module locally. A later cell runs `from numpy import random`,
    # which rebinds the global name `random` to `numpy.random`. Without the local
    # import, re-running this function after that cell would never seed Python's RNG.
    import random as py_random

    py_random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seeds(32)

In [None]:
# Drop edge tiles that are not full-size, channel-first RGB tiles.
expected_shape = (3, tile_size, tile_size)
tiles = list(filter(lambda tile: tile.image_data.shape == expected_shape, tiles))
# Alias, not a copy. Later cells rebind `tiles`, so this name keeps the filtered list.
tiles_copy = tiles

In [None]:
# Keep every tile that contains annotated pixels (mask value 1), then add up to
# three times as many annotation-free tiles as negatives.
# Read from `tiles_copy` (the shape-filtered list) so that re-running this cell
# does not resample from an already-resampled `tiles`.
tiles_with_masks = [t for t in tiles_copy if 1 in t.mask]
tiles_without_masks = [t for t in tiles_copy if 1 not in t.mask]
print(len(tiles_with_masks))
# Negatives come only from unannotated tiles, so positives are never duplicated.
# The count is capped by what is available, so np.random.choice cannot raise.
n_negatives = min(3 * len(tiles_with_masks), len(tiles_without_masks))
# Sample indices rather than passing tile objects to np.random.choice, which
# would first try to pack them into a NumPy array.
negative_idx = np.random.choice(len(tiles_without_masks), n_negatives, replace=False)
tiles = tiles_with_masks + [tiles_without_masks[i] for i in negative_idx]
print(len(tiles))

In [None]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from numpy import random

def generate_random_rng_state():
    """Reseed torch's global CPU RNG from NumPy's global RNG and return a snapshot of it.

    Side effect: calls ``torch.manual_seed``, so torch's global RNG is reseeded
    every time this runs.

    Returns:
        torch.ByteTensor: the CPU RNG state right after reseeding, usable with
        ``torch.set_rng_state``.
    """
    # `random` is `numpy.random` here (this cell runs `from numpy import random`).
    # Request int64 explicitly. With NumPy's default C-long dtype (32-bit on
    # Windows before NumPy 2), an upper bound of 2**32 - 1 raises ValueError.
    random_seed = int(random.randint(0, 2**32 - 1, dtype='int64'))
    torch.manual_seed(random_seed)
    return torch.get_rng_state()

class TiledWSIDataset(Dataset):
    """Map-style dataset of ``(image, mask)`` pairs built from pre-cut WSI tiles.

    Each tile object must expose ``image_data`` (a channel-first array that is
    transposed to HWC before ``ToTensor``) and ``mask``.

    ``length_modifier`` virtually repeats the tiles: index ``idx`` reads tile
    ``idx % len(tiles)``, so with ``length_modifier=2.0`` every tile appears
    twice. With ``training=True`` each index owns a frozen torch RNG state.
    Image and mask therefore get identical random flips, and a given index is
    always augmented the same way.

    NOTE(review): because repeated indices share a tile, an item-level
    ``random_split`` of this dataset can put the same tile in both splits.
    ``training`` lives on this object, so setting it on a ``Subset`` wrapper
    has no effect.
    """

    def __init__(self, tiles, length_modifier=1.0, training=True):
        self.tiles = [t.image_data for t in tiles]
        self.masks = [t.mask for t in tiles]
        # Geometric augmentations, applied to image AND mask under the same RNG state.
        self.general_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            #transforms.RandomRotation(degrees=(-90, 90)),
        ])
        # Image-only: convert to a float tensor, then normalise with ImageNet mean/std.
        self.image_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
            #transforms.GaussianBlur(kernel_size=(5, 5), sigma=(1.0, 2.0))
        ])
        self.length_modifier = length_modifier
        # One independent RNG state per virtual index. A single state repeated
        # with `[state] * n` would give every sample identical flips. Generating
        # the states reseeds torch's global RNG, so the caller's state is
        # restored afterwards.
        outer_state = torch.get_rng_state()
        try:
            self.random_state = [generate_random_rng_state() for _ in range(len(self))]
        finally:
            torch.set_rng_state(outer_state)
        self.training = training

    def __len__(self):
        return int(self.length_modifier * len(self.tiles))

    def __getitem__(self, idx):
        """Return ``(image, mask)``: a normalised image tensor and a mask with one leading channel."""
        tile_idx = idx % len(self.tiles)

        # image_data is channel-first; ToTensor expects HWC.
        image = self.image_transform(self.tiles[tile_idx].transpose(1, 2, 0))

        # Assumes masks are (1, H, W) or (C, H, W) (TODO confirm). Only the first
        # channel is kept, which matches the old repeat-then-index behaviour.
        # A 2-D (H, W) mask gets a leading channel axis instead.
        mask = torch.tensor(np.asarray(self.masks[tile_idx]))
        mask = mask.unsqueeze(0) if mask.ndim == 2 else mask[:1]

        if self.training:
            state = self.random_state[idx]
            # Replaying the same state makes image and mask draw identical flips.
            # The outer state is restored so each fetch does not rewind the
            # global torch RNG.
            outer_state = torch.get_rng_state()
            try:
                torch.set_rng_state(state)
                image = self.general_transform(image)

                torch.set_rng_state(state)
                mask = self.general_transform(mask)
            finally:
                torch.set_rng_state(outer_state)

        return image, mask

# Every tile appears twice (length_modifier=2.0), each copy with its own frozen flips.
dataset = TiledWSIDataset(tiles=tiles, length_modifier=2.0, training=True)
# Despite the name, `full_dataset` is a DataLoader over all items.
full_dataset = DataLoader(dataset=dataset, shuffle=True, batch_size=16)

In [None]:
# 80/20 item-level split. The sizes are computed here so this cell runs on a fresh kernel.
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
# `random_split` returns `Subset` views of the same TiledWSIDataset.
# Augmentation is controlled by `dataset.training`, which both views share.
# Assigning `test_dataset.training` would only add an unused attribute to the Subset.

In [None]:
from torch.optim import Adam

def dice_loss(pred, target):
    """Soft Dice loss between two tensors with the same number of elements.

    Both inputs are flattened with ``view(-1)``. `pred` should already be
    probabilities in [0, 1] (apply a sigmoid to logits first). The smoothing
    term of 1 keeps the ratio defined when both tensors are all zeros.

    Returns:
        0-dim tensor ``1 - (2*sum(pred*target) + 1) / (sum(pred) + sum(target) + 1)``.
    """
    smooth = 1.
    iflat = pred.view(-1)
    tflat = target.view(-1)
    intersection = (iflat * tflat).sum()

    return 1 - ((2. * intersection + smooth) / (iflat.sum() + tflat.sum() + smooth))

def criterion(pred, target):
    """Combined Dice + binary cross-entropy loss for single-channel logits.

    Args:
        pred: raw logits, e.g. the (N, 1, H, W) output of ``UNet(n_class=1)``.
        target: float mask with the same shape as `pred`, expected to hold 0/1 values.

    Returns:
        0-dim tensor: Dice on ``sigmoid(pred)`` plus BCE-with-logits.

    NOTE: a later cell rebinds the name `criterion` to ``nn.BCEWithLogitsLoss()``.
    """
    # Dice must see probabilities, not logits. Cross-entropy over a single
    # channel is identically zero (the softmax of one logit is 1), so the
    # binary form is used instead.
    return dice_loss(torch.sigmoid(pred), target) + F.binary_cross_entropy_with_logits(pred, target)


In [None]:
from torch.utils.data import random_split, DataLoader


# Split by *tile*, not by dataset item. With length_modifier > 1, several item
# indices map to the same tile (idx % n_tiles), so an item-level random_split
# can place the same tile in both the train and the test set (optimistic test loss).
n_tiles = len(dataset.tiles)
tile_order = torch.randperm(n_tiles).tolist()
train_tile_ids = set(tile_order[:int(0.8 * n_tiles)])
train_indices = [i for i in range(len(dataset)) if i % n_tiles in train_tile_ids]
test_indices = [i for i in range(len(dataset)) if i % n_tiles not in train_tile_ids]
train_dataset = torch.utils.data.Subset(dataset, train_indices)
test_dataset = torch.utils.data.Subset(dataset, test_indices)
train_size, test_size = len(train_dataset), len(test_dataset)

batch_size = 16

# Data loaders. The test loader keeps every sample in a fixed order so that
# per-epoch test losses are comparable. Both views share `dataset.training`,
# so test items are still flip-augmented.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4, drop_last=False)


In [None]:
import torch
torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Run 1: full UNet with single-channel logits, Adam at lr = 1e-4.
lr = 1e-4
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = UNet(n_class=1)
model.to(device)
optimizer = Adam(params=model.parameters(), lr=lr)
# BCEWithLogitsLoss applies the sigmoid itself. NOTE: this rebinds `criterion`,
# replacing the Dice-based `criterion` function defined in an earlier cell.
criterion = nn.BCEWithLogitsLoss()

def train(model, trainloader, testloader, criterion, optimizer, epochs, device):
  """Train `model`, evaluating it on `testloader` after every epoch.

  Whenever an epoch's mean test loss beats the best seen so far, the weights are
  saved to ``unet_<lr>_<tile_size>.pth`` in the working directory. Here ``<lr>``
  is the learning rate of the optimizer's first param group, and ``tile_size``
  is read from the notebook's global namespace. ``tqdm`` is also taken from
  the global namespace (no visible cell imports it).

  Args:
    model: network to optimise; assumed to already live on `device`.
    trainloader: sized iterable of ``(images, masks)`` batches used for updates.
    testloader: sized iterable of ``(images, masks)`` batches used only for evaluation.
    criterion: callable ``criterion(outputs, masks)`` returning a scalar tensor.
    optimizer: optimizer over ``model.parameters()``.
    epochs: number of passes over `trainloader`.
    device: device each batch is moved to.

  Returns:
    ``(train_losses, test_losses)``: per-epoch mean batch losses.

  Raises:
    ValueError: if either loader yields no batches (the mean loss would be undefined).
  """
  if len(trainloader) == 0 or len(testloader) == 0:
    raise ValueError('train() needs at least one batch in both trainloader and testloader')

  current_lr = optimizer.param_groups[0]['lr']
  best_test = np.inf
  train_losses = []
  test_losses = []

  for epoch in range(epochs):
    epoch_loss = 0.0
    test_loss = 0.0

    model.train()
    for images, masks in tqdm(trainloader):
      images, masks = images.to(device), masks.to(device)
      optimizer.zero_grad()
      outputs = model(images)
      loss = criterion(outputs, masks)
      loss.backward()
      optimizer.step()
      epoch_loss += loss.item()

    # The batch must be bound to `images` so the forward pass uses the test
    # images. no_grad skips building autograd graphs during evaluation.
    model.eval()
    with torch.no_grad():
      for images, masks in tqdm(testloader):
        images, masks = images.to(device), masks.to(device)
        test_loss += criterion(model(images), masks).item()

    epoch_train_loss = epoch_loss / len(trainloader)
    epoch_test_loss = test_loss / len(testloader)
    train_losses.append(epoch_train_loss)
    test_losses.append(epoch_test_loss)

    if epoch_test_loss < best_test:
      best_test = epoch_test_loss
      torch.save(model.state_dict(), f'unet_{current_lr}_{tile_size}.pth')

    print(f'Epoch {epoch+1}, Train Loss: {epoch_train_loss}, Test Loss: {epoch_test_loss}')

  return train_losses, test_losses


In [None]:
num_iterates = 50
UNet_trained = train(model, train_loader, test_loader, criterion, optimizer,
                     epochs=num_iterates, device=device)

In [None]:
# Loss curves for run 1 (lr = 1e-4). The x-axis is derived from the recorded
# losses, not from `num_iterates`: a later cell changes `num_iterates`, and
# re-running this cell would then fail on mismatched lengths. Epochs are
# 1-based, matching train()'s printout.
train_curve, test_curve = UNet_trained
epochs_axis = range(1, len(train_curve) + 1)
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(epochs_axis, train_curve, label='Train Loss', color='red')
ax.plot(epochs_axis, test_curve, label='Test Loss', color='blue')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.set_title('Train and Test Loss Over Epochs with combination of 3 slides')
ax.legend()
ax.grid(True)
plt.show()

In [None]:
# NOTE(review): `dataset` includes the training items (with flip augmentation),
# so these numbers are not a held-out estimate. Consider passing `test_dataset`.
confusion_matrix_1 = compute_confusion_matrix(model, dataset, device)
# Not named `eval`: that would shadow Python's builtin eval() for the rest of the session.
eval_results = evaluate_model(model, dataset, device)

In [None]:
# Run 2: fresh UNet with single-channel logits, Adam at lr = 1e-6.
lr = 1e-6
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model_1 = UNet(n_class=1)
model_1.to(device)
optimizer = Adam(params=model_1.parameters(), lr=lr)
# BCEWithLogitsLoss applies the sigmoid to the raw logits itself.
criterion = nn.BCEWithLogitsLoss()

def train(model, trainloader, testloader, criterion, optimizer, epochs, device):
  """Train `model`, evaluating it on `testloader` after every epoch.

  Whenever an epoch's mean test loss beats the best seen so far, the weights are
  saved to ``unet_<lr>_<tile_size>_<epochs>.pth`` in the working directory.
  Here ``<lr>`` is the learning rate of the optimizer's first param group, and
  ``tile_size`` is read from the notebook's global namespace. ``tqdm`` is also
  taken from the global namespace (no visible cell imports it).

  Args:
    model: network to optimise; assumed to already live on `device`.
    trainloader: sized iterable of ``(images, masks)`` batches used for updates.
    testloader: sized iterable of ``(images, masks)`` batches used only for evaluation.
    criterion: callable ``criterion(outputs, masks)`` returning a scalar tensor.
    optimizer: optimizer over ``model.parameters()``.
    epochs: number of passes over `trainloader`.
    device: device each batch is moved to.

  Returns:
    ``(train_losses, test_losses)``: per-epoch mean batch losses.

  Raises:
    ValueError: if either loader yields no batches (the mean loss would be undefined).
  """
  if len(trainloader) == 0 or len(testloader) == 0:
    raise ValueError('train() needs at least one batch in both trainloader and testloader')

  current_lr = optimizer.param_groups[0]['lr']
  best_test = np.inf
  train_losses = []
  test_losses = []

  for epoch in range(epochs):
    epoch_loss = 0.0
    test_loss = 0.0

    model.train()
    for images, masks in tqdm(trainloader):
      images, masks = images.to(device), masks.to(device)
      optimizer.zero_grad()
      outputs = model(images)
      loss = criterion(outputs, masks)
      loss.backward()
      optimizer.step()
      epoch_loss += loss.item()

    # The batch must be bound to `images` so the forward pass uses the test
    # images. no_grad skips building autograd graphs during evaluation.
    model.eval()
    with torch.no_grad():
      for images, masks in tqdm(testloader):
        images, masks = images.to(device), masks.to(device)
        test_loss += criterion(model(images), masks).item()

    epoch_train_loss = epoch_loss / len(trainloader)
    epoch_test_loss = test_loss / len(testloader)
    train_losses.append(epoch_train_loss)
    test_losses.append(epoch_test_loss)

    if epoch_test_loss < best_test:
      best_test = epoch_test_loss
      torch.save(model.state_dict(), f'unet_{current_lr}_{tile_size}_{epochs}.pth')

    print(f'Epoch {epoch+1}, Train Loss: {epoch_train_loss}, Test Loss: {epoch_test_loss}')

  return train_losses, test_losses


In [None]:
num_iterates = 50
UNet_trained_1 = train(model_1, train_loader, test_loader, criterion, optimizer,
                       epochs=num_iterates, device=device)

In [None]:
# Loss curves for run 2 (lr = 1e-6). The x-axis and the epoch count in the title
# come from the recorded losses, so re-running this cell after `num_iterates`
# changes neither crashes nor mislabels. Epochs are 1-based, as printed by train().
train_curve, test_curve = UNet_trained_1
epochs_axis = range(1, len(train_curve) + 1)
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(epochs_axis, train_curve, label='Train Loss', color='red')
ax.plot(epochs_axis, test_curve, label='Test Loss', color='blue')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.set_title(f'Train and Test Loss Over Epochs of {len(train_curve)} and lr = 1e-6')
ax.legend()
ax.grid(True)
plt.show()

In [None]:
import time
# Run 3: fresh UNet with single-channel logits, Adam at lr = 1e-6 (trained longer below).
lr = 1e-6
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model_2 = UNet(n_class=1)
model_2.to(device)
optimizer = Adam(params=model_2.parameters(), lr=lr)
# BCEWithLogitsLoss applies the sigmoid to the raw logits itself.
criterion = nn.BCEWithLogitsLoss()

def train(model, trainloader, testloader, criterion, optimizer, epochs, device):
  """Train `model` with per-epoch evaluation and wall-clock timing.

  Whenever an epoch's mean test loss beats the best seen so far, the weights are
  saved to ``unet_<lr>_<tile_size>_<epochs>.pth`` in the working directory.
  Here ``<lr>`` is the learning rate of the optimizer's first param group, and
  ``tile_size`` is read from the notebook's global namespace. ``tqdm`` is also
  taken from the global namespace (no visible cell imports it). Each epoch
  prints its losses and duration; the total training time is printed at the end.

  Args:
    model: network to optimise; assumed to already live on `device`.
    trainloader: sized iterable of ``(images, masks)`` batches used for updates.
    testloader: sized iterable of ``(images, masks)`` batches used only for evaluation.
    criterion: callable ``criterion(outputs, masks)`` returning a scalar tensor.
    optimizer: optimizer over ``model.parameters()``.
    epochs: number of passes over `trainloader`.
    device: device each batch is moved to.

  Returns:
    ``(train_losses, test_losses)``: per-epoch mean batch losses.

  Raises:
    ValueError: if either loader yields no batches (the mean loss would be undefined).
  """
  if len(trainloader) == 0 or len(testloader) == 0:
    raise ValueError('train() needs at least one batch in both trainloader and testloader')

  current_lr = optimizer.param_groups[0]['lr']
  best_test = np.inf
  train_losses = []
  test_losses = []

  start_time = time.time()

  for epoch in range(epochs):
    epoch_loss = 0.0
    test_loss = 0.0
    epoch_start_time = time.time()

    model.train()
    for images, masks in tqdm(trainloader):
      images, masks = images.to(device), masks.to(device)
      optimizer.zero_grad()
      outputs = model(images)
      loss = criterion(outputs, masks)
      loss.backward()
      optimizer.step()
      epoch_loss += loss.item()

    # The batch must be bound to `images` so the forward pass uses the test
    # images. no_grad skips building autograd graphs during evaluation.
    model.eval()
    with torch.no_grad():
      for images, masks in tqdm(testloader):
        images, masks = images.to(device), masks.to(device)
        test_loss += criterion(model(images), masks).item()

    epoch_train_loss = epoch_loss / len(trainloader)
    epoch_test_loss = test_loss / len(testloader)
    train_losses.append(epoch_train_loss)
    test_losses.append(epoch_test_loss)

    if epoch_test_loss < best_test:
      best_test = epoch_test_loss
      torch.save(model.state_dict(), f'unet_{current_lr}_{tile_size}_{epochs}.pth')

    epoch_duration = time.time() - epoch_start_time
    print(f'Epoch {epoch+1}, Train Loss: {epoch_train_loss}, Test Loss: {epoch_test_loss}, Time: {epoch_duration:.1f}s')

  total_duration_minutes = (time.time() - start_time) / 60
  print(f'Total Training Time: {total_duration_minutes:.2f} minutes')

  return train_losses, test_losses


In [None]:
num_iterates = 100
UNet_trained_2 = train(model_2, train_loader, test_loader, criterion, optimizer,
                       epochs=num_iterates, device=device)

In [None]:
# Loss curves for run 3 (lr = 1e-6). The x-axis and the epoch count in the title
# come from the recorded losses rather than `num_iterates`. Epochs are 1-based,
# as printed by train().
train_curve, test_curve = UNet_trained_2
epochs_axis = range(1, len(train_curve) + 1)
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(epochs_axis, train_curve, label='Train Loss', color='red')
ax.plot(epochs_axis, test_curve, label='Test Loss', color='blue')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.set_title(f'Train and Test Loss Over Epochs of {len(train_curve)} and lr = 1e-6')
ax.legend()
ax.grid(True)
plt.show()

In [None]:
# model.eval()

# for images, masks in train_loader:
#   images, masks = images.to(device), masks.to(device)
#   outputs = model(images)

#   outputs = outputs.cpu()
#   masks = masks.cpu()

#   for i in range(outputs.shape[0]):

#     threshold = 0.5
#     test_output = outputs[i].detach().numpy().transpose(1, 2, 0)
#     binary_output = np.zeros(test_output.shape, dtype=np.uint8)
#     binary_output[test_output > threshold] = 1

#     fig, ax = plt.subplots(1, 4, figsize=(8, 3))

#     image = images.cpu()

#     image = image[i].detach().numpy().transpose(1, 2, 0)
#     print(np.max(image))
#     print(np.min(image))
#     ax[0].imshow(image, cmap='viridis')
#     ax[0].set_title('Input')

#     ax[1].imshow(test_output, cmap='viridis')
#     ax[1].set_title('True Output')

#     ax[2].imshow(binary_output, cmap='viridis')
#     ax[2].set_title('Binary Output')

#     # just to stop strange colour things going on
#     current_mask = masks[i].detach().numpy().transpose(1, 2, 0)
#     if 0 not in current_mask:
#       current_mask[0][0] = 0

#     if 1 not in current_mask:
#       current_mask[0][0] = 1

#     ax[3].imshow(current_mask, cmap='viridis')
#     ax[3].set_title('Mask')

#     plt.show()

#   #break



In [None]:
# Persist the weights of `model` (the lr = 1e-4 run) under a fixed filename.
torch.save(obj=model.state_dict(), f='unet.pth')