In [1]:
import os
import time
import datetime 
from pathlib import Path

import numpy as np
import torch
import wandb
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
import torch.utils.data as data # for code compatibility
from torchvision import models, transforms
from torchvision.datasets import CIFAR10, CIFAR100
import timm

import scipy.stats
from sklearn.metrics import auc, roc_curve

import pytorch_lightning as pl

# util
from tqdm import tqdm
from collections import Counter

# customized 
import models.arch as models
import dataset.cifar10 as dataset

from utils import Logger, AverageMeter, accuracy, mkdir_p, savefig
from progress.bar import Bar as Bar
import utils as utils_
from tensorboardX import SummaryWriter

##
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("mps")

In [2]:
# for vit_large_patch16_224_cifar10, CIFAR-10
'''
lr=0.02
epochs=25
n_shadows = 64
shadow_id = -1 
model = "efficientnet_b7"
dataset = "cifar100"
pkeep = 0.5
savedir = f"exp/{model}_{dataset}"
debug = True
'''

# for vgg19, CIFAR-10
_lr = 0.02
_epochs = 50
_arch = "vgg19"
_dataset = "cifar10"
_n_classes = 10 # depend on dataset
_debug = True
_batch_size = 64

_n_labeled = 10000
_alpha = 0.75
_lambda_u = 20 # default: 75

_ema_decay = 0.999 # default: 0.999
_T = 0.5 # default: 0.5
_train_iteration = int(_n_labeled / _batch_size) # phase 1: support full-supervised learning, default: 1024

use_cuda = torch.cuda.is_available()

In [3]:
_out = 'cifar10@%d' % (_n_labeled)
_out

'cifar10@10000'

In [4]:
_train_iteration

156

In [5]:
SEED = 1583745484

# prepare dataset

In [6]:
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x7f93b3c462f0>

In [7]:
print(f'==> Preparing cifar10')
transform_train = transforms.Compose([
    dataset.RandomPadandCrop(32),
    dataset.RandomFlip(),
    dataset.ToTensor(),
])

transform_val = transforms.Compose([
    dataset.ToTensor(),
])
datadir = Path().home() / "dataset"

batch_size=_batch_size

train_labeled_set, train_unlabeled_set, val_set, test_set = dataset.get_cifar10(datadir, _n_labeled, transform_train=transform_train, transform_val=transform_val)
labeled_trainloader = data.DataLoader(train_labeled_set, batch_size=batch_size, shuffle=True, num_workers=2, drop_last=True)
unlabeled_trainloader = data.DataLoader(train_unlabeled_set, batch_size=batch_size, shuffle=True, num_workers=2, drop_last=True)
val_loader = data.DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=0)
test_loader = data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=0)

==> Preparing cifar10
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
#Labeled: 10000 #Unlabeled: 35000 #Val: 5000


In [8]:
train_labeled_set

Dataset CIFAR10_labeled
    Number of datapoints: 10000
    Root location: /home/dsanyal7/dataset
    Split: Train
    StandardTransform
Transform: Compose(
               <dataset.cifar10.RandomPadandCrop object at 0x7f93a17d7880>
               <dataset.cifar10.RandomFlip object at 0x7f92a48a6670>
               <dataset.cifar10.ToTensor object at 0x7f92a48a66a0>
           )

In [9]:
train_unlabeled_set

Dataset CIFAR10_unlabeled
    Number of datapoints: 35000
    Root location: /home/dsanyal7/dataset
    Split: Train
    StandardTransform
Transform: <dataset.cifar10.TransformTwice object at 0x7f93a17d7730>

In [10]:
# utils_.get_mean_and_std(train_labeled_set)

# expected values:
# (tensor([ 2.1660e-05, -8.8033e-04,  1.0356e-03]),
# tensor([0.8125, 0.8125, 0.7622]))

## model

In [11]:
def create_model(ema=False):
    model = models.network(_arch, pretrained=False, n_classes=_n_classes)
    model = model.cuda()

    if ema:
        for param in model.parameters():
            param.detach_()

    return model

model = create_model()
ema_model = create_model(ema=True)

arch: vgg19, pretrained: False, n_classes: 10




arch: vgg19, pretrained: False, n_classes: 10


