In [1]:
import os
import time
import datetime 
from pathlib import Path

import numpy as np
import torch
import wandb
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
import torch.utils.data as data # for code compatibility
from torchvision import models, transforms
from torchvision.datasets import CIFAR10, CIFAR100
import timm

import scipy.stats
from sklearn.metrics import auc, roc_curve

import pytorch_lightning as pl

# util
from tqdm import tqdm
from collections import Counter


# customized 
import models.arch as models
import dataset.cifar10 as dataset

from utils import Logger, AverageMeter, accuracy, mkdir_p, savefig
from progress.bar import Bar as Bar
import utils as utils_

##
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("mps")

In [2]:
# for vit_large_patch16_224_cifar10, CIFAR-10
'''
lr=0.02
epochs=25
n_shadows = 64
shadow_id = -1 
model = "efficientnet_b7"
dataset = "cifar100"
pkeep = 0.5
savedir = f"exp/{model}_{dataset}"
debug = True
'''

# for vgg19, CIFAR-10
_lr = 0.02
_epochs = 50
_arch = "vgg19"
_dataset = "cifar10"
_n_classes = 10 # depend on dataset
_debug = True
_batch_size = 64

_n_labeled = 10000
_alpha = 0.75
_lambda_u = 100

_ema_decay = 0.999 # default 
_T = 0.5
_train_iteration = int(_n_labeled / _batch_size) # phase 1: support full-supervised learning

In [3]:
_train_iteration

156

In [4]:
SEED = 1583745484

# prepare dataset

In [5]:
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x7f6e0c0e62f0>

In [6]:
print(f'==> Preparing cifar10')
transform_train = transforms.Compose([
    dataset.RandomPadandCrop(32),
    dataset.RandomFlip(),
    dataset.ToTensor(),
])

transform_val = transforms.Compose([
    dataset.ToTensor(),
])
datadir = Path().home() / "dataset"

batch_size=_batch_size

train_labeled_set, train_unlabeled_set, val_set, test_set = dataset.get_cifar10(datadir, _n_labeled, transform_train=transform_train, transform_val=transform_val)
labeled_trainloader = data.DataLoader(train_labeled_set, batch_size=batch_size, shuffle=True, num_workers=2, drop_last=True)
unlabeled_trainloader = data.DataLoader(train_unlabeled_set, batch_size=batch_size, shuffle=True, num_workers=2, drop_last=True)
val_loader = data.DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=0)
test_loader = data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=0)

==> Preparing cifar10
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
#Labeled: 10000 #Unlabeled: 35000 #Val: 5000


In [7]:
train_labeled_set

Dataset CIFAR10_labeled
    Number of datapoints: 10000
    Root location: /home/dsanyal7/dataset
    Split: Train
    StandardTransform
Transform: Compose(
               <dataset.cifar10.RandomPadandCrop object at 0x7f6cfcdc0ac0>
               <dataset.cifar10.RandomFlip object at 0x7f6cfcdc0af0>
               <dataset.cifar10.ToTensor object at 0x7f6cfcdc0bb0>
           )

In [8]:
train_unlabeled_set

Dataset CIFAR10_unlabeled
    Number of datapoints: 35000
    Root location: /home/dsanyal7/dataset
    Split: Train
    StandardTransform
Transform: <dataset.cifar10.TransformTwice object at 0x7f6df9c82a30>

In [9]:
# utils_.get_mean_and_std(train_labeled_set)

# expected values:
# (tensor([ 2.1660e-05, -8.8033e-04,  1.0356e-03]),
# tensor([0.8125, 0.8125, 0.7622]))

## model

In [10]:
def create_model(ema=False):
    model = models.network(_arch, pretrained=False, n_classes=_n_classes)
    model = model.cuda()

    if ema:
        for param in model.parameters():
            param.detach_()

    return model

model = create_model()
ema_model = create_model(ema=True)

arch: vgg19, pretrained: False, n_classes: 10




arch: vgg19, pretrained: False, n_classes: 10


In [11]:
# model
print('Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0))

Total params: 139.61M


In [12]:
class WeigthEMA(object):
    def __init__(self, model, ema_model, alpha=0.999):
        self.model = model
        self.ema_model = ema_model
        self.alpha = alpha
        self.params = list(model.state_dict().values())
        self.ema_params = list(ema_model.state_dict().values())
        self.wd = 0.02 * _lr # for lr=0.02, wd = 0.0004

        for param, ema_param in zip(self.params, self.ema_params):
            param.data.copy_(ema_param.data)

    def step(self):
        one_minus_alpha = 1.0 - self.alpha
        for param, ema_param in zip(self.params, self.ema_params):
            if ema_param.dtype==torch.float32:
                ema_param.mul_(self.alpha)
                ema_param.add_(param * one_minus_alpha)
                # customized weight decay (TODO: should carefully set up)
                # 
                # lr=0.02 -> (0.9996)^5928 (1 epoch): 0.09332.
                # lr=0.002 -> (0.99996)^5928 (1 epoch): 0.78889
                # param.mul_(1 - self.wd) # x 0.9996 (for lr=0.02), x0.99996 (for lr=0

class CustomLoss(object):
    def __call__(self, outputs_x, targets_x, outputs_u, targets_u, epoch):
        probs_u = torch.softmax(outputs_u, dim=1)
        
        Lx = -torch.mean(torch.sum(F.log_softmax(outputs_x, dim=1) * targets_x, dim=1))
        Lu = torch.mean((probs_u - targets_u)**2)
        w = 0

        return Lx, Lu, w

In [13]:
optim = torch.optim.SGD(model.parameters(), lr=_lr, momentum=0.9, weight_decay=5e-4)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=_epochs)

