# In [None]:
import torch
from torch import nn
from torchvision import models

class Net(nn.Module):
    """ConvNeXt-Tiny backbone followed by global average pooling and a linear head.

    ``forward`` returns a ``(logits, feature_map)`` pair. The un-pooled feature
    map is exposed so callers can build class-activation maps from it.
    """

    def __init__(self):
        super().__init__()
        backbone = models.convnext_tiny(weights=models.ConvNeXt_Tiny_Weights.IMAGENET1K_V1)
        # Keep only the backbone's feature stages; its own classifier head is
        # discarded in favour of the pooling + linear layer defined below.
        self.feature_extractor = nn.Sequential(*backbone.features.children())
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        # 768 must match the channel count of the final feature stage; 2 outputs.
        self.classifier = nn.Sequential(nn.Linear(768, 2))

    def forward(self, x):
        """Return ``(logits, feature_map)`` for the input batch ``x``."""
        feature_map = self.feature_extractor(x)
        pooled = self.gap(feature_map).view(x.size(0), -1)
        return self.classifier(pooled), feature_map

# In [None]:
import os
import pickle
import random

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm


def min_max_normalize(image, eps=1e-8):
    """Rescale each sample of a batch to the range [0, 1] independently.

    The minimum and maximum are taken over every non-batch dimension of
    ``image``. In this file the input is a channel-summed feature map,
    presumably shaped (batch, H, W).

    Args:
        image: tensor whose first dimension indexes samples.
        eps: lower bound for the per-sample value range. Without it, a
            constant sample (e.g. an all-zero map after ReLU) would divide by
            zero and produce NaN; with it, such a sample maps to all zeros.

    Returns:
        Tensor of the same shape as ``image``.
    """
    flat = image.flatten(1)
    # Shape that broadcasts a per-sample statistic back over the non-batch dims.
    stat_shape = (-1,) + (1,) * (image.dim() - 1)
    min_value = flat.min(dim=1).values.view(stat_shape)
    max_value = flat.max(dim=1).values.view(stat_shape)
    value_range = (max_value - min_value).clamp_min(eps)
    return (image - min_value) / value_range

class CustomDataset:
    """Map-style dataset over records of the form ``{'image': stem, 'label': name}``.

    ``label`` must be a key of ``label_dict`` ("NV" -> 0, "MEL" -> 1). Each item
    is ``(image, image, label)``: the transformed image is returned twice
    because the training loop unpacks two image slots per batch.
    """

    def __init__(self, all_json, transform, basedir='./ISIC_2019_Training_Input/'):
        """
        Args:
            all_json: indexable sequence of record dicts (see class docstring).
            transform: callable applied to the RGB ``PIL`` image.
            basedir: directory containing ``<stem>.jpg`` files.
        """
        self.all_json = all_json
        self.transform = transform
        self.basedir = basedir
        self.label_dict = {"NV": 0, "MEL": 1}

    def __getitem__(self, idx):
        record = self.all_json[idx]
        img_path = os.path.join(self.basedir, record['image'] + '.jpg')
        # The context manager guarantees the underlying file handle is closed
        # even if conversion or the transform raises.
        with Image.open(img_path) as img:
            image = self.transform(img.convert("RGB"))
        label = self.label_dict[record['label']]
        return image, image, torch.tensor(label, dtype=torch.long)

    def __len__(self):
        return len(self.all_json)


