In [1]:
import os
import time
import datetime 
from pathlib import Path

import numpy as np
import torch
import wandb
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
import torch.utils.data as data # for code compatibility
from torchvision import models, transforms
from torchvision.datasets import CIFAR10, CIFAR100
import timm

import scipy.stats
from sklearn.metrics import auc, roc_curve

import pytorch_lightning as pl

# util
from tqdm import tqdm
from collections import Counter


# customized 
import models.arch as models
import dataset.cifar10 as dataset

from utils import Bar, Logger, AverageMeter, accuracy, mkdir_p, savefig
import utils as utils_

##
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("mps")

In [2]:
# for vit_large_patch16_224_cifar10, CIFAR-10
'''
lr=0.02
epochs=25
n_shadows = 64
shadow_id = -1 
model = "efficientnet_b7"
dataset = "cifar100"
pkeep = 0.5
savedir = f"exp/{model}_{dataset}"
debug = True
'''

# for vgg19, CIFAR-10
_lr = 0.02
_epochs = 50
_arch = "vgg19"
_dataset = "cifar10"
_n_classes = 10 # depend on dataset
_debug = True
_batch_size = 64
_ema_decay = 0.999 # default 

_n_labeled = 10000
_alpha = 0.75
_lambda_u = 100

In [3]:
SEED = 1583745484

# prepare dataset

In [4]:
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x7f7e7a8c4390>

In [5]:
print(f'==> Preparing cifar10')
transform_train = transforms.Compose([
    dataset.RandomPadandCrop(32),
    dataset.RandomFlip(),
    dataset.ToTensor(),
])

transform_val = transforms.Compose([
    dataset.ToTensor(),
])
datadir = Path().home() / "dataset"

batch_size=_batch_size

train_labeled_set, train_unlabeled_set, val_set, test_set = dataset.get_cifar10(datadir, _n_labeled, transform_train=transform_train, transform_val=transform_val)
labeled_trainloader = data.DataLoader(train_labeled_set, batch_size=batch_size, shuffle=True, num_workers=2, drop_last=True)
unlabeled_trainloader = data.DataLoader(train_unlabeled_set, batch_size=batch_size, shuffle=True, num_workers=2, drop_last=True)
val_loader = data.DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=0)
test_loader = data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=0)

==> Preparing cifar10
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
#Labeled: 10000 #Unlabeled: 35000 #Val: 5000


In [6]:
train_labeled_set

Dataset CIFAR10_labeled
    Number of datapoints: 10000
    Root location: /home/dsanyal7/dataset
    Split: Train
    StandardTransform
Transform: Compose(
               <dataset.cifar10.RandomPadandCrop object at 0x7f7d6b5aeb50>
               <dataset.cifar10.RandomFlip object at 0x7f7d6b5aeb80>
               <dataset.cifar10.ToTensor object at 0x7f7d6b5aec40>
           )

In [7]:
train_unlabeled_set

Dataset CIFAR10_unlabeled
    Number of datapoints: 35000
    Root location: /home/dsanyal7/dataset
    Split: Train
    StandardTransform
Transform: <dataset.cifar10.TransformTwice object at 0x7f7d6e5d57c0>

In [8]:
# utils_.get_mean_and_std(train_labeled_set)

# expected values:
# (tensor([ 2.1660e-05, -8.8033e-04,  1.0356e-03]),
# tensor([0.8125, 0.8125, 0.7622]))

In [19]:
# adapting MixMatch's data loader...
train_dl = labeled_trainloader
test_dl = test_loader

In [20]:
def create_model(ema=False):
    model = models.network(_arch, pretrained=False, n_classes=_n_classes)
    model = model.cuda()

    if ema:
        for param in model.parameters():
            param.detach_()

    return model

model = create_model()
ema_model = create_model(ema=True)

arch: vgg19, pretrained: False, n_classes: 10
arch: vgg19, pretrained: False, n_classes: 10