ema_optim = WeigthEMA(model, ema_model, alpha=_ema_decay)

In [14]:
@torch.no_grad()
def get_acc(model, dl):
    acc = []
    for x, y in dl:
        x, y = x.to(DEVICE), y.to(DEVICE)
        acc.append(torch.argmax(model(x), dim=1) == y)
    acc = torch.cat(acc)
    acc = torch.sum(acc) / len(acc)

    return acc.item()

def interleave_offsets(batch, nu):
    groups = [batch // (nu + 1)] * (nu + 1)
    for x in range(batch - sum(groups)):
        groups[-x - 1] += 1
    offsets = [0]
    for g in groups:
        offsets.append(offsets[-1] + g)
    assert offsets[-1] == batch
    return offsets

def interleave(xy, batch):
    nu = len(xy) - 1
    offsets = interleave_offsets(batch, nu)
    xy = [[v[offsets[p]:offsets[p + 1]] for p in range(nu + 1)] for v in xy]
    for i in range(1, nu + 1):
        xy[0][i], xy[i][i] = xy[i][i], xy[0][i]
    return [torch.cat(v, dim=0) for v in xy]

def linear_rampup(current, rampup_length=_epochs):
    if rampup_length == 0:
        return 1.0
    else:
        current = np.clip(current / rampup_length, 0.0, 1.0)
        return float(current)

In [15]:
def train(labeled_trainloader, unlabeled_trainloader, model, optimizer, ema_optimizer, criterion, epoch, use_cuda):
    
    batch_time = AverageMeter()
    data_time = AverageMeter()
    
    losses = AverageMeter()
    losses_x = AverageMeter()
    losses_u = AverageMeter()
    ws = AverageMeter()
    end = time.time()
    
    bar = Bar('Training', max=_train_iteration)
    labeled_train_iter = iter(labeled_trainloader)
    unlabeled_train_iter = iter(unlabeled_trainloader)
    
    model.train()

    with tqdm(range(_train_iteration), desc="Training Progress", unit="batch") as progress_bar:
        for batch_idx in progress_bar:
    
            inputs_x, targets_x      = next(labeled_train_iter)
            (inputs_u, inputs_u2), _ = next(unlabeled_train_iter)
    
            # measure data loading time
            data_time.update(time.time() - end)
            
            batch_size = inputs_x.size(0)
    
            # convert label to one-hot
            targets_x = torch.zeros(batch_size, _n_classes).scatter_(1, targets_x.view(-1,1).long(), 1)
    
            if use_cuda:
                inputs_x, targets_x = inputs_x.cuda(), targets_x.cuda(non_blocking=True)
                inputs_u = inputs_u.cuda()
                inputs_u2 = inputs_u2.cuda()
    
            with torch.no_grad():
                # compute guessed labels of unlabel samples
                outputs_u = model(inputs_u)
                outputs_u2 = model(inputs_u2)
                p = (torch.softmax(outputs_u, dim=1) + torch.softmax(outputs_u2, dim=1)) / 2
                pt = p**(1/_T)
                targets_u = pt / pt.sum(dim=1, keepdim=True)
                targets_u = targets_u.detach()
            
            # mixup 
            all_inputs = torch.cat([inputs_x, inputs_u, inputs_u2], dim=0)
            all_targets = torch.cat([targets_x, targets_u, targets_u], dim=0)
    
            l = np.random.beta(_alpha, _alpha)
            # l = 1
            l = max(l, 1-l)
            
            idx = torch.randperm(all_inputs.size(0))
            
            input_a, input_b = all_inputs, all_inputs[idx]
            target_a, target_b = all_targets, all_targets[idx]
            
            # input_a > input_b is guaranteed.
            mixed_input = l * input_a + (1 - l) * input_b
            mixed_target = l * target_a + (1 - l) * target_b
    
            # interleave labeled and unlabed samples between batches to get correct batchnorm calculation 
            mixed_input = list(torch.split(mixed_input, batch_size))
            mixed_input = interleave(mixed_input, batch_size)
            
            logits = [model(mixed_input[0])]
            for input in mixed_input[1:]:
                logits.append(model(input))
    
            # put interleaved samples back
            logits = interleave(logits, batch_size)
            logits_x = logits[0]
            logits_u = torch.cat(logits[1:], dim=0)
    
            Lx, Lu, w = criterion(logits_x, mixed_target[:batch_size], logits_u, mixed_target[batch_size:], epoch+batch_idx/_train_iteration)
    
            loss = Lx + w * Lu
    
            losses.update(loss, inputs_x.size(0))
            losses_x.update(Lx, inputs_x.size(0))
            losses_u.update(Lu, inputs_x.size(0))
    
            optimizer.zero_grad()
            loss.backward()
            
            optimizer.step()
            ema_optimizer.step()
    
            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
        
            # plot progress
            bar.suffix  = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | Loss_x: {loss_x:.4f} | Loss_u: {loss_u:.4f} | W: {w:.4f}'.format(
                        batch=batch_idx + 1,
                        size=_train_iteration,
                        data=data_time.avg,
                        bt=batch_time.avg,
                        total=bar.elapsed_td,
                        eta=bar.eta_td,
                        loss=losses.avg,
                        loss_x=losses_x.avg,
                        loss_u=losses_u.avg,
                        w=ws.avg,
                        )
            bar.next()

            progress_bar.set_postfix({
                "Data": f"{data_time.avg:.3f}s",
                "Batch": f"{batch_time.avg:.3f}s",
                "Total": bar.elapsed_td,
                "ETA": bar.eta_td,
                "Loss": f"{losses.avg:.4f}",
                "Loss_x": f"{losses_x.avg:.4f}",
                "Loss_u": f"{losses_u.avg:.4f}",
                "W": f"{ws.avg:.4f}",
            })
            
        bar.finish()
    
    return (losses.avg, losses_x.avg, losses_u.avg,)

