# Imports

In [None]:
# Notebook magic: sets CUDA_VISIBLE_DEVICES to 0 for this kernel process.
%env CUDA_VISIBLE_DEVICES=0

In [None]:
import torch

In [None]:
import matplotlib
import matplotlib.pyplot as plt
# Default figure style for every plot in this notebook: white background, 15x5 size.
matplotlib.rcParams['figure.facecolor'] = 'white'
matplotlib.rcParams['figure.figsize'] = (15, 5)

In [None]:
import pandas as pd
# None removes the limit on how many columns pandas displays.
pd.options.display.max_columns = None

In [None]:
# %run executes the script inside the notebook namespace.
# NOTE(review): `config_logging` and `logging` are not imported in this notebook;
# presumably both are made available by the script run here - confirm.
%run ../utils/__init__.py
config_logging(logging.INFO)

# Load model

In [None]:
# NOTE(review): `create_cnn` (used below) is not defined in this notebook;
# presumably it is brought into scope by this script - confirm.
%run ../models/classification/__init__.py

In [None]:
# Build the classifier for 14 labels (range(14)) and move it to the GPU.
model = create_cnn(
    'densenet-121-v2', labels=range(14), gpool='avg', dropout=0.3, dropout_features=0.5,
).cuda()
# Bare expression: shown as the cell output to inspect the classification head.
model.classifier

# Load data

In [None]:
# NOTE(review): `prepare_data_classification` (used below) is not defined in this
# notebook; presumably it is brought into scope by this script - confirm.
%run ../datasets/__init__.py

In [None]:
# Small configuration for quick experiments: at most 100 samples, batches of 10.
dataset_kwargs = {
    'dataset_name': 'chexpert',
    'dataset_type': 'train',
    'max_samples': 100,
    'image_size': (256, 256),
    'batch_size': 10,
    # 'labels': ['Cardiomegaly'],
}
dataloader = prepare_data_classification(**dataset_kwargs)
dataset = dataloader.dataset
# Bare expression: number of samples in the loaded dataset.
len(dataset)

In [None]:
# Sum of the 'Cardiomegaly' entry of label_index divided by the dataset size.
# NOTE(review): presumably the fraction of positive samples for that label
# (a candidate `imratio` for AUCMLoss) - confirm what label_index holds.
dataset.label_index['Cardiomegaly'].sum() / len(dataset)

# Try LibAUC

In [None]:
from libauc.losses import APLoss_SH, AUCMLoss
from libauc.optimizers import SOAP_SGD, SOAP_ADAM, PESG

