In [1]:
! pip install progress

[0m

In [2]:
import sys
sys.path.append('../')
import argparse, os, shutil, time, random, math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
import torch.nn.functional as F
import losses
from datasets.cifar100 import get_cifar100
from train.train import get_train_fn, get_update_score_fn
from train.validate import get_valid_fn
from models.net import get_model
from losses.loss import get_loss, get_optimizer
from utils.config import parse_args, reproducibility, dataset_argument
from utils.common import make_imb_data, save_checkpoint, adjust_learning_rate

from tqdm.autonotebook import tqdm
from utils.logger import logger

In [3]:
# Build the run configuration in notebook mode (run_type='jupyter').
args = parse_args(run_type = 'jupyter')
# Seed the run from args.seed.
# NOTE(review): the argument dump printed by this cell shows `seed: None`, so the
# run is presumably not reproducible as configured - confirm what
# reproducibility(None) does and consider setting an explicit seed.
reproducibility(args.seed)
# Presumably fills in dataset-specific fields on args - confirm in utils.config.
args = dataset_argument(args)
best_acc = 0 # module-level best accuracy; train_fe/train_fc below keep their own local best_acc

# Attach the logger to args; it is later called as args.logger(message, level=...).
args.logger = logger(args)

---> ---cifar100---
---> Argument
    └> network     : resnet32
    └> epochs      : 200
    └> batch_size  : 128
    └> update_epoch: 1
    └> lr          : 0.1
    └> momentum    : 0.9
    └> wd          : 0.0002
    └> nesterov    : False
    └> scheduler   : warmup
    └> warmup      : 5
    └> aug_prob    : 0.5
    └> cutout      : False
    └> cmo         : False
    └> posthoc_la  : False
    └> cuda        : False
    └> randaug     : False
    └> num_test    : 8
    └> verbose     : False
    └> out         : ./results/cifar100/ce@N_500_ir_100/
    └> data_dir    : /input/dataset/
    └> workers     : 4
    └> seed        : None
    └> gpu         : None
    └> dataset     : cifar100
    └> num_max     : 500
    └> imb_ratio   : 100
    └> loss_fn     : ce
    └> num_experts : 3
    └> num_class   : 100


In [4]:
# Notebook-level overrides applied on top of the parsed arguments.
# imb_ratio, data_dir and aug_prob repeat the values already shown in the
# argument dump printed by the previous cell; aug_type does not appear there.
args.imb_ratio = 100
# NOTE(review): no code visible in this notebook reads args.aug_type - confirm where it is consumed.
args.aug_type = 'many'
args.data_dir = '/input/dataset/'
args.aug_prob = 0.5
# Alternative dataset location (disabled):
# args.data_dir = '/home/work/cuda/dataset'

In [5]:
# Clamp the requested head-class size to the per-class cap used by the original
# check (50000 training images spread over num_class classes). Written as a plain
# min() instead of assert/except so the clamp still happens when Python runs with
# assertions disabled (-O).
args.num_max = min(args.num_max, int(50000 / args.num_class))

# Per-class sample counts derived from the arguments; replaced below by the counts
# the dataset actually reports.
N_SAMPLES_PER_CLASS = make_imb_data(args.num_max, args.num_class, args.imb_ratio)

if args.dataset == 'cifar100':
    print('==> Preparing imbalanced CIFAR-100')
    # devset and testset are returned by the loader but are not used further in this notebook.
    trainset, allset, devset, testset = get_cifar100(
        os.path.join(args.data_dir, 'cifar100/'),
        imb_ratio=args.imb_ratio,
        cutout=args.cutout,
        contrast=(args.loss_fn == 'ncl'),
        randaug=args.randaug,
        aug_prob=args.aug_prob,
    )
    # Use the class counts of the dataset that was actually built.
    N_SAMPLES_PER_CLASS = trainset.img_num_list
else:
    # Fail with a clear message instead of a NameError on `trainset` below.
    raise ValueError(
        f"Unsupported dataset {args.dataset!r}: this notebook only prepares 'cifar100'"
    )


# Shuffled loader for training; unshuffled loader over `allset` for evaluation.
trainloader = data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=False)
allloader = data.DataLoader(allset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

[500, 477, 455, 434, 415, 396, 378, 361, 344, 328, 314, 299, 286, 273, 260, 248, 237, 226, 216, 206, 197, 188, 179, 171, 163, 156, 149, 142, 135, 129, 123, 118, 112, 107, 102, 98, 93, 89, 85, 81, 77, 74, 70, 67, 64, 61, 58, 56, 53, 51, 48, 46, 44, 42, 40, 38, 36, 35, 33, 32, 30, 29, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 15, 14, 13, 13, 12, 12, 11, 11, 10, 10, 9, 9, 8, 8, 7, 7, 7, 6, 6, 6, 6, 5, 5, 5, 5]
==> Preparing imbalanced CIFAR-100
Files already downloaded and verified
Magnitude set = tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30], dtype=torch.int32)
Operation set = tensor([ 0,  1,  2,  3,  3,  4,  5,  6,  6,  7,  8,  9,  9, 10, 11, 11, 12, 13,
        14, 14, 15, 16, 17, 17, 18, 19, 20, 20, 21, 22, 22], dtype=torch.int32)
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
#Train: 10847, #Test: 9000