def do_epoch1(model, dataloader, criterion, optim=None, device=None):
    """Run one pass over ``dataloader`` with a plain classification loss.

    Args:
        model: network whose forward returns ``(logits, feature_map)``; only
            the logits are used here.
        dataloader: yields ``(x, _, y_true)`` batches; the middle element is ignored.
        criterion: loss applied to ``(logits, y_true)``.
        optim: if given, every batch does ``zero_grad`` / ``backward`` / ``step``.
            If ``None`` no update happens; wrap the call in ``torch.no_grad()``
            to avoid building autograd graphs.
        device: device batches are moved to; defaults to the device of the
            model's first parameter.

    Returns:
        ``(mean_loss, mean_accuracy)``. Both are weighted by batch size, so a
        short final batch does not skew the epoch averages.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    if device is None:
        device = next(model.parameters()).device

    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    for x, _, y_true in tqdm(dataloader, leave=False):
        x, y_true = x.to(device), y_true.to(device)
        y_pred, _ = model(x)
        loss = criterion(y_pred, y_true)

        if optim is not None:
            optim.zero_grad()
            loss.backward()
            optim.step()

        batch_size = y_true.size(0)
        total_loss += loss.item() * batch_size
        total_correct += (y_pred.argmax(dim=1) == y_true).sum().item()
        total_samples += batch_size

    if total_samples == 0:
        raise ValueError("dataloader yielded no samples")
    return total_loss / total_samples, total_correct / total_samples


def _sample_paired_augmentation(reference, crop_size):
    """Draw one square crop box and two flip decisions, sized against ``reference``.

    Makes the same random draws as ``RandomCrop`` / ``RandomHorizontalFlip`` /
    ``RandomVerticalFlip``, but returns them so that identical parameters can
    be applied to several tensors without re-seeding any global RNG.
    """
    top, left, height, width = transforms.RandomCrop.get_params(reference, (crop_size, crop_size))
    do_hflip = bool(torch.rand(1) < 0.5)
    do_vflip = bool(torch.rand(1) < 0.5)
    return top, left, height, width, do_hflip, do_vflip


def _apply_paired_augmentation(tensor, params, resize):
    """Crop, optionally flip (horizontal, then vertical) and resize ``tensor``."""
    top, left, height, width, do_hflip, do_vflip = params
    out = transforms.functional.crop(tensor, top, left, height, width)
    if do_hflip:
        out = transforms.functional.hflip(out)
    if do_vflip:
        out = transforms.functional.vflip(out)
    return resize(out)


def _activation_map(feature_map):
    """ReLU the feature map, sum over dim 1 and min-max normalise per sample."""
    return min_max_normalize(torch.sum(F.relu(feature_map), dim=1))


def do_epoch2(model, dataloader, criterion1, criterion2, optim1=None, optim2=None, device=None):
    """Run one pass with two cross-entropy terms plus a CAAM consistency loss.

    For each batch:
      1. CE loss on the logits of ``x1``.
      2. The activation map of ``x1`` is resized to 224, then randomly cropped,
         flipped and resized. ``x2`` receives the *same* crop/flip. The
         activation map of the augmented ``x2`` must match the augmented map
         under ``criterion2``.
      3. CE loss on the logits of the augmented ``x2``.
    The three terms are summed into one training loss.

    Augmentation parameters are drawn once per batch from the global RNGs and
    shared explicitly. The global RNGs are never re-seeded, so augmentation
    diversity and DataLoader shuffling are unaffected.

    NOTE(review): the shared crop box is sized against the 224x224 resized map,
    so ``x2`` is only spatially aligned with it when inputs are 224x224.

    Args:
        model: network whose forward returns ``(logits, feature_map)``.
        dataloader: yields ``(x1, x2, y_true)`` batches.
        criterion1: classification loss used for both CE terms.
        criterion2: loss comparing the two activation maps.
        optim1, optim2: optimizers stepped after each batch when provided.
            Any that is ``None`` is skipped; with both ``None`` no update
            happens (augmentations are still random).
        device: device batches are moved to; defaults to the device of the
            model's first parameter.

    Returns:
        ``(mean_ce1_loss, mean_ce2_loss, mean_caam_loss, mean_accuracy)``,
        weighted by batch size. Accuracy is measured on the ``x1`` logits.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    if device is None:
        device = next(model.parameters()).device
    resize = transforms.Resize(224)
    optimizers = [opt for opt in (optim1, optim2) if opt is not None]

    total_ce1_loss = 0.0
    total_ce2_loss = 0.0
    total_caam_loss = 0.0
    total_correct = 0
    total_samples = 0

    for x1, x2, y_true in tqdm(dataloader, leave=False):
        x1, x2, y_true = x1.to(device), x2.to(device), y_true.to(device)

        # 1. Cross-entropy on the original input.
        y_pred, feature_map = model(x1)
        ce1_loss = criterion1(y_pred, y_true)

        # 2. CAAM consistency with one shared crop/flip for map and image.
        ori_caam = resize(_activation_map(feature_map))
        params = _sample_paired_augmentation(ori_caam, random.randint(160, 210))
        caam = _apply_paired_augmentation(ori_caam, params, resize)
        transform_x = _apply_paired_augmentation(x2, params, resize)
        ty_pred, transform_feature_map = model(transform_x)
        transform_caam = resize(_activation_map(transform_feature_map))
        caam_loss = criterion2(transform_caam, caam)

        # 3. Cross-entropy on the augmented input.
        ce2_loss = criterion1(ty_pred, y_true)

        loss = ce1_loss + caam_loss + ce2_loss

        if optimizers:
            for opt in optimizers:
                opt.zero_grad()
            loss.backward()
            for opt in optimizers:
                opt.step()

        batch_size = y_true.size(0)
        total_ce1_loss += ce1_loss.item() * batch_size
        total_ce2_loss += ce2_loss.item() * batch_size
        total_caam_loss += caam_loss.item() * batch_size
        total_correct += (y_pred.argmax(dim=1) == y_true).sum().item()
        total_samples += batch_size

    if total_samples == 0:
        raise ValueError("dataloader yielded no samples")
    return (
        total_ce1_loss / total_samples,
        total_ce2_loss / total_samples,
        total_caam_loss / total_samples,
        total_correct / total_samples,
    )


