In [39]:
import torch
import torch.nn as nn
import geffnet
from resnest.torch import resnest101
from pretrainedmodels import se_resnext101_32x4d
import albumentations as A 
import numpy as np
import os 
import pandas as pd 
from PIL import Image
import torch.nn.functional as F
#from accelerate import Accelerator
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import (roc_auc_score,
                             roc_curve,
                             auc,
                             accuracy_score,
                             mean_squared_error)
from transformers import get_linear_schedule_with_warmup

!pip install -U libauc
from libauc.losses import pAUC_CVaR_Loss
from libauc.optimizers import SOPA
from transformers import get_linear_schedule_with_warmup
from accelerate import Accelerator




In [40]:
# --- Runtime ---------------------------------------------------------------
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# --- Training schedule -----------------------------------------------------
num_epochs, total_epochs = 3, 60
batch_size = 64
decay_epochs, decay_factor = [20, 40], 10

# --- Optimiser / loss hyper-parameters ---------------------------------------
lr, weight_decay = 1e-3, 5e-4  # learning rate, regularization weight decay
eta = 1e1  # learning rate for control negative samples weights
beta = 0.1  # forwarded as `beta` to pAUC_CVaR_Loss

In [41]:
sigmoid = nn.Sigmoid()
class Swish(torch.autograd.Function):
    """Swish activation ``x * sigmoid(x)`` with a hand-written backward pass.

    Only the input is saved for backward; the sigmoid is recomputed there
    instead of being kept from the forward pass.
    """

    @staticmethod
    def forward(ctx, i):
        """Return ``i * sigmoid(i)`` and stash ``i`` for the backward pass."""
        result = i * sigmoid(i)
        ctx.save_for_backward(i)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient of the loss with respect to the input ``i``.

        Uses d/di [i * s(i)] = s(i) * (1 + i * (1 - s(i))).
        """
        # `saved_tensors` is the supported accessor; `saved_variables` is the
        # deprecated spelling of the same tuple.
        (i,) = ctx.saved_tensors
        sigmoid_i = sigmoid(i)
        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
class Swish_Module(nn.Module):
    """Module wrapper that applies the custom ``Swish`` autograd function."""

    def forward(self, x):
        """Apply Swish element-wise to ``x`` and return the result."""
        activated = Swish.apply(x)
        return activated
class Seresnext_Melanoma(nn.Module):
    """SE-ResNeXt101 (32x4d) backbone with an optional metadata branch.

    The backbone's ``last_linear`` is replaced by ``nn.Identity`` and a new
    linear head (``myfc``) is applied under five independent dropout masks,
    whose outputs are averaged.

    Args:
        enet_type: Accepted for signature parity with the other model
            classes; not used by this class.
        out_dim: Number of output logits.
        n_meta_features: Width of the metadata vector; ``0`` disables the
            metadata branch.
        n_meta_dim: Hidden and output widths of the metadata MLP.
        pretrained: When truthy, load the ``'imagenet'`` backbone weights.
    """

    def __init__(self, enet_type, out_dim, n_meta_features=0, n_meta_dim=[512, 128], pretrained=False):
        super().__init__()
        self.n_meta_features = n_meta_features

        backbone_weights = 'imagenet' if pretrained else None
        self.enet = se_resnext101_32x4d(num_classes=1000, pretrained=backbone_weights)
        self.enet.avg_pool = nn.AdaptiveAvgPool2d((1, 1))

        self.dropouts = nn.ModuleList([nn.Dropout(0.5) for _ in range(5)])

        head_in = self.enet.last_linear.in_features
        if n_meta_features > 0:
            hidden_dim = n_meta_dim[0]
            meta_out_dim = n_meta_dim[1]
            self.meta = nn.Sequential(
                nn.Linear(n_meta_features, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                Swish_Module(),
                nn.Dropout(p=0.3),
                nn.Linear(hidden_dim, meta_out_dim),
                nn.BatchNorm1d(meta_out_dim),
                Swish_Module(),
            )
            # The head sees image features concatenated with metadata features.
            head_in += meta_out_dim

        self.myfc = nn.Linear(head_in, out_dim)
        # Make the backbone return pooled features instead of class logits.
        self.enet.last_linear = nn.Identity()

    def extract(self, x):
        """Run the backbone on ``x`` and return its output."""
        return self.enet(x)

    def forward(self, x, x_meta=None):
        """Return logits averaged over all dropout masks.

        ``x_meta`` is only read when ``n_meta_features > 0``.
        """
        features = self.extract(x)
        features = features.squeeze(-1).squeeze(-1)
        if self.n_meta_features > 0:
            meta_features = self.meta(x_meta)
            features = torch.cat((features, meta_features), dim=1)

        masks = iter(self.dropouts)
        out = self.myfc(next(masks)(features))
        for mask in masks:
            out += self.myfc(mask(features))
        out /= len(self.dropouts)
        return out
    

class Effnet_Melanoma(nn.Module):
    """EfficientNet backbone (built through ``geffnet``) with a linear head.

    The backbone's ``classifier`` is replaced by ``nn.Identity`` and a new
    linear head (``myfc``) is applied under five independent dropout masks,
    whose outputs are averaged.

    Args:
        enet_type: Model name handed to ``geffnet.create_model``.
        out_dim: Number of output logits.
        n_meta_features: Stored on the instance; the metadata branch is
            disabled in this class, so it has no effect on the network.
        n_meta_dim: Unused while the metadata branch is disabled.
        pretrained: Forwarded to ``geffnet.create_model``.
    """

    def __init__(self, enet_type, out_dim, n_meta_features=0, n_meta_dim=[512, 128], pretrained=False):
        super().__init__()
        self.n_meta_features = n_meta_features
        self.enet = geffnet.create_model(enet_type, pretrained=pretrained)
        self.dropouts = nn.ModuleList([nn.Dropout(0.5) for _ in range(5)])

        # No metadata MLP is built here, so the head consumes the backbone
        # features only.
        head_in = self.enet.classifier.in_features
        self.myfc = nn.Linear(head_in, out_dim)
        # Make the backbone return pooled features instead of class logits.
        self.enet.classifier = nn.Identity()

    def extract(self, x):
        """Run the backbone on ``x`` and return its output."""
        return self.enet(x)

    def forward(self, x, x_meta=None):
        """Return logits averaged over all dropout masks.

        ``x_meta`` is accepted for signature parity and ignored.
        """
        features = self.extract(x)
        features = features.squeeze(-1).squeeze(-1)

        masks = iter(self.dropouts)
        out = self.myfc(next(masks)(features))
        for mask in masks:
            out += self.myfc(mask(features))
        out /= len(self.dropouts)
        return out

class Resnest_Melanoma(nn.Module):
    """ResNeSt-101 backbone with an optional metadata branch.

    The backbone's ``fc`` is replaced by ``nn.Identity`` and a new linear
    head (``myfc``) is applied under five independent dropout masks, whose
    outputs are averaged.

    Args:
        enet_type: Accepted for signature parity with the other model
            classes; not used by this class.
        out_dim: Number of output logits.
        n_meta_features: Width of the metadata vector; ``0`` disables the
            metadata branch.
        n_meta_dim: Hidden and output widths of the metadata MLP.
        pretrained: Forwarded to ``resnest101``.
    """

    def __init__(self, enet_type, out_dim, n_meta_features=0, n_meta_dim=[512, 128], pretrained=False):
        super().__init__()
        self.n_meta_features = n_meta_features
        self.enet = resnest101(pretrained=pretrained)
        self.dropouts = nn.ModuleList([nn.Dropout(0.5) for _ in range(5)])

        head_in = self.enet.fc.in_features
        if n_meta_features > 0:
            hidden_dim = n_meta_dim[0]
            meta_out_dim = n_meta_dim[1]
            self.meta = nn.Sequential(
                nn.Linear(n_meta_features, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                Swish_Module(),
                nn.Dropout(p=0.3),
                nn.Linear(hidden_dim, meta_out_dim),
                nn.BatchNorm1d(meta_out_dim),
                Swish_Module(),
            )
            # The head sees image features concatenated with metadata features.
            head_in += meta_out_dim

        self.myfc = nn.Linear(head_in, out_dim)
        # Make the backbone return pooled features instead of class logits.
        self.enet.fc = nn.Identity()

    def extract(self, x):
        """Run the backbone on ``x`` and return its output."""
        return self.enet(x)

    def forward(self, x, x_meta=None):
        """Return logits averaged over all dropout masks.

        ``x_meta`` is only read when ``n_meta_features > 0``.
        """
        features = self.extract(x)
        features = features.squeeze(-1).squeeze(-1)
        if self.n_meta_features > 0:
            meta_features = self.meta(x_meta)
            features = torch.cat((features, meta_features), dim=1)

        masks = iter(self.dropouts)
        out = self.myfc(next(masks)(features))
        for mask in masks:
            out += self.myfc(mask(features))
        out /= len(self.dropouts)
        return out


In [42]:
# Registry of models to fine-tune. Each entry is
#   (backbone name, model class, image size, checkpoint path).
CNNs = [
    (
        'seresnext101',
        Seresnext_Melanoma,
        640,
        'Melanoma_2020_models/9c_se_x101_640_ext_15ep_best_fold0.pth',
    ),
    (
        'tf_efficientnet_b4_ns',
        Effnet_Melanoma,
        448,
        'Melanoma_2020_models/9c_b4ns_448_ext_15ep-newfold_best_fold0.pth',
    ),
    (
        'tf_efficientnet_b5_ns',
        Effnet_Melanoma,
        448,
        'Melanoma_2020_models/9c_b5ns_448_ext_15ep-newfold_best_fold0.pth',
    ),
    (
        'tf_efficientnet_b6_ns',
        Effnet_Melanoma,
        576,
        'Melanoma_2020_models/9c_b6ns_576_ext_15ep_oldfold_best_fold0.pth',
    ),
    (
        'tf_efficientnet_b7_ns',
        Effnet_Melanoma,
        576,
        'Melanoma_2020_models/9c_b7ns_1e_576_ext_15ep_oldfold_best_fold0.pth',
    ),
    (
        'resnest101',
        Resnest_Melanoma,
        640,
        'Melanoma_2020_models/9c_nest101_2e_640_ext_15ep_best_fold0.pth',
    ),
]

In [43]:
class ImageDataset(Dataset):
    """Image dataset that yields ``(inputs, index)`` pairs for pAUC training.

    Args:
        csv_file: A ``pandas.DataFrame`` or a CSV path. Must provide an
            ``'image'`` column (file name, with or without ``.jpg``) and a
            ``'target'`` column (binary label).
        img_size: Side length, in pixels, of the square output image.
        img_dir: Directory that the ``'image'`` names are joined onto.
        mode: ``'train'`` enables augmentation and positive-rank indices;
            any other value yields un-augmented images and row indices.

    Each item is ``(inputs, index)`` where ``inputs`` holds

    * ``'pixel_values'``: float tensor of shape ``(3, img_size, img_size)``;
    * ``'labels'``: 0-d ``torch.long`` tensor.

    In ``'train'`` mode ``index`` is the sample's rank among the positive
    rows (``-1`` for negatives); otherwise it is the dataset row.

    NOTE(review): no intensity normalisation is applied, because the original
    code did not define one. Confirm what the pretrained checkpoints expect.
    """

    def __init__(self, csv_file, img_size, img_dir, mode='train'):
        self.data = csv_file if isinstance(csv_file, pd.DataFrame) else pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.image_paths = self.data['image'].values
        self.labels = self.data['target'].values
        # Both spellings are kept: `image_size` was the attribute that was set,
        # `img_size` the one that `__getitem__` actually read.
        self.image_size = img_size
        self.img_size = img_size
        self.mode = mode
        if mode == 'train':
            self.transform = A.Compose([
                A.Transpose(p=0.5),
                A.VerticalFlip(p=0.5),
                A.HorizontalFlip(p=0.5),
                A.RandomBrightnessContrast(brightness_limit=0.2, p=0.75),
                A.OneOf([
                    A.MotionBlur(blur_limit=5),
                    A.MedianBlur(blur_limit=5),
                    A.GaussianBlur(blur_limit=5),
                    A.GaussNoise(var_limit=(5.0, 30.0)),
                ], p=0.7),
                A.OneOf([
                    A.OpticalDistortion(distort_limit=1.0),
                    A.GridDistortion(num_steps=5, distort_limit=1.0),
                    A.ElasticTransform(alpha=3),
                ], p=0.7),
                A.CLAHE(clip_limit=4.0, p=0.7),
                A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=20, val_shift_limit=10, p=0.5),
                A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, border_mode=0, p=0.85),
                A.CoarseDropout(max_holes=1, max_height=int(img_size * 0.375), max_width=int(img_size * 0.375), p=0.7),
            ])
        else:
            self.transform = None

        # Rank of every positive row among the positives, for O(1) lookup.
        self.pos_indices = np.flatnonzero(self.labels == 1)
        self.pos_index_map = {int(row): rank for rank, row in enumerate(self.pos_indices)}

    def __len__(self):
        return len(self.image_paths)

    def _to_pixel_values(self, image):
        """Convert an HWC image array to a resized float CHW tensor."""
        # Flips/transposes can return non-contiguous views that
        # `torch.from_numpy` rejects, so force a contiguous copy first.
        chw = torch.from_numpy(np.ascontiguousarray(image)).permute(2, 0, 1).float()
        resized = F.interpolate(
            chw.unsqueeze(0),
            size=(self.img_size, self.img_size),
            mode='bilinear',
            align_corners=False,
        )
        return resized.squeeze(0)

    def __getitem__(self, idx):
        row = int(idx)
        img_path = os.path.join(self.img_dir, self.image_paths[row])
        if not os.path.exists(img_path):
            img_path += '.jpg'
        # The context manager closes the underlying file handle promptly.
        with Image.open(img_path) as handle:
            image = np.array(handle.convert('RGB'))

        # Read the label BEFORE the index is remapped; otherwise negatives
        # would pick up `labels[-1]` and positives the label of another row.
        label = self.labels[row]

        if self.mode == 'train':
            sample_index = self.pos_index_map.get(row, -1)
            image = self.transform(image=image)['image']
        else:
            sample_index = row

        inputs = {
            'pixel_values': self._to_pixel_values(image),
            'labels': torch.tensor(label, dtype=torch.long),
        }
        return inputs, sample_index
def collate_fn(batch):
    """Collate ``(inputs, index)`` samples into batched tensors.

    Returns ``(inputs, indices)`` where ``inputs`` maps ``'pixel_values'`` to
    the stacked images and ``'labels'`` to a label tensor, and ``indices`` is
    a tensor of the per-sample indices.
    """
    sample_dicts = [entry[0] for entry in batch]
    stacked_pixels = torch.stack([sample['pixel_values'] for sample in sample_dicts])
    label_tensor = torch.tensor([sample['labels'] for sample in sample_dicts])
    index_tensor = torch.tensor([entry[1] for entry in batch])
    return {'pixel_values': stacked_pixels, 'labels': label_tensor}, index_tensor

In [44]:
# Ground-truth table; one row per image.
train_metadata = pd.read_csv('/Users/jimmyhe/Desktop/KaggleCompetitions/CNNFineTune/ISIC_2019_Training_GroundTruth.csv')
# Binary target: 1 when any of the BCC / SCC / MEL columns equals 1.0.
train_metadata['target'] = (train_metadata[['BCC', 'SCC', 'MEL']] == 1.0).any(axis=1).astype(int)
# Image paths for the first ten rows only.
path_list = [
    f"/Users/jimmyhe/Desktop/KaggleCompetitions/ISISCANCER/MetaDataPlusProprocessed/train-image/image/{id}.jpg"
    for id in train_metadata['image']
][:10]

In [45]:
import logging
import tqdm

# Dataset sizes: number of positive rows and total number of rows.
pos_length = int((train_metadata['target'] == 1).sum())
data_length = len(train_metadata)

# Send log records both to a file and to the notebook's stream output.
log_file = 'training_log.log'
logging.basicConfig(
    handlers=[logging.FileHandler(log_file), logging.StreamHandler()],
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)


In [46]:
def collate_fn(batch):
    """Collate a list of samples into batched tensors.

    Two sample layouts are accepted:

    * ``(inputs, index)`` pairs, where ``inputs`` has ``'pixel_values'`` and
      ``'labels'``. Returns ``(inputs, indices)``.
    * Flat dicts with ``'pixel_values'`` and ``'target'`` (the layout this
      definition originally handled). Returns a single dict with
      ``'pixel_values'`` and ``'labels'``.

    This definition replaces any earlier ``collate_fn`` in the notebook, so it
    has to understand the ``(inputs, index)`` pairs as well.
    """
    if batch and isinstance(batch[0], dict):
        return {
            'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
            'labels': torch.tensor([x['target'] for x in batch]),
        }

    inputs = {
        'pixel_values': torch.stack([x[0]['pixel_values'] for x in batch]),
        'labels': torch.tensor([x[0]['labels'] for x in batch]),
    }
    indices = torch.tensor([x[1] for x in batch])
    return inputs, indices
def comp_score(solution: pd.DataFrame, submission: pd.DataFrame, min_tpr: float=0.80):
    """Return the un-normalised partial AUC above ``min_tpr``.

    Labels and scores are both flipped (``1 - value``), so capping the false
    positive rate of the flipped problem at ``1 - min_tpr`` corresponds to
    keeping the true positive rate of the original problem at or above
    ``min_tpr``. ``roc_auc_score(..., max_fpr=...)`` returns a rescaled
    value; the last step maps it back to the raw area.

    Args:
        solution: Ground-truth labels (read through ``.values``).
        submission: Predicted scores (read through ``.values``).
        min_tpr: Lower bound on the true positive rate of interest.
    """
    flipped_truth = np.abs(np.asarray(solution.values) - 1)
    flipped_scores = np.array([1.0 - score for score in submission.values])
    fpr_cap = abs(1 - min_tpr)
    scaled_auc = roc_auc_score(flipped_truth, flipped_scores, max_fpr=fpr_cap)
    # Undo the rescaling: `floor` is the area of a chance-level classifier
    # over [0, fpr_cap], `span` the gap up to a perfect one.
    floor = 0.5 * fpr_cap**2
    span = fpr_cap - 0.5 * fpr_cap**2
    return floor + span / (1.0 - 0.5) * (scaled_auc - 0.5)
def evaluate_model(model, dataloader, device, pos_class_idx=1):
    """Score ``model`` on ``dataloader`` and return the partial AUC.

    Args:
        model: Callable taking the image batch as its first positional
            argument. It may return a logits tensor or an object with a
            ``.logits`` attribute.
        dataloader: Yields ``(inputs, index)`` pairs, where ``inputs`` has
            ``'pixel_values'`` and ``'labels'``.
        device: Device the image batch is moved to.
        pos_class_idx: Logit column treated as the positive class when the
            model emits more than one column.
            NOTE(review): the default of 1 mirrors ``train_model``. Confirm
            it is the right column for the 9-logit checkpoints.

    Returns:
        The value of ``comp_score`` over all batches.
    """
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        # `tqdm` is bound to the module (`import tqdm`), so the progress-bar
        # class has to be reached as `tqdm.tqdm`.
        for inputs, _ in tqdm.tqdm(dataloader, desc="Evaluating", unit="batch"):
            pixel_values = inputs['pixel_values'].to(device)

            outputs = model(pixel_values)
            logits = getattr(outputs, 'logits', outputs)
            if logits.dim() > 1 and logits.shape[-1] > 1:
                logits = logits[:, pos_class_idx]
            # A ROC-based metric needs continuous scores, not argmax class ids.
            preds = torch.sigmoid(logits.reshape(-1))

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(inputs['labels'].cpu().numpy())

    roc_auc = comp_score(pd.Series(all_labels), pd.Series(all_preds))
    logger.info(f"pAUC Score: {roc_auc:.4f}")
    return roc_auc

def _positive_logit(model_output, pos_class_idx):
    """Return the 1-D positive-class logit from a tensor or ``.logits`` holder."""
    logits = getattr(model_output, 'logits', model_output)
    if logits.dim() > 1 and logits.shape[-1] > 1:
        return logits[:, pos_class_idx]
    return logits.reshape(-1)


def train_model(model, model_name, train_dataloader, val_dataloader, loss_fn, scheduler, device,
                accelerator, num_epochs=3, optimizer=None, save_dir=None, pos_class_idx=1):
    """Fine-tune ``model`` and checkpoint it after every epoch.

    Args:
        model: Callable taking the image batch as its first positional
            argument. It may return a logits tensor or an object with a
            ``.logits`` attribute.
        model_name: Used in checkpoint and log file names (``/`` becomes ``_``).
        train_dataloader: Yields ``(inputs, index)`` pairs, where ``inputs``
            has ``'pixel_values'`` and ``'labels'``.
        val_dataloader: Same batch layout; the index is ignored.
        loss_fn: Called as ``loss_fn(probabilities, labels, index)``.
        scheduler: Stepped once per training batch.
        device: Device that batches are moved to.
        accelerator: Object providing ``accumulate``, ``backward``, ``print``,
            ``wait_for_everyone``, ``unwrap_model``, ``save`` and
            ``is_main_process``.
        num_epochs: Number of passes over ``train_dataloader``.
        optimizer: Optimizer to step. Defaults to ``scheduler.optimizer``, so
            the scheduler and the stepped optimizer are the same object. A
            new ``SOPA`` is built only when neither is available.
        save_dir: Output directory. Defaults to a notebook-level ``save_dir``
            if one exists, otherwise the current directory.
        pos_class_idx: Logit column treated as the positive class when the
            model emits more than one column.
            NOTE(review): confirm the default of 1 for the 9-logit checkpoints.

    Returns:
        The best validation pAUC seen.

    No validation loss is reported: ``loss_fn`` takes a sample index, which
    validation batches do not supply in the form training batches do. Only
    metrics are computed on the validation set. The MSE and cross-entropy
    terms the loop used to compute were never used and have been dropped.
    """
    if save_dir is None:
        save_dir = globals().get('save_dir', '.')
    os.makedirs(save_dir, exist_ok=True)

    if optimizer is None:
        optimizer = getattr(scheduler, 'optimizer', None)
    if optimizer is None:
        optimizer = SOPA(model.parameters(), loss_fn=loss_fn, mode='adam', lr=lr, eta=eta, weight_decay=weight_decay)

    safe_name = model_name.replace('/', '_')
    loss_log = []
    best_pauc = 0
    for epoch in range(num_epochs):
        epoch_losses = {'epoch': epoch + 1, 'train_loss': 0.0}

        # ---- Training ------------------------------------------------------
        model.train()
        train_preds, train_labels = [], []
        # `tqdm` is bound to the module (`import tqdm`), hence `tqdm.tqdm`.
        for inputs, index in tqdm.tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}", unit="batch"):
            with accelerator.accumulate(model):
                optimizer.zero_grad()
                pixel_values = inputs['pixel_values'].to(device)
                labels = inputs['labels'].to(device)
                index = index.to(device)

                # LibAUC's examples feed sigmoid probabilities to the loss.
                probs = torch.sigmoid(_positive_logit(model(pixel_values), pos_class_idx))
                loss = loss_fn(probs, labels, index)

                accelerator.backward(loss)
                optimizer.step()
                scheduler.step()

            epoch_losses['train_loss'] += loss.item()
            train_preds.extend(probs.detach().cpu().numpy())
            train_labels.extend(inputs['labels'].cpu().numpy())

        epoch_losses['train_loss'] /= len(train_dataloader)
        epoch_losses['train_pauc'] = comp_score(pd.Series(train_labels), pd.Series(train_preds))
        epoch_losses['train_accuracy'] = accuracy_score(train_labels, np.round(train_preds))
        epoch_losses['train_MSE'] = mean_squared_error(train_labels, train_preds)

        # ---- Validation ----------------------------------------------------
        model.eval()
        val_preds, val_labels = [], []
        with torch.no_grad():
            for inputs, _ in tqdm.tqdm(val_dataloader, desc=f"Epoch {epoch + 1}/{num_epochs} - Validation", unit="batch"):
                pixel_values = inputs['pixel_values'].to(device)
                probs = torch.sigmoid(_positive_logit(model(pixel_values), pos_class_idx))
                val_preds.extend(probs.cpu().numpy())
                val_labels.extend(inputs['labels'].cpu().numpy())

        epoch_losses['val_pauc'] = comp_score(pd.Series(val_labels), pd.Series(val_preds))
        epoch_losses['val_accuracy'] = accuracy_score(val_labels, np.round(val_preds))
        epoch_losses['val_MSE'] = mean_squared_error(val_labels, val_preds)

        accelerator.print(f"Epoch {epoch + 1}/{num_epochs}, "
                          f"Train Loss: {epoch_losses['train_loss']:.4f}, "
                          f"Train pAUC: {epoch_losses['train_pauc']:.4f}, "
                          f"Train Accuracy: {epoch_losses['train_accuracy']:.4f}, "
                          f"Train MSE: {epoch_losses['train_MSE']:.4f}, "
                          f"Val pAUC: {epoch_losses['val_pauc']:.4f}, "
                          f"Val Accuracy: {epoch_losses['val_accuracy']:.4f}, "
                          f"Val MSE: {epoch_losses['val_MSE']:.4f}")

        # ---- Checkpoint: "best" when validation pAUC improves, else per-epoch
        if epoch_losses['val_pauc'] > best_pauc:
            best_pauc = epoch_losses['val_pauc']
            checkpoint_path = f"{save_dir}/best_model_{safe_name}.pth"
        else:
            checkpoint_path = f"{save_dir}/model_epoch_{epoch + 1}_{safe_name}.pth"
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        accelerator.save(unwrapped_model.state_dict(), checkpoint_path)

        # Log losses to CSV
        loss_log.append(epoch_losses)
        if accelerator.is_main_process:
            pd.DataFrame(loss_log).to_csv(f"{save_dir}/loss_log_{safe_name}.csv", index=False)

    accelerator.print(f"Best Validation pAUC Score: {best_pauc:.4f}")
    return best_pauc


In [47]:
def main(img_dir=None):
    """Fine-tune every model listed in ``CNNs`` with the pAUC-CVaR loss.

    Args:
        img_dir: Directory holding the training images. Defaults to the
            directory of the first entry in the notebook-level ``path_list``.

    NOTE(review): the train and validation datasets are both built from the
    full ``train_metadata`` frame, so the validation metrics are not
    held-out. Split the frame before relying on them.
    """
    accelerator = Accelerator()
    DEVICE = accelerator.device
    logger.info(f"DEVICE USING: {DEVICE}")

    if img_dir is None:
        if not path_list:
            raise ValueError("path_list is empty; pass img_dir explicitly")
        img_dir = os.path.dirname(path_list[0])

    # Iterate through each model in CNNs
    for enet_type, ModelClass, img_size, model_dir in CNNs:
        print(f"Processing {enet_type}")

        # Every model class shares the same constructor signature.
        model = ModelClass(enet_type, out_dim=9)

        # The image size belongs to the dataset; DataLoader has no such option.
        train_dataset = ImageDataset(train_metadata, img_size, img_dir, mode='train')
        # An explicit non-'train' mode keeps augmentation out of validation.
        val_dataset = ImageDataset(train_metadata, img_size, img_dir, mode='val')
        train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=1)
        val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=1)

        # Load the checkpoint once. If the keys do not match, retry with any
        # 'module.' prefix stripped from the keys.
        state_dict = torch.load(model_dir, map_location=DEVICE)
        try:
            model.load_state_dict(state_dict, strict=True)
        except RuntimeError:
            state_dict = {k[7:] if k.startswith('module.') else k: v for k, v in state_dict.items()}
            model.load_state_dict(state_dict, strict=True)
        model.to(DEVICE)

        num_training_steps = len(train_dataloader) * num_epochs
        num_warmup_steps = int(0.1 * num_training_steps)
        loss_fn = pAUC_CVaR_Loss(pos_len=pos_length, beta=beta)

        # Use the hyper-parameters from the config cell, matching the
        # optimizer that train_model builds.
        optimizer = SOPA(model.parameters(), loss_fn=loss_fn, mode='adam', lr=lr, eta=eta, weight_decay=weight_decay)

        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps
        )
        train_model(model, enet_type, train_dataloader, val_dataloader, loss_fn, scheduler, DEVICE,
                    accelerator, num_epochs=num_epochs)