In [None]:
import cv2
from tqdm import tqdm_notebook as tqdm
import fastai
from fastai.vision import *
import os
from mish_activation import *
import warnings
warnings.filterwarnings("ignore")
import skimage.io
import numpy as np
import pandas as pd
sys.path.insert(0, '../input/semisupervised-imagenet-models/semi-supervised-ImageNet1K-models-master/')
#hubconf = '../input/semisupervised-imagenet-models/semi-supervised-ImageNet1K-models-master/hubconf.py'
from hubconf import *
from tqdm.notebook import tqdm
from sklearn.utils import shuffle
import matplotlib.pyplot as plt
from skimage.transform import resize

In [None]:
!pip install tifffile

In [None]:
!pip install ../input/installed-packages/imagecodecs-2020.5.30-cp37-cp37m-manylinux2014_x86_64.whl

In [None]:
!pip install ../input/installed-packages/spams-2.6.1-cp37-cp37m-linux_x86_64.whl

In [None]:
import tifffile
import imagecodecs

In [None]:
# Input locations.
TRAIN_DATA = '../input/prostate-cancer-grade-assessment/train_images/'
TEST_DATA = '../input/prostate-cancer-grade-assessment/test_images/'
TRAIN = '../input/prostate-cancer-grade-assessment/train.csv'
TEST = '../input/prostate-cancer-grade-assessment/test.csv'
SAMPLE = '../input/prostate-cancer-grade-assessment/sample_submission.csv'
# Ensemble checkpoints RNXT50_0.pth .. RNXT50_3.pth; predictions are averaged over all of them.
MODELS = [f'../input/panda-stain-norm-downsampling-12x64x64-models/RNXT50_{i}.pth' for i in range(4)]

sz = 128      # side length (px) of the square tiles cut from each slide by tile()
bs = 1        # DataLoader batch size at prediction time
N = 12        # number of tiles kept per slide
nworkers = 2  # DataLoader worker processes

# Shape each slide's N stain-normalized tiles are resized to in PandaDataset.
downsize = (N,64,64,3)

# Model

In [None]:
def _resnext(url, block, layers, pretrained, progress, **kwargs):
    """Construct a ResNet/ResNeXt backbone from `block`, `layers` and `kwargs`.

    `url`, `pretrained` and `progress` are accepted only to keep the hub-style
    builder signature; no weights are downloaded or loaded here.
    """
    return ResNet(block, layers, **kwargs)

class Model(nn.Module):
    """Slide classifier: a ResNeXt encoder shared across tiles plus a pooled head.

    Every tile is encoded independently; the per-tile feature maps are then
    stacked along the height axis and pooled into one prediction per sample.

    Args:
        arch: key into `semi_supervised_model_urls`; the URL is passed to
            `_resnext`, which does not use it. The backbone is always built as
            Bottleneck [3, 4, 6, 3] with groups=32, width_per_group=4.
        n: number of output classes (size of the softmax).
        pre: unused.
    """
    def __init__(self, arch='resnext50_32x4d', n=6, pre=True):
        super().__init__()
        #m = torch.hub.load('facebookresearch/semi-supervised-ImageNet1K-models', arch)
        # No pretrained weights are fetched here; trained weights are expected to be
        # loaded afterwards with load_state_dict.
        m = _resnext(semi_supervised_model_urls[arch], Bottleneck, [3, 4, 6, 3], False, 
                progress=False,groups=32,width_per_group=4)
        # Drop the last two children (presumably global pool + fc — verify against ResNet).
        self.enc = nn.Sequential(*list(m.children())[:-2])       
        # Feature width feeding the backbone's final Linear layer.
        nc = list(m.children())[-1].in_features
        # AdaptiveConcatPool2d concatenates two poolings, hence 2*nc inputs.
        self.head = nn.Sequential(AdaptiveConcatPool2d(),Flatten(),nn.Linear(2*nc,512),
                Mish(),nn.BatchNorm1d(512),nn.Dropout(0.5),nn.Linear(512,n))
        
    def forward(self, x):
        """Return class probabilities of shape (batch, num_classes).

        `x` is indexed as 5-D: assumed (batch, n_tiles, C, H, W), which matches
        PandaDataset's per-sample (N, 3, H, W) output plus the batch dimension.
        """
        shape = x.shape
        n = shape[1]  # tiles per sample
        # Fold tiles into the batch: (batch*n_tiles, C, H, W).
        x = x.view(-1,shape[2],shape[3],shape[4])
        x = self.enc(x)
        shape = x.shape  # (batch*n_tiles, C', h, w)
        # Unfold tiles and concatenate their feature maps along the height axis:
        # (batch, C', n_tiles*h, w).
        x = x.view(-1,n,shape[1],shape[2],shape[3]).permute(0,2,1,3,4).contiguous()\
          .view(-1,shape[1],shape[2]*n,shape[3])
        x = self.head(x)
        x = F.softmax(x,dim=1)
        return x

