In [None]:
import os
from pathlib import Path

import pandas as pd
import numpy as np
from tqdm import tqdm
from openslide import OpenSlide

import torch
from torch import nn
from torch.utils.data import (
    ConcatDataset,
    DataLoader,
    Dataset,
    Subset,
    SubsetRandomSampler,
    TensorDataset,
    random_split,
)

import torchvision
from torchvision import transforms
from PIL import Image

import einops

In [None]:
# Outcome tables (columns: biopsy_id, label).
train_df, test_df = (
    pd.read_csv(csv_path)
    for csv_path in ('./train_outcomes.csv', './test_outcomes.csv')
)
# Slide mapping tables (columns: slide_id, biopsy_id, img path).
train_mapping, test_mapping = (
    pd.read_csv(csv_path)
    for csv_path in ('./train_mapping.csv', './test_mapping.csv')
)

In [None]:
# train_outcome_map
#   key:   biopsy_id
#   value: stage_number 0,1,2,3,4 (original note says NaN is excluded; no
#          filtering happens in this cell, so presumably that is done upstream
#          -- TODO confirm)
# Built column-wise instead of with DataFrame.iterrows(): iterrows() coerces
# each row to a single dtype, which can silently turn integer ids into floats.
train_outcome_map = dict(zip(train_df['biopsy_id'], train_df['label']))

# train_slide_map
#   key:   slide_id
#   value: Tuple(biopsy_id, slide_path)
train_slide_map = dict(
    zip(
        train_mapping['slide_id'],
        zip(train_mapping['biopsy_id'], train_mapping['downsampled_path']),
    )
)


In [None]:
# test_outcome_map
#   key:   biopsy_id
#   value: stage_number 0,1,2,3,4 (original note says NaN is excluded; no
#          filtering happens in this cell, so presumably that is done upstream
#          -- TODO confirm)
# Built column-wise instead of with DataFrame.iterrows(): iterrows() coerces
# each row to a single dtype, which can silently turn integer ids into floats.
test_outcome_map = dict(zip(test_df['biopsy_id'], test_df['label']))

# test_slide_map
#   key:   slide_id
#   value: Tuple(biopsy_id, slide_path)
test_slide_map = dict(
    zip(
        test_mapping['slide_id'],
        zip(test_mapping['biopsy_id'], test_mapping['downsampled_path']),
    )
)


In [None]:
# One entry per slide, in the iteration order of train_slide_map.
# train_x holds the image path, train_y the stage label of the slide's biopsy.
train_x = [img_path for _, img_path in train_slide_map.values()]
train_y = [train_outcome_map[biopsy_id] for biopsy_id, _ in train_slide_map.values()]

In [None]:
# One entry per slide, in the iteration order of test_slide_map.
# test_x holds the image path, test_y the stage label of the slide's biopsy.
test_x = [img_path for _, img_path in test_slide_map.values()]
test_y = [test_outcome_map[biopsy_id] for biopsy_id, _ in test_slide_map.values()]

In [None]:
# Sanity check: the path and label lists of each split should have equal length.
tuple(len(seq) for seq in (train_x, train_y, test_x, test_y))

In [None]:
# Per-channel mean / std passed to Normalize in both pipelines
# (these match the statistics commonly used with ImageNet-pretrained models).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Training pipeline: resize, light random augmentation, then a fixed 224 crop.
transform_aug_train = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.RandomResizedCrop(size=224, scale=(0.8, 1.0)),
        transforms.RandomRotation(degrees=15),
        transforms.RandomHorizontalFlip(),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]
)

# Evaluation pipeline: same resize / crop / normalisation, no random steps.
transform_aug_test = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]
)

In [None]:
class ImageDataset(Dataset):
    """Dataset that loads an image from disk and pairs it with its label.

    Args:
        x: sequence of image paths.
        y: sequence of labels, aligned index-by-index with ``x``.
        mode: ``'train'`` applies ``transform_aug_train``; ``'test'`` applies
            ``transform_aug_test``.

    Raises:
        ValueError: if ``mode`` is neither ``'train'`` nor ``'test'``.
    """

    _VALID_MODES = ('train', 'test')

    def __init__(self, x, y, mode='train'):
        # Fail fast: an unknown mode previously surfaced only inside
        # __getitem__ as an UnboundLocalError.
        if mode not in self._VALID_MODES:
            raise ValueError(
                f"mode must be one of {self._VALID_MODES}, got {mode!r}"
            )
        self.x = x # img_path
        self.y = y # label
        self.mode = mode # train/test

    def __getitem__(self, index):
        """Return ``(image_tensor, label)`` for the sample at ``index``."""
        if self.mode == 'train':
            transform = transform_aug_train
        elif self.mode == 'test':
            transform = transform_aug_test
        else:
            # self.mode may have been reassigned after construction.
            raise ValueError(
                f"mode must be one of {self._VALID_MODES}, got {self.mode!r}"
            )

        path = self.x[index]
        # The context manager closes the file handle; convert('RGB') forces the
        # pixel data to load and guarantees the 3 channels Normalize expects
        # (grayscale / RGBA inputs would otherwise fail).
        with Image.open(path) as img:
            x_pil = img.convert('RGB')
        return transform(x_pil), self.y[index]

    def __len__(self):
        """Number of samples (length of the path list)."""
        return len(self.x)

