#### This notebook is heavily based on (almost a rip-off of) https://www.kaggle.com/code/thedevastator/training-fastai-baseline -- thanks to https://www.kaggle.com/thedevastator


##### Since the images are from healthy people, I wanted to see whether we can determine age/gender using only the biopsy images.

In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

from fastai.vision.all import *
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import os
import cv2
import gc
import random
from albumentations import *
from sklearn.model_selection import KFold
import matplotlib.pyplot as plt

import warnings
warnings.filterwarnings("ignore")

In [None]:
# Run configuration shared by the cells below.
bs = 64            # mini-batch size handed to the training dataloaders
nfolds = 4         # number of KFold splits
fold = 0           # default fold index (also the default of HuBMAPDataset)
SEED = 2020        # seed used by seed_everything and by the KFold shuffle
root_data = '../input/hubmap-2022-512x512/'
# os.path.join yields exactly one separator whether or not root_data ends
# with '/'; the trailing '' keeps the final slash on the directory path.
TRAIN = os.path.join(root_data, 'train', '')
MASKS = os.path.join(root_data, 'masks', '')
LABELS = '../input/hubmap-organ-segmentation/train.csv'
NUM_WORKERS = 4    # worker processes for the dataloaders
# Fall back to the CPU when no CUDA device is visible.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
def seed_everything(seed):
    """Seed the Python, NumPy and PyTorch RNGs and set the cuDNN flags.

    Args:
        seed: integer seed applied to every generator.

    Returns:
        The same ``seed`` that was passed in.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for apply_seed in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    ):
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark mode gives ~10% speedup
    # but may lead to some stochasticity in the results
    torch.backends.cudnn.benchmark = True
    return seed
    
# Seed all RNGs once with the notebook-wide SEED before any data or model is built.
seed_everything(SEED)

In [None]:
# https://www.kaggle.com/datasets/thedevastator/hubmap-2022-256x256
# Three-element normalisation statistics: HuBMAPDataset.__getitem__ computes
# (img / 255.0 - mean) / std, and the preview cell inverts it for display.
# NOTE(review): the link above is the 256x256 dataset while root_data points
# at the 512x512 one — confirm these values still apply.
mean = np.array([0.7720342, 0.74582646, 0.76392896])
std = np.array([0.24745085, 0.26182273, 0.25782376])

def img2tensor(img,dtype:np.dtype=np.float32):
    """Convert an HWC (or HW) NumPy image into a CHW torch tensor.

    Args:
        img: 2-D or 3-D NumPy array; a 2-D array gets a trailing channel axis.
        dtype: NumPy dtype the data is cast to before wrapping it in a tensor.

    Returns:
        A tensor created with ``torch.from_numpy`` whose axes are
        (channel, row, column).
    """
    # Give grayscale input a channel axis so both cases transpose the same way.
    hwc = img[:, :, None] if img.ndim == 2 else img
    chw = hwc.transpose(2, 0, 1)
    return torch.from_numpy(chw.astype(dtype, copy=False))

class HuBMAPDataset(Dataset):
    """Images from ``TRAIN`` paired with ``[age, sex]`` targets from ``LABELS``.

    The ids listed in ``LABELS`` are split with
    ``KFold(n_splits=nfolds, shuffle=True, random_state=SEED)``. A file in
    ``TRAIN`` belongs to the dataset when its id (file name without extension,
    up to the first ``_``) falls on the selected side of the chosen fold.

    Args:
        fold: index of the KFold split to use.
        train: use the training side of the split when true, otherwise the
            validation side.
        tfms: optional callable invoked as ``tfms(image=..., mask=...)`` that
            returns a mapping with ``'image'`` and ``'mask'`` entries.
        is_debug: when true, ``__getitem__`` additionally returns the mask.
    """

    def __init__(self, fold=fold, train=True, tfms=None, is_debug=False):
        # Read the label table once; it drives both the split and the targets.
        self.df = pd.read_csv(LABELS)
        ids = self.df.id.astype(str).values
        kf = KFold(n_splits=nfolds,random_state=SEED,shuffle=True)
        ids = set(ids[list(kf.split(ids))[fold][0 if train else 1]])
        # Same id parsing as in __getitem__, so "123.png" and "123_0001.png"
        # are both recognised as belonging to id "123".
        self.fnames = [fname for fname in os.listdir(TRAIN)
                       if self._image_id(fname) in ids]
        self.train = train
        self.tfms = tfms
        self.is_debug = is_debug
        # id -> (age, sex flag) lookup so __getitem__ does not scan the frame
        # for every sample. setdefault keeps the first row for a repeated id.
        self._targets = {}
        for _id, age, sex in zip(self.df['id'], self.df['age'], self.df['sex']):
            self._targets.setdefault(
                int(_id), (float(age), 0.0 if sex == 'Male' else 1.0))

    @staticmethod
    def _image_id(fname):
        """Return the id part of ``fname``: extension removed, text before the first ``_``."""
        return os.path.splitext(fname)[0].split('_')[0]

    def __len__(self):
        """Return the number of selected image files."""
        return len(self.fnames)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            ``(X, y)``, or ``(X, Xmasks, y)`` when ``is_debug`` is set. ``X`` is
            the normalised image as a float32 CHW tensor, ``Xmasks`` the mask
            tensor and ``y`` is ``[age, sex]`` with sex encoded as 0 for
            ``'Male'`` and 1 otherwise.

        Raises:
            FileNotFoundError: if the image file cannot be read.
        """
        fname = self.fnames[idx]
        img_path = os.path.join(TRAIN,fname)
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None; fail with the path.
            raise FileNotFoundError(f'could not read image {img_path!r}')
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(os.path.join(MASKS,fname),cv2.IMREAD_GRAYSCALE)
        if self.tfms is not None:
            augmented = self.tfms(image=img,mask=mask)
            img,mask = augmented['image'],augmented['mask']
        X = img2tensor((img/255.0 - mean)/std)
        Xmasks = img2tensor(mask)

        age, sex = self._targets[int(self._image_id(fname))]
        y = torch.tensor([age, sex], dtype=torch.float32)

        if self.is_debug:
            return X, Xmasks, y
        return X, y
    