In [None]:
def _load_model(path):
    """Build a Model, load the checkpoint at `path`, and return it in eval mode on the GPU.

    The checkpoint is read with map_location='cpu' so it loads regardless of the
    device it was saved from. Because the state_dict is local to this function,
    the CPU copy of the weights is released once the model has been built.
    NOTE: torch.load unpickles the file — only load trusted checkpoints.
    """
    state_dict = torch.load(path, map_location=torch.device('cpu'))
    model = Model()
    model.load_state_dict(state_dict)
    return model.float().eval().cuda()


# One model per checkpoint in MODELS; predictions are later averaged over all of them.
models = [_load_model(path) for path in MODELS]

# Stain Normalization

In [None]:
# STAIN NORMALIZATION FUNCTIONS
import spams
class TissueMaskException(Exception):
    """Raised by get_tissue_mask() when no pixel falls below the luminosity threshold."""

######################################################################################################

def is_uint8_image(I):
    """Return True when `I` passes is_image() and its dtype is uint8."""
    return is_image(I) and I.dtype == np.uint8
######################################################################################################

def is_image(I):
    """Return True when `I` is a numpy array with exactly three dimensions."""
    return isinstance(I, np.ndarray) and I.ndim == 3
######################################################################################################

def get_tissue_mask(I, luminosity_threshold=0.8):
    """Return a boolean H x W mask of pixels darker than `luminosity_threshold`.

    Lightness is the LAB L channel of RGB image `I`, scaled to [0, 1].

    Raises:
        TissueMaskException: if no pixel is below the threshold.
    """
    lightness = cv2.cvtColor(I, cv2.COLOR_RGB2LAB)[:, :, 0] / 255.0
    tissue = lightness < luminosity_threshold
    if not tissue.any():
        raise TissueMaskException("Empty tissue mask computed")
    return tissue

######################################################################################################

def convert_RGB_to_OD(I, min_od=0.1):
    """Convert an RGB image to optical density (OD): OD = -log(I / 255).

    Zero-valued pixels are treated as 1 so the log stays finite, and the result
    is floored at `min_od`. The input array is not modified (the zero
    replacement is done on a copy), so callers can keep using their image.

    Args:
        I: image array with values in [0, 255] (uint8 in this notebook).
        min_od: lower bound applied to the OD values (default 0.1; an
            alternative 1e-6 floor was tried earlier).

    Returns:
        Floating-point array with the same shape as `I`.
    """
    I = np.where(I == 0, 1, I)
    return np.maximum(-np.log(I / 255), min_od)

######################################################################################################

def convert_OD_to_RGB(OD):
    """Map optical density back to uint8 intensities: 255 * exp(-OD).

    OD is floored at 1e-6 before the exponential; the cast truncates.

    Raises:
        AssertionError: if any OD value is negative.
    """
    assert OD.min() >= 0, "Negative optical density."
    floored = np.maximum(OD, 1e-6)
    intensities = 255 * np.exp(-floored)
    return intensities.astype(np.uint8)

######################################################################################################

def normalize_matrix_rows(A):
    """Scale every row of `A` to unit Euclidean norm."""
    row_norms = np.linalg.norm(A, axis=1, keepdims=True)
    return A / row_norms

######################################################################################################


def get_concentrations(I, stain_matrix, regularizer=0.01):
    """Solve non-negative per-pixel stain concentrations for image `I`.

    Each pixel's OD vector is sparse-coded against the rows of `stain_matrix`
    with spams.lasso (mode=2, positivity constraint). Returns an array with one
    row per pixel and one column per stain.
    """
    od_pixels = convert_RGB_to_OD(I).reshape((-1, 3))
    codes = spams.lasso(X=od_pixels.T, D=stain_matrix.T, mode=2,
                        lambda1=regularizer, pos=True)
    return codes.toarray().T

