In [1]:
from __future__ import print_function

import argparse
import csv
import os
import sys
import tifffile
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import torch
import glob
import pickle

import datetime
import math
import json
import time

import timm
from timm.models import create_model
from timm.scheduler import CosineLRScheduler
from timm.optim import create_optimizer
from timm.utils import NativeScaler, get_state_dict

from torch.utils.data import Dataset, random_split, DataLoader
from PIL import Image
import torchvision.models as models
import torchvision.transforms as T
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torch.nn as nn
from torchvision.utils import make_grid

import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler

from customloss import CustomLoss
from datamgr import *
import utils
from engine import *
from dataset import *
from models import *
from samplers import RASampler

import copy
from pathlib import Path
from tqdm import tqdm

In [2]:
def parse_args(argv=None):
    """Build the CBNA argument parser and parse ``argv``.

    Args:
        argv: optional sequence of CLI tokens, e.g. ``['--model_type', 'CNN']``.
            Defaults to an empty list (NOT ``sys.argv``), so that inside a
            Jupyter kernel the kernel's own command-line flags are ignored and
            every option takes its default value.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(description= 'CBNA training script')
    parser.add_argument('--model_type'       , default='MLP', help='CNN/MLP/ViT/Fusion')
    parser.add_argument('--data_set'     , default='CBNA')
    parser.add_argument('--data_path'     , default='Data/irc_patches/')
    parser.add_argument('--file_path'     , default='datafile/')
    parser.add_argument('--output_dir', default='checkpoints/', help='path where to save, empty for no saving')
    parser.add_argument('--device', default='cpu', help='device to use for training / testing')
    parser.add_argument('--batch_size', default=128, type=int, help='Per-GPU batch-size : number of distinct images loaded on one GPU.')
    parser.add_argument('--input_size'  , default=224, type=int, help ='Image size for training')
    parser.add_argument('--drop_path'   , type=float, default=0.05)
    parser.add_argument('--loss_fn'     , default='focal', help='bce/focal')
    parser.add_argument('--weighted'    , type=utils.bool_flag, default=True)
    parser.add_argument('--eval_crop_ratio', default=0.875, type=float, help="Crop ratio for evaluation")
    parser.add_argument('--num_workers', default=8, type=int, help='Number of data loading workers per GPU.')
    parser.add_argument('--seed'       , default=0, type=int, help='seed')
    parser.add_argument('--eval'        , action='store_true')
    parser.add_argument('--thres_method'   , default='global', help='global/adaptive')
    parser.add_argument('--threshold'      , default=0.5, type=float, help='compute TSS, 0.45 for CNN & Fusion, 0.5 for MLP')

    # An explicit list (rather than None) stops argparse from reading sys.argv,
    # which in a notebook contains the kernel's launch flags.
    return parser.parse_args([] if argv is None else list(argv))

In [3]:
# Parse the (default) options and resolve the per-dataset / per-model
# checkpoint directory, e.g. checkpoints/CBNA/MLP.
args = parse_args()
args.output_dir = Path(args.output_dir).joinpath(args.data_set, args.model_type)
if not args.output_dir.is_dir():
    try:
        args.output_dir.mkdir(parents=True, exist_ok=True)
        print("Directory '%s' created successfully" % args.output_dir)
    except OSError as error:
        # Best-effort: report which directory failed and why, then continue.
        print("Directory '%s' can not be created: %s" % (args.output_dir, error))

# Class list: one label per line of classes.txt.
if args.data_set == 'CBNA':
    args.classes = list(np.genfromtxt(Path(args.file_path).joinpath('classes.txt'), dtype='str'))
    args.num_classes = len(args.classes)
    print('number of labels: ', args.num_classes)
else:
    # Later cells build CBNA datasets unconditionally and need args.num_classes.
    raise ValueError("Unsupported data_set: %r (only 'CBNA' is handled)" % args.data_set)

number of labels:  2522


In [4]:
utils.fix_random_seeds(args.seed)
cudnn.benchmark = True

num_tasks = utils.get_world_size()
global_rank = utils.get_rank()

# Metadata columns used by the models that consume tabular inputs (MLP / Fusion);
# image-only models get None.
meta_features = None
if args.model_type in ("MLP", "Fusion"):
    meta_features = ['LANDOLT_MOIST',
                     'N_prct', 'pH', 'CN', 'TMeanY', 'TSeason', 'PTotY', 'PSeason', 'RTotY',
                     'RSeason', 'AMPL', 'LENGTH', 'eauvive', 'clay', 'silt', 'sand', 'cv_alti']


def make_cbna_split(split):
    """Return the CBNA dataset for ``split`` ('train' / 'val' / 'test').

    The MLP model is built without a transform (presumably it reads no
    images); every other model gets the non-augmenting transform from
    ``build_transform``, constructed fresh for each split.
    """
    if args.model_type == "MLP":
        return CBNA(args, split=split, meta_features=meta_features)
    return CBNA(args, split=split, meta_features=meta_features,
                transform=build_transform(args, aug=False))


dataset_train = make_cbna_split('train')
dataset_val = make_cbna_split('val')
dataset_test = make_cbna_split('test')

print(f"Data loaded: there are {len(dataset_train)} train images, {len(dataset_test)} test images.")

if args.model_type in ("MLP", "Fusion"):
    args.n_meta_features = len(dataset_train.meta_features)

args.train_label_cnt = dataset_train.get_label_count()
args.test_label_cnt = dataset_test.get_label_count()


def make_sequential_loader(dataset):
    """Return ``(sampler, loader)`` that iterate ``dataset`` in order, keeping the last partial batch."""
    sampler = torch.utils.data.SequentialSampler(dataset)
    loader = DataLoader(dataset, sampler=sampler, batch_size=args.batch_size,
                        num_workers=args.num_workers, pin_memory=True, drop_last=False)
    return sampler, loader


train_sampler, train_loader = make_sequential_loader(dataset_train)
val_sampler, val_loader = make_sequential_loader(dataset_val)
test_sampler, test_loader = make_sequential_loader(dataset_test)

print(f"Creating model: {args.model_type}")
if args.model_type == 'MLP':
    model = Mlp_CBNA(args, out_dim=args.num_classes)
elif args.model_type == 'CNN':
    model = Resnet_CBNA(args, out_dim=args.num_classes)
elif args.model_type == 'Fusion':
    model = Fusion_CBNA(args, out_dim=args.num_classes)
elif args.model_type == 'ViT':
    model = Vit_CBNA(args, out_dim=args.num_classes)
else:
    # Without this, `model` is undefined and the next line fails with a NameError.
    raise ValueError(f"Unknown model_type: {args.model_type!r} (expected CNN/MLP/ViT/Fusion)")
model.to(args.device)

n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('number of params:', n_parameters)

reload_file = utils.get_best_file(args.output_dir)
print("best_file" , reload_file)
checkpoint = torch.load(reload_file, map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['model'])
# This notebook only runs inference / calibration: switch BatchNorm and Dropout
# to eval behaviour so later forward passes do not update BN running statistics.
model.eval()
print("model loaded")

Data loaded: there are 98763 train images, 15980 test images.
Creating model: MLP
number of params: 401498
best_file checkpoints/CBNA/MLP/best.pth
model loaded


In [5]:
# Baseline: loss and TSS metrics of the uncalibrated model on the test split.
# `criterion` is also reused by the temperature-scaled evaluation further down.
criterion = CustomLoss(args.loss_fn, args)
evaluate(model, test_loader, criterion, args)

100%|█████████████████████████████████████████| 125/125 [00:03<00:00, 31.99it/s]


{'eval_loss': 0.00020920961105730385,
 'macro_tss': 0.6960844993591309,
 'micro_tss': 0.714148998260498,
 'weighted_tss': 0.6047736406326294}

In [15]:
class ModelWithTemperature(nn.Module):
    """
    Wrap a multi-label classifier with per-class temperature scaling.

    The wrapped model must return raw logits (NOT sigmoid/softmax
    probabilities). ``forward`` divides them element-wise by a learnable
    temperature vector holding one entry per class; this notebook then applies
    ``torch.sigmoid`` to the scaled logits.

    Args:
        model (nn.Module): network returning logits whose last dimension is
            the class dimension.
        num_classes (int, optional): length of the temperature vector.
            Defaults to ``len(args.classes)`` (notebook-global ``args``).
        init_temperature (float): initial value of every temperature entry.
    """
    def __init__(self, model, num_classes=None, init_temperature=0.5):
        super(ModelWithTemperature, self).__init__()
        self.model = model
        if num_classes is None:
            num_classes = len(args.classes)
        self.temperature = nn.Parameter(torch.ones(num_classes) * init_temperature)

    def forward(self, input, *more_inputs):
        """Run the wrapped model and return temperature-scaled logits.

        Accepts the model's positional inputs directly (``wrapper(x)`` or
        ``wrapper(img, meta)``) or a single tuple/list of them
        (``wrapper((img, meta))``), which is unpacked before calling the model.
        """
        inputs = (input,) + more_inputs
        if not more_inputs and isinstance(input, (tuple, list)):
            inputs = tuple(input)
        logits = self.model(*inputs)
        return self.temperature_scale(logits)

    def temperature_scale(self, logits):
        """Divide ``logits`` by the per-class temperature (broadcast over the last dimension)."""
        return logits / self.temperature

    def _collect_logits(self, loader, device):
        """Run the wrapped model over ``loader`` without gradients; return concatenated (logits, targets).

        Each batch is ``(data, targets)`` where ``data`` is a tensor, or a
        tuple/list of tensors for multi-input models such as Fusion.
        """
        logits_list, labels_list = [], []
        with torch.no_grad():
            for data, targets in loader:
                batch_inputs = data if isinstance(data, (tuple, list)) else (data,)
                batch_inputs = [x.to(device, non_blocking=True) for x in batch_inputs]
                logits_list.append(self.model(*batch_inputs))
                labels_list.append(targets.to(device, non_blocking=True))
        if not logits_list:
            raise ValueError('val_loader yielded no batches; cannot fit temperature')
        return torch.cat(logits_list), torch.cat(labels_list)

    def set_temperature(self, val_loader, criterion=None, device=None,
                        n_iters=15, lr=1e-1, weight_decay=5e-2):
        """Fit the temperature vector on a held-out loader; the model's weights are not optimised.

        Args:
            val_loader: iterable of ``(data, targets)`` batches.
            criterion: loss taking ``(logits, targets)``. Defaults to
                ``CustomLoss(args.loss_fn, args)``.
            device: device for inputs/targets. Defaults to ``args.device``.
            n_iters: number of AdamW steps on the temperature.
            lr: AdamW learning rate.
            weight_decay: AdamW weight decay applied to the temperature.

        Returns:
            self, to allow chaining.

        Raises:
            ValueError: if ``val_loader`` yields no batches.
        """
        if criterion is None:
            criterion = CustomLoss(args.loss_fn, args)
        if device is None:
            device = args.device

        # BatchNorm updates its running statistics in training mode even under
        # no_grad, and Dropout would randomise the logits: use eval behaviour.
        self.model.eval()

        # First: collect all the logits and labels for the validation set
        logits, labels = self._collect_logits(val_loader, device)
        loss_before_temperature = criterion(logits, labels).item()

        # Next: optimize only the temperature; `logits` are constants (no graph).
        # NOTE(review): weight decay pulls the temperature towards 0, i.e. towards
        # sharper predictions — confirm this regularisation is intended.
        optimizer = torch.optim.AdamW([self.temperature], lr=lr, weight_decay=weight_decay)

        for _ in range(n_iters):
            optimizer.zero_grad()
            loss = criterion(self.temperature_scale(logits), labels)
            print('loss', loss.item())
            loss.backward()
            optimizer.step()

        with torch.no_grad():
            loss_after_temperature = criterion(self.temperature_scale(logits), labels).item()

        print('Before temperature scaling: ', loss_before_temperature)
        print('After temperature scaling: ', loss_after_temperature)
        if (self.temperature <= 0).any():
            # Dividing by a non-positive temperature flips or blows up those logits.
            print('Warning: some fitted temperatures are <= 0; predictions for those classes are invalid.')

        return self

In [16]:
# Wrap the trained model with per-class temperatures and fit them on the
# validation split. set_temperature returns the wrapper itself; the trailing
# `;` keeps Jupyter from dumping the whole architecture repr into the output.
model_w_tem = ModelWithTemperature(model).to(args.device)
model_w_tem.set_temperature(val_loader);

loss 0.0003092587285209447
loss 0.0002684619394131005
loss 0.00024264470266643912
loss 0.00022566104598809034
loss 0.0002141133154509589
loss 0.00020603384473361075
loss 0.00020023580873385072
loss 0.0001959777728188783
loss 0.00019278208492323756
loss 0.00019033285207115114
loss 0.00018841648125089705
loss 0.00018688567797653377
loss 0.00018563726916909218
loss 0.00018459807324688882
loss 0.00018371552869211882
Before temperature scaling:  0.000200928290723823
After temperature scaling:  0.00018295169866178185


ModelWithTemperature(
  (model): Mlp_CBNA(
    (meta): Sequential(
      (0): Linear(in_features=17, out_features=512, bias=True)
      (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Dropout(p=0.3, inplace=False)
      (4): Linear(in_features=512, out_features=128, bias=True)
      (5): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (6): ReLU(inplace=True)
    )
    (fc): Linear(in_features=128, out_features=2522, bias=True)
  )
)

In [17]:
# Evaluate the temperature-scaled model on the test split with the same metrics
# as `evaluate`. Note: dividing a logit by a positive temperature never changes
# its sign, so with the global 0.5 threshold (assuming compute_scores compares
# sigmoid outputs against it) the binarised predictions — and hence TSS — are
# unchanged; only the loss is expected to move.
TP, TN, FP, FN = 0., 0., 0., 0.
label_cnt = 0.
eval_losses = []

# Thresholds do not depend on the batch: load them once, not once per batch.
if args.thres_method == 'adaptive':
    thresholds = torch.load(args.output_dir.joinpath('thresholds_train.pth'))
    thresholds = torch.as_tensor(thresholds).to(args.device)
elif args.thres_method == 'global':
    thresholds = args.threshold
else:
    raise ValueError(f"Unknown thres_method: {args.thres_method!r} (expected 'global' or 'adaptive')")

model_w_tem.eval()
with torch.no_grad():
    for data, targets in tqdm(test_loader, file=sys.stdout):
        if args.model_type in ("MLP", "CNN", "ViT"):
            samples = data.to(args.device, non_blocking=True)
            outputs = model_w_tem(samples)
        elif args.model_type == "Fusion":
            samples_img, samples_meta = data
            samples_img = samples_img.to(args.device, non_blocking=True)
            samples_meta = samples_meta.to(args.device, non_blocking=True)
            # Call the underlying two-input model explicitly, then apply the
            # temperature, so this does not depend on whether the wrapper's
            # forward unpacks tuple inputs.
            outputs = model_w_tem.temperature_scale(model_w_tem.model(samples_img, samples_meta))
        else:
            raise ValueError(f"Unknown model_type: {args.model_type!r} (expected CNN/MLP/ViT/Fusion)")

        targets = targets.to(args.device, non_blocking=True)
        # Per-class positive counts, used below as weights for weighted TSS.
        label_cnt += targets.sum(0).float()

        loss = criterion(outputs, targets)
        eval_losses.append(loss.item())

        tp, tn, fp, fn = compute_scores(torch.sigmoid(outputs), targets, thresholds)
        TP += tp
        TN += tn
        FP += fp
        FN += fn

    eval_loss = torch.tensor(eval_losses).mean().item()
    weight = label_cnt / label_cnt.sum()
    macro_tss, micro_tss, weighted_tss = compute_metrics(TP, TN, FP, FN, weight=weight)

    eval_stats = {'eval_loss': eval_loss, 'macro_tss': macro_tss, 'micro_tss': micro_tss, 'weighted_tss': weighted_tss}
    print("TSS after temperature scaling: ", eval_stats)

100%|█████████████████████████████████████████| 125/125 [00:03<00:00, 31.74it/s]
TSS after temperature scaling:  {'eval_loss': 0.000195201369933784, 'macro_tss': 0.6960844993591309, 'micro_tss': 0.714148998260498, 'weighted_tss': 0.6047736406326294}