In [16]:
train_criterion = CustomLoss()

In [None]:
for epoch in range(_epochs):

    train(labeled_trainloader, unlabeled_trainloader, model, optim, ema_optim, train_criterion, epoch, use_cuda=True)
    sched.step()

    test_acc = get_acc(model, test_loader)
    print(f"[Epoch {epoch}] Test Accuracy: {test_acc:.4f}")

print(f"[test] acc_test: {get_acc(model, test_loader):.4f}")

Training Progress: 100%|██████████| 156/156 [00:11<00:00, 13.79batch/s, Data=0.046s, Batch=0.073s, Total=0:00:11, ETA=0:00:00, Loss=2.2733, Loss_x=2.2733, Loss_u=0.0027, W=0.0000]


[Epoch 0] Test Accuracy: 0.1664


Training Progress: 100%|██████████| 156/156 [00:10<00:00, 14.23batch/s, Data=0.048s, Batch=0.071s, Total=0:00:11, ETA=0:00:00, Loss=2.2030, Loss_x=2.2030, Loss_u=0.0033, W=0.0000]


[Epoch 1] Test Accuracy: 0.1917


Training Progress: 100%|██████████| 156/156 [00:11<00:00, 14.12batch/s, Data=0.048s, Batch=0.071s, Total=0:00:11, ETA=0:00:00, Loss=2.1111, Loss_x=2.1111, Loss_u=0.0039, W=0.0000]


[Epoch 2] Test Accuracy: 0.2197


Training Progress: 100%|██████████| 156/156 [00:11<00:00, 14.11batch/s, Data=0.047s, Batch=0.072s, Total=0:00:11, ETA=0:00:00, Loss=2.0948, Loss_x=2.0948, Loss_u=0.0045, W=0.0000]


[Epoch 3] Test Accuracy: 0.2553


Training Progress: 100%|██████████| 156/156 [00:11<00:00, 14.14batch/s, Data=0.047s, Batch=0.071s, Total=0:00:11, ETA=0:00:00, Loss=2.0362, Loss_x=2.0362, Loss_u=0.0048, W=0.0000]


[Epoch 4] Test Accuracy: 0.2700


Training Progress: 100%|██████████| 156/156 [00:10<00:00, 14.20batch/s, Data=0.047s, Batch=0.071s, Total=0:00:11, ETA=0:00:00, Loss=2.0291, Loss_x=2.0291, Loss_u=0.0057, W=0.0000]


[Epoch 5] Test Accuracy: 0.2775


Training Progress: 100%|██████████| 156/156 [00:10<00:00, 14.19batch/s, Data=0.047s, Batch=0.071s, Total=0:00:11, ETA=0:00:00, Loss=1.9924, Loss_x=1.9924, Loss_u=0.0061, W=0.0000]


[Epoch 6] Test Accuracy: 0.3183


Training Progress: 100%|██████████| 156/156 [00:11<00:00, 14.09batch/s, Data=0.050s, Batch=0.072s, Total=0:00:11, ETA=0:00:00, Loss=1.9815, Loss_x=1.9815, Loss_u=0.0063, W=0.0000]


[Epoch 7] Test Accuracy: 0.3257


Training Progress:  62%|██████▏   | 97/156 [00:06<00:04, 14.32batch/s, Data=0.048s, Batch=0.073s, Total=0:00:07, ETA=0:00:05, Loss=1.9569, Loss_x=1.9569, Loss_u=0.0065, W=0.0000]