In [12]:
# model
print('Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0))

Total params: 139.61M


In [13]:
def linear_rampup(current, rampup_length=_epochs):
    if rampup_length == 0:
        return 1.0
    else:
        current = np.clip(current / rampup_length, 0.0, 1.0)
        return float(current)

In [14]:
class WeigthEMA(object):
    def __init__(self, model, ema_model, alpha=0.999):
        self.model = model
        self.ema_model = ema_model
        self.alpha = alpha
        self.params = list(model.state_dict().values())
        self.ema_params = list(ema_model.state_dict().values())
        self.wd = 0.02 * _lr # for lr=0.02, wd = 0.0004

        for param, ema_param in zip(self.params, self.ema_params):
            param.data.copy_(ema_param.data)

    def step(self):
        one_minus_alpha = 1.0 - self.alpha
        for param, ema_param in zip(self.params, self.ema_params):
            if ema_param.dtype==torch.float32:
                ema_param.mul_(self.alpha)
                ema_param.add_(param * one_minus_alpha)
                # customized weight decay (TODO: should carefully set up)
                # 
                # lr=0.02 -> (0.9996)^5928 (1 epoch): 0.09332.
                # lr=0.002 -> (0.99996)^5928 (1 epoch): 0.78889
                # param.mul_(1 - self.wd) # x 0.9996 (for lr=0.02), x0.99996 (for lr=0

class CustomLoss(object):
    def __call__(self, outputs_x, targets_x, outputs_u, targets_u, epoch):
        probs_u = torch.softmax(outputs_u, dim=1)
        
        Lx = -torch.mean(torch.sum(F.log_softmax(outputs_x, dim=1) * targets_x, dim=1))
        Lu = torch.mean((probs_u - targets_u)**2)

        w = _lambda_u * linear_rampup(epoch) # _lambda_u: 75 (default)

        return Lx, Lu, w

In [15]:
optim = torch.optim.SGD(model.parameters(), lr=_lr, momentum=0.9, weight_decay=5e-4)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=_epochs)

ema_optim = WeigthEMA(model, ema_model, alpha=_ema_decay)

In [16]:
@torch.no_grad()
def get_acc(model, dl):
    acc = []
    for x, y in dl:
        x, y = x.to(DEVICE), y.to(DEVICE)
        acc.append(torch.argmax(model(x), dim=1) == y)
    acc = torch.cat(acc)
    acc = torch.sum(acc) / len(acc)

    return acc.item()

def interleave_offsets(batch, nu):
    groups = [batch // (nu + 1)] * (nu + 1)
    for x in range(batch - sum(groups)):
        groups[-x - 1] += 1
    offsets = [0]
    for g in groups:
        offsets.append(offsets[-1] + g)
    assert offsets[-1] == batch
    return offsets

def interleave(xy, batch):
    nu = len(xy) - 1
    offsets = interleave_offsets(batch, nu)
    xy = [[v[offsets[p]:offsets[p + 1]] for p in range(nu + 1)] for v in xy]
    for i in range(1, nu + 1):
        xy[0][i], xy[i][i] = xy[i][i], xy[0][i]
    return [torch.cat(v, dim=0) for v in xy]

In [17]:
def train(labeled_trainloader, unlabeled_trainloader, model, optimizer, ema_optimizer, criterion, epoch, use_cuda):
    
    batch_time = AverageMeter()
    data_time = AverageMeter()
    
    losses = AverageMeter()
    losses_x = AverageMeter()
    losses_u = AverageMeter()
    ws = AverageMeter()
    end = time.time()
    
    bar = Bar('Training', max=_train_iteration)
    labeled_train_iter = iter(labeled_trainloader)
    unlabeled_train_iter = iter(unlabeled_trainloader)
    
    model.train()

    with tqdm(range(_train_iteration), desc="Training Progress", unit="batch") as progress_bar:
        for batch_idx in progress_bar:
    
            inputs_x, targets_x      = next(labeled_train_iter)
            (inputs_u, inputs_u2), _ = next(unlabeled_train_iter)
    
            # measure data loading time
            data_time.update(time.time() - end)
            
            batch_size = inputs_x.size(0)
    
            # convert label to one-hot
            targets_x = torch.zeros(batch_size, _n_classes).scatter_(1, targets_x.view(-1,1).long(), 1)
    
            if use_cuda:
                inputs_x, targets_x = inputs_x.cuda(), targets_x.cuda(non_blocking=True)
                inputs_u = inputs_u.cuda()
                inputs_u2 = inputs_u2.cuda()
    
            with torch.no_grad():
                # compute guessed labels of unlabel samples
                outputs_u = model(inputs_u)
                outputs_u2 = model(inputs_u2)
                p = (torch.softmax(outputs_u, dim=1) + torch.softmax(outputs_u2, dim=1)) / 2
                pt = p**(1/_T)
                targets_u = pt / pt.sum(dim=1, keepdim=True)
                targets_u = targets_u.detach()
            
            # mixup 
            all_inputs = torch.cat([inputs_x, inputs_u, inputs_u2], dim=0)
            all_targets = torch.cat([targets_x, targets_u, targets_u], dim=0)
    
            l = np.random.beta(_alpha, _alpha)
            # l = 1
            l = max(l, 1-l)
            
            idx = torch.randperm(all_inputs.size(0))
            
            input_a, input_b = all_inputs, all_inputs[idx]
            target_a, target_b = all_targets, all_targets[idx]
            
            # input_a > input_b is guaranteed.
            mixed_input = l * input_a + (1 - l) * input_b
            mixed_target = l * target_a + (1 - l) * target_b
    
            # interleave labeled and unlabed samples between batches to get correct batchnorm calculation 
            mixed_input = list(torch.split(mixed_input, batch_size))
            mixed_input = interleave(mixed_input, batch_size)
            
            logits = [model(mixed_input[0])]
            for input in mixed_input[1:]:
                logits.append(model(input))
    
            # put interleaved samples back
            logits = interleave(logits, batch_size)
            logits_x = logits[0]
            logits_u = torch.cat(logits[1:], dim=0)
    
            Lx, Lu, w = criterion(logits_x, mixed_target[:batch_size], logits_u, mixed_target[batch_size:], epoch+batch_idx/_train_iteration)
    
            loss = Lx + w * Lu
    
            losses.update(loss, inputs_x.size(0))
            losses_x.update(Lx, inputs_x.size(0))
            losses_u.update(Lu, inputs_x.size(0))
            ws.update(w, inputs_x.size(0))
    
            optimizer.zero_grad()
            loss.backward()
            
            optimizer.step()
            ema_optimizer.step()
    
            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
        
            # plot progress
            bar.suffix  = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | Loss_x: {loss_x:.4f} | Loss_u: {loss_u:.4f} | W: {w:.4f}'.format(
                        batch=batch_idx + 1,
                        size=_train_iteration,
                        data=data_time.avg,
                        bt=batch_time.avg,
                        total=bar.elapsed_td,
                        eta=bar.eta_td,
                        loss=losses.avg,
                        loss_x=losses_x.avg,
                        loss_u=losses_u.avg,
                        w=ws.avg,
                        )
            bar.next()

            progress_bar.set_postfix({
                "Data": f"{data_time.avg:.3f}s",
                "Batch": f"{batch_time.avg:.3f}s",
                "Total": bar.elapsed_td,
                "ETA": bar.eta_td,
                "Loss": f"{losses.avg:.4f}",
                "Loss_x": f"{losses_x.avg:.4f}",
                "Loss_u": f"{losses_u.avg:.4f}",
                "W": f"{ws.avg:.4f}",
            })
            
        bar.finish()
    
    return (losses.avg, losses_x.avg, losses_u.avg,)

