In [1]:
import os
from pathlib import Path

import pandas as pd
import numpy as np
from tqdm import tqdm
# from openslide import OpenSlide

import torch
from torch import nn
from torch.utils.data import (
    ConcatDataset,
    DataLoader,
    Dataset,
    Subset,
    SubsetRandomSampler,
    TensorDataset,
    random_split,
)

import torchvision
from torchvision import transforms
from PIL import Image

# import einops

# from eval_metrics import print_metrics_regression
from sklearn import metrics as sklearn_metrics
from models.swin_transformer import SwinTransformer



In [2]:
holdout = pd.read_pickle("./datasets/holdout.pkl")
holdout_x = holdout["x"]
holdout_y = holdout["y"]
holdout_id = holdout["id"]
holdout_x = torch.tensor(torch.stack(holdout_x).detach().cpu().numpy())
holdout_y = torch.tensor(holdout_y)

In [3]:
min_label = 0.
max_label = 4.

def reverse_min_max_norm(x, min_label=min_label, max_label=max_label):
    return x*(max_label-min_label)+min_label

In [4]:
class ImageDataset(Dataset):
    def __init__(self, x, y, biopsy_id):
        self.x = x # img_tensor_list
        self.y = y # label
        self.biopsy_id = biopsy_id

    def __getitem__(self, index):
        return self.x[index], self.y[index], self.biopsy_id[index]

    def __len__(self):
        return len(self.x)

In [5]:
batch_size = 32

epochs = 30
learning_rate = 1e-5
momentum = 0.9
weight_decay = 0.0 # 1e-8

device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

In [6]:
holdout_dataset = ImageDataset(holdout_x, holdout_y, holdout_id)
holdout_loader = DataLoader(holdout_dataset, batch_size=batch_size)

In [7]:
model = SwinTransformer()

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
hidden_dim = model.head.in_features
out_dim = 1

model.head = nn.Sequential(
    nn.Linear(hidden_dim, hidden_dim//16),
    nn.GELU(),
    nn.Linear(hidden_dim//16, out_dim),
    # nn.Linear(hidden_dim, out_dim),
    nn.Sigmoid()
)

model.load_state_dict(torch.load('checkpoints/model_swin.ckpt'), strict=False)

model.to(device)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


SwinTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 192, kernel_size=(4, 4), stride=(4, 4))
    (norm): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (layers): ModuleList(
    (0): BasicLayer(
      dim=192, input_resolution=(56, 56), depth=2
      (blocks): ModuleList(
        (0): SwinTransformerBlock(
          dim=192, input_resolution=(56, 56), num_heads=6, window_size=7, shift_size=0, mlp_ratio=4.0
          (norm1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            dim=192, window_size=(7, 7), num_heads=6
            (qkv): Linear(in_features=192, out_features=576, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=192, out_features=192, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (drop_path): Identity()
          (norm2)

In [8]:
def test_epoch(model, dataloader):
    y_pred = {} # key: biopsy_id, value: List[slice_stage_pred]
    model.eval()
    with torch.no_grad():
        for step, data in enumerate(dataloader):
            # print(step)
            batch_x, batch_y, batch_biopsy_id = data
            batch_x, batch_y = (
                batch_x.float().to(device),
                batch_y.float().to(device),
            )
            output = model(batch_x)
            output = torch.squeeze(output, dim=1)
            output = output.detach().cpu().numpy().tolist()

            for i in range(len(batch_biopsy_id)):
                biopsy_id = batch_biopsy_id[i]
                if biopsy_id not in y_pred:
                    y_pred[biopsy_id] = []
                y_pred[biopsy_id].append(output[i])
    
    submit_result_dict = {}
    for biopsy_id in y_pred:
        preds = np.array(y_pred[biopsy_id])
        submit_result_dict[biopsy_id] = reverse_min_max_norm(preds.mean())
    return submit_result_dict

In [9]:
submit_result_dict = test_epoch(model, holdout_loader)

In [10]:
biopsy_id_list = []
biopsy_stage_list = []

for biopsy_id in submit_result_dict:
    biopsy_id_list.append(biopsy_id)
    biopsy_stage_list.append(submit_result_dict[biopsy_id])

In [11]:
import csv    

with open("submit_1027.csv", "w") as infile:
    writer = csv.writer(infile)
    # writer.writerow(["header01", "header02"])
    for i in zip(biopsy_id_list, biopsy_stage_list):
        writer.writerow(i)