In [1]:
import os
import time
import datetime 
from pathlib import Path

import numpy as np
import torch
import wandb
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
import torch.utils.data as data # for code compatibility
from torchvision import models, transforms
from torchvision.datasets import CIFAR10, CIFAR100
import timm

import scipy.stats
from sklearn.metrics import auc, roc_curve

import pytorch_lightning as pl

# util
from tqdm import tqdm
from collections import Counter

# customized 
import models.arch as models
# import dataset.cifar10 as dataset
import dataset.cifar10_larger as dataset # for ViT model,

from utils import Logger, AverageMeter, accuracy, mkdir_p, savefig
from progress.bar import Bar as Bar
import utils as utils_
from tensorboardX import SummaryWriter

##
# Pick the best available accelerator. MPS is only selected when this PyTorch
# build exposes it and reports it as usable; otherwise fall back to the CPU so
# that DEVICE always names a device tensors can actually be moved to.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")
DEVICE

device(type='cuda')

In [2]:
# Restrict this process to GPU index 0.
# NOTE(review): the first cell has already called torch.cuda.is_available() by
# the time this assignment runs. CUDA_VISIBLE_DEVICES is normally only honored
# when it is set before CUDA is first queried -- TODO confirm this takes effect,
# or move the assignment above the torch import.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [3]:
# Inert string literal: a record of an earlier configuration
# (efficientnet_b7 / cifar100). It is never assigned or read.
'''
lr=0.02
epochs=25
n_shadows = 64
shadow_id = -1 
model = "efficientnet_b7"
dataset = "cifar100"
pkeep = 0.5
savedir = f"exp/{model}_{dataset}"
debug = True
'''

# ---- Active run configuration: vit_large_patch16_224 on CIFAR-10 ----
_lr = 2e-5
_epochs = 50
_arch = "vit_large_patch16_224"
_dataset = "cifar10"
_n_classes = 10  # must match the dataset
_debug = True
_batch_size = 16

# Labeled / unlabeled split sizes; together they account for 45000 samples.
_n_labeled = 10000
_n_unlabeled = 45000 - _n_labeled

_alpha = 0.75    # parameter of the Beta distribution used for the mixing coefficient
_lambda_u = 20   # upper bound of the unlabeled-loss weight (default: 75)

_ema_decay = 0.999  # default: 0.999
_T = 0.5            # sharpening temperature for guessed labels (default: 0.5)

# Steps per epoch: enough to walk the larger of the two splits once
# (phase 1: support full-supervised learning, default: 1024).
_train_iteration = max(_n_labeled, _n_unlabeled) // _batch_size

use_cuda = torch.cuda.is_available()

In [4]:
# Output directory; its name encodes the architecture, dataset, number of
# labeled samples, learning rate, unlabeled-loss weight and steps per epoch.
_out = f'experiments/{_arch}/{_dataset}@{_n_labeled}_lr_{_lr}_lu_{_lambda_u}_iter_{_train_iteration}'
_out

'experiments/vit_large_patch16_224/cifar10@10000_lr_2e-05_lu_20_iter_2187'

In [5]:
# Show the number of optimisation steps `train` performs per epoch.
_train_iteration

2187

In [6]:
# Fixed seed used for the NumPy and PyTorch random number generators.
SEED = 1583745484

# prepare dataset

In [7]:
# Seed NumPy's global RNG and PyTorch's default generator with SEED.
# NOTE(review): Python's built-in `random` module is not seeded here -- confirm
# nothing in the data pipeline draws from it.
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x7ff24c2e82f0>

In [8]:
print('==> Preparing cifar10')

# Training-time pipeline: RandomPadandCrop(224), random flip, tensor conversion.
transform_train = transforms.Compose(
    [
        dataset.RandomPadandCrop(224),
        dataset.RandomFlip(),
        dataset.ToTensor(),
    ]
)

# Evaluation pipeline: tensor conversion only.
transform_val = transforms.Compose([dataset.ToTensor()])

datadir = Path.home() / "dataset"
batch_size = _batch_size

# Time the dataset construction together with the loader setup.
start_time = time.time()

train_labeled_set, train_unlabeled_set, val_set, test_set = dataset.get_cifar10(
    datadir,
    _n_labeled,
    transform_train=transform_train,
    transform_val=transform_val,
)