In [None]:
# Copied from libauc code installed by python
# Authors have not released the code in github yet:
# https://github.com/yzhuoning/LibAUC/issues/7
class APLoss_SH(torch.nn.Module):
    """
    AP Loss with squared-hinge function: a surrogate loss to directly optimize AUPRC.

    The loss keeps one running estimate per training sample (``u_all`` and
    ``u_pos``, indexed by the sample position in the dataset), so every call
    needs the dataset indices of the samples in the batch.

    inputs:
        data_len: number of samples in the training set (rows of the
            running-estimate tables); required
        margin: margin for squared hinge loss, e.g., m in [0, 1]
        beta: factor for the moving average, which also refers to gamma in the paper
        choice_p: stored on the instance, not used by this implementation
        device: device on which the running estimates are allocated
    outputs:
        loss (scalar tensor, same dtype as the running estimates)
    Reference:
        Qi, Q., Luo, Y., Xu, Z., Ji, S. and Yang, T., 2021.
        Stochastic Optimization of Area Under Precision-Recall Curve for Deep Learning with Provable Convergence.
        arXiv preprint arXiv:2104.08736.
    Link:
        https://arxiv.org/abs/2104.08736
    """
    def __init__(self, data_len=None, margin=0.8, beta=0.99, choice_p=3, device='cuda'):
        super(APLoss_SH, self).__init__()
        if data_len is None:
            # Fail early with a readable message instead of an opaque torch.zeros error.
            raise ValueError('APLoss_SH requires data_len (number of training samples)')
        # Per-sample moving averages, one row per dataset sample.
        self.u_all = torch.zeros(data_len, 1, dtype=torch.float64, device=device)
        self.u_pos = torch.zeros(data_len, 1, dtype=torch.float64, device=device)
        self.margin = margin
        self.choice_p = choice_p
        self.beta = beta

    def forward(self, y_pred, y_true, index_s):
        """
        Compute the loss for one batch.

        inputs:
            y_pred: predicted scores, same shape as y_true
            y_true: binary targets (1 = positive, 0 = negative)
            index_s: dataset index of every sample in the batch, one per sample
        outputs:
            scalar loss tensor

        NOTE(review): a batch with positives but no negatives, or a positive whose
        running estimate is still exactly 0, appears to yield NaN (mean of an empty
        tensor / division by zero); this behavior is unchanged - confirm and handle
        upstream if needed.
        """
        y_pred_ps = y_pred[y_true == 1].reshape(-1, 1)
        y_pred_ns = y_pred[y_true == 0].reshape(-1, 1)
        pos_num = y_pred_ps.size(0)

        if pos_num == 0:
            # No positive sample: there are no (positive, negative) pairs and the
            # pairwise matrix below cannot be built. Fall back to a squared hinge
            # that compares the negatives against a reference score of 0.
            all_loss = torch.max(self.margin - (0 - y_pred_ns), torch.zeros_like(y_pred_ns)) ** 2
            # Keep the dtype consistent with the main branch.
            return all_loss.mean().to(self.u_all.dtype)

        y_true = y_true.reshape(-1)
        index_s = index_s[y_true == 1]  # dataset indices of the positives only

        # One row per positive, one column per negative: (pos_num, neg_num).
        y_pred_matrix = y_pred_ns.repeat(pos_num, 1).reshape(pos_num, -1)

        pos_mask = torch.zeros_like(y_pred_matrix)
        pos_mask[:, 0:pos_num] = 1

        all_loss = torch.max(self.margin - (y_pred_ps - y_pred_matrix), torch.zeros_like(y_pred_matrix)) ** 2
        pos_loss = torch.max(self.margin - (y_pred_matrix), torch.zeros_like(y_pred_ps)) ** 2 * pos_mask

        # The running estimates are state, not part of the autograd graph.
        # Updating them under no_grad keeps them from chaining the graphs of
        # successive iterations together.
        with torch.no_grad():
            self.u_all[index_s] = (1 - self.beta) * self.u_all[index_s] + self.beta * all_loss.sum(1, keepdim=True)
            self.u_pos[index_s] = (1 - self.beta) * self.u_pos[index_s] + self.beta * pos_loss.sum(1, keepdim=True)

        # Weights are treated as constants: gradients only flow through all_loss.
        p = (all_loss / self.u_all[index_s]).detach()
        return torch.mean(p * all_loss)

In [None]:
class MultilabelAPLoss(nn.Module):
    """
    Apply one independent APLoss_SH per label of a multilabel problem.

    inputs:
        n_labels: number of labels (columns of preds / targets)
        n_samples: number of samples in the training set, forwarded as
            ``data_len`` to every per-label loss
        **kwargs: forwarded unchanged to every APLoss_SH (e.g. margin, beta, device)
    outputs:
        tensor with one loss value per label, in label order
    """
    def __init__(self, n_labels, n_samples, **kwargs):
        super().__init__()

        # Each label gets its own loss, since every APLoss_SH keeps its own
        # per-sample running estimates.
        self.losses = nn.ModuleList([
            APLoss_SH(data_len=n_samples, **kwargs)
            for _ in range(n_labels)
        ])

    def forward(self, preds, targets, index):
        """
        inputs:
            preds: predictions, indexed as preds[:, label]
            targets: binary targets, indexed as targets[:, label]
            index: dataset index of every sample in the batch
        outputs:
            1-D tensor of length n_labels, differentiable w.r.t. preds
        """
        # torch.stack keeps each per-label loss attached to the autograd graph;
        # torch.tensor([...]) would copy the values into a new leaf tensor and
        # no gradient would reach the model.
        per_label = [
            loss(preds[:, idx], targets[:, idx], index)
            for idx, loss in enumerate(self.losses)
        ]
        return torch.stack(per_label).to(preds.device)