### Fixed feature-extractor check

In [6]:
def train_fe(many_aug, med_aug, few_aug, alternate=False):
    """Train a fresh network with the 'ce' loss under a fixed per-group augmentation state.

    Classes are grouped by their sample count in ``N_SAMPLES_PER_CLASS``:
    "many" (> 100 samples), "med" (20 to 100 samples) and "few" (< 20 samples).
    Each training sample receives the state value of its class group, and the
    resulting per-sample tensor is written to ``trainloader.dataset.curr_state``
    (presumably the per-sample augmentation strength read by the dataset -
    confirm in datasets.cifar100).

    Args:
        many_aug: State value for samples of many-shot classes.
        med_aug: State value for samples of medium-shot classes.
        few_aug: State value for samples of few-shot classes.
        alternate: If True, even epochs (0, 2, ...) use an all-zero state and odd
            epochs use the per-group state. If False, the per-group state is used
            on every epoch.

    Returns:
        The trained model.
    """
    dataset = trainloader.dataset

    # One state value per class, chosen by the class's group.
    class_state = [
        many_aug if n > 100 else med_aug if n >= 20 else few_aug
        for n in N_SAMPLES_PER_CLASS
    ]
    # Per-sample state looked up through each sample's label
    # (assumes len(dataset.targets) == len(dataset) - TODO confirm).
    aug_state = torch.tensor(
        [class_state[int(label)] for label in dataset.targets], dtype=torch.float
    )
    orig_state = torch.zeros(len(dataset))

    # Select the loss before building anything from args: train_fc sets
    # args.loss_fn to 'bs', and get_model receives args, so a later call of this
    # function could otherwise build its model under a different configuration
    # than the first call.
    args.loss_fn = 'ce'

    # Model
    print ("==> creating {}".format(args.network))
    model = get_model(args, N_SAMPLES_PER_CLASS)
    train_criterion = get_loss(args, N_SAMPLES_PER_CLASS)
    criterion = nn.CrossEntropyLoss(reduction='mean') # For test, validation
    optimizer = get_optimizer(args, model)
    train = get_train_fn(args)
    validate = get_valid_fn(args)

    best_acc, best_cls = 0, None
    for epoch in tqdm(range(args.epochs)):
        adjust_learning_rate(optimizer, epoch, None, args)

        # The per-group state is applied on every epoch unless alternating, in
        # which case even epochs fall back to the all-zero state. Without this,
        # the *_aug arguments would have no effect when alternate=False.
        if alternate and epoch % 2 == 0:
            dataset.curr_state = orig_state
        else:
            dataset.curr_state = aug_state

        train(args, trainloader, model, optimizer, train_criterion, epoch, None, None)
        _, all_acc, all_cls = validate(args, allloader, model, criterion, N_SAMPLES_PER_CLASS,  num_class=args.num_class, mode='All Valid')

        # Always record the first epoch so best_cls is defined even when the
        # accuracy never rises above the initial 0.
        if best_cls is None or best_acc < all_acc:
            best_acc, best_cls = all_acc, all_cls

    # Print the final results
    args.logger(f'Final Performance...', level=1)
    if best_cls is None:
        # Only reachable when the loop ran zero epochs.
        args.logger('no validation results (0 epochs trained)', level=2)
    else:
        args.logger(f'best bAcc (all): {best_acc}', level=2)
        args.logger(f'best bAcc (many): {best_cls[0]}', level=2)
        args.logger(f'best bAcc (med): {best_cls[1]}', level=2)
        args.logger(f'best bAcc (few): {best_cls[2]}', level=2)

    return model

