In [None]:
import torch
import torch.optim as optim
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset, ConcatDataset
from torchvision import datasets, models
from torchvision import transforms as T
from torchvision.io.image import read_image
from torchvision.transforms.functional import normalize, resize, to_pil_image

import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.metrics import precision_recall_fscore_support
import copy
import math
import glob
import os
import random
from PIL import Image
import albumentations as A
from torch.utils.data import DataLoader, ConcatDataset
import albumentations.pytorch
import cv2
import numpy as np
import math

!pip install timm

from google.colab import drive
drive.mount('/content/drive')

%cd /content/drive/MyDrive/skin_classification/tta_skin
# ..
from torchhelper import evaluate, CustomSkinDataset
from backbone import ERM
from adapt_methods import SAR, SHOT, T3A
!unzip "../../lesion_rec/isic2.zip" -d "/content"

In [None]:
# Image pipelines (albumentations):
# 1) train_transform: random photometric and geometric augmentation, then
#    resize to 224x224, normalize and convert to a tensor.
# 2) test_transform: deterministic resize + normalize + tensor conversion only,
#    so evaluation images are not augmented.
#
# RandomBrightness(limit=l) and RandomContrast(limit=l) are deprecated aliases of
# RandomBrightnessContrast with the other limit set to 0, and they are removed in
# newer albumentations releases. The two calls below are their equivalents.
train_transform = A.Compose([
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0, p=0.75),
    A.RandomBrightnessContrast(brightness_limit=0, contrast_limit=0.2, p=0.75),
    A.Transpose(p=0.5),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    # With probability 0.7, apply one of these blur/noise corruptions.
    # NOTE(review): `var_limit` is deprecated for GaussNoise in albumentations 2.x
    # (replaced by `std_range`); update it if the installed version warns.
    A.OneOf([
        A.MotionBlur(blur_limit=5),
        A.MedianBlur(blur_limit=5),
        A.GaussianBlur(blur_limit=5),
        A.GaussNoise(var_limit=(5.0, 30.0)),
    ], p=0.7),
    A.CLAHE(clip_limit=4.0, p=0.7),
    A.Resize(224, 224),
    A.Normalize(),
    A.pytorch.transforms.ToTensorV2(),
])

test_transform = A.Compose([
    A.Resize(224, 224),
    A.Normalize(),
    A.pytorch.transforms.ToTensorV2(),
])

In [None]:
# Build the ERM classifier for the 'deit' backbone and move it to the GPU.
# The second positional argument (1e-4) is presumably the learning rate;
# confirm against the ERM signature in backbone.py.
backbone = 'deit'
model = ERM(backbone, 1e-4)
model = model.to('cuda')

# Disable gradients for model.network[0] so its weights are not updated during
# training. It is presumably the pretrained feature extractor; confirm in
# backbone.ERM.
for param in model.network[0].parameters():
    param.requires_grad_(False)

In [None]:
# Leave-one-domain-out experiment. For each held-out artifact domain:
#   1. train on all other domains,
#   2. save the trained weights,
#   3. evaluate the base model on the held-out domain,
#   4. evaluate each test-time adaptation method (SHOT, T3A, SAR) on it.
SEED = 0
TEST_DOMAINS = ['dark_corner']  # domains to hold out (any subset of domain_list)
BATCH_SIZE = 32
NUM_EPOCHS = 10
MEL_UPSAMPLE = 2  # extra copies of each training domain's "mel" images
MODEL_DIR = './models'

# Seed every RNG the pipeline may draw from. Augmentation sampling presumably
# uses Python's `random` / NumPy; confirm for the installed albumentations.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

domain_list = ['clean', 'gel_bubble', 'hair', 'dark_corner', 'ruler']
# torch.save below fails if the target directory does not exist.
os.makedirs(MODEL_DIR, exist_ok=True)

for test_domain in TEST_DOMAINS:
    print("Testing on domain: ", test_domain)
    # Every other domain is used for training.
    train_domains = [x for x in domain_list if x != test_domain]
    print("Training on domains: ", train_domains)

    train_datasets = []
    for train_domain in train_domains:
        # All "ben" + "mel" images, plus MEL_UPSAMPLE extra copies of the "mel"
        # images to upsample the minority class.
        parts = [CustomSkinDataset(train_domain, train_transform, ["ben", "mel"])]
        parts += [CustomSkinDataset(train_domain, train_transform, ["mel"])
                  for _ in range(MEL_UPSAMPLE)]
        train_datasets.append(ConcatDataset(parts))

    # Build the train and held-out data loaders.
    image_datasets = {'train_set': ConcatDataset(train_datasets)}
    dataloaders = {
        'train_set': DataLoader(image_datasets['train_set'],
                                batch_size=BATCH_SIZE,
                                shuffle=True,
                                num_workers=2,
                                pin_memory=True),
        # Held-out loader: no shuffling, so evaluation order is deterministic.
        test_domain: DataLoader(CustomSkinDataset(test_domain, test_transform, ["ben", "mel"]),
                                batch_size=BATCH_SIZE,
                                shuffle=False,
                                num_workers=2,
                                pin_memory=True),
    }

    # Train a fresh copy of the untrained `model` for every held-out domain.
    # If train_model updates its argument in place, training `model` directly
    # would leak weights between folds and across re-runs of this cell.
    # NOTE(review): `train_model` is neither defined nor imported in this
    # notebook (presumably from a deleted cell); restore it before running.
    model_trained = train_model(copy.deepcopy(model), dataloaders['train_set'], NUM_EPOCHS)

    print("Saving trained base model...")
    save_path = os.path.join(MODEL_DIR, f'{backbone}{test_domain}v1.pth')
    torch.save(model_trained.state_dict(), save_path)

    print("Evaluating base model")
    acc, pr, rc, f1 = evaluate(model_trained, test_domain)
    print(f'Acc.: {acc.item():.3f} | Rc.: {rc:.3f} | Pr.: {pr:.3f} | F1: {f1:.3f}')

    print("Evaluating adaptation models")
    adapt_methods = {'SHOT': SHOT, 'T3A': T3A, 'SAR': SAR}
    for adapt_method, adapt_cls in adapt_methods.items():
        print(f'Using: {adapt_method}...')
        # Each method wraps its own copy, in case adaptation updates the
        # weights in place; the trained base model stays untouched.
        adapt_model = adapt_cls(copy.deepcopy(model_trained))
        acc, pr, rc, f1 = evaluate(adapt_model, test_domain, True)
        print(f'{adapt_method} Acc.: {acc.item():.3f} | Rc.: {rc:.3f} | Pr.: {pr:.3f} | F1: {f1:.3f}')