In [None]:
# Optimisation hyper-parameters.
batch_size = 2
epochs = 20
learning_rate = 1e-3
momentum = 0.9  # defined here but not passed to the Adam optimizer built in this file
weight_decay = 0 # 1e-8

# Use the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
train_dataset = ImageDataset(train_x, train_y, mode='train')
# shuffle=True: without it every epoch replays the samples in mapping-CSV
# order, so each batch always contains the same (likely related) slides.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataset = ImageDataset(test_x, test_y, mode='test')
# Evaluation order does not affect the metric, so the test loader stays sequential.
test_loader = DataLoader(test_dataset, batch_size=batch_size)

In [None]:
# for data in test_dataset:
#     x, y = data
#     print(x.shape, y)

In [None]:
# train_x_list = []
# train_y_list = train_y
# for i in tqdm(range(len(train_x))):
#     path = train_x[i]
#     x_pil = Image.open(path)
#     x_tensor = transform_aug_train(x_pil)
#     train_x_list.append(x_tensor)

# test_x_list = []
# test_y_list = test_y
# for i in tqdm(range(len(test_x))):
#     path = test_x[i]
#     x_pil = Image.open(path)
#     x_tensor = transform_aug_test(x_pil)
#     test_x_list.append(x_tensor)

# pd.to_pickle({'x': train_x_list, 'y': train_y_list}, f'./train.pkl')
# pd.to_pickle({'x': test_x_list, 'y': test_y_list}, f'./test.pkl')

In [None]:
def mse_loss(y_pred, y_true):
    """Mean squared error between ``y_pred`` and ``y_true`` (mean reduction)."""
    return nn.functional.mse_loss(y_pred, y_true)

def focal_mse_loss(inputs, targets, activate='sigmoid', beta=.2, gamma=1):
    """Squared error re-weighted by a function of the absolute error, averaged.

    Each element's squared error is multiplied by ``w ** gamma`` where
    ``w = tanh(beta * |err|)`` when ``activate == 'tanh'`` and
    ``w = 2 * sigmoid(beta * |err|) - 1`` for any other value of ``activate``.

    Args:
        inputs: predicted values (tensor).
        targets: reference values (tensor).
        activate: ``'tanh'`` selects the tanh weighting; anything else selects
            the sigmoid weighting.
        beta: scale applied to the absolute error inside the weighting.
        gamma: exponent applied to the weighting.

    Returns:
        Scalar tensor: the mean of the weighted squared errors.
    """
    err = inputs - targets
    abs_err = torch.abs(err)
    if activate == 'tanh':
        weight = torch.tanh(beta * abs_err)
    else:
        weight = 2 * torch.sigmoid(beta * abs_err) - 1
    weighted = err ** 2
    weighted.mul_(weight ** gamma)
    return torch.mean(weighted)

def huber_loss(inputs, targets, beta=1.):
    """Mean Huber-style loss: quadratic below ``beta``, linear at or above it.

    For absolute error ``e`` the per-element loss is ``0.5 * e**2 / beta`` when
    ``e < beta`` and ``e - 0.5 * beta`` otherwise.

    Args:
        inputs: predicted values (tensor).
        targets: reference values (tensor).
        beta: absolute-error threshold between the two regimes.

    Returns:
        Scalar tensor: the mean of the per-element losses.
    """
    abs_err = torch.abs(inputs - targets)
    quadratic = 0.5 * abs_err ** 2 / beta
    linear = abs_err - 0.5 * beta
    return torch.where(abs_err < beta, quadratic, linear).mean()

# Loss handed to train_epoch by the training loop. focal_mse_loss and
# huber_loss accept the same (prediction, target) arguments.
criterion = mse_loss

