In [None]:
import os
import cv2
import glob
import torch
import random
import pandas as pd

from tqdm import tqdm
from PIL import Image as PImage
from fastai.vision.all import *

# Asymmetric Loss (ASL)

In [None]:
# LOSS: Asymmetric Loss (ASL)
class AsymmetricLoss(nn.Module):
    """Asymmetric Loss (ASL) for multi-label classification.

    With p = sigmoid(x) and, for negatives, q = min(1 - p + clip, 1) (the clip
    shift is applied only when `clip` is a positive number), each element with a
    binary target y contributes

        -( y * log(p) * (1 - p) ** gamma_pos  +  (1 - y) * log(q) * (1 - q) ** gamma_neg )

    where the logs are taken of values clamped to at least `eps`.

    Parameters
    ----------
    gamma_neg: focusing exponent applied to negative targets.
    gamma_pos: focusing exponent applied to positive targets.
    clip: margin added to 1 - p for negatives (asymmetric clipping); None or <= 0 disables it.
    eps: lower clamp applied to probabilities before taking the log.
    disable_torch_grad_focal_loss: if True, the focusing weight is computed without
        autograd tracking, i.e. treated as a constant during backprop.
    reduction: 'mean' (default), 'sum' or 'none'. This is exposed as an attribute so
        fastai's `BaseLoss.reduction` / `loss_not_reduced` can switch it to 'none'.
    """

    def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8,
                 disable_torch_grad_focal_loss=True, reduction='mean'):
        super().__init__()

        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
        self.eps = eps
        self.reduction = reduction

    def forward(self, x, y):
        """Compute the asymmetric loss.

        Parameters
        ----------
        x: input logits
        y: targets (multi-label binarized vector), broadcastable against `x`

        Returns
        -------
        The negated weighted log-likelihood, reduced according to `self.reduction`.

        Raises
        ------
        ValueError: if `self.reduction` is not 'mean', 'sum' or 'none'.
        """
        # Calculating probabilities
        x_sigmoid = torch.sigmoid(x)
        xs_pos = x_sigmoid
        xs_neg = 1 - x_sigmoid

        # Asymmetric clipping: shift the negative-class probability up by `clip`
        # (capped at 1). Negatives predicted with p <= clip then contribute zero loss.
        if self.clip is not None and self.clip > 0:
            xs_neg = (xs_neg + self.clip).clamp(max=1)

        # Basic (clamped) binary cross-entropy terms
        los_pos = y * torch.log(xs_pos.clamp(min=self.eps))
        los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps))
        loss = los_pos + los_neg

        # Asymmetric focusing
        if self.gamma_neg > 0 or self.gamma_pos > 0:
            # Used as a context manager, set_grad_enabled restores the caller's grad
            # mode on exit. A call made under torch.no_grad() (e.g. validation) is
            # therefore not switched back into grad mode.
            track_grad = torch.is_grad_enabled() and not self.disable_torch_grad_focal_loss
            with torch.set_grad_enabled(track_grad):
                pt = xs_pos * y + xs_neg * (1 - y)  # pt = p if t > 0 else clipped 1-p
                one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
                one_sided_w = torch.pow(1 - pt, one_sided_gamma)
            loss = loss * one_sided_w

        loss = -loss
        if self.reduction == 'none':
            return loss
        if self.reduction == 'sum':
            return loss.sum()
        if self.reduction == 'mean':
            return loss.mean()
        raise ValueError(f"unsupported reduction: {self.reduction!r}")
    
    
@delegates()
class AsymmetricLossFlat(BaseLoss):
    "fastai `BaseLoss` wrapper around `AsymmetricLoss` with sigmoid activation and threshold decoding."
    @use_kwargs_dict(keep=True, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8)
    def __init__(self, *args, axis=-1, floatify=True, thresh=0.5, **kwargs):
        # The `pos_weight` branch copied from BCEWithLogitsLossFlat is removed on purpose.
        # AsymmetricLoss takes no `pos_weight`, so passing one raised a TypeError anyway.
        # Extra args/kwargs (gamma_neg, gamma_pos, clip, eps, ...) are forwarded to
        # AsymmetricLoss by BaseLoss.
        super().__init__(AsymmetricLoss, *args, axis=axis, floatify=floatify, is_2d=False, **kwargs)
        self.thresh = thresh  # probability cut-off used by `decodes`

    def decodes(self, x):
        "Binarize `x`: True wherever it exceeds `self.thresh`."
        return x > self.thresh

    def activation(self, x):
        "Map raw logits to independent per-label probabilities."
        return torch.sigmoid(x)

# Data

In [None]:
# Training image file names, sorted: glob returns files in arbitrary filesystem
# order, which would make the DataFrame (and anything derived from it) vary between runs.
# Labels are the part of the name after the last '_' and before the first '.'
# (a '|'-separated id string, as split by `get_y`).
img_list = [os.path.basename(fp) for fp in sorted(glob.glob('../input/hpa-cell/train/*.jpg'))]
label_list = [name.split('_')[-1].split('.')[0] for name in img_list]

In [None]:
# One row per training image: file name and its raw '|'-separated label string.
# Built in one step instead of filling an empty frame column by column.
df = pd.DataFrame({'img': img_list, 'label': label_list})

In [None]:
# Folder holding the training images; `get_x` joins row file names onto it.
path = Path('../input', 'hpa-cell', 'train')
# Label vocabulary: the strings '0' through '18'.
labels = list(map(str, range(19)))

In [None]:
def get_x(r):
    """Return the image file for DataFrame row `r`: its 'img' name under `path`.

    `path` is read from the module namespace at call time. Reassigning it (as the
    test section does) therefore redirects this function to the new folder.
    """
    img_name = r['img']
    return path / img_name


