In [14]:
import os
from pathlib import Path

import pandas as pd
import numpy as np
from tqdm import tqdm
import random
# from openslide import OpenSlide

import torch
from torch import nn
from torch.utils.data import (
    ConcatDataset,
    DataLoader,
    Dataset,
    Subset,
    SubsetRandomSampler,
    TensorDataset,
    random_split,
)

import torchvision
from torchvision import transforms
from PIL import Image

# import einops

# from eval_metrics import print_metrics_regression
from sklearn import metrics as sklearn_metrics
from models.slice_att import SliceAttention

In [15]:
# Pre-split train/test data. Change DATA_DIR to point the notebook at another dataset.
# NOTE(review): pd.read_pickle can execute arbitrary code on load — only use trusted files.
DATA_DIR = Path("./datasets_mini")

train_x = pd.read_pickle(DATA_DIR / "train_x.pkl")
train_y = pd.read_pickle(DATA_DIR / "train_y.pkl")
train_id = pd.read_pickle(DATA_DIR / "train_id.pkl")

test_x = pd.read_pickle(DATA_DIR / "test_x.pkl")
test_y = pd.read_pickle(DATA_DIR / "test_y.pkl")
test_id = pd.read_pickle(DATA_DIR / "test_id.pkl")

In [16]:
# Min-max scale labels to [0, 1] using the *training* range only; the test labels are
# transformed with the same statistics.
# NOTE: not idempotent — re-running this cell reads the already-scaled labels and resets
# min_label/max_label to 0.0/1.0. Re-run the loading cell first.
min_label = train_y.min().item()
max_label = train_y.max().item()
label_range = max_label - min_label
assert label_range > 0, f"cannot min-max normalize: all training labels equal {min_label}"
train_y = (train_y - min_label) / label_range
test_y = (test_y - min_label) / label_range

min_label, max_label

(0.0, 4.0)

In [17]:
def min_max_norm(x, min_label=min_label, max_label=max_label):
    """Map values from [min_label, max_label] onto [0, 1].

    The default bounds are captured from the module-level ``min_label`` / ``max_label``
    at the moment this cell is executed (Python evaluates defaults once, at def time).
    """
    span = max_label - min_label
    return (x - min_label) / span

def reverse_min_max_norm(x, min_label=min_label, max_label=max_label):
    """Inverse of ``min_max_norm``: map [0, 1]-scaled values back to [min_label, max_label]."""
    span = max_label - min_label
    return x * span + min_label

In [18]:
# Re-normalize images from ImageNet statistics to the breast cancer dataset's statistics.
# Each stats pair is written once, so the inverse transform cannot drift from it
# (the original inline version had a 0.255 typo for the ImageNet blue-channel std).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
DATASET_MEAN = [0.92740107, 0.90446373, 0.94529596]
DATASET_STD = [0.02340832, 0.06800389, 0.04525188]

transform_dataset = transforms.Compose([
    # Undo ImageNet normalization: Normalize computes (z - mean) / std, and with
    # mean = -m/s, std = 1/s that is (z + m/s) * s == z * s + m.
    transforms.Normalize(
        mean=[-m / s for m, s in zip(IMAGENET_MEAN, IMAGENET_STD)],
        std=[1 / s for s in IMAGENET_STD],
    ),
    # Apply the dataset's own per-channel statistics.
    transforms.Normalize(DATASET_MEAN, DATASET_STD),
])

In [19]:
class ImageDataset(Dataset):
    """Biopsy-level dataset: item ``k`` holds every image row of the k-th unique biopsy id.

    Args:
        x: per-image inputs aligned with ``biopsy_id`` and indexed along the first axis
            (presumably a torch tensor, since ``collate_fn`` combines items with ``torch.cat``).
        y: per-image labels aligned with ``x``. Every row of a biopsy is assumed to carry
            the same label; the first one is returned.
        biopsy_id: per-image biopsy identifiers (1-D).
    """

    def __init__(self, x, y, biopsy_id):
        self.x = x  # per-image inputs (img tensor list)
        self.y = y  # per-image labels
        self.biopsy_id = np.array(biopsy_id)
        # Sorted unique ids; inverse[i] is the position of row i's id within id_set.
        self.id_set, inverse = np.unique(self.biopsy_id, return_inverse=True)
        inverse = inverse.reshape(-1)
        # Precompute each biopsy's row indices once. The stable sort keeps the original
        # row order, matching a boolean mask, without rescanning all rows per item.
        order = np.argsort(inverse, kind="stable")
        counts = np.bincount(inverse, minlength=len(self.id_set))
        self._rows_per_id = np.split(order, np.cumsum(counts)[:-1])

    def __len__(self):
        return len(self.id_set)

    def __getitem__(self, index, transform=False):
        """Return ``(images, label, n_images)`` for the biopsy at position ``index``.

        Args:
            index: position in the sorted unique biopsy ids.
            transform: falsy (default) returns the images unchanged; a callable is applied
                to the biopsy's images; any other truthy value applies the module-level
                ``transform_dataset``. The DataLoader only ever passes ``index``.
        """
        rows = self._rows_per_id[index]
        cur_x = self.x[rows]
        cur_y = self.y[int(rows[0])]
        cur_len = len(cur_x)

        if transform:
            transform_fn = transform if callable(transform) else transform_dataset
            # Previously the transformed tensor was computed and then discarded.
            cur_x = transform_fn(cur_x)
        return cur_x, cur_y, cur_len
    