In [21]:
# model
print('Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0))

Total params: 139.61M


In [22]:
class WeigthEMA(object):
    def __init__(self, model, ema_model, alpha=0.999):
        self.model = model
        self.ema_model = ema_model
        self.alpha = alpha
        self.params = list(model.state_dict().values())
        self.ema_params = list(ema_model.state_dict().values())
        self.wd = 0.02 * _lr

        for param, ema_param in zip(self.params, self.ema_params):
            param.data.copy_(ema_param.data)

    def step(self):
        one_minus_alpha = 1.0 - self.alpha
        for param, ema_param in zip(self.params, self.ema_params):
            if ema_param.dtype==torch.float32:
                ema_param.mul_(self.alpha)
                ema_param.add_(param * one_minus_alpha)
                # customized weight decay
                param.mul_(1 - self.wd)

In [23]:
optim = torch.optim.SGD(model.parameters(), lr=_lr, momentum=0.9, weight_decay=5e-4)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=_epochs)

ema_optim = WeigthEMA(model, ema_model, alpha=_ema_decay)

In [24]:
@torch.no_grad()
def get_acc(model, dl):
    acc = []
    for x, y in dl:
        x, y = x.to(DEVICE), y.to(DEVICE)
        acc.append(torch.argmax(model(x), dim=1) == y)
    acc = torch.cat(acc)
    acc = torch.sum(acc) / len(acc)

    return acc.item()