# Initialize device, data, and model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load Train.pkl / Valid.pkl produced by a trusted source.
with open("./Train.pkl", "rb") as f:
    train_set = pickle.load(f)
with open("./Valid.pkl", "rb") as f:
    val_set = pickle.load(f)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

train_dataset = CustomDataset(train_set, transform)
val_dataset = CustomDataset(val_set, transform)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)


# Class weights: inverse class frequency over the training records,
# normalised to sum to 1. Records with other labels are ignored.
numoflabel_1 = sum(1 for sample in train_set if sample['label'] == "NV")
numoflabel_2 = sum(1 for sample in train_set if sample['label'] == "MEL")
if numoflabel_1 == 0 or numoflabel_2 == 0:
    # An absent class would make its inverse frequency infinite (and the
    # normalised weights NaN), silently breaking the weighted loss.
    raise ValueError(
        f"Training set must contain both classes; got NV={numoflabel_1}, MEL={numoflabel_2}"
    )

weights = torch.tensor([numoflabel_1, numoflabel_2], dtype=torch.float32)
weights = weights / weights.sum()
weights = 1.0 / weights
weights = weights / weights.sum()

model = Net().to(device)
# Warm-up optimizer over all parameters.
optim = torch.optim.Adam(model.parameters(), lr=5e-5)
# Main phase: separate optimizers for the backbone and the linear head.
optm_1 = torch.optim.Adam(model.feature_extractor.parameters(), lr=1e-4)
optm_2 = torch.optim.Adam(model.classifier.parameters(), lr=1e-5)

lr_schedule1 = torch.optim.lr_scheduler.ReduceLROnPlateau(optm_1, patience=10, verbose=True)
lr_schedule2 = torch.optim.lr_scheduler.ReduceLROnPlateau(optm_2, patience=10, verbose=True)

criterion1 = nn.CrossEntropyLoss(weight=weights.to(device))
criterion2 = nn.MSELoss()
best_loss = float('inf')

# Initial 2-epoch training (do_epoch1): plain weighted cross-entropy.
for epoch in range(1, 3):
    model.train()
    train_loss, train_accuracy = do_epoch1(model, train_loader, criterion1, optim=optim)
    print(f"Warm-up {epoch:03d}: Train CE Loss: {train_loss:.4f}, Acc: {train_accuracy:.4f}")

for epoch in range(1, 100):
    model.train()
    train_loss1, train_loss2, train_loss3, train_accuracy = do_epoch2(
        model, train_loader, criterion1, criterion2, optim1=optm_1, optim2=optm_2
    )
    model.eval()
    with torch.no_grad():
        val_loss1, val_loss2, val_loss3, val_accuracy = do_epoch2(
            model, val_loader, criterion1, criterion2
        )

    print(f"Epoch {epoch:03d}: Train CE1 Loss: {train_loss1:.4f}, Train CE2 Loss: {train_loss2:.4f}, Train CAAM Loss: {train_loss3:.4f}, Acc: {train_accuracy:.4f} | "
          f"Val CE1 Loss: {val_loss1:.4f}, Val CE2 Loss: {val_loss2:.4f}, Val CAAM Loss: {val_loss3:.4f}, Acc: {val_accuracy:.4f}")

    # Checkpoint on the summed validation objective.
    val_total = val_loss1 + val_loss2 + val_loss3
    if val_total < best_loss:
        print('Saving model...')
        best_loss = val_total
        torch.save(model.state_dict(), './GIT.pt')

    # Backbone LR follows the augmented-input CE; head LR follows CE1 + CAAM.
    lr_schedule1.step(val_loss2)
    lr_schedule2.step(val_loss1 + val_loss3)