In [None]:
def train_epoch(model, dataloader, loss_fn, optimizer, scheduler):
    """Run one optimisation pass over ``dataloader`` and return the mean batch loss.

    Args:
        model: module being trained; switched to train mode.
        dataloader: iterable yielding ``(batch_x, batch_y)`` tensor pairs.
        loss_fn: callable ``loss_fn(prediction, target)`` returning a scalar tensor.
        optimizer: optimizer stepped once per batch.
        scheduler: scheduler whose ``step`` is called once per epoch with the
            mean training loss (as ``ReduceLROnPlateau`` expects).

    Returns:
        Mean of the per-batch loss values for this epoch.
    """
    # Send batches to wherever the model's weights live, so inputs and
    # parameters can never end up on different devices.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    train_loss = []
    model.train()
    for batch_x, batch_y in dataloader:
        batch_x = batch_x.float().to(run_device)
        batch_y = batch_y.float().to(run_device)
        optimizer.zero_grad()
        output = model(batch_x)
        # A single-output head yields (batch, 1) while labels are (batch,).
        # Drop the trailing singleton so the loss is element-wise; squeezing
        # dim 0 only worked for batch size 1, and otherwise let the loss
        # broadcast to a (batch, batch) matrix.
        if output.dim() == batch_y.dim() + 1 and output.shape[-1] == 1:
            output = output.squeeze(-1)
        loss = loss_fn(output, batch_y)
        train_loss.append(loss.item())
        loss.backward()
        optimizer.step()
    metric_train_loss = np.array(train_loss).mean()
    scheduler.step(metric_train_loss)
    return metric_train_loss

def val_epoch(model, dataloader):
    """Evaluate ``model`` on ``dataloader`` and return the mean squared error.

    The squared errors are accumulated over every element of every batch, so
    the result does not depend on how the data is split into batches.

    Args:
        model: module to evaluate; switched to eval mode.
        dataloader: iterable yielding ``(batch_x, batch_y)`` tensor pairs.

    Returns:
        Mean squared error as a Python float, or ``nan`` when ``dataloader``
        yields no elements.
    """
    model.eval()
    # Send batches to wherever the model's weights live.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    squared_error_sum = 0.0
    n_elements = 0
    with torch.no_grad():
        for batch_x, batch_y in dataloader:
            batch_x = batch_x.float().to(run_device)
            batch_y = batch_y.float().to(run_device)
            pred = model(batch_x)
            # Align a (batch, 1) prediction with (batch,) labels so the error
            # is element-wise rather than broadcast to (batch, batch).
            if pred.dim() == batch_y.dim() + 1 and pred.shape[-1] == 1:
                pred = pred.squeeze(-1)
            err = pred - batch_y
            squared_error_sum += torch.sum(err ** 2).item()
            n_elements += err.numel()
    if n_elements == 0:
        return float("nan")
    return squared_error_sum / n_elements

In [None]:
# ResNet-18 backbone; its stock `fc` layer is replaced below.
model = torchvision.models.resnet18(num_classes=1)

hidden_dim = model.fc.in_features
out_dim = 1

# Small MLP regression head in place of the stock classifier.
# NOTE(review): the final Sigmoid bounds predictions to (0, 1), while the
# label comments in this notebook describe stages 0-4. Confirm the labels are
# scaled to [0, 1], or drop / rescale the Sigmoid.
model.fc = nn.Sequential(
    nn.Linear(hidden_dim, hidden_dim//4),
    nn.GELU(),
    nn.Linear(hidden_dim//4, out_dim),
    nn.Sigmoid()
)

# Put the weights on the same device the training loop sends batches to.
model = model.to(device)

# The optimizer must be created AFTER the head is swapped in. Built earlier,
# it would hold the parameters of the discarded `fc` layer and the new head
# would never be updated.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)

# strict=False skips checkpoint keys that have no counterpart in this model
# (and model keys absent from the checkpoint), so the replaced head keeps its
# fresh initialisation. Values are copied into the existing parameters, so the
# optimizer created above stays valid.
model.load_state_dict(torch.load('./checkpoints/resnet18-f37072fd.pth'), strict=False)

In [None]:
# Train for `epochs` passes, checkpointing whenever the evaluation MSE improves.
# NOTE(review): test_loader doubles as the validation set here, so the saved
# "best" checkpoint is selected on the test data.
best_score = 1e8
for epoch in range(epochs):
    train_loss = train_epoch(model, train_loader, criterion, optimizer, scheduler)
    metric_valid = val_epoch(model, test_loader)

    if metric_valid < best_score:
        # New best score: remember it and persist the weights.
        best_score = metric_valid
        print("Saving best model ...")
        print("Val Score:", metric_valid)
        torch.save(model.state_dict(), "./checkpoints/model.ckpt")

    print(f"Epoch {epoch}: Loss = {train_loss}")