def collate_fn(batch):
    """Merge biopsy items into one flat image batch.

    Each item is ``(images, label, n_images)``. Returns the images of every biopsy
    concatenated along dim 0, the stacked labels (one per biopsy), and the running
    totals of ``n_images`` (cumulative end offsets of each biopsy inside the image batch).
    """
    images, targets, counts = zip(*batch)
    stacked_images = torch.cat(images, dim=0)
    stacked_targets = torch.stack(targets)
    offsets = torch.cumsum(torch.tensor(list(counts), dtype=torch.int32), dim=0)
    return stacked_images, stacked_targets, offsets

In [20]:
batch_size = 256  # biopsies per batch; each biopsy contributes all of its images

epochs = 50
learning_rate = 2e-4
momentum = 0.9  # NOTE(review): not passed to the Adam optimizer used in this notebook
weight_decay = 0  # 1e-8

# Seed every RNG in play so weight init and shuffling are reproducible.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Prefer the second GPU when one exists (the original setting); otherwise fall back to
# the first GPU, then the CPU. A bare "cuda:1" fails on single-GPU machines.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [21]:
# Biopsy-level datasets; collate_fn flattens each batch of biopsies into one image batch.
train_dataset = ImageDataset(train_x, train_y, train_id)
# Shuffle training biopsies every epoch. Without it each epoch visits them in the same
# order (ImageDataset orders items by the sorted unique ids).
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
test_dataset = ImageDataset(test_x, test_y, test_id)
test_loader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=collate_fn)

In [22]:
def mse_loss(y_pred, y_true):
    """Mean squared error between predictions and targets, averaged over all elements."""
    return nn.functional.mse_loss(y_pred, y_true, reduction="mean")

def focal_mse_loss(inputs, targets, activate='sigmoid', beta=.2, gamma=1):
    """Squared error reweighted by a bounded function of the absolute error, then averaged.

    The weight is ``tanh(beta * |e|) ** gamma`` when ``activate == 'tanh'``; otherwise it is
    ``(2 * sigmoid(beta * |e|) - 1) ** gamma``. Larger errors therefore get larger weights.
    """
    abs_err = torch.abs(inputs - targets)
    if activate == 'tanh':
        weight = torch.tanh(beta * abs_err) ** gamma
    else:
        weight = (2 * torch.sigmoid(beta * abs_err) - 1) ** gamma
    weighted = (inputs - targets) ** 2 * weight
    return torch.mean(weighted)

def huber_loss(inputs, targets, beta=1.):
    """Mean Huber loss: quadratic ``0.5 * e**2 / beta`` for ``|e| < beta``, else linear ``|e| - 0.5 * beta``."""
    abs_err = torch.abs(inputs - targets)
    quadratic = 0.5 * abs_err ** 2 / beta
    linear = abs_err - 0.5 * beta
    return torch.where(abs_err < beta, quadratic, linear).mean()

# Loss passed to train_epoch by the training loop; focal_mse_loss or huber_loss can be swapped in.
criterion = mse_loss

In [23]:
def train_epoch(model, model_att, dataloader, loss_fn, optimizer, scheduler):
    """Run one optimization pass over ``dataloader``.

    ``model`` maps the flat image batch to per-image features, and ``model_att`` pools them
    per biopsy using the offsets from ``collate_fn``. After the pass, the scheduler is stepped
    with the mean per-batch loss, which is also returned.
    """
    for module in (model, model_att):
        module.train()

    batch_losses = []
    for batch_x, batch_y, batch_len in dataloader:
        inputs = batch_x.float().to(device)
        targets = batch_y.float().to(device)

        optimizer.zero_grad()
        features = model(inputs)
        preds = model_att(features, batch_len).squeeze(dim=1)

        loss = loss_fn(preds, targets)
        batch_losses.append(loss.item())
        loss.backward()
        optimizer.step()

    epoch_loss = np.array(batch_losses).mean()
    # ReduceLROnPlateau-style scheduler: driven by the epoch's mean training loss.
    scheduler.step(epoch_loss)
    return epoch_loss