def interleave_offsets(batch, nu):
    groups = [batch // (nu + 1)] * (nu + 1)
    for x in range(batch - sum(groups)):
        groups[-x - 1] += 1
    offsets = [0]
    for g in groups:
        offsets.append(offsets[-1] + g)
    assert offsets[-1] == batch
    return offsets


def interleave(xy, batch):
    nu = len(xy) - 1
    offsets = interleave_offsets(batch, nu)
    xy = [[v[offsets[p]:offsets[p + 1]] for p in range(nu + 1)] for v in xy]
    for i in range(1, nu + 1):
        xy[0][i], xy[i][i] = xy[i][i], xy[0][i]
    return [torch.cat(v, dim=0) for v in xy]

In [25]:
def train(labeled_trainloader, unlabeled_trainloader, model, optimizer, ema_optimizer, criterion, epoch, use_cuda):
    
    model.train()
    
    losses = AverageMeter()
    losses_x = AverageMeter()
    losses_u = AverageMeter()
    
    loss_total = 0
    pbar = tqdm(train_dl)
    for itr, (inputs_x, targets_x) in enumerate(pbar):
        inputs_x, targets_x = inputs_x.to(DEVICE), targets_x.to(DEVICE)

        batch_size = inputs_x.size(0)

        # convert label to one-hot
        targets_x = torch.zeros(batch_size, _n_classes, device=DEVICE).scatter_(1, targets_x.view(-1,1).long(), 1)

        # mixup 
        all_inputs = torch.cat([inputs_x], dim=0)
        all_targets = torch.cat([targets_x], dim=0)

        l = np.random.beta(_alpha, _alpha)
        # l = 1
        l = max(l, l-1)
        
        idx = torch.randperm(all_inputs.size(0))
        
        input_a, input_b = all_inputs, all_inputs[idx]
        target_a, target_b = all_targets, all_targets[idx]
        
        mixed_input = l * input_a + (1 - l) * input_b
        mixed_target = l * target_a + (1 - l) * target_b

        # interleave labeled and unlabed samples between batches to get correct batchnorm calculation 
        mixed_input = list(torch.split(mixed_input, batch_size))
        mixed_input = interleave(mixed_input, batch_size)

        outputs_x = model(inputs_x)
        Lx, Lu, w = criterion(outputs_x, targets_x, epoch)

        loss = Lx + w * Lu

        losses.update(loss, inputs_x.size(0))
        losses_x.update(Lx, inputs_x.size(0))
        losses_u.update(Lu, inputs_x.size(0))
        
        pbar.set_postfix_str(f"loss: {loss:.2f}")

        optimizer.zero_grad()
        loss.backward()
        
        optimizer.step()
        # ema_optimizer.step()
    
    return None
    

In [26]:
class CustomLoss(object):
    # def __call__(self, outputs_x, targets_x, outputs_u, targets_u, epoch):
    def __call__(self, outputs_x, targets_x, epoch):
        Lx = -torch.mean(torch.sum(F.log_softmax(outputs_x, dim=1) * targets_x, dim=1))
        Lu = 0
        w = 0

        return Lx, Lu, w

In [27]:
train_criterion = CustomLoss()

In [None]:
for epoch in range(_epochs):

    train(train_dl, None, model, optim, ema_optim, train_criterion, epoch, use_cuda=True)
    sched.step()

    test_acc = get_acc(model, test_dl)
    print(f"[Epoch {epoch}] Test Accuracy: {test_acc:.4f}")

print(f"[test] acc_test: {get_acc(model, test_dl):.4f}")

100%|██████████| 156/156 [00:04<00:00, 38.56it/s, loss: 2.20]


[Epoch 0] Test Accuracy: 0.1234


100%|██████████| 156/156 [00:04<00:00, 38.50it/s, loss: 2.08]


[Epoch 1] Test Accuracy: 0.1705


100%|██████████| 156/156 [00:04<00:00, 38.54it/s, loss: 2.13]


[Epoch 2] Test Accuracy: 0.1874


100%|██████████| 156/156 [00:04<00:00, 38.35it/s, loss: 1.86]


[Epoch 3] Test Accuracy: 0.2591


100%|██████████| 156/156 [00:04<00:00, 38.63it/s, loss: 1.78]


[Epoch 4] Test Accuracy: 0.2505


100%|██████████| 156/156 [00:04<00:00, 38.64it/s, loss: 1.84]


[Epoch 5] Test Accuracy: 0.2921


100%|██████████| 156/156 [00:04<00:00, 38.58it/s, loss: 1.56]


[Epoch 6] Test Accuracy: 0.3186


100%|██████████| 156/156 [00:04<00:00, 38.57it/s, loss: 1.80]


[Epoch 7] Test Accuracy: 0.3517


100%|██████████| 156/156 [00:04<00:00, 38.50it/s, loss: 1.76]


[Epoch 8] Test Accuracy: 0.3047


100%|██████████| 156/156 [00:04<00:00, 38.54it/s, loss: 1.49]


[Epoch 9] Test Accuracy: 0.3464


100%|██████████| 156/156 [00:04<00:00, 38.49it/s, loss: 1.65]


[Epoch 10] Test Accuracy: 0.4041


100%|██████████| 156/156 [00:04<00:00, 38.68it/s, loss: 1.18]


[Epoch 11] Test Accuracy: 0.4461


100%|██████████| 156/156 [00:04<00:00, 38.53it/s, loss: 1.49]


[Epoch 12] Test Accuracy: 0.4726


100%|██████████| 156/156 [00:04<00:00, 38.72it/s, loss: 1.38]


[Epoch 13] Test Accuracy: 0.5256


100%|██████████| 156/156 [00:03<00:00, 39.29it/s, loss: 1.73]


[Epoch 14] Test Accuracy: 0.5198


100%|██████████| 156/156 [00:04<00:00, 38.64it/s, loss: 1.29]


[Epoch 15] Test Accuracy: 0.5302


100%|██████████| 156/156 [00:04<00:00, 38.78it/s, loss: 1.28]


[Epoch 16] Test Accuracy: 0.5527


100%|██████████| 156/156 [00:04<00:00, 38.64it/s, loss: 1.47]


[Epoch 17] Test Accuracy: 0.5924


100%|██████████| 156/156 [00:04<00:00, 38.45it/s, loss: 1.25]


[Epoch 18] Test Accuracy: 0.5380


100%|██████████| 156/156 [00:04<00:00, 38.38it/s, loss: 1.23]


[Epoch 19] Test Accuracy: 0.5979


100%|██████████| 156/156 [00:04<00:00, 38.68it/s, loss: 1.20]


[Epoch 20] Test Accuracy: 0.6078


100%|██████████| 156/156 [00:04<00:00, 38.47it/s, loss: 0.86]
