In [11]:
import torch
import torchvision
import torch.nn as nn

import pandas as pd
from typing import (Dict, IO, List, Tuple)
from pathlib import Path

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
USE_GPU = True
dtype = torch.float32 # we will be using float throughout this tutorial

if USE_GPU and torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print('using device:', device)

using device: cuda


In [3]:
MODEL_PATH = 'saved_models/resnet18_e10_va0.99230.pt'
RETURN_PREACTIVATION = True  # return features from the model, if false return classification logits
NUM_CLASSES = 3  # only used if RETURN_PREACTIVATION = False


In [4]:
def load_model_weights(model, weights):

    model_dict = model.state_dict()
    weights = {k: v for k, v in weights.items() if k in model_dict}
    if weights == {}:
        print('No weight could be loaded..')
    model_dict.update(weights)
    model.load_state_dict(model_dict)

    return model

In [5]:
model = torchvision.models.__dict__['resnet18'](pretrained=False)
state_dict = torch.load(MODEL_PATH, map_location='cuda:0')['model_state_dict']


In [6]:
# for key in list(state_dict.keys()):
#     state_dict[key.replace('model.', '').replace('resnet.', '')] = state_dict.pop(key)

model = load_model_weights(model, state_dict)

if torch.cuda.is_available():
    model.cuda()

for param in model.parameters():
    param.requires_grad = False


In [35]:
from dataloader import WsiDataset
from torch.utils.data import DataLoader

PATH_IMG = "/kuacc/users/skoc21/dataset/pannet/wsi/wsi-tiles/HE__20211011_160350_(3)/"
fname = PATH_IMG.split('/')[-2]

output_folder = Path('output_folder')
image_folder = Path(fname)

dataset = WsiDataset(img_dir=PATH_IMG, extension='png')
dataloader_ts = DataLoader(dataset, batch_size=64)
# lst_test_idx = pd.read_csv('test_idx.csv')['test_idx'].to_list()
# test_sampler = torch.utils.data.SubsetRandomSampler(lst_test_idx)
# dataloader_ts = DataLoader(dataset, batch_size=64, sampler=test_sampler)



In [21]:
correct_ts = 0
with torch.no_grad():
    model.eval()
    for data_dict_ts in dataloader_ts:
        data_ts, target_ts = data_dict_ts['image'].to(device), data_dict_ts['ann'].to(device)
        outputs_ts = model(data_ts)
        confidence,pred_ts = torch.max(nn.Softmax(dim=1)(outputs_ts), dim=1)
        print(confidence,pred_ts)
        correct_ts += torch.sum(pred_ts==target_ts).item()
    print(100 * correct_ts/len(lst_test_idx))

KeyError: 'ann'

In [36]:
with output_folder.joinpath(f"{image_folder.name}.csv").open(
        mode="w") as writer:

    writer.write("tile,x,y,prediction,confidence\n")

    # Loop through all of the patches.
    for batch_num, data_dict_ts in enumerate(dataloader_ts):
        batch_window_names = data_dict_ts['name_img']

        confidences, test_preds = torch.max(nn.Softmax(dim=1)(model(data_dict_ts['image'].to(device=device))), dim=1)
        for i in range(test_preds.shape[0]):
            # Find coordinates and predicted class.
            xy = batch_window_names[i].split("_")[-1].split(".")[0].split('-')[:2]
            writer.write(
                f"{','.join([batch_window_names[i].split('/')[-1], xy[0], xy[1], f'{test_preds[i].data.item()}', f'{confidences[i].data.item():.5f}'])}\n"
            )