In [1]:
import torch.nn as nn
import pandas as pd
import torch
import openslide
import sys
import imgaug.augmenters as iaa
#sys.path.append('../brown-datathon/src')
from config import Config
from models.scse_pyramid_unet import UNet
import numpy as np

In [None]:
def iou(pred, target, n_classes=2, empty_score=0.):
    """Per-class intersection-over-union between a label map and per-class binary masks.

    For each 0-based class index ``cls`` in ``range(n_classes)``, the pixels where
    ``pred == cls + 1`` are compared with the pixels where ``target[:, :, cls] == 1``.
    Label 0 in ``pred`` is never scored, so it acts as background.

    Args:
        pred: integer label map (e.g. an argmax over class logits). It is
            flattened before comparison.
        target: array indexed as ``target[:, :, cls]`` that holds 1 where class
            ``cls`` is present. Each channel must contain as many elements as ``pred``.
        n_classes: number of foreground classes (target channels) to score.
        empty_score: value reported for a class whose union is empty, i.e. absent
            from both ``pred`` and ``target``. The default ``0.`` keeps the original
            behaviour. Pass ``float('nan')`` and average with ``np.nanmean`` to
            leave such classes out of a mean instead of counting them as 0.

    Returns:
        List of ``n_classes`` floats, one IoU per class.

    Raises:
        ValueError: if a target channel and ``pred`` differ in element count.
    """
    pred = np.asarray(pred).ravel()
    ious = []
    for cls in range(n_classes):
        target_inds = np.asarray(target[:, :, cls]).ravel() == 1
        if target_inds.size != pred.size:
            raise ValueError(
                f"pred has {pred.size} elements but target channel {cls} has {target_inds.size}"
            )
        # pred uses 1-based class labels (0 = background); target channels are 0-based.
        pred_inds = pred == (cls + 1)
        # count_nonzero avoids materialising the boolean-indexed copies.
        intersection = np.count_nonzero(pred_inds & target_inds)
        union = np.count_nonzero(pred_inds) + np.count_nonzero(target_inds) - intersection
        ious.append(float(intersection) / union if union else float(empty_score))
    return ious
    

In [2]:
# Load pyramid level 1 of the slide, rasterise the tangle/threads boxes into a
# 2-channel mask, and tile the image into fixed-size patches for inference.
# NOTE(review): assumes the slide has at least two pyramid levels and that the CSV
# box coordinates are in level-1 pixel space. Confirm this against the annotation export.
PATCH_SIZE = 512

reader = openslide.OpenSlide('285383.svs')
level_size = reader.level_dimensions[1]  # (width, height)
# np.array(PIL image) yields (height, width, 4) uint8 directly. That is much faster
# and smaller than building an int64 array from img.getdata(). Alpha is dropped.
whole_region = np.array(reader.read_region((0, 0), 1, level_size))[:, :, :3]

bbox = pd.read_csv('285383.csv')
tangle_threads = bbox[bbox['label'].isin(['tangle', 'threads'])]
box_xy = tangle_threads[['x', 'y']].values
box_wh = tangle_threads[['w', 'h']].values
# Clip at 0 before the unsigned cast. Casting a negative float to uint32 is not
# well-defined and typically wraps to a huge value, which empties the slice below
# and silently drops the box.
top_left = np.clip(np.floor(box_xy), 0, None).astype(np.uint32)
bottom_right = np.clip(np.ceil(box_xy + box_wh), 0, None).astype(np.uint32)
# 1 = tangle, 2 = threads. Mask channel (label - 1) holds each class.
label = np.where(tangle_threads['label'] == 'tangle', 1, 2)
mask = np.zeros_like(whole_region[:, :, :2])
for (x0, y0), (x1, y1), box_label in zip(top_left, bottom_right, label):
    mask[y0:y1, x0:x1, box_label - 1] = 1

