In [None]:
import cv2
from tqdm import tqdm_notebook as tqdm
import fastai
from fastai.vision import *
import os
from mish_activation import *
import warnings
warnings.filterwarnings("ignore")
import skimage.io
import numpy as np
import pandas as pd
from bisect import bisect_right

In [None]:
sz = 128        # tile side length in pixels (see `tile`)
bs = 2          # slides per inference batch (the inference cell uses 2)
N = 128         # tiles kept per slide
nworkers = 2    # DataLoader worker processes

# Swap these in to run the pipeline on the training set for local validation.
#DATA = '../input/prostate-cancer-grade-assessment/train_images/'
#TEST = '../input/prostate-cancer-grade-assessment/train.csv'
DATA = '../input/prostate-cancer-grade-assessment/test_images'
TEST = '../input/prostate-cancer-grade-assessment/test.csv'
SAMPLE = '../input/prostate-cancer-grade-assessment/sample_submission.csv'

MODEL_DIR = '../input/panda-init-class-model1'
N_FOLDS = 4
# (checkpoint prefix, ensemble vote weight). Every prefix has N_FOLDS fold
# checkpoints named f'{prefix}_{fold}.pth'; keeping prefix and weight together
# guarantees the weights stay aligned with the checkpoints they belong to.
MODEL_GROUPS = [
    ('RNXT50_128k_0', 1),
    ('RNXT50_128k_3', 1),
    ('RNXT50_128kr_1', 1),
    ('RNXT50_128kr_2', 6),
    ('RNXT50_128kr_2feature', 1),
    ('RNXT50_128kr_9', 1),
    ('RNXT50_128kr_7c', 1),
    ('RNXT50_128kr_3', 1),
]
MODELS = [f'{MODEL_DIR}/{prefix}_{fold}.pth'
          for prefix, _ in MODEL_GROUPS for fold in range(N_FOLDS)]
# One vote weight per checkpoint, index-aligned with MODELS (group-major, fold-minor).
ws = [weight for _, weight in MODEL_GROUPS for fold in range(N_FOLDS)]
assert len(ws) == len(MODELS), "every checkpoint needs exactly one vote weight"

In [None]:
from torchvision.models.resnet import ResNet, Bottleneck

class Model(nn.Module):
    """ResNeXt-50 (32x4d) tile encoder followed by a pooling head over all tiles of a slide.

    The encoder is applied to every tile independently; the per-tile feature maps
    belonging to one slide are then concatenated along the height axis and pooled
    jointly, so the head produces one output row per slide.
    """
    def __init__(self, arch='resnext50_32x4d', n=11, pre=True):
        """Build the encoder and head.

        Args:
            arch: unused; the backbone is always built as ResNeXt-50 32x4d below.
            n: number of outputs of the final linear layer.
            pre: unused; no pretrained weights are loaded in this constructor.
        """
        super().__init__()
        # ResNeXt-50 (32x4d): Bottleneck blocks [3, 4, 6, 3], 32 groups of width 4.
        m = ResNet(Bottleneck, [3, 4, 6, 3], groups=32, width_per_group=4)
        # Keep every child except the last two (torchvision ResNet's avgpool and fc).
        self.enc = nn.Sequential(*list(m.children())[:-2])
        # Encoder output channel count, read from the dropped final fc layer.
        nc = list(m.children())[-1].in_features
        # The first Linear takes 2*nc features: AdaptiveConcatPool2d concatenates
        # two poolings (fastai: adaptive avg + max) of the nc-channel map.
        self.head = nn.Sequential(AdaptiveConcatPool2d(),Flatten(),
                                  nn.Linear(2*nc,512),Mish(),nn.GroupNorm(32,512),
                                  nn.Dropout(0.5),nn.Linear(512,n))
        
    def forward(self, x):
        """Score a batch of tile stacks.

        Args:
            x: 5-D tensor (batch, n_tiles, C, H, W); the rank is implied by the
                indexing of ``shape[1]`` .. ``shape[4]`` below.

        Returns:
            Tensor of shape (batch, 1). Only the first of the head's ``n``
            outputs is returned (NOTE(review): the remaining outputs are
            presumably auxiliary training targets — confirm).
        """
        shape = x.shape
        n = shape[1]  # tiles per slide
        # Fold tiles into the batch axis: (batch*n, C, H, W).
        x = x.view(-1,shape[2],shape[3],shape[4])
        x = self.enc(x)
        shape = x.shape  # (batch*n, C', h, w) encoder feature maps
        # (batch, n, C', h, w) -> (batch, C', n, h, w) -> (batch, C', n*h, w):
        # a slide's tile feature maps are stacked along height so the head pools them jointly.
        x = x.view(-1,n,shape[1],shape[2],shape[3]).permute(0,2,1,3,4).contiguous()\
          .view(-1,shape[1],shape[2]*n,shape[3])
        x = self.head(x)
        return x[:,:1]

