In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the whole input tree and print the full path of every file found.
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        # One line of output per file, so a large dataset produces a very long listing.
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
!pip install --upgrade fastai > /dev/null
!pip install --upgrade fastcore > /dev/null
!pip install pretrainedmodels > /dev/null

In [None]:
from fastai.vision.all import *
import pretrainedmodels

In [None]:
!apt update && apt install -y openslide-tools
!pip install openslide-python

In [None]:
# Load the dependencies
from fastai.basics import *
from fastai.callback.all import *
from fastai.vision.all import *

import seaborn as sns
import numpy as np
import pandas as pd
import os
import cv2

import openslide

# Global plotting defaults: seaborn "whitegrid" style with the "paper" context.
sns.set(style="whitegrid")
sns.set_context("paper")

# Default colormap matplotlib uses when rendering images.
# NOTE(review): `matplotlib` is not imported explicitly in this cell; presumably
# it comes into scope through one of the star imports above — confirm.
matplotlib.rcParams['image.cmap'] = 'ocean_r'

In [None]:
# Root folder of the competition dataset, relative to the working directory.
source = Path("../input/prostate-cancer-grade-assessment/")
# NOTE(review): `files` is not read again in the cells visible here.
files = os.listdir(source)
# Last expression of the cell: display the directory listing.
source.ls()

In [None]:
# Folders holding the slide images and their label masks inside the dataset.
train = source/'train_images'
mask = source/'train_label_masks'
# Per-slide label table; later cells read its image_id and isup_grade columns.
train_labels = pd.read_csv(source/'train.csv')
# Rendered as a (DataFrame, int) tuple: the first rows plus the row count.
train_labels.head(), len(train_labels)

In [None]:
# Explicit dtype for each label column: pandas string dtype for the textual
# columns, plain int for the ISUP grade.
train_labels_dtypes = dict(
    image_id='string',
    data_provider='string',
    isup_grade=int,
    gleason_score='string',
)
train_labels = train_labels.astype(train_labels_dtypes)

In [None]:
# Unique slide ids present on disk: the first 32 characters of each .tiff stem.
localfiles = {tiff_path.stem[:32] for tiff_path in train.glob('*.tiff')}
len(localfiles)

In [None]:
# Keep only the label rows whose image_id has a matching .tiff file on disk.
train_labels = train_labels[train_labels.image_id.isin(localfiles)]
print(len(train_labels))

In [None]:
# Spot-check one slide id: its label row and whether its file is on disk.
# Only the last expression of a cell is rendered, so both results are returned
# together as a tuple; on separate lines the label lookup was silently discarded.
(train_labels[train_labels.image_id == "b5db121ca6ba4d979a6bef814d5fdb17"],
 'b5db121ca6ba4d979a6bef814d5fdb17' in localfiles)

In [None]:
def open_tiff_image(image_path):
    """Check that `image_path` can be opened as a slide by OpenSlide.

    The slide is opened inside a ``with`` block so its handle is closed again
    before returning; validating a whole folder therefore does not accumulate
    open handles.

    Args:
        image_path: path (str or Path-like) of the slide file.

    Returns:
        The slide's ``dimensions`` attribute as reported by OpenSlide.

    Raises:
        Whatever ``openslide.OpenSlide`` raises when the file cannot be opened.
    """
    with openslide.OpenSlide(str(image_path)) as slide:
        return slide.dimensions

def check_tiff_images(path):
    """Collect the ids of every ``*.tiff`` file under `path` that can be opened.

    Each file is handed to `open_tiff_image`. When that call raises, the file
    path and the exception are printed and the file is left out of the result.
    The id recorded for a readable file is the first 32 characters of its stem.

    Returns:
        An `L` of id strings, in the order the files were globbed.
    """
    readable_ids = L()
    for tiff_path in path.glob("*.tiff"):
        try:
            open_tiff_image(tiff_path)
        except Exception as err:
            print(str(tiff_path))
            print(err)
            continue
        readable_ids.append(tiff_path.stem[:32])
    return readable_ids

In [None]:
# Ids of the training slides that check_tiff_images managed to open.
valid_train_files = check_tiff_images(train)
len(valid_train_files)

