In [None]:
!pip install pytorch-lightning
!pip install wandb
!pip install -U git+https://github.com/qubvel/segmentation_models.pytorch
#!pip install lightning-bolts
#!pip install torchmetrics

Collecting pytorch-lightning
[?25l  Downloading https://files.pythonhosted.org/packages/c2/a1/a991780873b5fd760fb99dfda01916fe9e5b186f0ba70a120e6b4f79cfaa/pytorch_lightning-1.3.1-py3-none-any.whl (805kB)
[K     |████████████████████████████████| 808kB 24.8MB/s 
Collecting torchmetrics>=0.2.0
[?25l  Downloading https://files.pythonhosted.org/packages/3b/e8/513cd9d0b1c83dc14cd8f788d05cd6a34758d4fd7e4f9e5ecd5d7d599c95/torchmetrics-0.3.2-py3-none-any.whl (274kB)
[K     |████████████████████████████████| 276kB 33.1MB/s 
Collecting future>=0.17.1
[?25l  Downloading https://files.pythonhosted.org/packages/45/0b/38b06fd9b92dc2b68d58b75f900e97884c45bedd2ff83203d933cf5851c9/future-0.18.2.tar.gz (829kB)
[K     |████████████████████████████████| 829kB 29.5MB/s 
[?25hCollecting pyDeprecate==0.3.0
  Downloading https://files.pythonhosted.org/packages/14/52/aa227a0884df71ed1957649085adf2b8bc2a1816d037c2f18b3078854516/pyDeprecate-0.3.0-py3-none-any.whl
Collecting fsspec[http]>=2021.4.0
[?25l  

In [None]:
import math
import os
import logging
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import imageio
import random
from datetime import datetime

import torch
from torch import nn, tensor
from torch.distributions.categorical import Categorical
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset, random_split, TensorDataset

from torchvision import transforms
from torchvision.datasets import MNIST

import wandb

import pytorch_lightning as pl
from pytorch_lightning.metrics import functional as FM
from pytorch_lightning.callbacks import Callback, EarlyStopping
from pytorch_lightning.loggers import WandbLogger

import segmentation_models_pytorch as smp

In [None]:
# Mount Google Drive at /content/drive (Colab only); the project paths defined
# in the next cell live under this mount point.
from google.colab import drive
drive.mount("/content/drive")

Mounted at /content/drive


In [None]:
# Locations on the mounted Drive used by this project.
drive_path = "/content/drive/MyDrive"
_project_root = drive_path + "/imc-prediction"
examples_dir = _project_root + "/examples"
checkpoints_dir = _project_root + "/checkpoints"

In [None]:
# Target-channel configuration read by MyUNet below.
# The commented-out pairs are earlier settings kept for reference
# (presumably for larger protein panels -- confirm before re-enabling).
#n_protein_channels = 38
#collagen_index = 21
#n_protein_channels = 27
#collagen_index = 14
# Number of output channels the model predicts.
n_protein_channels = 3
# Index of the collagen channel within the target tensor.
collagen_index = 1

In [None]:
# Emit the "pytorch_lightning" logger's messages at DEBUG level; the same logger
# is used by LoggingCallback to report checkpoint uploads.
console_logger = logging.getLogger("pytorch_lightning")
console_logger.setLevel(logging.DEBUG)

In [None]:
# adapted from github.com/PyTorchLightning/lightning-bolts/blob/master/pl_bolts/models/vision/unet.py
class UNet(nn.Module):
    """Plain U-Net: a stem, ``num_layers - 1`` down blocks, as many up blocks, and a 1x1 head.

    Args:
        num_classes: number of output channels produced by the final 1x1 convolution.
        input_channels: number of channels in the input image.
        num_layers: depth of the network (stem plus down blocks); must be >= 1.
        features_start: channel width after the stem; doubled by every down block.
        bilinear: passed to ``Up``; selects bilinear upsampling over transposed convolution.
        dropout: dropout probability passed to every ``DoubleConv``.

    Raises:
        ValueError: if ``num_layers`` is smaller than 1.
    """

    def __init__(
        self,
        num_classes,
        input_channels=3,
        num_layers=5,
        features_start=64,
        bilinear=False,
        dropout=0,
    ):
        if num_layers < 1:
            raise ValueError(f"num_layers={num_layers}, expected: num_layers > 0")

        super().__init__()
        self.num_layers = num_layers

        # Channel width at every resolution level, shallowest first.
        widths = [features_start * 2 ** level for level in range(num_layers)]
        level_pairs = list(zip(widths, widths[1:]))

        # Modules are created in the order stem -> down path -> up path -> head.
        stem = DoubleConv(input_channels, features_start, dropout)
        down_path = [Down(narrow, wide, dropout) for narrow, wide in level_pairs]
        up_path = [Up(wide, narrow, dropout, bilinear) for narrow, wide in reversed(level_pairs)]
        head = nn.Conv2d(features_start, num_classes, kernel_size=1)

        self.layers = nn.ModuleList([stem, *down_path, *up_path, head])

    def forward(self, x):
        depth = self.num_layers

        # Down path: keep every intermediate activation for the skip connections.
        skips = [self.layers[0](x)]
        for down in self.layers[1:depth]:
            skips.append(down(skips[-1]))

        # Up path: start from the deepest activation and merge skips from deep to shallow.
        out = skips.pop()
        for up in self.layers[depth:-1]:
            out = up(out, skips.pop())

        return self.layers[-1](out)

class DoubleConv(nn.Module):
    """Two 3x3 same-padding convolutions, each followed by dropout and an in-place LeakyReLU.

    Args:
        in_ch: channels of the input tensor.
        out_ch: channels produced by both convolutions.
        dropout: dropout probability applied after each convolution.
    """

    def __init__(self, in_ch, out_ch, dropout=0):
        super().__init__()
        stages = []
        # The first stage maps in_ch -> out_ch, the second keeps out_ch.
        for stage_in in (in_ch, out_ch):
            stages += [
                nn.Conv2d(stage_in, out_ch, kernel_size=3, padding=1),
                #nn.BatchNorm2d(out_ch),
                nn.Dropout(dropout),
                nn.LeakyReLU(inplace=True),
            ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)

class Down(nn.Module):
    """Encoder stage: 2x2 max pooling with stride 2, followed by a ``DoubleConv``."""

    def __init__(self, in_ch, out_ch, dropout=0):
        super().__init__()
        pool = nn.MaxPool2d(kernel_size=2, stride=2)
        convs = DoubleConv(in_ch, out_ch, dropout)
        self.net = nn.Sequential(pool, convs)

    def forward(self, x):
        return self.net(x)

class Up(nn.Module):
    """Decoder stage: upsample, pad to the skip tensor's size, concatenate, then ``DoubleConv``.

    Args:
        in_ch: channels of the tensor coming from the deeper level.
        out_ch: channels produced by this stage.
        dropout: dropout probability passed to ``DoubleConv``.
        bilinear: if true, upsample with bilinear interpolation followed by a 1x1
            convolution; otherwise use a stride-2 transposed convolution.
    """

    def __init__(self, in_ch, out_ch, dropout=0, bilinear=False):
        super().__init__()
        halved = in_ch // 2
        if not bilinear:
            self.upsample = nn.ConvTranspose2d(in_ch, halved, kernel_size=2, stride=2)
        else:
            self.upsample = nn.Sequential(
                nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
                nn.Conv2d(in_ch, halved, kernel_size=1),
            )

        self.conv = DoubleConv(in_ch, out_ch, dropout)

    def forward(self, x1, x2):
        upsampled = self.upsample(x1)

        # Grow the upsampled map to x2's height/width; an odd remainder goes to the bottom/right.
        pad_h = x2.shape[2] - upsampled.shape[2]
        pad_w = x2.shape[3] - upsampled.shape[3]
        left, top = pad_w // 2, pad_h // 2
        upsampled = F.pad(upsampled, [left, pad_w - left, top, pad_h - top])

        # Skip tensor first, upsampled features second, joined along the channel axis.
        return self.conv(torch.cat([x2, upsampled], dim=1))
        # output one channel for each protein channel and then take loss for each?

In [None]:
class MyUNet(pl.LightningModule):
    """Lightning module that trains a U-Net style network to predict protein channels.

    The wrapped network emits one logit map per protein channel (the notebook-level
    ``n_protein_channels``); the loss is per-pixel binary cross-entropy between those
    logits and the targets.

    Args:
        batch_size: not used by the model; accepted so that it is recorded by
            ``save_hyperparameters`` (and therefore shows up in wandb).
        input_channels: number of channels in the input images.
        num_layers: depth of the hand-written ``UNet`` (``architecture="UNet"`` only).
        dropout: dropout probability. NOTE(review): for the ``smp`` models it is only
            passed through ``aux_params``, i.e. to the auxiliary classification head
            whose output ``forward`` discards -- confirm this is intended.
        learning_rate: Adam learning rate.
        weight_decay: Adam weight decay.
        architecture: ``"UNet"``, ``"smp.Unet"``, ``"smp.UnetPlusPlus"`` or ``"smp.MAnet"``.
        encoder_name: forwarded to the ``smp`` constructor.
        encoder_depth: forwarded to the ``smp`` constructor.
        decoder_channels: forwarded to the ``smp`` constructor.
        decoder_use_batchnorm: forwarded to the ``smp`` constructor.

    Raises:
        ValueError: if ``architecture`` is not one of the supported names.
    """

    # `architecture` values served by segmentation_models_pytorch; the class name
    # is whatever follows the "smp." prefix.
    _SMP_ARCHITECTURES = ("smp.Unet", "smp.UnetPlusPlus", "smp.MAnet")

    def __init__(
        self,
        batch_size,  # here just for inclusion in wandb hparams
        input_channels=1,
        num_layers=4,
        dropout=0,
        learning_rate=3e-5,
        weight_decay=0,
        architecture="smp.UnetPlusPlus",
        encoder_name="resnet34",
        encoder_depth=5,
        decoder_channels=(256, 128, 64, 32, 16),
        decoder_use_batchnorm=False,
    ):
        super().__init__()
        # Must run before any other local is created so only the arguments are captured.
        self.save_hyperparameters()

        if architecture == "UNet":
            self.net = UNet(
                features_start=64,
                input_channels=input_channels,
                num_layers=num_layers,
                num_classes=n_protein_channels,
                dropout=dropout,
            )
        elif architecture in self._SMP_ARCHITECTURES:
            smp_model_cls = getattr(smp, architecture[len("smp."):])
            self.net = smp_model_cls(
                encoder_name=encoder_name,
                encoder_depth=encoder_depth,
                encoder_weights="imagenet",
                decoder_use_batchnorm=decoder_use_batchnorm,
                decoder_channels=decoder_channels,
                in_channels=input_channels,
                classes=n_protein_channels,
                # With aux_params the smp model returns a (masks, labels) pair;
                # forward() keeps only the masks.
                aux_params=dict(
                    dropout=dropout,
                    classes=n_protein_channels,
                ),
            )
        else:
            # Previously an unknown name left `self.net` undefined and only failed
            # later, inside forward(), with an AttributeError.
            supported = ("UNet",) + self._SMP_ARCHITECTURES
            raise ValueError(f"architecture={architecture!r}, expected one of {supported}")

    def _uses_smp(self):
        """Return True when the wrapped network comes from segmentation_models_pytorch."""
        return "smp." in self.hparams.architecture

    @staticmethod
    def _crop_for_smp(images):
        """Drop the last 15 rows and the last 16 columns of a (N, C, H, W) tensor.

        Presumably this brings the 271x304 images to 256x288 so the size is divisible
        by the encoder's total downsampling factor -- TODO confirm.
        """
        return images[:, :, :-15, :-16]

    def forward(self, x):
        """Return one logit map per protein channel (inputs are cropped for smp models)."""
        if self._uses_smp():
            # [0]: keep the segmentation masks, drop the auxiliary classification output.
            return self.net(self._crop_for_smp(x))[0]
        return self.net(x)

    def _logits_and_targets(self, batch):
        """Run the network on a batch and crop the targets to match its output."""
        x, y = batch
        if self._uses_smp():
            y = self._crop_for_smp(y)
        return self(x), y

    def training_step(self, batch, batch_idx):
        y_pred, y = self._logits_and_targets(batch)
        # Same quantity as binary_cross_entropy(sigmoid(y_pred), y), computed from the
        # logits directly so that saturated logits do not lose precision.
        loss = F.binary_cross_entropy_with_logits(y_pred, y)
        self.log("train_loss", loss, on_epoch=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        y_pred, y = self._logits_and_targets(batch)
        loss = F.binary_cross_entropy_with_logits(y_pred, y)
        # Loss restricted to the collagen channel (notebook-level `collagen_index`).
        loss_collagen = F.binary_cross_entropy_with_logits(
            y_pred[:, collagen_index], y[:, collagen_index]
        )
        self.log("val_loss", loss, on_epoch=True, logger=True)
        self.log("val_loss_collagen", loss_collagen, on_epoch=True, logger=True)
        return {
            "val_loss": loss,
            "val_loss_collagen": loss_collagen,
        }

    def configure_optimizers(self):
        return torch.optim.Adam(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=self.hparams.weight_decay,
        )

In [None]:
# Each row shows the first example of one batch, so any batch size works.
def display_predictions(loader, net=None):
    """Plot inputs, targets and predictions for the first example of up to 14 batches.

    Draws a new 14x8 matplotlib figure, which becomes pyplot's current figure (so a
    caller can hand ``plt`` to ``wandb.Image``), and returns it. The network is put
    in eval mode while predicting and its previous train/eval mode is restored
    afterwards.

    Args:
        loader: iterable yielding ``(inputs, targets)`` batches. Input channel 1 is
            shown as "DNA", input channel 0 as "Pano", and target channels 0/1/2
            as alpha/collagen/keratin.
        net: network to run; defaults to the notebook-level ``model``.

    Returns:
        The matplotlib figure that was drawn.
    """
    if net is None:
        net = model  # notebook-level model

    n_rows = 14
    column_titles = [
        "DNA input",
        "Pano input",
        "True alpha",
        "Pred alpha",
        "True collagen",
        "Pred collagen",
        "True keratin",
        "Pred keratin",
    ]
    n_cols = len(column_titles)
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(n_cols * 1.5, n_rows * 1.5))

    device = next(net.parameters()).device
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            # zip with range() stops after n_rows batches without fetching an extra one.
            for i, (xs, ys) in zip(range(n_rows), loader):
                # Only the first example is shown, so only that one is predicted.
                # sigmoid: the network outputs logits, the targets are compared
                # against sigmoid(logits) in the loss.
                preds = torch.sigmoid(net(xs[:1].to(device)))[0].cpu().numpy()

                axs[i][0].imshow(xs[0][1])
                axs[i][1].imshow(xs[0][0])
                for channel in range(3):
                    axs[i][2 + 2 * channel].imshow(ys[0][channel])
                    axs[i][3 + 2 * channel].imshow(preds[channel])

                for ax in axs[i]:
                    ax.set_xticks([])
                    ax.set_yticks([])
    finally:
        # Do not leave a model that is being trained stuck in eval mode.
        net.train(was_training)

    fig.tight_layout(pad=0)

    # Column titles above the first row and, shifted below the axes, under the last row.
    for ax, title in zip(axs[0], column_titles):
        ax.set_title(title)
    for ax, title in zip(axs[-1], column_titles):
        ax.set_title(title, y=-0.13)

    return fig

In [None]:
class LoggingCallback(Callback):
    """Log prediction figures to wandb and periodically upload checkpoints.

    Figures are produced by ``display_predictions`` from the notebook-level
    ``train_loader`` / ``val_loader`` during the first three epochs and on every
    third epoch after that. A checkpoint is saved and uploaded on every third epoch.
    """

    @staticmethod
    def _is_figure_epoch(trainer):
        """True for epochs 0-2 and for every third epoch (2, 5, 8, ...)."""
        return trainer.current_epoch < 3 or (trainer.current_epoch + 1) % 3 == 0

    @staticmethod
    def _log_predictions(trainer, pl_module, loader, key):
        """Draw predictions for ``loader`` and log the figure to wandb under ``key``."""
        was_training = pl_module.training
        try:
            display_predictions(loader)#, pl_module.hparams.protein_multiplier_index)
            # wandb.Image(plt, ...) captures pyplot's current figure.
            image = wandb.Image(plt, caption=f"epoch_{trainer.current_epoch}")
            trainer.logger.experiment.log(
                {
                    key: image,
                    "global_step": trainer.global_step,
                },
                commit=False,  # docs: "When logging manually ... make sure to use commit=False ..."
            )
        finally:
            # Close the figure so one is not leaked per logged epoch, and make sure
            # plotting did not leave the module in eval mode.
            plt.close(plt.gcf())
            pl_module.train(was_training)

    # `outputs` defaults to None so the hook works both with Lightning versions that
    # pass it and with versions that no longer do.
    def on_train_epoch_end(self, trainer, pl_module, outputs=None):
        if self._is_figure_epoch(trainer):
            self._log_predictions(trainer, pl_module, train_loader, "train_preds")
        if (trainer.current_epoch + 1) % 3 == 0:
            checkpoint_name = f"epoch_{trainer.current_epoch}.ckpt"
            trainer.save_checkpoint(checkpoint_name)
            trainer.logger.experiment.save(checkpoint_name)
            console_logger.info(f"Saved {checkpoint_name} to wandb server")

    def on_validation_epoch_end(self, trainer, pl_module):
        if self._is_figure_epoch(trainer):
            self._log_predictions(trainer, pl_module, val_loader, "val_preds")

In [None]:
# Load the pre-built tensors from Drive: the inputs that already include the DNA
# channel, and the alpha/collagen/keratin targets. The commented-out lines load
# other tensor variants and show how inputs_with_dna can be rebuilt from them.
#inputs = torch.load(f"{drive_path}/imc-prediction/tensors/inputs.pt")
inputs_with_dna = torch.load(f"{drive_path}/imc-prediction/tensors/inputs_with_dna.pt")
alpha_collagen_keratin_targets = torch.load(f"{drive_path}/imc-prediction/tensors/alpha_collagen_keratin_targets.pt")
#collagen_targets = torch.load(f"{drive_path}/imc-prediction/tensors/collagen_targets.pt")
#dna_targets = torch.load(f"{drive_path}/imc-prediction/tensors/dna_targets.pt")
#inputs_with_dna = torch.cat([inputs, dna_targets], dim=1)

In [None]:
# Hold out a quarter of the examples (fixed seed) for validation; the rest is training data.
n_examples = inputs_with_dna.shape[0]
all_indices = range(n_examples)
random.seed(123)
val_indices = random.sample(all_indices, n_examples // 4)
# sorted(): make the training order explicit instead of relying on set iteration order.
train_indices = sorted(set(all_indices) - set(val_indices))

# dtype is given explicitly so that an empty split still yields a valid (integer) index.
train_index = torch.tensor(train_indices, dtype=torch.long)
val_index = torch.tensor(val_indices, dtype=torch.long)

train_inputs = inputs_with_dna.index_select(dim=0, index=train_index)
train_targets = alpha_collagen_keratin_targets.index_select(dim=0, index=train_index)
val_inputs = inputs_with_dna.index_select(dim=0, index=val_index)
val_targets = alpha_collagen_keratin_targets.index_select(dim=0, index=val_index)

In [None]:
# These transforms are applied to the inputs and to the targets in separate calls,
# so each one must be deterministic; otherwise an input and its target could end up
# with different geometry.
t1 = transforms.Compose([
    transforms.RandomRotation([90, 90], expand=True),  # angle range [90, 90]: always 90 degrees
    # Resize the rotated image to (271, 304). This replaces
    # RandomResizedCrop([271, 304], scale=[1, 1]), whose randomly drawn aspect ratio
    # can occasionally produce a slightly different crop from one call to the next.
    transforms.Resize([271, 304]),  # reflect-pad instead of resizing?
])
#t2 = transforms.RandomRotation([180, 180])
#t3 = transforms.Compose([
#    transforms.RandomRotation([270, 270], expand=True),
#    transforms.RandomResizedCrop([271, 304], scale=[1, 1]),  # reflect-pad instead of resizing?
#])
t4 = transforms.RandomHorizontalFlip(p=1)  # p=1: always flips
#t5 = transforms.RandomVerticalFlip(p=1)

In [None]:
# Training set = the original examples followed by their t1 and t4 augmented copies;
# inputs and targets are augmented in the same order so they stay paired.
# Disabled experiments: Normalize(mean=[0.45, 0.45], std=[0.225, 0.225]) applied to the
# inputs before concatenation, and the t2 / t3 / t5 augmentations.
_active_augmentations = (t1, t4)

train_inputs_augmented = torch.cat(
    [train_inputs] + [augment(train_inputs) for augment in _active_augmentations]
)
train_targets_augmented = torch.cat(
    [train_targets] + [augment(train_targets) for augment in _active_augmentations]
)

In [None]:
# Drop the notebook-level references to tensors that are no longer needed so their
# memory can be reclaimed (the augmented training tensors and the validation
# tensors are kept).
inputs = None
collagen_targets = None
dna_targets = None
inputs_with_dna = None
alpha_collagen_keratin_targets = None  # the full target tensor was loaded but never released
train_inputs = None
train_targets = None

In [None]:
#train_dataset = TensorDataset(
#    torch.load(f"{drive_path}/imc-prediction/tensors/train_inputs.pt"),
#    torch.load(f"{drive_path}/imc-prediction/tensors/train_targets.pt"),
#)
#val_dataset = TensorDataset(
#    torch.load(f"{drive_path}/imc-prediction/tensors/val_inputs.pt"),
#    torch.load(f"{drive_path}/imc-prediction/tensors/val_targets.pt"),
#)

In [None]:
#batch_size = 1
batch_size = 32

# Settings shared by both loaders (pin_memory is left at its default).
_loader_kwargs = dict(batch_size=batch_size, num_workers=4)

# Training batches are reshuffled every epoch; validation keeps a fixed order.
train_loader = DataLoader(
    TensorDataset(train_inputs_augmented, train_targets_augmented),
    shuffle=True,
    **_loader_kwargs,
)
val_loader = DataLoader(
    TensorDataset(val_inputs, val_targets),
    shuffle=False,
    **_loader_kwargs,
)

In [None]:
# The datasets inside the loaders were built from these tensors; the notebook-level
# names themselves are not used again.
train_inputs_augmented = train_targets_augmented = None
val_inputs = val_targets = None

In [None]:
# Fix the global random seed before the model is built.
pl.seed_everything(55)

# Model for this run: two input channels and a UNet++ from segmentation_models_pytorch
# with a depth-4 encoder (decoder_channels has one entry per encoder stage).
# Commented-out arguments are alternatives tried in other runs.
model = MyUNet(
    batch_size=batch_size,
    #num_layers=5,
    input_channels=2,
    dropout=0.42,
    #weight_decay=1e-4,
    #learning_rate=1e-3,
    #protein_multiplier=16,
    #protein_multiplier_index=0,
    #protein_multiplier=8,
    #protein_multiplier_index=collagen_index,
    architecture="smp.UnetPlusPlus",
    #architecture="smp.MAnet",
    encoder_depth=4,
    decoder_channels=(256, 128, 64, 32),
    #encoder_depth=3,
    #decoder_channels=(128, 64, 32),
    decoder_use_batchnorm=True,
)

Global seed set to 55


In [None]:
# wandb run named after the current timestamp; [:19] keeps "YYYY-MM-DD HH:MM:SS"
# and drops the fractional seconds.
logger = WandbLogger(
  name=f"{datetime.now()}"[:19],
  project="unet1",
)
# Train for at most 500 epochs on one GPU, validating every epoch and stopping early
# once val_loss has not improved for 25 validation checks.
# NOTE(review): `gpus` and `stochastic_weight_avg` are Trainer arguments of the
# Lightning 1.x series installed at the top of the notebook.
trainer = pl.Trainer(
    max_epochs=500,
    logger=logger,
    callbacks=[
        LoggingCallback(),
        EarlyStopping(
            #monitor="val_loss_multiplied",
            monitor="val_loss",
            patience=25,
        ),
    ],
    #accumulate_grad_batches=10,
    stochastic_weight_avg=True,
    check_val_every_n_epoch=1,
    gpus=1,
    #resume_from_checkpoint="./epoch_50.ckpt",
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Both every_n_train_steps and every_n_val_epochs are not set. Setting every_n_val_epochs=1


In [None]:
# Run training on train_loader, validating on val_loader.
trainer.fit(model, train_loader, val_loader)

In [None]:
# higher learning rate
# MSE and other losses (hybrid MSE / cross entropy?)
# L1 loss? L1 "reconstruction" loss? pix2pix paper says it's better than GAN for sem seg
# GAN conditional on input channels!
# ^ look for GANs conditional on images online - is this how style transfer works? StyleGAN
# use tanh and make everything (-1, 1)? see how you did this in the homework
#   other sebastian suggestions
# ask jin sun / sebastian / etc what evaluation metrics would be good
#  per pixel classification isn't great because of intensities
#  have pathologists rate? how exactly?

In [None]:
# print out weight matrices, see if weights are mostly 0. but is that dead relu, not vanishing gradient?
# wouldn't vanishing gradient be if change to weights is mostly 0?

In [None]:
# make dataset with 4 channels + data augmentation (or augment at runtime)
#     should fit in memory
#     train with holdout in same patients
# get new models really overfitting:
#     log better training examples - pull consistent indices directly from dataset and run .cuda()/.cpu() etc?
#     more encoder/decoder channels?
#     switch back to one protein (also try 2-4, think about which would be good)
#     look through wandb logs for best overfitting
# pathology checkpoints
#    or pathology datasets to fine tune smp models
# look at a variety of proteins. maybe certain ones it does much better than others?

In [None]:
# train within patients - then you don't have issue of different disease, etc
# train on large greyscale pathology dataset with self supervision fill in the blank etc
# try coordconv
# test late 2cnm83td models against training data to assess overfitting - wandb didn't get informative examples

In [None]:
# Mark the current wandb run as finished.
wandb.finish()

In [None]:
def get_val_preds(model, protein_channel, n_rows=5):
    """Plot input, prediction and target for the first ``n_rows`` validation examples.

    Each row shows, left to right: input channel 0, the predicted probability map
    for ``protein_channel`` and the target for ``protein_channel``. Batches are read
    from the notebook-level ``val_loader``.

    Args:
        model: network to evaluate; it is switched to eval mode.
        protein_channel: index of the target/prediction channel to display.
        n_rows: number of validation examples to plot.
    """
    column_titles = ("Input", "Prediction", "Target")
    # Inputs must live on the same device as the model's weights.
    device = next(model.parameters()).device
    model.eval()

    # squeeze=False keeps `axs` two-dimensional even when n_rows == 1.
    fig, axs = plt.subplots(n_rows, 3, figsize=(3 * 2, n_rows * 2), squeeze=False)
    row = 0
    with torch.no_grad():
        for xs, ys in val_loader:
            if row == n_rows:
                break
            # A batch may hold more examples than there are rows left to fill.
            remaining = n_rows - row
            xs, ys = xs[:remaining], ys[:remaining]
            # sigmoid: the model outputs logits, the targets are probabilities.
            preds = torch.sigmoid(model(xs.to(device))).cpu()
            for x, pred, y in zip(xs, preds, ys):
                panels = (x[0], pred[protein_channel], y[protein_channel])
                for ax, panel in zip(axs[row], panels):
                    ax.imshow(panel.numpy())
                    ax.set_xticks([])
                    ax.set_yticks([])
                row += 1

    for ax, title in zip(axs[0], column_titles):
        ax.set_title(title)
    fig.tight_layout()

In [None]:
# Authenticate this session with wandb before restoring files from earlier runs.
!wandb login

[34m[1mwandb[0m: Currently logged in as: [33mbzrry[0m (use `wandb login --relogin` to force relogin)


In [None]:
# Download checkpoints that earlier runs uploaded to wandb; each file is written to
# the current working directory under its original name.
wandb.restore("epoch_39.ckpt", run_path="bzrry/unet1/2cnm83td")
wandb.restore("epoch_59.ckpt", run_path="bzrry/unet1/2cnm83td")
wandb.restore("epoch_79.ckpt", run_path="bzrry/unet1/3b1u16fg")

<_io.TextIOWrapper name='/content/epoch_79.ckpt' mode='r' encoding='UTF-8'>

In [None]:
# Rebuild the models from the restored checkpoints and move them to the GPU.
model_2cnm83td_39 = MyUNet.load_from_checkpoint("./epoch_39.ckpt").cuda()
model_2cnm83td_59 = MyUNet.load_from_checkpoint("./epoch_59.ckpt").cuda()
#model_3b1u16fg_79 = MyUNet.load_from_checkpoint(f"./epoch_79.ckpt").cuda()

In [None]:
# Epoch-39 checkpoint: collagen channel on validation examples.
get_val_preds(model_2cnm83td_39, collagen_index)

In [None]:
# Epoch-39 checkpoint: target channel 0 on validation examples.
get_val_preds(model_2cnm83td_39, 0)

In [None]:
# Epoch-59 checkpoint: collagen channel on validation examples.
get_val_preds(model_2cnm83td_59, collagen_index)

In [None]:
# Epoch-59 checkpoint: target channel 0 on validation examples.
get_val_preds(model_2cnm83td_59, 0)

In [None]:
!nvidia-smi --query-gpu=utilization.gpu,utilization.memory,memory.total,memory.free,memory.used --format=csv -l 5
#!nvidia-smi
#!nvidia-smi --gpu-reset -i 0

utilization.gpu [%], utilization.memory [%], memory.total [MiB], memory.free [MiB], memory.used [MiB]
0 %, 0 %, 16280 MiB, 65 MiB, 16215 MiB


In [None]:
# might need a bigger model / more layers/params/etc to handle more training examples?
# get to overfitting
# make sure model is actually training against all the examples...seems like it's going fast?
# layer/instance norm instead of batch norm?

In [None]:
# Start a new wandb run with default settings.
wandb.init()

In [None]:
class FullyConv(nn.Module):
    """Four stacked convolutions with LeakyReLU between them.

    Channel widths are 1 -> num_features -> num_features -> num_features -> 1; every
    convolution uses the same kernel size, stride and padding. There is no
    activation after the last convolution.
    """

    def __init__(self, num_features, kernel_size, stride, padding):
        super().__init__()
        widths = [1, num_features, num_features, num_features, 1]
        conv_settings = dict(kernel_size=kernel_size, stride=stride, padding=padding)

        modules = []
        for c_in, c_out in zip(widths, widths[1:]):
            modules.append(nn.Conv2d(c_in, c_out, **conv_settings))
            modules.append(nn.LeakyReLU())
        modules.pop()  # the output convolution is not followed by an activation

        self.net = nn.Sequential(*modules)

    def forward(self, x):
        return self.net(x)