######################################################################################################

def get_stain_matrix(I, luminosity_threshold=0.8, angular_percentile=99):
    """Estimate a 2x3 stain matrix from the tissue pixels of image `I`.

    Tissue pixels (see get_tissue_mask) are converted to optical density (OD).
    The two leading eigenvectors of their OD covariance span a plane. The
    `100 - angular_percentile` and `angular_percentile` percentiles of the
    pixels' angles in that plane give the two stain vectors.

    Args:
        I: image array, presumably RGB uint8 (the dtype check below is disabled).
        luminosity_threshold: passed to get_tissue_mask.
        angular_percentile: percentile used for the extreme projection angles.

    Returns:
        2x3 array with unit-norm rows. The row with the larger first-channel
        OD component comes first; the author labels it "H" (first channel is R,
        assuming RGB input).

    Raises:
        TissueMaskException: via get_tissue_mask, if no tissue pixel is found.
    """
    
    #assert is_uint8_image(I), "Image should be RGB uint8."
    # Convert to OD and ignore background
    tissue_mask = get_tissue_mask(I, luminosity_threshold=luminosity_threshold).reshape((-1,))
    OD = convert_RGB_to_OD(I).reshape((-1, 3))
    
    OD = OD[tissue_mask]

    # Eigenvectors of cov in OD space (orthogonal as cov symmetric)
    _, V = np.linalg.eigh(np.cov(OD, rowvar=False))

    # The two principal eigenvectors (eigh returns eigenvalues in ascending order)
    V = V[:, [2, 1]]

    # Make sure vectors are pointing the right way (non-negative first component)
    if V[0, 0] < 0: V[:, 0] *= -1
    if V[0, 1] < 0: V[:, 1] *= -1

    # Project on this basis.
    That = np.dot(OD, V)

    # Angular coordinates with respect to the principal, orthogonal eigenvectors
    phi = np.arctan2(That[:, 1], That[:, 0])

    # Min and max angles (robust extremes via percentiles)
    minPhi = np.percentile(phi, 100 - angular_percentile)
    maxPhi = np.percentile(phi, angular_percentile)

    # The two principal colors, mapped back from the plane into 3-D OD space
    v1 = np.dot(V, np.array([np.cos(minPhi), np.sin(minPhi)]))
    v2 = np.dot(V, np.array([np.cos(maxPhi), np.sin(maxPhi)]))

    # Order of H and E.
    # H first row.
    if v1[0] > v2[0]:
        HE = np.array([v1, v2])
    else:
        HE = np.array([v2, v1])

    return normalize_matrix_rows(HE)

######################################################################################################

def mapping(target, source):
    """Stain-normalize `source` towards the stain appearance of `target`.

    A stain matrix is estimated for each image with get_stain_matrix(), and
    per-pixel stain concentrations are solved with get_concentrations(). The
    source concentrations are rescaled so their per-stain 99th percentiles
    match the target's. They are then recombined with the *target* stain
    matrix and converted from optical density back to 0..255 intensities.

    Args:
        target: reference image (RGB uint8 in this notebook).
        source: image to normalize; the result has its shape.

    Returns:
        uint8 array with the same shape as `source`.
    """
    # NOTE: target statistics are recomputed on every call even though the
    # target is fixed; caching them would save repeated work.
    stain_matrix_target = get_stain_matrix(target)
    target_concentrations = get_concentrations(target, stain_matrix_target)
    maxC_target = np.percentile(target_concentrations, 99, axis=0).reshape((1, 2))

    stain_matrix_source = get_stain_matrix(source)
    source_concentrations = get_concentrations(source, stain_matrix_source)
    maxC_source = np.percentile(source_concentrations, 99, axis=0).reshape((1, 2))

    # Per-stain rescaling. A stain whose source 99th percentile is 0 is left
    # unscaled instead of producing inf/NaN through division by zero.
    scale = np.divide(maxC_target, maxC_source,
                      out=np.ones_like(maxC_target), where=maxC_source > 0)
    source_concentrations *= scale

    od = np.dot(source_concentrations, stain_matrix_target)
    rgb = 255 * np.exp(-od)
    # Stain vectors can have negative components, so OD can be negative and
    # rgb can exceed 255; clip before the cast so values do not wrap around.
    return np.clip(rgb, 0, 255).reshape(source.shape).astype(np.uint8)