In [None]:
class AUCMLoss(torch.nn.Module):
    """
    AUCM Loss: min-max margin surrogate used to directly optimize AUROC.

    The auxiliary variables ``a``, ``b`` and ``alpha`` are created here as
    one-element float32 tensors with ``requires_grad=True`` and are exposed as
    attributes so an external optimizer can update them.

    inputs:
        margin: margin term for AUCM loss, e.g., m in [0, 1]
        imratio: imbalance ratio, i.e., the ratio of number of positive samples
            to number of total samples; when None it is computed from the first
            batch passed to forward and kept afterwards
        device: device on which a, b and alpha are allocated
    outputs:
        loss value (tensor of shape (1,))

    Reference:
        Yuan, Z., Yan, Y., Sonka, M. and Yang, T., 2020.
        Robust Deep AUC Maximization: A New Surrogate Loss and Empirical Studies on Medical Image Classification.
        arXiv preprint arXiv:2012.03173.
    Link:
        https://arxiv.org/abs/2012.03173
    """
    def __init__(self, margin=1.0, imratio=None, device='cuda'):
        super().__init__()
        self.margin = margin
        self.p = imratio

        def _aux_variable():
            # One-element trainable scalar living on the requested device.
            return torch.zeros(1, dtype=torch.float32, device=device, requires_grad=True)

        self.a = _aux_variable()
        self.b = _aux_variable()
        self.alpha = _aux_variable()

    def forward(self, y_pred, y_true):
        # Lazily estimate the imbalance ratio from the first batch seen.
        if self.p is None:
            self.p = (y_true == 1).float().sum() / y_true.shape[0]
        p = self.p

        # Work on column vectors so predictions and targets line up element-wise.
        y_pred = y_pred.reshape(-1, 1)
        y_true = y_true.reshape(-1, 1)
        is_pos = (1 == y_true).float()
        is_neg = (0 == y_true).float()

        pos_term = (1 - p) * torch.mean((y_pred - self.a) ** 2 * is_pos)
        neg_term = p * torch.mean((y_pred - self.b) ** 2 * is_neg)
        score_gap = torch.mean((p * y_pred * is_neg - (1 - p) * y_pred * is_pos))
        margin_term = 2 * self.alpha * (p * (1 - p) * self.margin + score_gap)
        alpha_penalty = p * (1 - p) * self.alpha ** 2

        return pos_term + neg_term + margin_term - alpha_penalty

In [None]:
# Pick the loss under test; the commented lines are the AP-loss alternatives.
# loss_fn = APLoss_SH(data_len=len(dataloader.dataset))
# loss_fn = MultilabelAPLoss(14, len(dataloader.dataset), device='cuda')
loss_fn = AUCMLoss(imratio=0.1)

In [None]:
# Smoke test of the loss on random data: 7 scores in [0, 1) and 7 random binary labels.
predictions = torch.rand(7, 1).cuda()
labels = (torch.rand(7, 1) > 0.5).long().cuda()
l = loss_fn(predictions, labels)
l

In [None]:
# The optimizer receives the loss' auxiliary tensors (a, b, alpha) and its
# imbalance ratio in addition to the model.
# optimizer = SOAP_SGD(model.parameters(), lr=0.001)
optimizer = PESG(
    model,
    a=loss_fn.a,
    b=loss_fn.b,
    alpha=loss_fn.alpha,
    imratio=loss_fn.p,
    lr=0.001,
)

In [None]:
# Switch the model to training mode; `_ =` suppresses the cell output.
_ = model.train()

In [None]:
# Column of predictions / labels the single-label loss is computed on.
# NOTE(review): which label column 1 corresponds to depends on the dataset's
# label order, not visible here - confirm.
INDEX = 1

# Single optimization step: the trailing `break` stops after the first batch.
for batch in dataloader:
    images = batch.image.cuda()
    labels = batch.labels.cuda()
    # index = batch.idx.cuda()

    out = model(images)
    # The first element of the model output is used as the predictions.
    prediction = out[0]
    
    loss = loss_fn(prediction[:, INDEX], labels[:, INDEX])

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    break

In [None]:
# Shapes of the tensors left over from the batch processed above.
prediction.size(), labels.size() # , index.size()

In [None]:
from medai.models import load_compiled_model
from medai.utils import RunId

In [None]:
# Identifier of the run to load.
# NOTE(review): the meaning of the positional arguments (timestamp-like string,
# False, 'cls') is defined by RunId, not visible here - confirm.
run_id = RunId('0321_052008', False, 'cls')

In [None]:
cm = load_compiled_model(run_id)
# Bare expression: shows the class of the loaded model.
type(cm.model)

In [None]:
# Attribute access without a call: the cell displays the bound `parameters`
# method rather than iterating over the parameters.
cm.model.parameters