# Training loaders shuffle and drop the last incomplete batch.
labeled_trainloader = data.DataLoader(
    train_labeled_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    drop_last=True,
)
unlabeled_trainloader = data.DataLoader(
    train_unlabeled_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    drop_last=True,
)

# Evaluation loaders keep the dataset order and load in the main process.
val_loader = data.DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=0)
test_loader = data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=0)

elapsed_time = time.time() - start_time
print(f"Execution time for get_cifar10: {elapsed_time:.2f} seconds")

==> Preparing cifar10
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
#Labeled: 10000 #Unlabeled: 35000 #Val: 5000
Execution time for get_cifar10: 156.80 seconds


In [9]:
# Display a summary of the labeled training split.
train_labeled_set

Dataset CIFAR10_labeled
    Number of datapoints: 10000
    Root location: /home/dsanyal7/dataset
    Split: Train
    StandardTransform
Transform: Compose(
               <dataset.cifar10_larger.RandomPadandCrop object at 0x7ff13cf85c70>
               <dataset.cifar10_larger.RandomFlip object at 0x7ff13cf85ca0>
               <dataset.cifar10_larger.ToTensor object at 0x7ff13cf85d60>
           )

In [10]:
# Shape of the input tensor of the first labeled sample (after transform_train).
train_labeled_set[0][0].shape

torch.Size([3, 224, 224])

In [11]:
# Display a summary of the unlabeled training split.
train_unlabeled_set

Dataset CIFAR10_unlabeled
    Number of datapoints: 35000
    Root location: /home/dsanyal7/dataset
    Split: Train
    StandardTransform
Transform: <dataset.cifar10_larger.TransformTwice object at 0x7ff24b459af0>

In [12]:
# utils_.get_mean_and_std(train_labeled_set)

# expected values:
# (tensor([ 2.1660e-05, -8.8033e-04,  1.0356e-03]),
# tensor([0.8125, 0.8125, 0.7622]))

## model

In [13]:
def create_model(ema=False):
    """Build the `_arch` network with `_n_classes` outputs and move it to the GPU.

    Args:
        ema: when true, every parameter is detached in place so autograd does
            not track it.

    Returns:
        The constructed network.
    """
    net = models.network(_arch, pretrained=False, n_classes=_n_classes).cuda()

    if not ema:
        return net

    for weight in net.parameters():
        weight.detach_()
    return net

# `model` is the network that gets trained; `ema_model` is a second instance
# whose parameters are detached from autograd.
model = create_model()
ema_model = create_model(ema=True)

arch: vit_large_patch16_224, pretrained: False, n_classes: 10
Freezing ViT-Large intermediate layers...
arch: vit_large_patch16_224, pretrained: False, n_classes: 10
Freezing ViT-Large intermediate layers...


In [14]:
# Report the number of parameters of `model`, in millions.
print('Total params: %.2fM' % (sum([weight.numel() for weight in model.parameters()]) / 1e6))

Total params: 303.31M


In [15]:
def linear_rampup(current, rampup_length=_epochs):
    """Return the ramp-up factor `current / rampup_length`, clipped to [0, 1].

    Args:
        current: current position (the training loop passes a fractional epoch).
        rampup_length: length of the ramp; 0 disables the ramp.

    Returns:
        A Python float in [0, 1]; always 1.0 when `rampup_length` is 0.
    """
    if rampup_length != 0:
        return float(np.clip(current / rampup_length, 0.0, 1.0))
    return 1.0

In [16]:
class WeigthEMA(object):
    """Maintains `ema_model`'s state as an exponential moving average of `model`'s."""

    def __init__(self, model, ema_model, alpha=0.999):
        """Capture both state dicts and start them from identical values.

        Args:
            model: network whose state is being averaged.
            ema_model: network that receives the averaged state.
            alpha: decay factor applied to the running average on every step.
        """
        self.model = model
        self.ema_model = ema_model
        self.alpha = alpha
        self.params = list(model.state_dict().values())
        self.ema_params = list(ema_model.state_dict().values())
        self.wd = 0.02 * _lr  # for lr=0.02, wd = 0.0004

        # Overwrite the live tensors with the EMA model's values.
        for live, averaged in zip(self.params, self.ema_params):
            live.data.copy_(averaged.data)

    def step(self):
        """Update every float32 EMA tensor: ema = alpha * ema + (1 - alpha) * live."""
        decay = self.alpha
        blend = 1.0 - decay
        for live, averaged in zip(self.params, self.ema_params):
            # Tensors of any other dtype are left untouched.
            if averaged.dtype != torch.float32:
                continue
            averaged.mul_(decay)
            averaged.add_(live * blend)
            # A manual weight-decay step on the live tensor,
            # `live.mul_(1 - self.wd)`, is intentionally disabled here
            # (TODO: should carefully set up). Notes kept from the original:
            #   lr=0.02  -> (0.9996)^5928  (1 epoch): 0.09332
            #   lr=0.002 -> (0.99996)^5928 (1 epoch): 0.78889