In [None]:
def trans(img):
    """Stain-normalize one tile against the module-level `target` image.

    The tile first goes through cv2's BGR->RGB channel swap. Tiles with no
    pixel of LAB lightness below 0.8 are returned after the swap as-is. Tiles
    with exactly one such pixel are also returned as-is, and a notice is
    printed. Otherwise the result of mapping(target, tile) is returned.

    NOTE(review): tiles built in PandaDataset come from skimage.io, which
    presumably already yields RGB, so this swap may actually reverse the
    channel order. Confirm against the training pipeline before changing it.
    """
    rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    lightness = cv2.cvtColor(rgb, cv2.COLOR_RGB2LAB)[:, :, 0] / 255.0
    tissue_pixels = (lightness < 0.8).sum()
    if tissue_pixels == 0:
        return rgb
    if tissue_pixels == 1:
        print('almost empty img')
        return rgb
    return mapping(target, rgb)

# Data

In [None]:
# Reference tile whose stain appearance every tile is normalized towards.
TARGET_PATH = '../input/panda-tiles-16x128x128/train/002a4db09dad406c85505a00fb6f6144_0.png'
target = cv2.imread(TARGET_PATH)
if target is None:
    # cv2.imread signals a missing/unreadable file by returning None rather than raising.
    raise FileNotFoundError(f'Could not read stain-normalization target image: {TARGET_PATH}')
target = cv2.cvtColor(target, cv2.COLOR_BGR2RGB)  # cv2 decodes images as BGR
plt.imshow(target);