In [7]:
def adjust_learning_rate_crt(optimizer, epoch, args):
    """Set the step-decayed learning rate for classifier re-training.

    With ``step = epoch + 1`` the learning rate is ``args.lr`` for steps 1-3,
    ``args.lr * 0.1`` for steps 4-6, ``args.lr * 0.01`` for steps 7-8 and
    ``args.lr * 0.001`` from step 9 on. The value is written to every parameter
    group of ``optimizer``.

    Args:
        optimizer: Object exposing ``param_groups``, a sequence of dicts.
        epoch: Zero-based epoch index.
        args: Object exposing the base learning rate as ``args.lr``.

    Returns:
        The learning rate that was applied.
    """
    step = epoch + 1

    # (exclusive lower bound on step, decay factor), strongest decay first.
    schedule = ((8, 0.001), (6, 0.01), (3, 0.1))
    for boundary, factor in schedule:
        if step > boundary:
            new_lr = args.lr * factor
            break
    else:
        # No boundary crossed yet: keep the base learning rate untouched.
        new_lr = args.lr

    for group in optimizer.param_groups:
        group['lr'] = new_lr
    return new_lr
    
def train_fc(model, epochs=10):
    """Re-train only ``model.linear`` with the 'bs' loss, leaving other weights to the optimizer untouched.

    Only the parameters of ``model.linear`` are handed to the SGD optimizer, and
    every entry of ``trainloader.dataset.curr_state`` is reset to 0 before
    training. ``args.loss_fn`` is switched to ``'bs'`` while this function runs
    and restored afterwards, so the change does not leak into later cells.

    Args:
        model: Network exposing its final classifier as ``model.linear``.
        epochs: Number of re-training epochs. The thresholds in
            ``adjust_learning_rate_crt`` are fixed, so they match the default of 10.

    Returns:
        None. The best accuracies are reported through ``args.logger``.
    """
    previous_loss_fn = args.loss_fn
    args.loss_fn = 'bs'
    try:
        train_criterion = get_loss(args, N_SAMPLES_PER_CLASS)
        criterion = nn.CrossEntropyLoss(reduction='mean') # For test, validation
        train = get_train_fn(args)
        validate = get_valid_fn(args)

        # TODO : Freeze implementation
        # Only the linear layer is optimized; the remaining parameters are not
        # frozen explicitly.
        # NOTE(review): depending on what `train` does, non-parameter state
        # (e.g. normalization statistics) may still change - confirm.
        optimizer = optim.SGD(model.linear.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.wd)

        # TODO : Augmentation change
        # Reset the per-sample state in place, element by element.
        curr_state = trainloader.dataset.curr_state
        for i in range(len(curr_state)):
            curr_state[i] = 0

        best_acc, best_cls = 0, None
        for epoch in tqdm(range(epochs)):
            adjust_learning_rate_crt(optimizer, epoch, args)
            train(args, trainloader, model, optimizer, train_criterion, epoch, None, None)
            _, all_acc, all_cls = validate(args, allloader, model, criterion, N_SAMPLES_PER_CLASS,  num_class=args.num_class, mode='All Valid')

            # Always record the first epoch so best_cls is defined even when the
            # accuracy never rises above the initial 0.
            if best_cls is None or best_acc < all_acc:
                best_acc, best_cls = all_acc, all_cls
    finally:
        # Undo the temporary loss override, also when training is interrupted.
        args.loss_fn = previous_loss_fn

    # Print the final results
    args.logger(f'Final Performance...', level=1)
    if best_cls is None:
        # Only reachable when the loop ran zero epochs.
        args.logger('no validation results (0 epochs trained)', level=2)
    else:
        args.logger(f'best bAcc (all): {best_acc}', level=2)
        args.logger(f'best bAcc (many): {best_cls[0]}', level=2)
        args.logger(f'best bAcc (med): {best_cls[1]}', level=2)
        args.logger(f'best bAcc (few): {best_cls[2]}', level=2)

In [None]:
# Experiment 1 - requested states: many-shot = 5, medium-shot = 0, few-shot = 0.
# Train the network, then re-train its linear classifier.
model1 = train_fe(5, 0, 0, alternate=False)
train_fc(model1)

==> creating resnet32


  0%|          | 0/200 [00:00<?, ?it/s]

In [None]:
# Experiment 2 - requested states: many-shot = 0, medium-shot = 0, few-shot = 5.
# Train the network, then re-train its linear classifier.
model2 = train_fe(0, 0, 5, alternate=False)
train_fc(model2)

In [None]:
# Experiment 3 - requested states: 0 for all three class groups.
# Train the network, then re-train its linear classifier.
model3 = train_fe(0, 0, 0, alternate=False)
train_fc(model3)