In [18]:
def validate(valloader, model, criterion, epoch, use_cuda, mode):

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    bar = Bar(f'{mode}', max=len(valloader))
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(valloader):
            # measure data loading time
            data_time.update(time.time() - end)

            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda(non_blocking=True)
            # compute output
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(outputs, targets, topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1.item(), inputs.size(0))
            top5.update(prec5.item(), inputs.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # plot progress
            bar.suffix  = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format(
                        batch=batch_idx + 1,
                        size=len(valloader),
                        data=data_time.avg,
                        bt=batch_time.avg,
                        total=bar.elapsed_td,
                        eta=bar.eta_td,
                        loss=losses.avg,
                        top1=top1.avg,
                        top5=top5.avg,
                        )
            bar.next()
        bar.finish()
        
    return (losses.avg, top1.avg)

In [19]:
train_criterion = CustomLoss()
criterion = nn.CrossEntropyLoss()

In [20]:
title = 'noisy-cifar-10'

logger = Logger(os.path.join(_out, 'log.txt'), title=title)
logger.set_names(['Train Loss', 'Train Loss X', 'Train Loss U',  'Valid Loss', 'Valid Acc.', 'Test Loss', 'Test Acc.'])
writer = SummaryWriter(_out)

best_acc = 0  # best test accuracy
step = 0
test_accs = []

for epoch in range(_epochs):

    train_loss, train_loss_x, train_loss_u = train(labeled_trainloader, unlabeled_trainloader, model, optim, ema_optim, train_criterion, epoch, use_cuda)
    _, train_acc = validate(labeled_trainloader, ema_model, criterion, epoch, use_cuda, mode='Train Stats')
    val_loss, val_acc = validate(val_loader, ema_model, criterion, epoch, use_cuda, mode='Valid Stats')
    test_loss, test_acc = validate(test_loader, ema_model, criterion, epoch, use_cuda, mode='Test Stats ')

    sched.step()

    _test_acc = get_acc(model, test_loader)
    print(f"[Epoch {epoch}] (naive) Test Accuracy: {_test_acc:.4f}")

    print(f"[Epoch {epoch}] Train Accuracy: {train_acc:.4f}")
    print(f"[Epoch {epoch}] Validation Accuracy: {val_acc:.4f}")
    print(f"[Epoch {epoch}] Test Accuracy: {test_acc:.4f}")

    
    step = _train_iteration * (epoch + 1)

    writer.add_scalar('losses/train_loss', train_loss, step)
    writer.add_scalar('losses/valid_loss', val_loss, step)
    writer.add_scalar('losses/test_loss', test_loss, step)

    writer.add_scalar('accuracy/train_acc', train_acc, step)
    writer.add_scalar('accuracy/val_acc', val_acc, step)
    writer.add_scalar('accuracy/test_acc', test_acc, step)

    logger.append([train_loss, train_loss_x, train_loss_u, val_loss, val_acc, test_loss, test_acc])

    best_acc = max(val_acc, best_acc)
    test_accs.append(test_acc)

logger.close()
writer.close()

print('Best acc:')
print(best_acc)

print('Mean acc:')
print(np.mean(test_accs[-20:]))

Training Progress: 100%|██████████| 156/156 [00:11<00:00, 13.62batch/s, Data=0.046s, Batch=0.074s, Total=0:00:11, ETA=0:00:00, Loss=2.2719, Loss_x=2.2713, Loss_u=0.0027, W=0.1987]


[Epoch 0] (naive) Test Accuracy: 0.1770
[Epoch 0] Train Accuracy: 11.5986
[Epoch 0] Validation Accuracy: 11.4200
[Epoch 0] Test Accuracy: 11.9700


Training Progress: 100%|██████████| 156/156 [00:11<00:00, 14.12batch/s, Data=0.046s, Batch=0.071s, Total=0:00:11, ETA=0:00:00, Loss=2.2017, Loss_x=2.1997, Loss_u=0.0033, W=0.5987]


[Epoch 1] (naive) Test Accuracy: 0.2145
[Epoch 1] Train Accuracy: 15.2244
[Epoch 1] Validation Accuracy: 15.7600
[Epoch 1] Test Accuracy: 15.3400


Training Progress: 100%|██████████| 156/156 [00:11<00:00, 14.08batch/s, Data=0.047s, Batch=0.072s, Total=0:00:11, ETA=0:00:00, Loss=2.1232, Loss_x=2.1195, Loss_u=0.0037, W=0.9987]


[Epoch 2] (naive) Test Accuracy: 0.2389
[Epoch 2] Train Accuracy: 12.8806
[Epoch 2] Validation Accuracy: 13.0200
[Epoch 2] Test Accuracy: 12.8900


Training Progress: 100%|██████████| 156/156 [00:11<00:00, 14.01batch/s, Data=0.046s, Batch=0.072s, Total=0:00:11, ETA=0:00:00, Loss=2.0799, Loss_x=2.0736, Loss_u=0.0045, W=1.3987]


[Epoch 3] (naive) Test Accuracy: 0.2749
[Epoch 3] Train Accuracy: 13.3814
[Epoch 3] Validation Accuracy: 13.5400
[Epoch 3] Test Accuracy: 13.3700


Training Progress: 100%|██████████| 156/156 [00:11<00:00, 14.15batch/s, Data=0.046s, Batch=0.072s, Total=0:00:11, ETA=0:00:00, Loss=2.0595, Loss_x=2.0512, Loss_u=0.0046, W=1.7987]


[Epoch 4] (naive) Test Accuracy: 0.2733
[Epoch 4] Train Accuracy: 14.8538
[Epoch 4] Validation Accuracy: 14.8800
[Epoch 4] Test Accuracy: 14.8200


Training Progress: 100%|██████████| 156/156 [00:11<00:00, 13.90batch/s, Data=0.048s, Batch=0.073s, Total=0:00:11, ETA=0:00:00, Loss=2.0439, Loss_x=2.0312, Loss_u=0.0058, W=2.1987]


[Epoch 5] (naive) Test Accuracy: 0.3002
[Epoch 5] Train Accuracy: 17.1174
[Epoch 5] Validation Accuracy: 17.5000
[Epoch 5] Test Accuracy: 16.9900


Training Progress: 100%|██████████| 156/156 [00:11<00:00, 13.92batch/s, Data=0.048s, Batch=0.073s, Total=0:00:11, ETA=0:00:00, Loss=2.0109, Loss_x=1.9944, Loss_u=0.0063, W=2.5987]


[Epoch 6] (naive) Test Accuracy: 0.2799
[Epoch 6] Train Accuracy: 19.7616
[Epoch 6] Validation Accuracy: 20.5200
[Epoch 6] Test Accuracy: 19.3900


Training Progress: 100%|██████████| 156/156 [00:11<00:00, 14.01batch/s, Data=0.046s, Batch=0.072s, Total=0:00:11, ETA=0:00:00, Loss=1.9912, Loss_x=1.9706, Loss_u=0.0069, W=2.9987]


[Epoch 7] (naive) Test Accuracy: 0.3645
[Epoch 7] Train Accuracy: 22.1054
[Epoch 7] Validation Accuracy: 23.1000
[Epoch 7] Test Accuracy: 22.1600


Training Progress: 100%|██████████| 156/156 [00:11<00:00, 13.97batch/s, Data=0.048s, Batch=0.072s, Total=0:00:11, ETA=0:00:00, Loss=1.9704, Loss_x=1.9463, Loss_u=0.0071, W=3.3987]


[Epoch 8] (naive) Test Accuracy: 0.3748
[Epoch 8] Train Accuracy: 24.6294
[Epoch 8] Validation Accuracy: 25.6800
[Epoch 8] Test Accuracy: 25.2800


Training Progress: 100%|██████████| 156/156 [00:11<00:00, 13.90batch/s, Data=0.048s, Batch=0.073s, Total=0:00:11, ETA=0:00:00, Loss=1.9357, Loss_x=1.9066, Loss_u=0.0076, W=3.7987]


[Epoch 9] (naive) Test Accuracy: 0.2978
[Epoch 9] Train Accuracy: 27.6342
[Epoch 9] Validation Accuracy: 29.2200
[Epoch 9] Test Accuracy: 28.5700


Training Progress: 100%|██████████| 156/156 [00:11<00:00, 13.93batch/s, Data=0.048s, Batch=0.073s, Total=0:00:11, ETA=0:00:00, Loss=1.9290, Loss_x=1.8957, Loss_u=0.0079, W=4.1987]


[Epoch 10] (naive) Test Accuracy: 0.3344
[Epoch 10] Train Accuracy: 31.0397
[Epoch 10] Validation Accuracy: 33.0400
[Epoch 10] Test Accuracy: 32.5800


Training Progress: 100%|██████████| 156/156 [00:11<00:00, 13.94batch/s, Data=0.047s, Batch=0.073s, Total=0:00:11, ETA=0:00:00, Loss=1.8840, Loss_x=1.8446, Loss_u=0.0086, W=4.5987]


[Epoch 11] (naive) Test Accuracy: 0.3969
[Epoch 11] Train Accuracy: 35.8073
[Epoch 11] Validation Accuracy: 37.5400
[Epoch 11] Test Accuracy: 36.5700


Training Progress: 100%|██████████| 156/156 [00:11<00:00, 13.86batch/s, Data=0.047s, Batch=0.073s, Total=0:00:11, ETA=0:00:00, Loss=1.9036, Loss_x=1.8537, Loss_u=0.0100, W=4.9987]


[Epoch 12] (naive) Test Accuracy: 0.4610
[Epoch 12] Train Accuracy: 38.0108
[Epoch 12] Validation Accuracy: 40.4000
[Epoch 12] Test Accuracy: 39.8200


Training Progress: 100%|██████████| 156/156 [00:11<00:00, 14.00batch/s, Data=0.048s, Batch=0.073s, Total=0:00:11, ETA=0:00:00, Loss=1.8515, Loss_x=1.7989, Loss_u=0.0097, W=5.3987]


[Epoch 13] (naive) Test Accuracy: 0.4837
[Epoch 13] Train Accuracy: 41.2059
[Epoch 13] Validation Accuracy: 42.8400
[Epoch 13] Test Accuracy: 42.1400


Training Progress: 100%|██████████| 156/156 [00:11<00:00, 13.93batch/s, Data=0.046s, Batch=0.072s, Total=0:00:11, ETA=0:00:00, Loss=1.8450, Loss_x=1.7855, Loss_u=0.0103, W=5.7987]


[Epoch 14] (naive) Test Accuracy: 0.4766
[Epoch 14] Train Accuracy: 43.5096
[Epoch 14] Validation Accuracy: 45.1400
[Epoch 14] Test Accuracy: 44.2000


Training Progress: 100%|██████████| 156/156 [00:11<00:00, 13.97batch/s, Data=0.046s, Batch=0.072s, Total=0:00:11, ETA=0:00:00, Loss=1.7861, Loss_x=1.7215, Loss_u=0.0104, W=6.1987]


[Epoch 15] (naive) Test Accuracy: 0.5340
[Epoch 15] Train Accuracy: 45.6530
[Epoch 15] Validation Accuracy: 47.2400
[Epoch 15] Test Accuracy: 47.0300


Training Progress: 100%|██████████| 156/156 [00:11<00:00, 14.01batch/s, Data=0.047s, Batch=0.072s, Total=0:00:11, ETA=0:00:00, Loss=1.7839, Loss_x=1.7143, Loss_u=0.0105, W=6.5987]


[Epoch 16] (naive) Test Accuracy: 0.4641
[Epoch 16] Train Accuracy: 48.9083
[Epoch 16] Validation Accuracy: 49.4800
[Epoch 16] Test Accuracy: 48.8700


Training Progress: 100%|██████████| 156/156 [00:11<00:00, 13.95batch/s, Data=0.048s, Batch=0.072s, Total=0:00:11, ETA=0:00:00, Loss=1.7610, Loss_x=1.6864, Loss_u=0.0107, W=6.9987]


[Epoch 17] (naive) Test Accuracy: 0.5333
[Epoch 17] Train Accuracy: 51.0116
[Epoch 17] Validation Accuracy: 51.7800
[Epoch 17] Test Accuracy: 50.7400


Training Progress: 100%|██████████| 156/156 [00:11<00:00, 13.95batch/s, Data=0.048s, Batch=0.073s, Total=0:00:11, ETA=0:00:00, Loss=1.7034, Loss_x=1.6219, Loss_u=0.0110, W=7.3987]


[Epoch 18] (naive) Test Accuracy: 0.5650
[Epoch 18] Train Accuracy: 53.7660
[Epoch 18] Validation Accuracy: 53.5000
[Epoch 18] Test Accuracy: 52.6700


Training Progress: 100%|██████████| 156/156 [00:11<00:00, 13.93batch/s, Data=0.046s, Batch=0.073s, Total=0:00:11, ETA=0:00:00, Loss=1.7186, Loss_x=1.6275, Loss_u=0.0117, W=7.7987]


[Epoch 19] (naive) Test Accuracy: 0.6063
[Epoch 19] Train Accuracy: 55.3786
[Epoch 19] Validation Accuracy: 55.9600
[Epoch 19] Test Accuracy: 55.0700


Training Progress: 100%|██████████| 156/156 [00:11<00:00, 13.99batch/s, Data=0.047s, Batch=0.073s, Total=0:00:11, ETA=0:00:00, Loss=1.6936, Loss_x=1.6002, Loss_u=0.0114, W=8.1987]


[Epoch 20] (naive) Test Accuracy: 0.6074
[Epoch 20] Train Accuracy: 57.7925
[Epoch 20] Validation Accuracy: 57.5400
[Epoch 20] Test Accuracy: 56.8000


Training Progress: 100%|██████████| 156/156 [00:11<00:00, 13.90batch/s, Data=0.049s, Batch=0.073s, Total=0:00:11, ETA=0:00:00, Loss=1.6533, Loss_x=1.5573, Loss_u=0.0112, W=8.5987]


[Epoch 21] (naive) Test Accuracy: 0.5939
[Epoch 21] Train Accuracy: 59.7556
[Epoch 21] Validation Accuracy: 59.2400
[Epoch 21] Test Accuracy: 58.8300


Training Progress: 100%|██████████| 156/156 [00:11<00:00, 13.98batch/s, Data=0.047s, Batch=0.073s, Total=0:00:11, ETA=0:00:00, Loss=1.6473, Loss_x=1.5448, Loss_u=0.0114, W=8.9987]


[Epoch 22] (naive) Test Accuracy: 0.6376
[Epoch 22] Train Accuracy: 61.3081
[Epoch 22] Validation Accuracy: 60.6800
[Epoch 22] Test Accuracy: 60.3000


Training Progress: 100%|██████████| 156/156 [00:11<00:00, 13.91batch/s, Data=0.047s, Batch=0.073s, Total=0:00:11, ETA=0:00:00, Loss=1.5872, Loss_x=1.4837, Loss_u=0.0110, W=9.3987]


[Epoch 23] (naive) Test Accuracy: 0.6564
[Epoch 23] Train Accuracy: 63.1310
[Epoch 23] Validation Accuracy: 62.2400
[Epoch 23] Test Accuracy: 61.8000


Training Progress: 100%|██████████| 156/156 [00:11<00:00, 13.96batch/s, Data=0.047s, Batch=0.073s, Total=0:00:11, ETA=0:00:00, Loss=1.6152, Loss_x=1.4997, Loss_u=0.0118, W=9.7987]


[Epoch 24] (naive) Test Accuracy: 0.6719
[Epoch 24] Train Accuracy: 64.5933
[Epoch 24] Validation Accuracy: 63.8000
[Epoch 24] Test Accuracy: 63.1800


Training Progress: 100%|██████████| 156/156 [00:11<00:00, 13.90batch/s, Data=0.047s, Batch=0.073s, Total=0:00:11, ETA=0:00:00, Loss=1.6242, Loss_x=1.5028, Loss_u=0.0119, W=10.1987]


[Epoch 25] (naive) Test Accuracy: 0.6551
[Epoch 25] Train Accuracy: 66.6066
[Epoch 25] Validation Accuracy: 65.0600
[Epoch 25] Test Accuracy: 64.5700


Training Progress: 100%|██████████| 156/156 [00:11<00:00, 14.04batch/s, Data=0.048s, Batch=0.072s, Total=0:00:11, ETA=0:00:00, Loss=1.5986, Loss_x=1.4708, Loss_u=0.0121, W=10.5987]


[Epoch 26] (naive) Test Accuracy: 0.6458
[Epoch 26] Train Accuracy: 67.6382
[Epoch 26] Validation Accuracy: 66.0400
[Epoch 26] Test Accuracy: 65.6600


Training Progress: 100%|██████████| 156/156 [00:11<00:00, 13.94batch/s, Data=0.046s, Batch=0.073s, Total=0:00:11, ETA=0:00:00, Loss=1.5566, Loss_x=1.4290, Loss_u=0.0116, W=10.9987]


[Epoch 27] (naive) Test Accuracy: 0.7110
[Epoch 27] Train Accuracy: 68.7700
[Epoch 27] Validation Accuracy: 67.2200
[Epoch 27] Test Accuracy: 66.8100


Training Progress: 100%|██████████| 156/156 [00:11<00:00, 13.73batch/s, Data=0.049s, Batch=0.074s, Total=0:00:11, ETA=0:00:00, Loss=1.4717, Loss_x=1.3444, Loss_u=0.0112, W=11.3987]


[Epoch 28] (naive) Test Accuracy: 0.6841
[Epoch 28] Train Accuracy: 70.4728
[Epoch 28] Validation Accuracy: 67.9800
[Epoch 28] Test Accuracy: 67.8100


Training Progress: 100%|██████████| 156/156 [00:11<00:00, 13.82batch/s, Data=0.049s, Batch=0.073s, Total=0:00:11, ETA=0:00:00, Loss=1.4594, Loss_x=1.3305, Loss_u=0.0109, W=11.7987]


[Epoch 29] (naive) Test Accuracy: 0.7048
[Epoch 29] Train Accuracy: 71.2340
[Epoch 29] Validation Accuracy: 68.6800
[Epoch 29] Test Accuracy: 68.6200


Training Progress: 100%|██████████| 156/156 [00:11<00:00, 14.03batch/s, Data=0.045s, Batch=0.072s, Total=0:00:11, ETA=0:00:00, Loss=1.5465, Loss_x=1.3989, Loss_u=0.0121, W=12.1987]


[Epoch 30] (naive) Test Accuracy: 0.6998
[Epoch 30] Train Accuracy: 72.8866
[Epoch 30] Validation Accuracy: 69.5000
[Epoch 30] Test Accuracy: 69.4400


Training Progress: 100%|██████████| 156/156 [00:11<00:00, 13.99batch/s, Data=0.043s, Batch=0.073s, Total=0:00:11, ETA=0:00:00, Loss=1.5036, Loss_x=1.3542, Loss_u=0.0119, W=12.5987]


[Epoch 31] (naive) Test Accuracy: 0.7165
[Epoch 31] Train Accuracy: 73.7580
[Epoch 31] Validation Accuracy: 70.3800
[Epoch 31] Test Accuracy: 70.0900


Training Progress: 100%|██████████| 156/156 [00:11<00:00, 13.79batch/s, Data=0.045s, Batch=0.074s, Total=0:00:11, ETA=0:00:00, Loss=1.5376, Loss_x=1.3786, Loss_u=0.0122, W=12.9987]


[Epoch 32] (naive) Test Accuracy: 0.7297
[Epoch 32] Train Accuracy: 75.3105
[Epoch 32] Validation Accuracy: 70.7800
[Epoch 32] Test Accuracy: 70.8800


Training Progress: 100%|██████████| 156/156 [00:11<00:00, 13.93batch/s, Data=0.049s, Batch=0.073s, Total=0:00:11, ETA=0:00:00, Loss=1.5414, Loss_x=1.3749, Loss_u=0.0124, W=13.3987]


[Epoch 33] (naive) Test Accuracy: 0.7364
[Epoch 33] Train Accuracy: 76.0817
[Epoch 33] Validation Accuracy: 71.3600
[Epoch 33] Test Accuracy: 71.5400


Training Progress: 100%|██████████| 156/156 [00:11<00:00, 14.03batch/s, Data=0.047s, Batch=0.072s, Total=0:00:11, ETA=0:00:00, Loss=1.5061, Loss_x=1.3347, Loss_u=0.0124, W=13.7987]


[Epoch 34] (naive) Test Accuracy: 0.7293
[Epoch 34] Train Accuracy: 76.8229
[Epoch 34] Validation Accuracy: 71.9400
[Epoch 34] Test Accuracy: 72.0200


Training Progress: 100%|██████████| 156/156 [00:11<00:00, 13.92batch/s, Data=0.046s, Batch=0.073s, Total=0:00:11, ETA=0:00:00, Loss=1.4685, Loss_x=1.2931, Loss_u=0.0123, W=14.1987]


[Epoch 35] (naive) Test Accuracy: 0.7351
[Epoch 35] Train Accuracy: 77.6643
[Epoch 35] Validation Accuracy: 72.6000
[Epoch 35] Test Accuracy: 72.5700


Training Progress: 100%|██████████| 156/156 [00:11<00:00, 13.89batch/s, Data=0.044s, Batch=0.073s, Total=0:00:11, ETA=0:00:00, Loss=1.4531, Loss_x=1.2760, Loss_u=0.0121, W=14.5987]


[Epoch 36] (naive) Test Accuracy: 0.7167
[Epoch 36] Train Accuracy: 78.3253
[Epoch 36] Validation Accuracy: 73.3400
[Epoch 36] Test Accuracy: 73.2400


Training Progress: 100%|██████████| 156/156 [00:11<00:00, 13.92batch/s, Data=0.049s, Batch=0.073s, Total=0:00:11, ETA=0:00:00, Loss=1.4297, Loss_x=1.2555, Loss_u=0.0116, W=14.9987]


[Epoch 37] (naive) Test Accuracy: 0.7538
[Epoch 37] Train Accuracy: 79.1066
[Epoch 37] Validation Accuracy: 74.1200
[Epoch 37] Test Accuracy: 73.8000


Training Progress: 100%|██████████| 156/156 [00:11<00:00, 13.96batch/s, Data=0.044s, Batch=0.073s, Total=0:00:11, ETA=0:00:00, Loss=1.3933, Loss_x=1.2170, Loss_u=0.0114, W=15.3987]


[Epoch 38] (naive) Test Accuracy: 0.7421
[Epoch 38] Train Accuracy: 80.2484
[Epoch 38] Validation Accuracy: 74.8800
[Epoch 38] Test Accuracy: 74.2700


Training Progress: 100%|██████████| 156/156 [00:11<00:00, 13.96batch/s, Data=0.047s, Batch=0.073s, Total=0:00:11, ETA=0:00:00, Loss=1.4953, Loss_x=1.2923, Loss_u=0.0128, W=15.7987]


[Epoch 39] (naive) Test Accuracy: 0.7485
[Epoch 39] Train Accuracy: 80.1683
[Epoch 39] Validation Accuracy: 75.5800
[Epoch 39] Test Accuracy: 74.5200


Training Progress: 100%|██████████| 156/156 [00:11<00:00, 13.92batch/s, Data=0.047s, Batch=0.073s, Total=0:00:11, ETA=0:00:00, Loss=1.3561, Loss_x=1.1714, Loss_u=0.0114, W=16.1987]


[Epoch 40] (naive) Test Accuracy: 0.7616
[Epoch 40] Train Accuracy: 80.9195
[Epoch 40] Validation Accuracy: 76.0800
[Epoch 40] Test Accuracy: 74.9700


Training Progress: 100%|██████████| 156/156 [00:11<00:00, 14.03batch/s, Data=0.048s, Batch=0.072s, Total=0:00:11, ETA=0:00:00, Loss=1.3922, Loss_x=1.1970, Loss_u=0.0118, W=16.5987]


[Epoch 41] (naive) Test Accuracy: 0.7665
[Epoch 41] Train Accuracy: 81.9411
[Epoch 41] Validation Accuracy: 76.3000
[Epoch 41] Test Accuracy: 75.2900


Training Progress: 100%|██████████| 156/156 [00:11<00:00, 13.92batch/s, Data=0.047s, Batch=0.073s, Total=0:00:11, ETA=0:00:00, Loss=1.4353, Loss_x=1.2284, Loss_u=0.0122, W=16.9987]


[Epoch 42] (naive) Test Accuracy: 0.7584
[Epoch 42] Train Accuracy: 82.6723
[Epoch 42] Validation Accuracy: 76.6200
[Epoch 42] Test Accuracy: 75.6200


Training Progress: 100%|██████████| 156/156 [00:11<00:00, 13.99batch/s, Data=0.047s, Batch=0.073s, Total=0:00:11, ETA=0:00:00, Loss=1.4596, Loss_x=1.2399, Loss_u=0.0126, W=17.3987]


[Epoch 43] (naive) Test Accuracy: 0.7701
[Epoch 43] Train Accuracy: 82.9427
[Epoch 43] Validation Accuracy: 76.6800
[Epoch 43] Test Accuracy: 75.9200


Training Progress: 100%|██████████| 156/156 [00:11<00:00, 14.01batch/s, Data=0.044s, Batch=0.072s, Total=0:00:11, ETA=0:00:00, Loss=1.4016, Loss_x=1.1956, Loss_u=0.0116, W=17.7987]


[Epoch 44] (naive) Test Accuracy: 0.7698
[Epoch 44] Train Accuracy: 83.2732
[Epoch 44] Validation Accuracy: 76.9600
[Epoch 44] Test Accuracy: 76.1600


Training Progress: 100%|██████████| 156/156 [00:11<00:00, 13.94batch/s, Data=0.047s, Batch=0.073s, Total=0:00:11, ETA=0:00:00, Loss=1.4055, Loss_x=1.1887, Loss_u=0.0119, W=18.1987]


[Epoch 45] (naive) Test Accuracy: 0.7763
[Epoch 45] Train Accuracy: 84.0345
[Epoch 45] Validation Accuracy: 77.1800
[Epoch 45] Test Accuracy: 76.3100


Training Progress: 100%|██████████| 156/156 [00:11<00:00, 13.95batch/s, Data=0.048s, Batch=0.073s, Total=0:00:11, ETA=0:00:00, Loss=1.4113, Loss_x=1.1906, Loss_u=0.0119, W=18.5987]


[Epoch 46] (naive) Test Accuracy: 0.7779
[Epoch 46] Train Accuracy: 83.7139
[Epoch 46] Validation Accuracy: 77.3800
[Epoch 46] Test Accuracy: 76.5600


Training Progress: 100%|██████████| 156/156 [00:11<00:00, 13.96batch/s, Data=0.049s, Batch=0.073s, Total=0:00:11, ETA=0:00:00, Loss=1.3744, Loss_x=1.1550, Loss_u=0.0115, W=18.9987]


[Epoch 47] (naive) Test Accuracy: 0.7694
[Epoch 47] Train Accuracy: 84.5853
[Epoch 47] Validation Accuracy: 77.4800
[Epoch 47] Test Accuracy: 76.7200


Training Progress: 100%|██████████| 156/156 [00:11<00:00, 13.92batch/s, Data=0.047s, Batch=0.073s, Total=0:00:11, ETA=0:00:00, Loss=1.3872, Loss_x=1.1621, Loss_u=0.0116, W=19.3987]


[Epoch 48] (naive) Test Accuracy: 0.7782
[Epoch 48] Train Accuracy: 84.5052
[Epoch 48] Validation Accuracy: 77.7600
[Epoch 48] Test Accuracy: 76.9800


Training Progress: 100%|██████████| 156/156 [00:11<00:00, 13.99batch/s, Data=0.046s, Batch=0.073s, Total=0:00:11, ETA=0:00:00, Loss=1.4107, Loss_x=1.1719, Loss_u=0.0121, W=19.7987]


[Epoch 49] (naive) Test Accuracy: 0.7754
[Epoch 49] Train Accuracy: 84.8858
[Epoch 49] Validation Accuracy: 77.8600
[Epoch 49] Test Accuracy: 77.1600
Best acc:
77.86
Mean acc:
74.203