In [None]:
def tile(img, tile_size=None, n_tiles=None):
    """Cut an H x W x 3 image into square tiles and keep the `n_tiles` darkest.

    The image is padded with white (255), split evenly on both sides of each
    spatial axis, so that H and W become multiples of `tile_size`. It is then
    cut into non-overlapping tiles. If there are fewer than `n_tiles` tiles,
    all-white tiles are appended. Tiles are ranked by the sum of their pixel
    values, lowest first.

    Args:
        img: image array of shape (H, W, 3).
        tile_size: tile side length in pixels; defaults to the global `sz`.
        n_tiles: number of tiles to return; defaults to the global `N`.

    Returns:
        Array of shape (n_tiles, tile_size, tile_size, 3).
    """
    if tile_size is None:
        tile_size = sz
    if n_tiles is None:
        n_tiles = N
    h, w = img.shape[:2]
    pad_h = (tile_size - h % tile_size) % tile_size
    pad_w = (tile_size - w % tile_size) % tile_size
    img = np.pad(img, [[pad_h // 2, pad_h - pad_h // 2],
                       [pad_w // 2, pad_w - pad_w // 2],
                       [0, 0]], constant_values=255)
    tiles = img.reshape(img.shape[0] // tile_size, tile_size,
                        img.shape[1] // tile_size, tile_size, 3)
    tiles = tiles.transpose(0, 2, 1, 3, 4).reshape(-1, tile_size, tile_size, 3)
    if len(tiles) < n_tiles:
        tiles = np.pad(tiles, [[0, n_tiles - len(tiles)], [0, 0], [0, 0], [0, 0]],
                       constant_values=255)
    # Lowest pixel sum = least white background (presumably the most tissue).
    idxs = np.argsort(tiles.reshape(tiles.shape[0], -1).sum(-1))[:n_tiles]
    return tiles[idxs]



# Previous per-channel stats (presumably for tiles without stain normalization); kept for reference.
#mean = torch.tensor([1.0-0.90949707, 1.0-0.8188697, 1.0-0.87795304])
#std = torch.tensor([0.36357649, 0.49984502, 0.40477625])

### Image Stats for stain normalized tiles 12x64x64
# Per-channel statistics used by PandaDataset to standardize the inverted images
# (1 - pixel/255); the mean is written as 1 - <value> to match that inversion.
# These must equal the values used during training — TODO confirm their source.
mean = torch.tensor([1.0-0.90008685, 1.0-0.80557228, 1.0-0.88988811])
std = torch.tensor([0.3247047, 0.41417728, 0.33721329])



class PandaDataset(Dataset):
    """Dataset yielding stain-normalized, downsampled tile stacks per slide.

    Args:
        path: directory containing `<image_id>.tiff` slides.
        test: path to a CSV with an `image_id` column listing the slides to load.
    """

    def __init__(self, path, test):
        self.path = path
        self.names = list(pd.read_csv(test).image_id)

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):
        """Return (tiles, name).

        `tiles` is a float tensor of shape (N, 3, downsize[1], downsize[2]) holding
        1 - pixel/255, standardized with the global `mean`/`std`.
        """
        name = self.names[idx]
        # [-1] selects the last level of the multi-resolution TIFF
        # (presumably the lowest resolution).
        img = skimage.io.MultiImage(os.path.join(self.path, name + '.tiff'))[-1]
        tiles = tile(img)

        # Stain-normalize each tile into a uint8 buffer. It must be uint8:
        # skimage's resize rescales integer images to [0, 1] but leaves float
        # images untouched. With the previous float64 buffer of 0..255 values,
        # `255 * img_resized` overflowed and wrapped when cast to uint8.
        normed = np.zeros((N, sz, sz, 3), dtype=np.uint8)
        for i in range(N):
            normed[i] = trans(tiles[i])

        img_resized = resize(normed, downsize)  # float in [0, 1]
        img_resized = (255 * img_resized).astype(np.uint8)

        tiles = torch.Tensor(1.0 - img_resized / 255.0)
        tiles = (tiles - mean) / std
        return tiles.permute(0, 3, 1, 2), name

# Show stain normalization examples

In [None]:
# Visual check of stain normalization: one row per slide, showing the target tile
# followed by (source tile, normalized tile) pairs for the first tiles of the slide.
n_slides, n_pairs = 4, 2
fig, ax = plt.subplots(n_slides, 1 + 2 * n_pairs, figsize=(20, 20))
names = sorted(os.listdir(TRAIN_DATA))[:n_slides]  # sorted for a reproducible selection

for i, name in enumerate(tqdm(names)):
    slide = skimage.io.MultiImage(os.path.join(TRAIN_DATA, name))[-1]
    tiles = tile(slide)
    ax[i][0].imshow(target)
    ax[i][0].set_title("Target Image", fontsize=10)
    for j in range(n_pairs):
        img = tiles[j]
        tran = trans(img)
        ax[i][2*j+1].imshow(img)
        ax[i][2*j+1].set_title("Source Image", fontsize=10)
        ax[i][2*j+2].imshow(tran)
        ax[i][2*j+2].set_title("Transformed Image", fontsize=10)


# Prediction

In [None]:
sub_df = pd.read_csv(SAMPLE)
if os.path.exists(TEST_DATA):
    print('Starting predictions')
    ds = PandaDataset(TEST_DATA, TEST)
    dl = DataLoader(ds, batch_size=bs, num_workers=nworkers, shuffle=False)
    names, probs, preds = [], [], []
    # Tile side length after PandaDataset's downsampling. This is kept in its own
    # variable: reassigning the global `sz` here would silently change the tile
    # size used by tile()/PandaDataset on any later run in the same kernel.
    tile_px = downsize[1]

    with torch.no_grad():
        for x, y in tqdm(dl):
            x = x.cuda()
            n_batch = x.shape[0]  # the last batch can be smaller than bs

            # Dihedral TTA: the 8 flips/transpositions of the tiles' spatial axes.
            xt = x.transpose(-1, -2)
            x = torch.stack([x, x.flip(-1), x.flip(-2), x.flip(-1, -2),
                             xt, xt.flip(-1), xt.flip(-2), xt.flip(-1, -2)], 1)
            x = x.view(-1, N, 3, tile_px, tile_px)
            p = torch.stack([model(x) for model in models], 1)
            # Average class probabilities over the 8 TTA views and all models.
            prob = p.view(n_batch, 8 * len(models), -1).mean(1).cpu()
            pred = prob.argmax(-1)
            names.append(y)
            probs.append(prob)
            preds.append(pred)

    names = np.concatenate(names)
    probs = torch.cat(probs).numpy()
    preds = torch.cat(preds).numpy()
    sub_df = pd.DataFrame({'image_id': names, 'isup_grade': preds})
    sub_df.to_csv('submission.csv', index=False)
else:
    # No test images (e.g. an interactive run): keep the sample submission as-is.
    print('Found No test data')

In [None]:
# Write whatever sub_df holds: model predictions when test images were found,
# otherwise the sample submission read from SAMPLE.
sub_df.to_csv("submission.csv", index=False)
sub_df.head()