In [1]:
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import skimage.io
import albumentations as A

import random
import sys
import os

sys.path = [
    '../input/efficientnet-pytorch/EfficientNet-PyTorch/EfficientNet-PyTorch-master',
] + sys.path

from efficientnet_pytorch import EfficientNet
from tqdm.notebook import tqdm
import numpy as np

In [2]:
class Config:
    """Static inference settings, read as class attributes by the cells below."""

    # Nominal input geometry.
    width = 224
    height = 224
    dim = (width, height)

    # Inference / model settings.
    batch_size = 1
    n_classes = 6
    seed = 42

    # Cross-validation ensemble: one checkpoint per fold, in fold order.
    folds = 5
    model_name = 'efficientnet-b0'
    # The outermost iterable of a comprehension is evaluated in the class body,
    # so `folds` is visible here; this yields checkpoint_Fold1.pt .. checkpoint_Fold5.pt.
    pretrained_model = [
        f'../input/checkpoint-fold{k}/checkpoint_Fold{k}.pt' for k in range(1, folds + 1)
    ]

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
def seed_everything(seed):
    """Seed every RNG the notebook relies on so runs are repeatable.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU and every
    CUDA device). It also exports ``PYTHONHASHSEED`` and configures cuDNN to use
    deterministic kernels only.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    # Only affects Python interpreters started after this point (e.g. spawned workers);
    # the current process's hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and select non-deterministic algorithms,
    # which defeats `deterministic = True`; it must be off for reproducible results.
    torch.backends.cudnn.benchmark = False
# Seed once, before any data loading or model construction in the cells below.
seed_everything(Config.seed)

In [3]:
# Competition data root and the sub-folders that hold the slide .tiff files.
path, train_file, test_file = (
    '../input/prostate-cancer-grade-assessment',
    'train_images',
    'test_images',
)

In [4]:
# Number of train.csv rows to score when the test image folder is missing
# (presumably an interactive dry run; the hidden test set appears only at scoring time).
N = 10

# Build the folder from the configured constants instead of repeating the literal path.
test_folder = os.path.join(path, test_file)
if os.path.exists(test_folder):
    subm = pd.read_csv(os.path.join(path, 'sample_submission.csv'))[['image_id']]
    image_folder = test_folder
else:
    # .loc slicing is label-inclusive, so :N - 1 keeps the first N rows of the default RangeIndex.
    subm = pd.read_csv(os.path.join(path, 'train.csv')).loc[:N - 1, ['image_id']]
    image_folder = os.path.join(path, train_file)

In [5]:
def PlantModel(model_name=Config.model_name, n_classes=None):
    """Build an EfficientNet backbone with a two-layer classification head.

    ``EfficientNet.from_name`` builds the architecture without pretrained weights.
    The fold checkpoints are loaded into the returned model afterwards via
    ``load_state_dict``. The stock ``_fc`` layer is replaced by
    ``Linear(in_features, 1000) -> ReLU -> Dropout(0.5) -> Linear(1000, n_classes)``.

    Args:
        model_name: EfficientNet variant to build, e.g. ``'efficientnet-b0'``.
        n_classes: number of output logits; defaults to ``Config.n_classes``
            (looked up at call time).

    Returns:
        The freshly constructed ``nn.Module``.
    """
    if n_classes is None:
        n_classes = Config.n_classes
    # Honor the caller's model_name rather than always building Config.model_name.
    model = EfficientNet.from_name(model_name)
    in_features = model._fc.in_features
    model._fc = nn.Sequential(nn.Linear(in_features, 1000, bias=True),
                              nn.ReLU(),
                              nn.Dropout(p=.5),
                              nn.Linear(1000, n_classes, bias=True))

    return model

In [6]:
tile_size = image_size = 256  # side length (px) of each square tile
n_tiles = 16                  # number of tiles kept per slide


def get_tiles(img, mode=0):
    """Cut an RGB image into square tiles and keep the ``n_tiles`` darkest ones.

    The image is padded with white (255) so both sides are multiples of
    ``tile_size``. ``mode`` adds ``(tile_size * mode) // 2`` extra padding
    pixels per axis, split across both sides, which shifts the tile grid. Tiles
    are ranked by ascending total intensity, so tiles with the most non-white
    content come first.

    Args:
        img: ``(H, W, C)`` array; the reshape below assumes exactly 3 channels.
        mode: grid-offset selector (see above).

    Returns:
        ``(tiles, ok)``. ``tiles`` is a list of exactly ``n_tiles`` dicts
        ``{'img': (tile_size, tile_size, 3) array, 'idx': rank}``; images with
        fewer tiles are topped up with all-white tiles. ``ok`` is True when at
        least ``n_tiles`` tiles contained a non-white pixel.
    """
    h, w, _ = img.shape
    pad_h = (tile_size - h % tile_size) % tile_size + ((tile_size * mode) // 2)
    pad_w = (tile_size - w % tile_size) % tile_size + ((tile_size * mode) // 2)

    padded = np.pad(img, [[pad_h // 2, pad_h - pad_h // 2], [pad_w // 2, pad_w - pad_w // 2], [0, 0]],
                    constant_values=255)
    tiles = padded.reshape(
        padded.shape[0] // tile_size,
        tile_size,
        padded.shape[1] // tile_size,
        tile_size,
        3
    )
    tiles = tiles.transpose(0, 2, 1, 3, 4).reshape(-1, tile_size, tile_size, 3)

    # A completely white tile sums to exactly tile_size**2 * 3 * 255.
    n_tiles_with_info = (tiles.reshape(tiles.shape[0], -1).sum(1) < tile_size ** 2 * 3 * 255).sum()
    # Compare the tile count (not the image height) and pad up to n_tiles, so small
    # images still yield exactly n_tiles tiles and the pad width is never negative.
    if len(tiles) < n_tiles:
        tiles = np.pad(tiles, [[0, n_tiles - len(tiles)], [0, 0], [0, 0], [0, 0]], constant_values=255)
    idxs = np.argsort(tiles.reshape(tiles.shape[0], -1).sum(-1))[:n_tiles]
    tiles = tiles[idxs]
    result = [{'img': tile, 'idx': i} for i, tile in enumerate(tiles)]
    return result, n_tiles_with_info >= n_tiles


class PlantData(Dataset):
    """Dataset that turns each slide into one square mosaic of tiles.

    For every row of ``df`` it reads image index 1 of ``<image_folder>/<image_id>.tiff``
    via ``skimage.io.MultiImage`` and cuts it with ``get_tiles``. The tiles are inverted
    (``255 - tile``) so the white background becomes 0, then stitched into an
    ``int(sqrt(n_tiles))`` x ``int(sqrt(n_tiles))`` grid. The result is returned as a
    CHW float32 tensor scaled by 1/255.

    Args:
        df: frame with an ``image_id`` column; its index is reset.
        image_size: side length of each tile slot in the mosaic. It must match the
            tile size produced by ``get_tiles`` for the slot assignment to fit.
        n_tiles: tiles per mosaic. Assumed to be a perfect square; otherwise the
            truncated square root drops tiles.
        tile_mode: forwarded to ``get_tiles`` as ``mode``.
        rand: place the tiles in a random permutation of grid positions.
        sub_imgs: read tile indices ``n_tiles .. 2*n_tiles-1`` instead of the first ``n_tiles``.
        transform: callable used as ``transform(image=...)['image']``. It is applied
            to every tile and again to the assembled mosaic.
    """

    def __init__(self,
                 df,
                 image_size,
                 n_tiles=n_tiles,
                 tile_mode=0,
                 rand=False,
                 sub_imgs=False,
                 transform=None
                ):

        self.df = df.reset_index(drop=True)
        self.image_size = image_size
        self.n_tiles = n_tiles
        self.tile_mode = tile_mode
        self.rand = rand
        self.sub_imgs = sub_imgs
        self.transform = transform

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, index):
        """Build the mosaic tensor, shape ``(3, S*k, S*k)`` with ``S = image_size`` and ``k = int(sqrt(n_tiles))``."""
        row = self.df.iloc[index]
        img_id = row.image_id

        tiff_file = os.path.join(image_folder, f'{img_id}.tiff')
        image = skimage.io.MultiImage(tiff_file)[1]
        tiles, _ = get_tiles(image, self.tile_mode)

        if self.rand:
            idxes = np.random.choice(list(range(self.n_tiles)), self.n_tiles, replace=False)
        else:
            idxes = list(range(self.n_tiles))
        # NOTE(review): get_tiles returns at most its own n_tiles tiles, so with
        # sub_imgs these offsets may always hit the white fallback below.
        idxes = np.asarray(idxes) + self.n_tiles if self.sub_imgs else idxes

        # Use the instance's tile size rather than the module-level `image_size`.
        size = self.image_size
        n_row_tiles = int(np.sqrt(self.n_tiles))
        images = np.zeros((size * n_row_tiles, size * n_row_tiles, 3))
        for h in range(n_row_tiles):
            for w in range(n_row_tiles):
                i = h * n_row_tiles + w

                if len(tiles) > idxes[i]:
                    this_img = tiles[idxes[i]]['img']
                else:
                    # Missing tile: a blank white tile, which becomes 0 after inversion.
                    this_img = np.full((size, size, 3), 255, dtype=np.uint8)
                this_img = 255 - this_img
                if self.transform is not None:
                    this_img = self.transform(image=this_img)['image']
                h1 = h * size
                w1 = w * size
                images[h1:h1 + size, w1:w1 + size] = this_img

        if self.transform is not None:
            images = self.transform(image=images)['image']
        images = images.astype(np.float32)
        images /= 255
        images = images.transpose(2, 0, 1)

        return torch.tensor(images)

In [7]:
# Test-time augmentation: each op fires independently with probability 0.7, so
# every pass over `test_dataloader` draws a fresh set of flips/transposes.
transforms_test = A.Compose(
    [
        A.Transpose(p=0.7),
        A.VerticalFlip(p=0.7),
        A.HorizontalFlip(p=0.7),
    ]
)

dataset_test = PlantData(
    subm,
    image_size,
    n_tiles=n_tiles,
    tile_mode=0,
    transform=transforms_test,
)
test_dataloader = DataLoader(
    dataset_test,
    batch_size=Config.batch_size,
    shuffle=False,  # keep prediction order aligned with the rows of `subm`
    num_workers=4,
)

In [8]:
# Per-fold predictions keyed 'Fold_1' .. 'Fold_<folds>'. Each value is replaced
# by a logits tensor averaged over `tta_count` passes.
submission_values = {f'Fold_{fold + 1}': [] for fold in range(Config.folds)}
tta_count = 10


def predict_with_tta(model, dataloader, n_passes):
    """Average a model's raw outputs over ``n_passes`` full passes of ``dataloader``.

    When the loader applies random transforms (as ``test_dataloader`` does), each
    pass sees different augmentations. Everything runs under ``torch.no_grad()``,
    so no autograd graph is built.

    Returns:
        CPU tensor with one row per sample, in dataloader order.
    """
    model.eval()
    total = 0
    with torch.no_grad():
        for _ in range(n_passes):
            # Collect per-batch outputs and concatenate once (repeated torch.cat is quadratic).
            batch_outputs = [model(images.to(Config.device)).cpu() for images in tqdm(dataloader)]
            total += torch.cat(batch_outputs, dim=0)
    return total / n_passes


for fold in range(Config.folds):
    print('#' * 60)
    print('#' * 60)
    print('\t\t\tFOLD {}/{}'.format(fold + 1, Config.folds))
    print('#' * 60)
    print('#' * 60)

    model = PlantModel()
    # Map to CPU first so GPU-saved checkpoints also load on CPU-only machines;
    # the model is moved to Config.device right after.
    model.load_state_dict(torch.load(Config.pretrained_model[fold], map_location='cpu'))
    model.to(Config.device)
    submission_values[f'Fold_{fold + 1}'] = predict_with_tta(model, test_dataloader, tta_count)

# Each fold votes with its argmax class. The mean vote is rounded back to an integer,
# because the target is a class index and a fractional average is not a valid grade.
fold_votes = torch.stack([preds.argmax(1) for preds in submission_values.values()]).float()
subm['isup_grade'] = fold_votes.mean(0).round().long().numpy()

subm[['image_id', 'isup_grade']].to_csv('submission.csv', index=False)

############################################################
############################################################
			FOLD 1/5
############################################################
############################################################


HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))


############################################################
############################################################
			FOLD 2/5
############################################################
############################################################


HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))


############################################################
############################################################
			FOLD 3/5
############################################################
############################################################


HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))


############################################################
############################################################
			FOLD 4/5
############################################################
############################################################


HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))


############################################################
############################################################
			FOLD 5/5
############################################################
############################################################


HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))






In [9]:
# Inspect the per-fold logits; each tensor is already averaged over the `tta_count` TTA passes.
submission_values

{'Fold_1': tensor([[  4.1607,  -4.0861,  -6.4145,  -8.3567,  -7.3244,  -8.2743],
         [  4.1973,  -3.8506,  -7.0716,  -8.7232,  -8.6588,  -8.7205],
         [ -3.7808,  -7.7207,  -8.7204,  -5.1375,   0.6321,  -2.1111],
         [ -4.7518,  -6.2061,  -7.6633,  -4.5102,   0.9932,  -1.5190],
         [  5.4936,  -5.1059,  -8.4454, -10.8253, -10.2800, -10.8516],
         [ -0.6583,   0.2306,  -3.5075,  -5.7548,  -4.7117,  -5.8966],
         [ -1.2063,   0.0173,  -2.9847,  -5.5254,  -6.0779,  -7.1173],
         [ -9.7274,  -0.2168,  -0.0382,  -3.3017,  -4.4470,  -6.9851],
         [-10.7833,   4.9732,  -5.2913,  -8.6701, -10.3810, -13.7862],
         [  5.0880,  -4.7157,  -8.3293,  -9.9800, -10.1440, -10.3200]]),
 'Fold_2': tensor([[  3.3364,  -3.1619,  -5.3122,  -6.6518,  -6.9792,  -7.3209],
         [  4.3792,  -4.1816,  -7.4323,  -8.7920,  -9.9312,  -9.9225],
         [ -9.4010, -10.4474,  -8.0928,  -4.8355,  -0.1657,   0.4337],
         [ -7.1042,  -8.2404,  -7.1517,  -5.0202,  -0.4

In [10]:
# Preview the frame whose image_id / isup_grade columns were written to submission.csv.
subm

Unnamed: 0,image_id,isup_grade
0,0005f7aaab2800f6170c399693a96917,0
1,000920ad0b612851f8e01bcc880d9b3d,0
2,0018ae58b01bdadc8e347995b69f99aa,4
3,001c62abd11fa4b57bf7a6c603a11bb9,4
4,001d865e65ef5d2579c190a0e0350d8f,0
5,002a4db09dad406c85505a00fb6f6144,1
6,003046e27c8ead3e3db155780dc5498e,1
7,0032bfa835ce0f43a92ae0bbab6871cb,2
8,003a91841da04a5a31f808fb5c21538a,1
9,003d4dd6bd61221ebc0bfb9350db333f,0