# Row-major grid of PATCH_SIZE tiles. Edge tiles are padded bottom/right with white.
# NOTE(review): `// PATCH_SIZE + 1` adds an extra (empty, fully padded) row/column when
# a dimension is an exact multiple of PATCH_SIZE. The inference cell relies on this grid.
n_rows = whole_region.shape[0] // PATCH_SIZE + 1
n_cols = whole_region.shape[1] // PATCH_SIZE + 1
padder = iaa.PadToFixedSize(PATCH_SIZE, PATCH_SIZE, position='right-bottom', pad_cval=255)
patches = []
for row in range(n_rows):
    for col in range(n_cols):
        patch = whole_region[row * PATCH_SIZE:(row + 1) * PATCH_SIZE,
                             col * PATCH_SIZE:(col + 1) * PATCH_SIZE]
        if patch.shape[:2] != (PATCH_SIZE, PATCH_SIZE):
            patch = padder.augment_image(patch)
        patches.append(patch / 255)
# (N, H, W, C) -> (N, C, H, W) for the network.
input_patches = np.transpose(np.stack(patches), (0, 3, 1, 2))

In [51]:
# Run every trained variant over the tiled slide, stitch the per-patch argmax
# predictions back into a full-size label map, and score it against `mask`.
models_list = ['no_shake_drop_no_deepcut', 'with_shake_drop_no_deepcut',
               'no_shake_drop_deepcut', 'shake_drop_deepcut']
device = torch.device('cuda')

# Same row-major grid the tiling cell used to build `input_patches`.
patch_size = input_patches.shape[-1]
n_rows = whole_region.shape[0] // patch_size + 1
n_cols = whole_region.shape[1] // patch_size + 1
assert len(input_patches) == n_rows * n_cols, "patch grid does not match input_patches"

ious = []
for model_idx, model_name in enumerate(models_list):
    # Odd entries of models_list are the shake-drop variants.
    shake_drop = model_idx % 2 == 1
    # NOTE(review): the *_deepcut and *_no_deepcut variants are built with identical
    # UNet arguments. Confirm whether one of them should toggle a deep-cut flag.
    model = UNet(84, 3, 3,
                 shake_drop, True, 4, 2).to(device)
    checkpoint = torch.load('../brown-datathon/src/training_logs/' + model_name + '/best.pth.tar',
                            map_location=device)
    # NOTE(review): assumes the checkpoint file pickles a whole module (hence
    # .state_dict()). If it stores a plain dict, load the right key instead.
    model.load_state_dict(checkpoint.state_dict())
    model.eval()

    preds = []
    with torch.no_grad():
        # One batch per row of patches.
        for row in range(n_rows):
            batch = torch.from_numpy(input_patches[row * n_cols:(row + 1) * n_cols]).float().to(device)
            output = model(batch)
            preds.append(torch.argmax(output, 1).cpu().numpy())
            del batch, output  # release GPU memory before the next row
    catted = np.concatenate(preds)
    # Each row's patches are placed side by side, then the rows are stacked, and the
    # result is cropped back to the unpadded level-1 size (height, width).
    recon_mask = np.vstack([np.hstack(catted[row * n_cols:(row + 1) * n_cols])
                            for row in range(n_rows)])
    recon_mask = recon_mask[:reader.level_dimensions[1][1], :reader.level_dimensions[1][0]]
    ious.append(iou(recon_mask, mask))

# NOTE(review): the line below is only a string literal, so none of the code inside it
# runs. It is an old draft of per-patch contour extraction. It references `cv2` and
# `t_region`, neither of which is imported or defined in the cells shown. Its loop also
# zeroes `pred` in place, so class 2 would be lost after the first pass. Consider
# deleting it or moving it into a markdown cell.
'\nwith torch.no_grad():\n    model.eval()\n    output = model(torch.from_numpy(t_region).float().to(device))\n    pred = torch.squeeze(torch.argmax(output,1).long()).cpu().data.numpy()\n    conts = []\n    for cl in range(2):\n        pred[pred != (cl + 1)] = 0\n        contours, hierarchy = cv2.findContours(pred.astype(np.uint8),cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)\n        conts.append(contours)\n        #y,x = np.mgrid[0:512,0:512]\n        #coord = np.vstack([x.ravel(),y.ravel()]).T\n        #print(x[pred == (cl + 1)])\n        #print(y[pred == (cl + 1)])\n        #print([np.min(pred[:, 1]) - 1, np.min(pred[:, 2]) - 1, np.max(pred[:, 1])+1,np.max(pred[:, 2]) + 1])\n'