In [None]:
import os
import time
import skimage.io
import numpy as np
import pandas as pd
import cv2
import PIL.Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import SubsetRandomSampler, RandomSampler, SequentialSampler
import albumentations
from sklearn.model_selection import StratifiedKFold
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import zipfile
import torchvision
import shutil
# Create the directory the mosaics are written to. exist_ok=True keeps this
# cell re-runnable: a plain os.mkdir raises FileExistsError on the second run.
os.makedirs('output/', exist_ok=True)

# Config

In [None]:
# Tiling parameters.
tile_size = 256    # edge length (pixels) of the square tiles cut by get_tiles
image_size = 256   # edge length (pixels) of each tile slot in the saved mosaic
n_tiles = 36       # number of tiles kept per slide
num_workers = 8    # not referenced by the cells visible in this part of the notebook

# Input locations: the training table and the folder holding the slide images.
data_dir = '../input/prostate-cancer-grade-assessment'
image_folder = os.path.join(data_dir, 'train_images')
df_train = pd.read_csv(os.path.join(data_dir, 'train.csv'))
print(image_folder)

# Create Folds

In [None]:
# Assign every row to one of 8 folds, stratified on isup_grade so each fold
# has a similar grade distribution. The fixed seed keeps the split reproducible.
skf = StratifiedKFold(8, shuffle=True, random_state=42)
df_train['fold'] = -1
fold_col = df_train.columns.get_loc('fold')
for i, (train_idx, valid_idx) in enumerate(skf.split(df_train, df_train['isup_grade'])):
    # split() yields *positional* row indices, so write through .iloc; the
    # label-based .loc is only equivalent while the index is a default RangeIndex.
    df_train.iloc[valid_idx, fold_col] = i
df_train.head()

# Dataset

In [None]:
def get_tiles(img, mode=0):
    """Cut an image into square tiles and keep the ``n_tiles`` darkest ones.

    The image is padded with white (255) so both spatial dimensions become a
    multiple of the module-level ``tile_size``, split into a grid of
    ``tile_size`` x ``tile_size`` tiles, and the tiles are ranked by their
    pixel sum in ascending order (lowest sum, i.e. least white, first).

    Args:
        img: array of shape (height, width, channels). Any channel count is
            accepted; the white threshold and the reshape follow it.
        mode: adds ``(tile_size * mode) // 2`` pixels of extra white padding
            per axis, which shifts the tile grid. The padded size must stay a
            multiple of ``tile_size`` (true for even ``mode`` values),
            otherwise the reshape raises ``ValueError``.

    Returns:
        A tuple ``(result, enough)`` where ``result`` is a list of exactly
        ``n_tiles`` dicts ``{'img': tile, 'idx': rank}`` (all-white filler
        tiles are appended when the image yields fewer tiles), and ``enough``
        is truthy when at least ``n_tiles`` tiles were not completely white.
    """
    h, w, c = img.shape
    extra = (tile_size * mode) // 2
    pad_h = (tile_size - h % tile_size) % tile_size + extra
    pad_w = (tile_size - w % tile_size) % tile_size + extra

    # Spread the padding over both sides; the odd pixel goes bottom/right.
    padded = np.pad(
        img,
        [[pad_h // 2, pad_h - pad_h // 2], [pad_w // 2, pad_w - pad_w // 2], [0, 0]],
        constant_values=255,
    )

    # (rows, tile, cols, tile, c) -> (rows, cols, tile, tile, c) -> (N, tile, tile, c)
    grid = padded.reshape(
        padded.shape[0] // tile_size,
        tile_size,
        padded.shape[1] // tile_size,
        tile_size,
        c,
    )
    tiles = grid.transpose(0, 2, 1, 3, 4).reshape(-1, tile_size, tile_size, c)

    # A tile carries information when it is not entirely white.
    white_sum = tile_size ** 2 * c * 255
    n_tiles_with_info = (tiles.reshape(tiles.shape[0], -1).sum(1) < white_sum).sum()

    # Top up with white tiles so exactly n_tiles can always be returned.
    if len(tiles) < n_tiles:
        tiles = np.pad(
            tiles,
            [[0, n_tiles - len(tiles)], [0, 0], [0, 0], [0, 0]],
            constant_values=255,
        )

    # Lowest pixel sum first; keep the first n_tiles.
    order = np.argsort(tiles.reshape(tiles.shape[0], -1).sum(-1))[:n_tiles]
    tiles = tiles[order]

    result = [{'img': tile, 'idx': rank} for rank, tile in enumerate(tiles)]
    return result, n_tiles_with_info >= n_tiles


In [None]:
# Augmentation pipeline: transpose, vertical flip and horizontal flip, each
# applied independently with probability 0.5.
transforms_train = albumentations.Compose([
    augmentation(p=0.5)
    for augmentation in (
        albumentations.Transpose,
        albumentations.VerticalFlip,
        albumentations.HorizontalFlip,
    )
])

In [None]:
# Keep only the rows whose fold number is below `fold` (folds 0..3).
fold = 4
df_train = df_train[df_train['fold'] < fold]

# Only the first 10 remaining rows are processed.
indexes = np.arange(len(df_train))[:10]

# Mosaic layout: n_row_tiles x n_row_tiles slots, filled in row-major order.
idxes = list(range(n_tiles))
n_row_tiles = int(np.sqrt(n_tiles))
mosaic_side = image_size * n_row_tiles

for index in tqdm(indexes):
    row = df_train.iloc[index]
    img_id = row.image_id
    tiff_file = os.path.join(image_folder, f'{img_id}.tiff')
    # Entry 2 of the multi-image TIFF is the one that gets tiled.
    image = skimage.io.MultiImage(tiff_file)[2]
    tiles, OK = get_tiles(image, 0)

    images = np.zeros((mosaic_side, mosaic_side, 3))
    for i in range(n_row_tiles * n_row_tiles):
        h, w = divmod(i, n_row_tiles)
        if idxes[i] < len(tiles):
            this_img = tiles[idxes[i]]['img']
        else:
            # No tile for this slot: use an all-white placeholder.
            this_img = np.full((image_size, image_size, 3), 255, dtype=np.uint8)
        # Each tile is augmented independently before being pasted.
        this_img = transforms_train(image=this_img)['image']
        h1, w1 = h * image_size, w * image_size
        images[h1:h1 + image_size, w1:w1 + image_size] = this_img

    # Scale to [0, 1] floats before saving as PNG.
    images = images.astype(np.float32) / 255
    plt.imsave('output/' + f'{img_id}.png', images)