class CustomLoss(object):
    """Training objective made of a labeled and an unlabeled term."""

    def __call__(self, outputs_x, targets_x, outputs_u, targets_u, epoch):
        """Compute both loss terms and the current weight of the unlabeled term.

        Args:
            outputs_x: logits for the labeled part of the batch.
            targets_x: target distributions for `outputs_x`.
            outputs_u: logits for the unlabeled part of the batch.
            targets_u: target distributions for `outputs_u`.
            epoch: position passed to `linear_rampup`.

        Returns:
            Tuple `(Lx, Lu, w)`: soft-target cross-entropy on the labeled part,
            mean squared error between softmax(outputs_u) and `targets_u`,
            and `_lambda_u` scaled by the ramp-up factor.
        """
        # Labeled term: cross-entropy against soft targets.
        log_probs_x = F.log_softmax(outputs_x, dim=1)
        Lx = -torch.mean(torch.sum(log_probs_x * targets_x, dim=1))

        # Unlabeled term: mean squared error over every element.
        probs_u = torch.softmax(outputs_u, dim=1)
        squared_error = (probs_u - targets_u) ** 2
        Lu = torch.mean(squared_error)

        # _lambda_u: 75 (default)
        w = _lambda_u * linear_rampup(epoch)

        return Lx, Lu, w

In [17]:
# SGD with momentum over `model`'s parameters; the learning rate follows a
# cosine schedule whose period is `_epochs` scheduler steps.
optim = torch.optim.SGD(
    model.parameters(),
    lr=_lr,
    momentum=0.9,
    weight_decay=5e-4,
)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=_epochs)

# EMA tracker that keeps `ema_model` following `model` with decay `_ema_decay`.
ema_optim = WeigthEMA(model, ema_model, alpha=_ema_decay)

In [18]:
@torch.no_grad()
def get_acc(model, dl):
    """Return the top-1 accuracy of `model` over `dl` as a fraction in [0, 1].

    The model is put into eval mode for the measurement, so train-only layer
    behaviour does not distort the result, and its previous train/eval mode is
    restored afterwards even if a batch fails.

    Args:
        model: network to evaluate; called as `model(x)` and expected to
            return per-class scores along dim 1.
        dl: iterable of `(inputs, labels)` batches; must yield at least one batch.

    Returns:
        Accuracy as a Python float.
    """
    was_training = model.training
    model.eval()
    try:
        hits = []
        for x, y in dl:
            x, y = x.to(DEVICE), y.to(DEVICE)
            hits.append(torch.argmax(model(x), dim=1) == y)
    finally:
        # Hand the model back in the mode the caller left it in.
        model.train(was_training)

    hits = torch.cat(hits)
    return (torch.sum(hits) / len(hits)).item()

def interleave_offsets(batch, nu):
    """Split `batch` positions into `nu + 1` contiguous groups and return their bounds.

    Group sizes differ by at most one; the extra positions go to the last groups.

    Args:
        batch: total number of positions to split.
        nu: number of groups minus one.

    Returns:
        List of `nu + 2` cumulative offsets, starting at 0 and ending at `batch`.
    """
    n_groups = nu + 1
    base, extra = divmod(batch, n_groups)
    first_padded = n_groups - extra  # groups from this index on are one larger

    offsets = [0]
    for idx in range(n_groups):
        size = base + 1 if idx >= first_padded else base
        offsets.append(offsets[-1] + size)

    assert offsets[-1] == batch
    return offsets