In [None]:
# Ids of the mask files that check_tiff_images managed to open.
valid_mask_files = check_tiff_images(mask)
len(valid_mask_files)

In [None]:
# Peek at the first recorded mask id.
valid_mask_files[0]

In [None]:
# Ids of readable training slides that have no readable mask counterpart.
file_to_delete = set(valid_train_files) - set(valid_mask_files)

In [None]:
# Compare the first ten ids from each list side by side.
valid_train_files[:10], valid_mask_files[:10]

In [None]:
# How many slide ids are going to be dropped from the labels.
len(file_to_delete)

In [None]:
import os
# Drop every label row whose slide id is in file_to_delete. One vectorised
# membership mask replaces re-filtering the whole frame once per id.
# Only the label table is changed; no file is removed from disk.
train_labels = train_labels[~train_labels.image_id.isin(file_to_delete)]
len(train_labels)

In [None]:
# Number of label rows left after the filtering steps.
len(train_labels)

In [None]:
def custom_img(fn):
    """Load the thumbnail of one training slide as a `PILImage`.

    Args:
        fn: a record of the label table; only its ``image_id`` attribute is
            read, to build ``{train}/{image_id}.tiff``.

    Returns:
        A `PILImage` created from the tensor of a thumbnail requested with
        ``size=(255, 255)``.

    Raises:
        Whatever ``openslide.OpenSlide`` raises when the slide cannot be
        opened. The path and the error are printed first, then the original
        exception is re-raised instead of failing later on an unbound local.
    """
    fn = f'{train}/{fn.image_id}.tiff'
    try:
        slide = openslide.OpenSlide(str(fn))
    except Exception as e:
        print(fn)
        print(e)
        raise
    # Close the slide handle as soon as the thumbnail has been copied to a tensor.
    with slide:
        t = tensor(slide.get_thumbnail(size=(255, 255)))
    return PILImage.create(t)

def show_selective(p, scale=True, cmap=plt.cm.ocean_r, min_px=None, max_px=None):
    """Convert `p` with `tensor` and clip the values to ``[min_px, max_px]``.

    A bound left as ``None`` is not applied. The clipping is done by masked
    assignment on the converted tensor, which is then returned. `scale` and
    `cmap` are accepted but not used by the body.
    """
    pixels = tensor(p)
    if min_px is not None:
        too_low = pixels < min_px
        pixels[too_low] = float(min_px)
    if max_px is not None:
        too_high = pixels > max_px
        pixels[too_high] = float(max_px)
    return pixels

def custom_selective_mask(fn):
    """Load the first channel of one mask thumbnail as a tensor.

    Args:
        fn: a record of the label table; only its ``image_id`` attribute is
            read, to build ``{mask}/{image_id}_mask.tiff``.

    Returns:
        The result of `show_selective` applied (with no clipping bounds) to
        channel 0 of a thumbnail requested with ``size=(255, 255)``.

    Raises:
        Whatever ``openslide.OpenSlide`` raises when the mask cannot be
        opened. The path and the error are printed first, then the original
        exception is re-raised instead of failing later on an unbound local.
    """
    fn = f'{mask}/{fn.image_id}_mask.tiff'
    try:
        slide = openslide.OpenSlide(str(fn))
    except Exception as e:
        print(fn)
        print(e)
        raise
    # Close the slide handle as soon as the thumbnail has been copied to a tensor.
    with slide:
        t = tensor(slide.get_thumbnail(size=(255, 255)))[:,:,0]
    return show_selective(t, min_px=None, max_px=None)

In [None]:
# Block types for the DataBlock: two ImageBlocks followed by a CategoryBlock.
blocks = (ImageBlock,
          ImageBlock,
          CategoryBlock)

# One getter per block, in the same order: slide thumbnail, first channel of
# the mask thumbnail, and the isup_grade column.
getters = [
           custom_img,
           custom_selective_mask,
           ColReader('isup_grade')
          ]

