In [None]:
import os, sys
# Prepend the bundled EfficientNet-PyTorch sources so `efficientnet_pytorch` resolves to them first.
sys.path = ['../input/efficientnet-pytorch/EfficientNet-PyTorch/EfficientNet-PyTorch-master', *sys.path]

In [None]:
# Basic Python and machine-learning libraries
import random, cv2
import pandas as pd
import numpy as np
import skimage.io
from tqdm.notebook import tqdm

# PyTorch and Albumentations (data augmentation library)
import torch
import albumentations
from torch import nn
from torch.utils.data import Dataset, DataLoader

from efficientnet_pytorch import EfficientNet

In [None]:
CONFIG = {}

# --- Data locations ---
CONFIG['data_dir'] = '../input/prostate-cancer-grade-assessment/'
CONFIG['test_img_dir'] = os.path.join(CONFIG['data_dir'], 'test_images')

# --- Models: backbone name and one checkpoint per CV fold (all are ensembled) ---
CONFIG['backbone'] = 'efficientnet-b1'
CONFIG['model_dir'] = ['../input/newpandamodels/efficientnet-b1_fold_0_epoch_7.pt',
                       '../input/newpandamodels/efficientnet-b1_fold_1_epoch_8.pt',
                       '../input/newpandamodels/efficientnet-b1_fold_2_epoch_9.pt',
                       '../input/newpandamodels/efficientnet-b1_fold_3_epoch_8.pt',
                       '../input/newpandamodels/efficientnet-b1_fold_4_epoch_8.pt']

# --- Target encoding ---
# sum_prediction=True  -> 5 ordinal outputs (BCE-style targets)
# sum_prediction=False -> 6 categorical logits (CCE), one per grade
CONFIG['sum_prediction'] = False
CONFIG['out_dim'] = 5 if CONFIG['sum_prediction'] else 6

# --- Tiling / mosaic ---
# tile_size: side of each tile cut from the slide.
# image_size: side of the mosaic cell each tile is placed into.
# tile_mode: forwarded to get_tiles (0 = no extra padding, 2 = one extra tile of
# padding, which offsets the tile grid by half a tile).
CONFIG['image_size'] = 256
CONFIG['tile_size'] = 256
CONFIG['tile_mode'] = 0
# Number of tiles per slide; must be a perfect square (e.g. 16, 25, 36) so the
# tiles form a square mosaic. The fold checkpoints were presumably trained with
# this value -- TODO confirm before changing it.
CONFIG['n_tiles'] = 36
CONFIG['batch_size'] = 2
CONFIG['num_workers'] = 4

CONFIG['device'] = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ImageNet mean and std (RGB), used by the Normalize transform
CONFIG['mean'] = [0.485, 0.456, 0.406]
CONFIG['std'] = [0.229, 0.224, 0.225]

# Fail fast: the dataset lays tiles out on an int(sqrt(n_tiles)) square grid, so a
# non-square n_tiles would silently drop tiles instead of erroring.
if int(round(CONFIG['n_tiles'] ** 0.5)) ** 2 != CONFIG['n_tiles']:
    raise ValueError("CONFIG['n_tiles'] must be a perfect square, got {}".format(CONFIG['n_tiles']))

In [None]:
# Report which device inference will run on.
print(str(CONFIG['device']))

# Dataset