def val_epoch(model, model_att, dataloader):
    """Evaluate both modules on ``dataloader`` without gradients.

    Predictions and targets are mapped back to the original label scale with
    ``reverse_min_max_norm`` before computing the mean squared error, which is returned.
    """
    model.eval()
    model_att.eval()
    pred_values, true_values = [], []
    with torch.no_grad():
        for images, labels, offsets in dataloader:
            images = images.float().to(device)
            labels = labels.float().to(device)
            scores = torch.squeeze(model_att(model(images), offsets), dim=1)
            pred_values += scores.detach().cpu().numpy().tolist()
            true_values += labels.detach().cpu().numpy().tolist()

    predictions = reverse_min_max_norm(np.array(pred_values))
    targets = reverse_min_max_norm(np.array(true_values))
    return sklearn_metrics.mean_squared_error(targets, predictions)

In [27]:
# Backbone: ImageNet-pretrained ResNet-50 with its classification head replaced by an
# identity, so it outputs a `hidden_dim`-sized feature vector per image.
model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
hidden_dim = model.fc.in_features
model.fc = nn.Identity()
# Head that turns per-image features into one prediction per biopsy (see train_epoch).
# NOTE(review): SliceAttention is project-local; (in_dim, attn_dim) argument meaning assumed — confirm.
model_att = SliceAttention(hidden_dim, hidden_dim // 8)

# Move modules to the target device *before* constructing the optimizer, as the PyTorch
# optim docs recommend.
model.to(device)
model_att.to(device)

optimizer = torch.optim.Adam(
    list(model.parameters()) + list(model_att.parameters()),
    lr=learning_rate,
    weight_decay=weight_decay,
)
# Lowers the LR when the value given to scheduler.step() stops improving (train_epoch passes
# the epoch's mean training loss).
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)


def set_parameter_requires_grad(model, feature_extracting):
    """Freeze every parameter of ``model`` when ``feature_extracting`` is truthy."""
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False

# Offline alternative to weights=DEFAULT:
# model.load_state_dict(torch.load('./checkpoints/resnet50-11ad3fa6.pth'), strict=False)
# Feature extraction only (frozen backbone):
# set_parameter_requires_grad(model, True)

In [28]:
eval_every = 1  # run validation (and possibly checkpoint) every `eval_every` epochs
checkpoint_dir = Path("./checkpoints")
checkpoint_dir.mkdir(parents=True, exist_ok=True)

# NOTE(review): the "best" model is selected on test_loader, so best_score is an optimistic
# estimate; a separate validation split would keep the test set unbiased.
best_score = 1e8
for epoch in range(epochs):
    train_loss = train_epoch(
        model,
        model_att,
        train_loader,
        criterion,
        optimizer,
        scheduler
    )
    print(f"Epoch {epoch}: Loss = {train_loss}")
    if epoch % eval_every == 0:
        metric_valid = val_epoch(model, model_att, test_loader)
        print("Val Score:", metric_valid)
        if metric_valid < best_score:
            best_score = metric_valid
            print("Saving best model ...")
            torch.save(model.state_dict(), checkpoint_dir / "model_resnet50_1028.ckpt")
            # The attention head is needed to reproduce the score as well. The backbone-only
            # file above keeps its original format for existing loaders.
            torch.save(model_att.state_dict(), checkpoint_dir / "model_att_resnet50_1028.ckpt")
    

Epoch 0: Loss = 0.08266448974609375
Val Score: 0.49302292362571143
Saving best model ...
Epoch 1: Loss = 0.04256656765937805
Val Score: 0.37895327921013583
Saving best model ...
Epoch 2: Loss = 0.0419778898358345
Val Score: 0.5346572665292936
Epoch 3: Loss = 0.042613279074430466
Val Score: 0.5108245264618505
Epoch 4: Loss = 0.03430700674653053
Val Score: 0.3903572415563362
Epoch 5: Loss = 0.02525983192026615
Val Score: 0.31059334295862334
Saving best model ...
Epoch 6: Loss = 0.020161766558885574
Val Score: 0.30864520851371646
Saving best model ...
Epoch 7: Loss = 0.018717989325523376
Val Score: 0.3358035080638183
Epoch 8: Loss = 0.018421906977891922
Val Score: 0.3361505342605279
Epoch 9: Loss = 0.017290957272052765
Val Score: 0.31166993208170124
Epoch 10: Loss = 0.015002255327999592
Val Score: 0.2859653541918664
Saving best model ...
Epoch 11: Loss = 0.012182102538645267
Val Score: 0.2664742755606964
Saving best model ...
Epoch 12: Loss = 0.009751765988767147
Val Score: 0.275892621823

KeyboardInterrupt: 

In [13]:
# NOTE(review): the output below is from an earlier kernel run — this cell is In [13], executed
# before the In [14]-[28] cells above — so it is not reproduced by the current cell order.
best_score  # sigmoid (presumably a sigmoid variant of the attention head — TODO confirm)

0.21758774127243202

In [26]:
# NOTE(review): the output below is from an earlier run — this cell is In [26], executed before
# the model was rebuilt (In [27]) and retrained (In [28]); it is not the interrupted run's score.
best_score  # relu (presumably a relu variant of the attention head — TODO confirm)

0.19854353574712819