def get_y(r):
    """Split row `r`'s '|'-separated 'label' string into a list of label ids."""
    label_field = r['label']
    return label_field.split(sep='|')

In [None]:
# Training DataBlock: image in, multi-label target over `labels` out.
# 20% of rows go to validation. The fixed seed makes that split reproducible
# across runs, so validation metrics are comparable between experiments.
SPLIT_SEED = 42
dblock = DataBlock(blocks=(ImageBlock, MultiCategoryBlock(vocab=labels)),
                   splitter=RandomSplitter(valid_pct=0.2, seed=SPLIT_SEED),
                   get_x=get_x,
                   get_y=get_y,
                   item_tfms=Resize(256),
                   batch_tfms=[Normalize.from_stats(*imagenet_stats), *aug_transforms()])

In [None]:
BATCH_SIZE = 4  # images per batch
dls = dblock.dataloaders(df, bs=BATCH_SIZE)

In [None]:
# Display the training-split Datasets as a quick sanity check of the pipeline.
dls.train_ds

In [None]:
# Display the validation-split Datasets (the 20% held out by RandomSplitter(0.2)).
dls.valid_ds

# Train

In [None]:
# Training learner: DenseNet-121 backbone, asymmetric loss, multi-label
# accuracy/precision metrics, switched to fp16 via `to_fp16`.
asl_loss = AsymmetricLossFlat()
train_metrics = [accuracy_multi, PrecisionMulti()]
learn = cnn_learner(dls, densenet121, loss_func=asl_loss, metrics=train_metrics)
learn = learn.to_fp16()

In [None]:
N_EPOCHS = 4  # epoch count passed to fine_tune
# Checkpoint name 'd121' is what the test section reloads with learn.load('d121').
save_best = SaveModelCallback(fname='d121')
learn.fine_tune(N_EPOCHS, cbs=[save_best])

# Test

In [None]:
# Test image file names, sorted: glob returns files in arbitrary filesystem
# order, which would make the DataFrame (and anything derived from it) vary between runs.
# Labels are the part of the name after the last '_' and before the first '.'
# (a '|'-separated id string, as split by `get_y`).
img_list = [os.path.basename(fp) for fp in sorted(glob.glob('../input/hpa-cell/test/*.jpg'))]
label_list = [name.split('_')[-1].split('.')[0] for name in img_list]

In [None]:
# One row per test image: file name and its raw '|'-separated label string.
# Built in one step instead of filling an empty frame column by column.
df = pd.DataFrame({'img': img_list, 'label': label_list})

In [None]:
# Folder holding the test images. `get_x` reads `path` when called, so this
# reassignment also points it at the test folder.
path = Path('../input', 'hpa-cell', 'test')
# Label vocabulary: the strings '0' through '18' (same as for training).
labels = list(map(str, range(19)))

In [None]:
# Inference DataBlock. RandomSplitter(0.) puts every test image in the training
# split, and fastai applies `aug_transforms` there. Keeping them would feed
# randomly flipped/rotated/zoomed/relit images to the model at prediction time,
# so only normalization is kept.
dblock = DataBlock(blocks=(ImageBlock, MultiCategoryBlock(vocab=labels)),
                   splitter=RandomSplitter(0.),
                   get_x=get_x,
                   get_y=get_y,
                   item_tfms=Resize(256),
                   batch_tfms=[Normalize.from_stats(*imagenet_stats)])

In [None]:
BATCH_SIZE = 4  # images per batch
dls = dblock.dataloaders(df, bs=BATCH_SIZE)

In [None]:
# Display the training-split Datasets. With RandomSplitter(0.) it holds every test image.
dls.train_ds

In [None]:
# Display the validation-split Datasets. It is expected to be empty, given RandomSplitter(0.).
dls.valid_ds

In [None]:
# Inference learner with the same architecture as training.
# pretrained=False skips downloading ImageNet weights: learn.load('d121') replaces
# them anyway, and the download fails in offline environments. Normalization is
# already part of this DataBlock's batch_tfms, so nothing else depends on it.
learn = cnn_learner(dls, densenet121, pretrained=False,
                    metrics=[accuracy_multi, PrecisionMulti()]).to_fp16()

In [None]:
# Restore the weights written by SaveModelCallback(fname='d121') during training.
learn.load('d121')

In [None]:
# Predict over the items of `train_ds` (where every test image lives here), in the
# same order, so the rows still line up with `learn.dls.train_ds.items` below.
# `get_preds(0)` would run the *training* pipeline, with its random augmentation
# and crops. `test_dl` builds a loader with the validation pipeline instead.
# with_labels=True keeps the targets, so `t` is returned.
pred_dl = learn.dls.test_dl(learn.dls.train_ds.items, with_labels=True)
p, t = learn.get_preds(dl=pred_dl)

In [None]:
# Convert predictions/targets to numpy. np.asarray is a no-op on arrays, so
# re-running this cell works, whereas `.numpy()` fails once p/t are already ndarrays.
p = np.asarray(p)
t = np.asarray(t)

In [None]:
# Rows of the training-split dataset, which the predictions above were computed over.
pred_items = learn.dls.train_ds.items
img_lst = pred_items['img'].tolist()

In [None]:
# Persist image names and per-label probabilities side by side.
RESULT_DIR = '../result'
os.makedirs(RESULT_DIR, exist_ok=True)  # np.save does not create missing directories
np.save(os.path.join(RESULT_DIR, 'imgs.npy'), img_lst)
np.save(os.path.join(RESULT_DIR, 'probs.npy'), p)