In [None]:
def get_tiles(img, tile_size, n_tiles, mode=0):
    """Cut an image into square tiles and keep the ``n_tiles`` with the lowest pixel sum.

    The image is padded with white (255) on both sides of each spatial axis up to a
    multiple of ``tile_size`` and split into a row-major grid of tiles. If fewer than
    ``n_tiles`` tiles exist, all-white tiles are appended. Tiles are then ordered by
    ascending pixel sum (least white first) and the first ``n_tiles`` are kept.

    Args:
        img: array of shape (H, W, C); any channel count C is supported.
        tile_size: side length of each square tile, in pixels.
        n_tiles: number of tiles to return.
        mode: adds ``tile_size * mode // 2`` extra pixels of padding per axis, split
            between both sides. ``mode=2`` offsets the grid by half a tile. For
            ``tile_size > 1`` an odd ``mode`` leaves a padded size that is not a
            multiple of ``tile_size`` and raises ``ValueError``.

    Returns:
        list of ``n_tiles`` dicts ``{'img': (tile_size, tile_size, C) array, 'idx': rank}``
        where ``rank`` is the position in the sorted order (0 = lowest pixel sum).

    Raises:
        ValueError: if the padded image cannot be split evenly into tiles.
    """
    h, w, c = img.shape
    extra = (tile_size * mode) // 2
    pad_h = (tile_size - h % tile_size) % tile_size + extra
    pad_w = (tile_size - w % tile_size) % tile_size + extra
    if (h + pad_h) % tile_size or (w + pad_w) % tile_size:
        raise ValueError(
            "mode={} pads the image to {}x{}, which is not divisible by tile_size={}".format(
                mode, h + pad_h, w + pad_w, tile_size))

    img = np.pad(img, [[pad_h // 2, pad_h - pad_h // 2], [pad_w // 2, pad_w - pad_w // 2], [0, 0]],
                 constant_values=255)
    # (H', W', C) -> (rows, tile, cols, tile, C) -> (rows * cols, tile, tile, C), row-major.
    img = img.reshape(img.shape[0] // tile_size, tile_size, img.shape[1] // tile_size, tile_size, c)
    img = img.transpose(0, 2, 1, 3, 4).reshape(-1, tile_size, tile_size, c)

    if len(img) < n_tiles:
        img = np.pad(img, [[0, n_tiles - len(img)], [0, 0], [0, 0], [0, 0]], constant_values=255)
    # Padding is white, so the lowest sums are the least-white tiles (presumably the
    # ones carrying the most tissue).
    idxs = np.argsort(img.reshape(img.shape[0], -1).sum(-1))[:n_tiles]
    img = img[idxs]

    return [{'img': tile, 'idx': i} for i, tile in enumerate(img)]


In [None]:
class PANDA_Dataset(Dataset):
    """Test-time dataset that turns each slide into one square mosaic of tiles.

    Each item reads ``<test_img_dir>/<image_id>.tiff`` and takes the page at index 1
    of the multi-resolution TIFF. It cuts that page into ``n_tiles`` tiles with
    ``get_tiles`` and lays the first ``int(sqrt(n_tiles)) ** 2`` of them out
    row-major on a square canvas, where every tile fills an
    ``image_size`` x ``image_size`` cell.

    Args:
        df: frame with an ``image_id`` column; its index is reset.
        config: dict providing ``test_img_dir``, ``tile_size``, ``n_tiles``,
            ``tile_mode`` and ``image_size``.
        img_transform: optional albumentations-style callable, applied to the float32
            mosaic as ``img_transform(image=...)['image']``. When absent, the mosaic
            is only scaled by 1/255.

    Items are ``(tensor of shape (3, S, S), image_id)`` with
    ``S = image_size * int(sqrt(n_tiles))`` (assuming the transform keeps the shape).
    """

    def __init__(self,
                 df,
                 config,
                 img_transform=None
                ):

        self.df = df.reset_index(drop=True)
        self.config = config
        self.img_transform = img_transform

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, index):
        image_id = self.df['image_id'].values[index]
        img_path = os.path.join(self.config['test_img_dir'], image_id) + '.tiff'
        image = skimage.io.MultiImage(img_path)[1]

        tile_size = self.config['tile_size']
        cell = self.config['image_size']
        n_tiles = self.config['n_tiles']
        tiles = get_tiles(image, tile_size, n_tiles, self.config['tile_mode'])

        n_row_tiles = int(np.sqrt(n_tiles))
        mosaic = np.zeros((cell * n_row_tiles, cell * n_row_tiles, 3))
        for i in range(n_row_tiles * n_row_tiles):
            row, col = divmod(i, n_row_tiles)
            tile = tiles[i]['img']
            if tile_size != cell:
                # Without this, tile_size != image_size crashed on the slice assignment.
                tile = cv2.resize(tile, (cell, cell))
            mosaic[row * cell:(row + 1) * cell, col * cell:(col + 1) * cell] = tile

        mosaic = mosaic.astype(np.float32)
        if self.img_transform is not None:
            mosaic = self.img_transform(image=mosaic)['image']
        else:
            mosaic /= 255

        # HWC -> CHW for PyTorch.
        return torch.tensor(mosaic.transpose(2, 0, 1)), image_id

# Model

In [None]:
class Model(nn.Module):
    """EfficientNet backbone followed by a single linear head.

    Args:
        backbone: EfficientNet variant name passed to ``EfficientNet.from_name``
            (e.g. ``'efficientnet-b1'``). ``from_name`` builds only the architecture;
            the weights are expected to be loaded from a checkpoint afterwards.
        out_dim: number of outputs of the linear head.
    """

    def __init__(self, backbone, out_dim=6):
        super().__init__()
        # Use the requested backbone; this used to be hard-coded to 'efficientnet-b1'.
        self.enet = EfficientNet.from_name(backbone)

        # Swap the stock classifier for our own head. The backbone's output (what
        # used to feed _fc) then goes straight into self.fc.
        self.fc = nn.Linear(self.enet._fc.in_features, out_dim)
        self.enet._fc = nn.Identity()

    def forward(self, x):
        return self.fc(self.enet(x))

In [None]:
def load_fold_models(paths, config):
    """Build one Model per checkpoint, load its weights and switch it to eval mode.

    Args:
        paths: checkpoint files each holding a ``state_dict`` for ``Model``.
        config: dict providing ``backbone``, ``out_dim`` and ``device``.

    Returns:
        list of models on ``config['device']``, in the order of ``paths``.
    """
    loaded = []
    for ckpt_path in paths:
        net = Model(config['backbone'], config['out_dim'])
        net.to(config['device'])
        # torch.load unpickles the file: only point this at trusted checkpoints.
        weights = torch.load(ckpt_path, map_location=config['device'])
        net.load_state_dict(weights)
        net.eval()
        loaded.append(net)
    return loaded


# An empty list used to die with NameError on `del state_dict`; report it explicitly.
if not CONFIG['model_dir']:
    raise ValueError("CONFIG['model_dir'] must list at least one checkpoint")
models = load_fold_models(CONFIG['model_dir'], CONFIG)

# Prediction

In [None]:
# Inference preprocessing only: per-channel normalization with the CONFIG mean/std,
# no augmentation.
test_img_transforms = albumentations.Compose(
    transforms=[
        albumentations.Normalize(
            mean=CONFIG['mean'],
            std=CONFIG['std'],
            always_apply=True,
        ),
    ],
)

In [None]:
def dihedral_tta(batch):
    """Return the 8 dihedral (flip/transpose) views of every sample in ``batch``.

    Args:
        batch: tensor of shape (B, C, H, W). H must equal W, because half of the
            views transpose the last two axes.

    Returns:
        Tensor of shape (B * 8, C, H, W). Rows ``8*b .. 8*b + 7`` are the views of
        sample ``b`` in this order: identity, flip W, flip H, flip both, then the
        transpose and the transpose flipped along W, H and both.
    """
    transposed = batch.transpose(-1, -2)
    views = [batch, batch.flip(-1), batch.flip(-2), batch.flip(-1, -2),
             transposed, transposed.flip(-1), transposed.flip(-2), transposed.flip(-1, -2)]
    # Stacking on dim 1 keeps each sample's views adjacent once dims 0 and 1 are merged.
    return torch.stack(views, 1).flatten(0, 1)


def predict_grades(models, loader, config):
    """Predict one grade per image, averaging over the 8 TTA views and all models.

    Args:
        models: list of modules in eval mode, each mapping (N, C, H, W) inputs to
            (N, config['out_dim']) outputs.
        loader: DataLoader yielding ``(images, image_ids)`` batches.
        config: dict providing ``device`` and ``sum_prediction``.

    Returns:
        ``(image_ids, grades)``: a list of ids and an int64 numpy array, both in
        loader order.
    """
    image_ids, grades = [], []
    with torch.no_grad():
        for images, batch_ids in tqdm(loader):
            # Use the actual batch size: the last batch is smaller whenever the dataset
            # length is not a multiple of config['batch_size'].
            n_images = images.shape[0]
            inputs = dihedral_tta(images.to(config['device']))
            outputs = torch.stack([model(inputs) for model in models], 1)  # (N*8, M, out_dim)
            outputs = outputs.view(n_images, -1, outputs.shape[-1])        # (N, 8*M, out_dim)
            if config['sum_prediction']:
                # Ordinal decode: rounded sum of averaged sigmoid outputs. Assumes the
                # models were trained on cumulative binary targets -- TODO confirm.
                batch_grades = outputs.sigmoid().mean(1).sum(-1).round().long()
            else:
                batch_grades = outputs.mean(1).argmax(-1)
            image_ids.extend(batch_ids)
            grades.append(batch_grades.cpu())
    if not grades:
        return image_ids, np.zeros(0, dtype=np.int64)
    return image_ids, torch.cat(grades).numpy()


sample_path = os.path.join(CONFIG['data_dir'], 'sample_submission.csv')
sub_df = pd.read_csv(sample_path)
# test_images is presumably present only when the notebook is rescored on the hidden
# test set; otherwise sub_df stays the sample submission.
if os.path.exists(CONFIG['test_img_dir']):
    test_path = os.path.join(CONFIG['data_dir'], 'test.csv')
    test_df = pd.read_csv(test_path)

    test_ds = PANDA_Dataset(test_df, CONFIG, test_img_transforms)
    test_dl = DataLoader(test_ds, batch_size=CONFIG['batch_size'],
                         num_workers=CONFIG['num_workers'], shuffle=False)

    image_id_list, preds_list = predict_grades(models, test_dl, CONFIG)

    sub_df = pd.DataFrame({'image_id': image_id_list, 'isup_grade': preds_list})
    sub_df.to_csv('submission.csv', index=False)

In [None]:
# Persist whichever sub_df exists at this point (model predictions, or the sample
# submission when no test images were found), then preview its first rows.
sub_df.to_csv(path_or_buf="submission.csv", index=False)
sub_df.head(n=5)