def get_aug(p=1.0):
    """Build the training augmentation pipeline.

    Args:
        p: probability forwarded to the outer ``Compose``.

    Returns:
        A ``Compose`` of flips, 90-degree rotation and shift/scale/rotate,
        followed by two ``OneOf`` groups that are created with ``p=0.``.
    """
    flips_and_turns = [HorizontalFlip(), VerticalFlip(), RandomRotate90()]
    affine = ShiftScaleRotate(
        shift_limit=0.0625,
        scale_limit=0.2,
        rotate_limit=15,
        p=0.5,
        border_mode=cv2.BORDER_CONSTANT,
    )
    # switched off (p=0.)
    distortion = OneOf(
        [
            OpticalDistortion(p=0.3),
            GridDistortion(p=.1),
            IAAPiecewiseAffine(p=0.3),
        ],
        p=0.,
    )
    # switched off (p=0.)
    colour = OneOf(
        [
            HueSaturationValue(10,15,10),
            CLAHE(clip_limit=2),
            RandomBrightnessContrast(),
        ],
        p=0.,
    )
    return Compose([*flips_and_turns, affine, distortion, colour], p=p)

In [None]:
#example of train images with masks
# is_debug=True makes the dataset return (image, mask, target) triples.
ds = HuBMAPDataset(tfms=get_aug(), is_debug=True)
dl = DataLoader(ds,batch_size=4,shuffle=False,num_workers=NUM_WORKERS)
# Only the first batch is used for the preview.
imgs,masks,tgts = next(iter(dl))

plt.figure(figsize=(16,16))
for i,(img,mask) in enumerate(zip(imgs,masks)):
    # Undo the (img/255 - mean)/std normalisation and go back to HWC uint8 for display.
    img = ((img.permute(1,2,0)*std + mean)*255.0).numpy().astype(np.uint8)
    # NOTE(review): the grid is 8x8 although the batch holds 4 images — confirm intended.
    plt.subplot(8,8,i+1)
    # Title shows the target values for this sample.
    plt.title(tgts[i].numpy().tolist())
    plt.imshow(img,vmin=0,vmax=255)
    # Draw the mask semi-transparently on top of the image.
    plt.imshow(mask.squeeze().numpy(), alpha=0.2)
    plt.axis('off')
    plt.subplots_adjust(wspace=None, hspace=None)
    