In [None]:
def load_model(path):
    """Load one checkpoint into a fresh `Model` and prepare it for GPU inference.

    Weights are first mapped to the CPU, then the model is cast to float32,
    switched to eval mode (disables dropout) and moved to the GPU.
    NOTE: `torch.load` unpickles the file — only load trusted checkpoints.

    Args:
        path: path to a state-dict checkpoint compatible with `Model()`.

    Returns:
        The ready-to-use `Model` instance.
    """
    state_dict = torch.load(path, map_location=torch.device('cpu'))
    model = Model()
    model.load_state_dict(state_dict)
    # nn.Module.float/eval/cuda all return the module itself, so they chain.
    return model.float().eval().cuda()


# Building the list through a function keeps `model`/`state_dict` out of the
# global kernel namespace, frees each CPU state dict as soon as it is loaded,
# and works (yielding []) when MODELS is empty.
models = [load_model(path) for path in MODELS]

In [None]:
def tile(img, tile_size=None, n_tiles=None):
    """Cut an image into square tiles and keep the `n_tiles` darkest ones.

    The image is padded with white (255) up to a multiple of `tile_size` on both
    spatial axes and split into non-overlapping tiles in row-major order. If
    there are fewer than `n_tiles` tiles, all-white tiles are appended. The
    `n_tiles` tiles with the smallest pixel sum are returned in ascending-sum order.

    Args:
        img: array of shape (H, W, C).
        tile_size: tile side length; defaults to the global `sz`.
        n_tiles: number of tiles to return; defaults to the global `N`.

    Returns:
        Array of shape (n_tiles, tile_size, tile_size, C) with `img`'s dtype.
    """
    if tile_size is None:
        tile_size = sz
    if n_tiles is None:
        n_tiles = N
    h, w, c = img.shape
    pad_h = (tile_size - h % tile_size) % tile_size
    pad_w = (tile_size - w % tile_size) % tile_size
    # Split the padding evenly; an odd extra pixel goes to the bottom/right.
    img = np.pad(img,
                 [[pad_h // 2, pad_h - pad_h // 2],
                  [pad_w // 2, pad_w - pad_w // 2],
                  [0, 0]],
                 mode='constant', constant_values=255)
    img = img.reshape(img.shape[0] // tile_size, tile_size,
                      img.shape[1] // tile_size, tile_size, c)
    img = img.transpose(0, 2, 1, 3, 4).reshape(-1, tile_size, tile_size, c)
    if len(img) < n_tiles:
        img = np.pad(img, [[0, n_tiles - len(img)], [0, 0], [0, 0], [0, 0]],
                     mode='constant', constant_values=255)
    # Lowest pixel sum = least white; presumably the tiles with the most tissue.
    idxs = np.argsort(img.reshape(img.shape[0], -1).sum(-1))[:n_tiles]
    return img[idxs]

# Per-channel normalisation statistics applied to the inverted tiles
# (255 - pixel) / 255. The mean is written as 1 - m; since inversion maps a
# value v to 1 - v, m is presumably the channel mean of the non-inverted image.
mean = torch.Tensor([1.0 - m for m in (0.85506157, 0.7035249, 0.80203127)])
std = torch.Tensor([0.40011922, 0.52504386, 0.42675745])

class PandaDataset(Dataset):
    """Dataset yielding the normalised tile stack and image id of each slide.

    Args:
        path: directory containing the ``<image_id>.tiff`` files.
        test: CSV file with an ``image_id`` column listing the slides to load.
    """

    def __init__(self, path, test):
        self.path = path
        self.names = list(pd.read_csv(test).image_id)

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):
        """Return ``(tiles, image_id)``; tiles are shaped (N, C, sz, sz)."""
        name = self.names[idx]
        # Use the directory this dataset was constructed with, not the global
        # DATA, so the ``path`` argument is actually honoured.
        # Index [1] takes the second image of the multi-image TIFF
        # (presumably the middle-resolution pyramid level — confirm).
        img = skimage.io.MultiImage(os.path.join(self.path, name + '.tiff'))[1]
        # Invert so white (255, also `tile`'s padding value) maps to 0, scale to [0, 1].
        tiles = torch.Tensor((255 - tile(img)) / 255.0)
        tiles = (tiles - mean) / std
        # (N, sz, sz, C) -> (N, C, sz, sz), channels-first for the CNN encoder.
        return tiles.permute(0, 3, 1, 2), name

In [None]:
# Cumulative cut points on the 0..6 score: a slide's grade is the number of cut
# points <= its score, so these 5 cut points map scores to ISUP grades 0..5.
ths = np.array([1.03125,1.03125,0.8125,0.90625,1.0859375]).cumsum()
# Default to the sample submission; replaced below when the test images exist.
sub_df = pd.read_csv(SAMPLE)


def _tta_views(x):
    """Return the test-time-augmentation views of a batch of tile stacks.

    Six of the eight dihedral transforms of the last two (spatial) axes; the two
    unused ones would be ``t.flip(-2)`` and ``t.flip(-1, -2)``.
    """
    t = x.transpose(-1, -2)
    return [x, x.flip(-1), x.flip(-2), x.flip(-1, -2), t, t.flip(-1)]


if os.path.exists(DATA):
    bs = 2  # slides per batch
    ds = PandaDataset(DATA, TEST)
    dl = DataLoader(ds, batch_size=bs, num_workers=nworkers, shuffle=False)
    names, preds = [], []

    with torch.no_grad():
        for x, y in tqdm(dl):
            x = x.cuda()
            b = x.shape[0]
            views = _tta_views(x)
            n_tta = len(views)  # derived from the view list so the two cannot drift apart
            # (b, n_tta, N, 3, sz, sz) -> (b*n_tta, N, 3, sz, sz), slide-major order.
            x = torch.stack(views, 1).view(-1, N, 3, sz, sz)
            # Each model returns (b*n_tta, 1); stacking gives (b*n_tta, n_models, 1).
            p = torch.stack([model(x) for model in models], 1)
            # Average over the TTA views -> (b, n_models), then squash to [0, 6].
            p = p.view(b, n_tta, len(models)).mean(1).cpu()
            p = 6.0 * torch.sigmoid(p)

            # Per-model grades; side='right' matches bisect_right(ths, score).
            grades = np.searchsorted(ths, p.numpy(), side='right')
            for row in grades:
                # Weighted majority vote across models (ws = per-model weights).
                preds.append(np.argmax(np.bincount(row, ws)))

            # y is the collated batch of image ids; extend() also works when
            # the loader yields no batches (np.concatenate([]) would raise).
            names.extend(y)

    sub_df = pd.DataFrame({'image_id': names, 'isup_grade': preds})

In [None]:
# Persist the submission (real predictions or the sample fallback) and preview it.
sub_df.to_csv(path_or_buf="submission.csv", index=False)
sub_df.head(n=5)