def interleave(xy, batch):
    """Swap slices between the first tensor and each of the others.

    Every tensor in `xy` is cut along dim 0 at the bounds given by
    `interleave_offsets(batch, len(xy) - 1)`; slice `k` of the first tensor is
    then exchanged with slice `k` of tensor `k`, and the slices are joined back.

    Args:
        xy: list of tensors.
        batch: number of positions passed to `interleave_offsets`.

    Returns:
        List of tensors, one per input tensor, with the slices exchanged.
    """
    nu = len(xy) - 1
    bounds = interleave_offsets(batch, nu)
    spans = list(zip(bounds[:-1], bounds[1:]))

    chunks = [[tensor[lo:hi] for lo, hi in spans] for tensor in xy]
    for k in range(1, nu + 1):
        from_first, from_other = chunks[0][k], chunks[k][k]
        chunks[0][k] = from_other
        chunks[k][k] = from_first

    return [torch.cat(parts, dim=0) for parts in chunks]

In [19]:
def train(labeled_trainloader, unlabeled_trainloader, model, optimizer, ema_optimizer, criterion, epoch, use_cuda):
    """Run one epoch (`_train_iteration` steps) of semi-supervised training.

    Each step draws a labeled batch and a twice-augmented unlabeled batch,
    guesses sharpened soft labels for the unlabeled images, mixes all inputs
    and targets with a Beta-sampled coefficient, and optimises `Lx + w * Lu`
    as returned by `criterion`.

    Args:
        labeled_trainloader: loader yielding `(inputs, class_indices)` batches.
        unlabeled_trainloader: loader yielding `((view_1, view_2), _)` batches.
        model: network being trained; it is put into train mode here.
        optimizer: optimizer stepped once per iteration.
        ema_optimizer: object whose `step()` is called after every optimizer step.
        criterion: callable `(logits_x, targets_x, logits_u, targets_u, epoch)`
            returning `(Lx, Lu, w)`.
        epoch: index of the current epoch; a fractional epoch
            (`epoch + batch_idx / _train_iteration`) is passed to `criterion`.
        use_cuda: when true, batches are moved to the GPU.

    Returns:
        Tuple `(loss, loss_x, loss_u)` with the running averages of the total,
        labeled and unlabeled losses over the epoch.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()

    losses = AverageMeter()
    losses_x = AverageMeter()
    losses_u = AverageMeter()
    ws = AverageMeter()
    end = time.time()

    bar = Bar('Training', max=_train_iteration)
    labeled_train_iter = iter(labeled_trainloader)
    unlabeled_train_iter = iter(unlabeled_trainloader)

    model.train()

    with tqdm(range(_train_iteration), desc="Training Progress", unit="batch") as progress_bar:
        for batch_idx in progress_bar:
            # Restart a loader only when it is exhausted; any other error
            # (including KeyboardInterrupt) must propagate.
            try:
                inputs_x, targets_x = next(labeled_train_iter)
            except StopIteration:
                labeled_train_iter = iter(labeled_trainloader)
                inputs_x, targets_x = next(labeled_train_iter)

            try:
                (inputs_u, inputs_u2), _ = next(unlabeled_train_iter)  # two different augment(x)s
            except StopIteration:
                unlabeled_train_iter = iter(unlabeled_trainloader)
                (inputs_u, inputs_u2), _ = next(unlabeled_train_iter)

            # measure data loading time
            data_time.update(time.time() - end)

            batch_size = inputs_x.size(0)

            # convert label to one-hot
            targets_x = torch.zeros(batch_size, _n_classes).scatter_(1, targets_x.view(-1, 1).long(), 1)

            if use_cuda:
                inputs_x, targets_x = inputs_x.cuda(), targets_x.cuda(non_blocking=True)
                inputs_u = inputs_u.cuda()
                inputs_u2 = inputs_u2.cuda()

            with torch.no_grad():
                # Guess labels for the unlabeled samples: average the
                # predictions of both views, then sharpen with temperature _T.
                outputs_u = model(inputs_u)
                outputs_u2 = model(inputs_u2)
                p = (torch.softmax(outputs_u, dim=1) + torch.softmax(outputs_u2, dim=1)) / 2
                pt = p**(1/_T)
                targets_u = pt / pt.sum(dim=1, keepdim=True)
                targets_u = targets_u.detach()

            # mixup
            all_inputs = torch.cat([inputs_x, inputs_u, inputs_u2], dim=0)
            all_targets = torch.cat([targets_x, targets_u, targets_u], dim=0)

            l = np.random.beta(_alpha, _alpha)
            # Keep the larger share on the un-permuted sample.
            l = max(l, 1-l)

            idx = torch.randperm(all_inputs.size(0))

            input_a, input_b = all_inputs, all_inputs[idx]
            target_a, target_b = all_targets, all_targets[idx]

            mixed_input = l * input_a + (1 - l) * input_b
            mixed_target = l * target_a + (1 - l) * target_b

            # interleave labeled and unlabeled samples between batches to get correct batchnorm calculation
            mixed_input = list(torch.split(mixed_input, batch_size))
            mixed_input = interleave(mixed_input, batch_size)

            # One forward pass per interleaved chunk, in order.
            logits = [model(chunk) for chunk in mixed_input]

            # put interleaved samples back
            logits = interleave(logits, batch_size)
            logits_x = logits[0]
            logits_u = torch.cat(logits[1:], dim=0)

            Lx, Lu, w = criterion(logits_x, mixed_target[:batch_size], logits_u, mixed_target[batch_size:], epoch+batch_idx/_train_iteration)

            loss = Lx + w * Lu

            # Record plain Python numbers so the meters do not keep autograd
            # history alive across iterations.
            losses.update(loss.item(), inputs_x.size(0))
            losses_x.update(Lx.item(), inputs_x.size(0))
            losses_u.update(Lu.item(), inputs_x.size(0))
            ws.update(w, inputs_x.size(0))

            optimizer.zero_grad()
            loss.backward()

            optimizer.step()
            ema_optimizer.step()

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # plot progress
            bar.suffix  = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | Loss_x: {loss_x:.4f} | Loss_u: {loss_u:.4f} | W: {w:.4f}'.format(
                        batch=batch_idx + 1,
                        size=_train_iteration,
                        data=data_time.avg,
                        bt=batch_time.avg,
                        total=bar.elapsed_td,
                        eta=bar.eta_td,
                        loss=losses.avg,
                        loss_x=losses_x.avg,
                        loss_u=losses_u.avg,
                        w=ws.avg,
                        )
            bar.next()

            progress_bar.set_postfix({
                "Data": f"{data_time.avg:.3f}s",
                "Batch": f"{batch_time.avg:.3f}s",
                "Total": bar.elapsed_td,
                "ETA": bar.eta_td,
                "Loss": f"{losses.avg:.4f}",
                "Loss_x": f"{losses_x.avg:.4f}",
                "Loss_u": f"{losses_u.avg:.4f}",
                "W": f"{ws.avg:.4f}",
            })

        bar.finish()

    return (losses.avg, losses_x.avg, losses_u.avg,)

In [20]:
def validate(valloader, model, criterion, epoch, use_cuda, mode):
    """Evaluate `model` on `valloader` and report average loss and top-1 score.

    Args:
        valloader: loader yielding `(inputs, targets)` batches.
        model: network to evaluate; it is switched to eval mode.
        criterion: callable `(outputs, targets)` returning a scalar loss tensor.
        epoch: accepted for interface compatibility; not used here.
        use_cuda: when true, batches are moved to the GPU.
        mode: label shown on the progress bar.

    Returns:
        Tuple `(average loss, average top-1)` taken from the running meters.
    """
    suffix_template = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'

    batch_timer, data_timer = AverageMeter(), AverageMeter()
    loss_meter, top1_meter, top5_meter = AverageMeter(), AverageMeter(), AverageMeter()

    # switch to evaluate mode
    model.eval()

    tick = time.time()
    bar = Bar(f'{mode}', max=len(valloader))
    with torch.no_grad():
        for step_no, (inputs, targets) in enumerate(valloader, start=1):
            # time spent waiting for this batch
            data_timer.update(time.time() - tick)

            if use_cuda:
                inputs = inputs.cuda()
                targets = targets.cuda(non_blocking=True)

            outputs = model(inputs)
            loss = criterion(outputs, targets)
            prec1, prec5 = accuracy(outputs, targets, topk=(1, 5))

            # weight every meter by the number of samples in the batch
            n_samples = inputs.size(0)
            loss_meter.update(loss.item(), n_samples)
            top1_meter.update(prec1.item(), n_samples)
            top5_meter.update(prec5.item(), n_samples)

            # total time for this batch
            batch_timer.update(time.time() - tick)
            tick = time.time()

            bar.suffix = suffix_template.format(
                batch=step_no,
                size=len(valloader),
                data=data_timer.avg,
                bt=batch_timer.avg,
                total=bar.elapsed_td,
                eta=bar.eta_td,
                loss=loss_meter.avg,
                top1=top1_meter.avg,
                top5=top5_meter.avg,
            )
            bar.next()
        bar.finish()

    return (loss_meter.avg, top1_meter.avg)

In [21]:
# Training objective: CustomLoss returns (labeled loss, unlabeled loss, weight).
train_criterion = CustomLoss()
# Loss passed to validate() for the train/validation/test statistics.
criterion = nn.CrossEntropyLoss()

In [22]:
if not os.path.isdir(_out):
    mkdir_p(_out)

title = 'noisy-cifar-10'

logger = Logger(os.path.join(_out, 'log.txt'), title=title)
logger.set_names(['Train Loss', 'Train Loss X', 'Train Loss U',  'Valid Loss', 'Valid Acc.', 'Test Loss', 'Test Acc.'])
writer = SummaryWriter(_out)

best_acc = 0  # best validation accuracy seen so far
step = 0
test_accs = []

# The try/finally guarantees the log file and the event writer are closed even
# when training is interrupted (e.g. by a KeyboardInterrupt) or fails.
try:
    for epoch in range(_epochs):

        train_loss, train_loss_x, train_loss_u = train(labeled_trainloader, unlabeled_trainloader, model, optim, ema_optim, train_criterion, epoch, use_cuda)

        # All reported statistics come from the EMA model.
        _, train_acc = validate(labeled_trainloader, ema_model, criterion, epoch, use_cuda, mode='Train Stats')
        val_loss, val_acc = validate(val_loader, ema_model, criterion, epoch, use_cuda, mode='Valid Stats')
        test_loss, test_acc = validate(test_loader, ema_model, criterion, epoch, use_cuda, mode='Test Stats ')

        sched.step()

        # "Naive" accuracy: the non-averaged model on the test set, as a fraction.
        _test_acc = get_acc(model, test_loader)
        print(f"[Epoch {epoch}] (naive) Test Accuracy: {_test_acc:.4f}")

        print(f"[Epoch {epoch}] Train Accuracy: {train_acc:.4f}")
        print(f"[Epoch {epoch}] Validation Accuracy: {val_acc:.4f}")
        print(f"[Epoch {epoch}] Test Accuracy: {test_acc:.4f}")

        # Global step = number of training iterations completed so far.
        step = _train_iteration * (epoch + 1)

        writer.add_scalar('losses/train_loss', train_loss, step)
        writer.add_scalar('losses/valid_loss', val_loss, step)
        writer.add_scalar('losses/test_loss', test_loss, step)

        writer.add_scalar('accuracy/train_acc', train_acc, step)
        writer.add_scalar('accuracy/val_acc', val_acc, step)
        writer.add_scalar('accuracy/test_acc', test_acc, step)

        logger.append([train_loss, train_loss_x, train_loss_u, val_loss, val_acc, test_loss, test_acc])

        best_acc = max(val_acc, best_acc)
        test_accs.append(test_acc)
finally:
    logger.close()
    writer.close()

print('Best acc:')
print(best_acc)

# Mean test accuracy over the last (up to) 20 epochs.
print('Mean acc:')
print(np.mean(test_accs[-20:]))

Training Progress: 100%|██████████| 2187/2187 [26:42<00:00,  1.37batch/s, Data=0.047s, Batch=0.733s, Total=0:26:42, ETA=0:00:00, Loss=2.1575, Loss_x=2.1567, Loss_u=0.0039, W=0.1999]


[Epoch 0] (naive) Test Accuracy: 0.2682
[Epoch 0] Train Accuracy: 23.0000
[Epoch 0] Validation Accuracy: 23.9400
[Epoch 0] Test Accuracy: 24.8700


Training Progress:  65%|██████▌   | 1428/2187 [17:31<09:18,  1.36batch/s, Data=0.048s, Batch=0.736s, Total=0:17:31, ETA=0:09:16, Loss=2.1000, Loss_x=2.0977, Loss_u=0.0043, W=0.5305]


KeyboardInterrupt: 