In [None]:
# Data pipeline: items are resized to 224 and converted to tensors, batches are
# converted to float and normalised with the ImageNet statistics.
# 10% of the rows are held out for validation; the splitter is seeded so the
# same rows are held out on every run and metrics stay comparable between runs.
dblock_model = DataBlock(blocks=blocks,
                   getters=getters,
                   splitter=RandomSplitter(0.1, seed=42),
                   item_tfms=[Resize(224), ToTensor],
                   batch_tfms=[IntToFloatTensor, Normalize.from_stats(*imagenet_stats)])

# Build the dataloaders from the filtered label table and preview four samples.
dl = dblock_model.dataloaders(train_labels, bs=128)
dl.show_batch(max_n=4)

In [None]:
# Display the vocab attached to the training dataset.
dl.train_ds.vocab

In [None]:
class ProstateCancerModel(Module):
    """Two-input model: one shared encoder applied to both inputs, then a head.

    `forward` encodes `x1` and `x2` with the same `encoder`, concatenates the
    two results along dimension 1, and passes the concatenation to `head`.
    """

    def __init__(self, encoder, head):
        self.encoder = encoder
        self.head = head

    def forward(self, x1, x2):
        # Same encoder for both inputs, first input first.
        encoded = [self.encoder(view) for view in (x1, x2)]
        return self.head(torch.cat(encoded, dim=1))

def loss_func(out, targ):
    """Flattened cross-entropy between `out` and `targ` cast to long."""
    criterion = CrossEntropyLossFlat()
    return criterion(out, targ.long())

def siamese_splitter(model):
    """Split `model` into two parameter groups: encoder first, then head.

    Same body as `prostate_cancer_splitter`, which is the one passed to the
    Learner.
    """
    encoder_params = params(model.encoder)
    head_params = params(model.head)
    return [encoder_params, head_params]

# Shared image encoder: resnet34 passed through create_body with cut=-2.
# NOTE(review): fastai is installed with `--upgrade` and no version pin. Recent
# fastai releases appear to expect an instantiated model, not the architecture
# function, as the first argument of create_body — confirm against the
# installed version.
encoder = create_body(resnet34, cut=-2)
# Head built for 512*2 input features and one output per entry of dl.vocab.
# NOTE(review): 512*2 is presumably sized for the two encodings concatenated in
# ProstateCancerModel.forward — verify against the encoder's channel count.
head = create_head(512*2, len(dl.vocab), ps=0.5)
model = ProstateCancerModel(encoder, head)

In [None]:
def prostate_cancer_splitter(model):
    """Return the parameter groups of `model`: ``[encoder params, head params]``."""
    return [params(part) for part in (model.encoder, model.head)]

In [None]:
# Quadratic-weighted Cohen's kappa.
# The weighting must be passed to the constructor: the metric captures its
# scorer keyword arguments when it is built, so assigning `kp.weights`
# afterwards only adds an unused attribute and the unweighted kappa is reported.
kp = CohenKappa(weights='quadratic')

In [None]:
# Assemble the learner from the dataloaders, the two-input model, the custom
# loss, the encoder/head parameter split, and the accuracy and kappa metrics.
learner = Learner(dl,
                  model,
                  loss_func=loss_func,
                  splitter=prostate_cancer_splitter,
                  metrics=[accuracy, kp]
                  )
# Freeze before training. With the two-group splitter this presumably leaves
# only the head trainable — TODO confirm.
learner.freeze()

In [None]:
# Run fine_tune with epochs=1; every other argument is left at its default.
learner.fine_tune(1)

In [None]:
# Save the learner after the first training stage.
# NOTE(review): the learner references functions defined in this notebook
# (loss, getters, splitter); loading the export presumably needs them defined
# again — confirm.
learner.export("prostate_stage_1.pkl")

In [None]:
# Unfreeze the learner before the second training stage.
learner.unfreeze()

In [None]:
# Run the learning-rate finder.
# NOTE(review): its result is not captured, and the later fit_one_cycle call is
# made without an explicit learning rate.
learner.lr_find()

In [None]:
# Move the model to the GPU and assign it back to the learner.
# NOTE(review): this requires a CUDA device — confirm the notebook always runs on GPU.
learner.model = model.cuda()

In [None]:
# Second stage: fit_one_cycle for 10 epochs; no learning rate is passed.
learner.fit_one_cycle(10)

In [None]:
# Save the learner after the second training stage.
learner.export("prostate_stage_2.pkl")