# Drop the preview objects; NOTE(review): tgts is left defined.
del ds,dl,imgs,masks

In [None]:
try:
    import timm
except:
    !pip install timm
    import timm

class Model(nn.Module):
    """Thin wrapper around a timm backbone created with ``num_classes=2``.

    Args:
        model_name: architecture name passed to ``timm.create_model``.
        pretrained: forwarded to ``timm.create_model``.
    """

    def __init__(self, model_name='tf_efficientnet_b0_ns', pretrained=True):
        super().__init__()
        # Release cached CUDA memory before the new backbone is created.
        torch.cuda.empty_cache()
        backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=2)
        self.model = backbone

    def forward(self, x):
        """Return the backbone output for ``x`` unchanged."""
        return self.model(x)

# model = Model('resnet18', pretrained=True).to(device)

# Train

### Trying out a resnet18 on 512x512 images [https://www.kaggle.com/datasets/thedevastator/hubmap-2022-512x512]

In [None]:
def the_loss(y, y1):
    """Sum of an age MSE term and a gender BCE-with-logits term.

    Args:
        y: 2-D tensor; column 0 is compared by MSE, column 1 is used as the
            logits of the binary cross-entropy.
        y1: 2-D tensor with the matching reference values in columns 0 and 1.

    Returns:
        Scalar tensor ``mse(y[:, 0], y1[:, 0]) + bce_with_logits(y[:, 1], y1[:, 1])``.
    """
    age_term = F.mse_loss(y[:, 0], y1[:, 0], reduction='mean')
    gender_term = F.binary_cross_entropy_with_logits(
        y[:, 1], y1[:, 1], reduction='mean'
    )
    return age_term + gender_term

def rmse_age(y1, y):
    """Root-mean-square error between column 0 of ``y1`` and column 0 of ``y``."""
    age_error = y[:, 0] - y1[:, 0]
    return (age_error ** 2).mean() ** 0.5

def gender_accuracy(y1, y):
    """Fraction of samples whose thresholded column-1 logit matches the label.

    Args:
        y1: 2-D tensor whose column 1 holds logits; a sigmoid below 0.5
            counts as class 0.
        y: 2-D tensor whose column 1 holds labels, with 0 marking class 0.

    Returns:
        Tensor holding the share of samples where both sides agree.
    """
    logits = y1[:, 1].flatten()
    labels = y[:, 1].flatten()
    predicted_zero = torch.sigmoid(logits) < 0.5
    actual_zero = labels == 0
    agree = actual_zero == predicted_zero
    return agree.sum() / agree.numel()

# Train one model per selected fold. Note that the loop variable rebinds the
# module-level name `fold`.
for fold in range(nfolds):
    # Only fold 0 is trained; extend the list to run more folds.
    if fold not in [0]: continue
    # Training split gets augmentation, validation split does not.
    ds_t = HuBMAPDataset(fold=fold, train=True, tfms=get_aug())
    ds_v = HuBMAPDataset(fold=fold, train=False)
    data = ImageDataLoaders.from_dsets(ds_t,ds_v,bs=bs,num_workers=NUM_WORKERS,pin_memory=True).cuda()
    # A fresh resnet18-based model is created for every fold.
    model = Model('resnet18', pretrained=True).to(device)
    learn = Learner(
        data, model, loss_func=the_loss,
        metrics=[rmse_age, gender_accuracy],
        cbs=[
            ShowGraphCallback(),
            # Checkpoint name is per fold; which quantity is monitored is left
            # at the callback's default.
            SaveModelCallback(fname=f'best_fold_{fold}'),
        ],
    ).to_fp16()

    # 10 epochs of one-cycle training; NOTE(review): pct_start=0. presumably
    # removes the warm-up phase — confirm against the fastai docs.
    learn.fit_one_cycle(10, lr_max=1e-3, pct_start=0.)
    gc.collect()