In [1]:
import os
import time
import datetime 
from pathlib import Path

import numpy as np
import torch
import wandb
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Subset, random_split
import torch.utils.data as data # for code compatibility
from torchvision import models, transforms
from torchvision.datasets import CIFAR10, CIFAR100
import timm

import scipy.stats
from sklearn.metrics import auc, roc_curve

import pytorch_lightning as pl

# util
from tqdm import tqdm
from collections import Counter

# customized 
import models.arch as models

import dataset.cifar100 as dataset
# import dataset.cifar100_larger as dataset # for ViT model,

from utils import Logger, AverageMeter, accuracy, mkdir_p, savefig
from progress.bar import Bar as Bar
import utils as utils_
from tensorboardX import SummaryWriter

##
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("mps")
DEVICE

device(type='cuda')

In [2]:
# TODO: 
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [None]:
# for vgg19, CIFAR-100
_arch = "vgg19"
_dataset = "cifar100"

_lr = 0.02
_epochs = 50

_debug = True
_batch_size = 64

_n_labeled = 10000
# XXX TODO 
_n_unlabeled = 45000 - _n_labeled 

_alpha = 0.75
_lambda_u = 20 # default: 75

_ema_decay = 0.999 # default: 0.999
_T = 0.5 # default: 0.5
# XXX TODO 
_train_iteration = int(max(_n_labeled, _n_unlabeled) / _batch_size) # phase 1: support full-supervised learning, default: 1024
# _train_iteration = int(_n_labeled / _batch_size) # phase 1: support full-supervised learning, default: 1024

# XXX TODO 
_n_classes = 100 # depend on dataset

use_cuda = torch.cuda.is_available()

In [4]:
_out = 'experiments/{}_only/{}@{}_lu_{}_iter_{}'.format(_arch, _dataset, _n_labeled, _lambda_u, _train_iteration)
_out

'experiments/vgg19_only/cifar100@10000_lu_20_iter_546'

In [5]:
_train_iteration

546

In [6]:
SEED = 1583745484

# prepare dataset

In [7]:
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x7f451c069350>

In [8]:
print(f'==> Preparing cifar100')
transform_train = transforms.Compose([
    dataset.RandomPadandCrop(32),
    dataset.RandomFlip(),
    dataset.ToTensor(),
])

transform_val = transforms.Compose([
    dataset.ToTensor(),
])
datadir = Path().home() / "dataset"

batch_size=_batch_size

# measure time for data pre-processing
start_time = time.time()

train_labeled_set, train_unlabeled_set, val_set, test_set = dataset.get_cifar100(datadir, _n_labeled, transform_train=transform_train, transform_val=transform_val)
labeled_trainloader = data.DataLoader(train_labeled_set, batch_size=batch_size, shuffle=True, num_workers=2, drop_last=True)
unlabeled_trainloader = data.DataLoader(train_unlabeled_set, batch_size=batch_size, shuffle=True, num_workers=2, drop_last=True)
val_loader = data.DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=0)
"""
test_loader = data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=0)
"""
indices = np.random.choice(len(test_set), int(0.2 * len(test_set)), replace=False)
subset_test_set = Subset(test_set, indices)
test_loader = data.DataLoader(subset_test_set, batch_size=batch_size, shuffle=False, num_workers=0)

==> Preparing cifar100
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
#Labeled: 10000 #Unlabeled: 35000 #Val: 5000


In [9]:
train_labeled_set

Dataset CIFAR100_labeled
    Number of datapoints: 10000
    Root location: /home/dsanyal7/dataset
    Split: Train
    StandardTransform
Transform: Compose(
               <dataset.cifar100.RandomPadandCrop object at 0x7f440ccc2940>
               <dataset.cifar100.RandomFlip object at 0x7f440ccc2970>
               <dataset.cifar100.ToTensor object at 0x7f440ccc2a30>
           )

In [10]:
train_labeled_set[0][0].shape

torch.Size([3, 32, 32])

In [11]:
train_unlabeled_set

Dataset CIFAR100_unlabeled
    Number of datapoints: 35000
    Root location: /home/dsanyal7/dataset
    Split: Train
    StandardTransform
Transform: <dataset.cifar100.TransformTwice object at 0x7f4509bf6b80>

In [12]:
len(subset_test_set)

2000

In [13]:
# utils_.get_mean_and_std(train_labeled_set)

# expected values:
# (tensor([ 2.1660e-05, -8.8033e-04,  1.0356e-03]),
# tensor([0.8125, 0.8125, 0.7622]))

## model

In [14]:
def create_model(ema=False):
    model = models.network(_arch, pretrained=True, n_classes=_n_classes)
    model = model.cuda()

    if ema:
        for param in model.parameters():
            param.detach_()

    return model

model = create_model()
ema_model = create_model(ema=True)

arch: vgg19, pretrained: True, n_classes: 100




Do not freeze layers for model: vgg19
arch: vgg19, pretrained: True, n_classes: 100
Do not freeze layers for model: vgg19


In [15]:
# model
print('Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0))

Total params: 139.98M


In [16]:
def linear_rampup(current, rampup_length=_epochs):
    if rampup_length == 0:
        return 1.0
    else:
        current = np.clip(current / rampup_length, 0.0, 1.0)
        return float(current)

In [17]:
class WeigthEMA(object):
    def __init__(self, model, ema_model, alpha=0.999):
        self.model = model
        self.ema_model = ema_model
        self.alpha = alpha
        self.params = list(model.state_dict().values())
        self.ema_params = list(ema_model.state_dict().values())
        self.wd = 0.02 * _lr # for lr=0.02, wd = 0.0004

        for param, ema_param in zip(self.params, self.ema_params):
            param.data.copy_(ema_param.data)

    def step(self):
        one_minus_alpha = 1.0 - self.alpha
        for param, ema_param in zip(self.params, self.ema_params):
            if ema_param.dtype==torch.float32:
                ema_param.mul_(self.alpha)
                ema_param.add_(param * one_minus_alpha)
                # customized weight decay (TODO: should carefully set up)
                # 
                # lr=0.02 -> (0.9996)^5928 (1 epoch): 0.09332.
                # lr=0.002 -> (0.99996)^5928 (1 epoch): 0.78889
                # param.mul_(1 - self.wd) # x 0.9996 (for lr=0.02), x0.99996 (for lr=0

class CustomLoss(object):
    def __call__(self, outputs_x, targets_x, outputs_u, targets_u, epoch):
        """
        probs_u = torch.softmax(outputs_u, dim=1)
        """
        
        Lx = -torch.mean(torch.sum(F.log_softmax(outputs_x, dim=1) * targets_x, dim=1))
        
        # CHECK
        """
        Lu = torch.mean((probs_u - targets_u)**2)
        w = _lambda_u * linear_rampup(epoch) # _lambda_u: 75 (default)
        """ 
        Lu = 0
        w = 0

        return Lx, Lu, 0

In [18]:
optim = torch.optim.SGD(model.parameters(), lr=_lr, momentum=0.9, weight_decay=5e-4)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=_epochs)

ema_optim = WeigthEMA(model, ema_model, alpha=_ema_decay)

In [19]:
@torch.no_grad()
def get_acc(model, dl):
    acc = []
    for x, y in dl:
        x, y = x.to(DEVICE), y.to(DEVICE)
        acc.append(torch.argmax(model(x), dim=1) == y)
    acc = torch.cat(acc)
    acc = torch.sum(acc) / len(acc)

    return acc.item()

def interleave_offsets(batch, nu):
    groups = [batch // (nu + 1)] * (nu + 1)
    for x in range(batch - sum(groups)):
        groups[-x - 1] += 1
    offsets = [0]
    for g in groups:
        offsets.append(offsets[-1] + g)
    assert offsets[-1] == batch
    return offsets

def interleave(xy, batch):
    nu = len(xy) - 1
    offsets = interleave_offsets(batch, nu)
    xy = [[v[offsets[p]:offsets[p + 1]] for p in range(nu + 1)] for v in xy]
    for i in range(1, nu + 1):
        xy[0][i], xy[i][i] = xy[i][i], xy[0][i]
    return [torch.cat(v, dim=0) for v in xy]

In [20]:
def train(labeled_trainloader, unlabeled_trainloader, model, optimizer, ema_optimizer, criterion, epoch, use_cuda):
    
    batch_time = AverageMeter()
    data_time = AverageMeter()
    
    losses = AverageMeter()
    losses_x = AverageMeter()
    losses_u = AverageMeter()
    ws = AverageMeter()
    end = time.time()
    
    bar = Bar('Training', max=_train_iteration)
    labeled_train_iter = iter(labeled_trainloader)
    unlabeled_train_iter = iter(unlabeled_trainloader)
    
    model.train()

    with tqdm(range(_train_iteration), desc="Training Progress", unit="batch") as progress_bar:
        for batch_idx in progress_bar:
            try:
                inputs_x, targets_x = next(labeled_train_iter)
            except:
                labeled_train_iter = iter(labeled_trainloader)
                inputs_x, targets_x = next(labeled_train_iter)

            """
            try:
                (inputs_u, inputs_u2), _ = next(unlabeled_train_iter) # two different augment(x)s;
            except:
                unlabeled_train_iter = iter(unlabeled_trainloader)
                (inputs_u, inputs_u2), _ = next(unlabeled_train_iter)
                """
    
            # measure data loading time
            data_time.update(time.time() - end)
            
            batch_size = inputs_x.size(0)
    
            # convert label to one-hot
            targets_x = torch.zeros(batch_size, _n_classes).scatter_(1, targets_x.view(-1,1).long(), 1)
    
            if use_cuda:
                inputs_x, targets_x = inputs_x.cuda(), targets_x.cuda(non_blocking=True)
                """
                inputs_u = inputs_u.cuda()
                inputs_u2 = inputs_u2.cuda()
                """

            """
            with torch.no_grad():
                # compute guessed labels of unlabel samples
                outputs_u = model(inputs_u)
                outputs_u2 = model(inputs_u2)
                p = (torch.softmax(outputs_u, dim=1) + torch.softmax(outputs_u2, dim=1)) / 2
                pt = p**(1/_T)
                targets_u = pt / pt.sum(dim=1, keepdim=True)
                targets_u = targets_u.detach()
            """
            
            # mixup 
            """ 
            all_inputs = torch.cat([inputs_x, inputs_u, inputs_u2], dim=0)
            all_targets = torch.cat([targets_x, targets_u, targets_u], dim=0)
            """
            all_inputs = torch.cat([inputs_x], dim=0)
            all_targets = torch.cat([targets_x], dim=0)

            # CHECK: 
            """
            l = np.random.beta(_alpha, _alpha)
            """
            l = 1
            l = max(l, 1-l)
            
            idx = torch.randperm(all_inputs.size(0))
            
            input_a, input_b = all_inputs, all_inputs[idx]
            target_a, target_b = all_targets, all_targets[idx]
            
            # input_a > input_b is guaranteed.
            mixed_input = l * input_a + (1 - l) * input_b
            mixed_target = l * target_a + (1 - l) * target_b

            """
            # interleave labeled and unlabed samples between batches to get correct batchnorm calculation 
            mixed_input = list(torch.split(mixed_input, batch_size))
            mixed_input = interleave(mixed_input, batch_size)
            """

            """
            logits = [model(mixed_input[0])]
            for input in mixed_input[1:]:
                logits.append(model(input))
            """
            logits = [model(mixed_input)]
    
            # put interleaved samples back
            """
            logits = interleave(logits, batch_size)
            logits_x = logits[0]
            logits_u = torch.cat(logits[1:], dim=0)
            """
            logits_x = logits[0]

            """
            Lx, Lu, w = criterion(logits_x, mixed_target[:batch_size], logits_u, mixed_target[batch_size:], epoch+batch_idx/_train_iteration)
            """
            Lx, Lu, w = criterion(logits_x, mixed_target[:batch_size], None, None, epoch+batch_idx/_train_iteration)
    
            loss = Lx + w * Lu
    
            losses.update(loss, inputs_x.size(0))
            losses_x.update(Lx, inputs_x.size(0))
            losses_u.update(Lu, inputs_x.size(0))
            ws.update(w, inputs_x.size(0))
    
            optimizer.zero_grad()
            loss.backward()
            
            optimizer.step()
            # CHECK 
            # ema_optimizer.step()
    
            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
        
            # plot progress
            bar.suffix  = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | Loss_x: {loss_x:.4f} | Loss_u: {loss_u:.4f} | W: {w:.4f}'.format(
                        batch=batch_idx + 1,
                        size=_train_iteration,
                        data=data_time.avg,
                        bt=batch_time.avg,
                        total=bar.elapsed_td,
                        eta=bar.eta_td,
                        loss=losses.avg,
                        loss_x=losses_x.avg,
                        loss_u=losses_u.avg,
                        w=ws.avg,
                        )
            bar.next()

            progress_bar.set_postfix({
                "Data": f"{data_time.avg:.3f}s",
                "Batch": f"{batch_time.avg:.3f}s",
                "Total": bar.elapsed_td,
                "ETA": bar.eta_td,
                "Loss": f"{losses.avg:.4f}",
                "Loss_x": f"{losses_x.avg:.4f}",
                "Loss_u": f"{losses_u.avg:.4f}",
                "W": f"{ws.avg:.4f}",
            })
            
        bar.finish()
    
    return (losses.avg, losses_x.avg, losses_u.avg,)

In [21]:
def validate(valloader, model, criterion, epoch, use_cuda, mode):

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    bar = Bar(f'{mode}', max=len(valloader))
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(valloader):
            # measure data loading time
            data_time.update(time.time() - end)

            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda(non_blocking=True)
            # compute output
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(outputs, targets, topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1.item(), inputs.size(0))
            top5.update(prec5.item(), inputs.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # plot progress
            bar.suffix  = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format(
                        batch=batch_idx + 1,
                        size=len(valloader),
                        data=data_time.avg,
                        bt=batch_time.avg,
                        total=bar.elapsed_td,
                        eta=bar.eta_td,
                        loss=losses.avg,
                        top1=top1.avg,
                        top5=top5.avg,
                        )
            bar.next()
        bar.finish()
        
    return (losses.avg, top1.avg)

In [22]:
train_criterion = CustomLoss()
criterion = nn.CrossEntropyLoss()

In [23]:
if not os.path.isdir(_out):
    mkdir_p(_out)

title = 'noisy-cifar-10'

logger = Logger(os.path.join(_out, 'log.txt'), title=title)
logger.set_names(['Train Loss', 'Train Loss X', 'Train Loss U',  'Valid Loss', 'Valid Acc.', 'Test Loss', 'Test Acc.'])
writer = SummaryWriter(_out)

best_acc = 0  # best test accuracy
step = 0
test_accs = []

for epoch in range(_epochs):

    train_loss, train_loss_x, train_loss_u = train(labeled_trainloader, unlabeled_trainloader, model, optim, ema_optim, train_criterion, epoch, use_cuda)
    """
    _, train_acc = validate(labeled_trainloader, ema_model, criterion, epoch, use_cuda, mode='Train Stats')
    val_loss, val_acc = validate(val_loader, ema_model, criterion, epoch, use_cuda, mode='Valid Stats')
    test_loss, test_acc = validate(test_loader, ema_model, criterion, epoch, use_cuda, mode='Test Stats ')
    """
    _, train_acc = 0.0, 0.0
    val_loss, val_acc = 0.0, 0.0
    test_loss, test_acc = validate(test_loader, model, criterion, epoch, use_cuda, mode='Test Stats ')

    sched.step()

    _test_acc = get_acc(model, test_loader)
    """
    print(f"[Epoch {epoch}] (naive) Test Accuracy: {_test_acc:.4f}")
    print(f"[Epoch {epoch}] Train Accuracy: {train_acc:.4f}")
    print(f"[Epoch {epoch}] Validation Accuracy: {val_acc:.4f}")
    print(f"[Epoch {epoch}] Test Accuracy: {test_acc:.4f}")
    """
    print(f"[Epoch {epoch}] (naive) Test Accuracy: {_test_acc:.4f}")
    print(f"[Epoch {epoch}] Test Accuracy: {test_acc:.4f}")
    
    
    step = _train_iteration * (epoch + 1)

    writer.add_scalar('losses/train_loss', train_loss, step)
    writer.add_scalar('losses/valid_loss', val_loss, step)
    writer.add_scalar('losses/test_loss', test_loss, step)

    writer.add_scalar('accuracy/train_acc', train_acc, step)
    writer.add_scalar('accuracy/val_acc', val_acc, step)
    writer.add_scalar('accuracy/test_acc', test_acc, step)

    logger.append([train_loss, train_loss_x, train_loss_u, val_loss, val_acc, test_loss, test_acc])

    best_acc = max(val_acc, best_acc)
    test_accs.append(test_acc)

logger.close()
writer.close()

print('Best acc:')
print(best_acc)

print('Mean acc:')
print(np.mean(test_accs[-20:]))

Training Progress: 100%|██████████| 546/546 [00:14<00:00, 36.99batch/s, Data=0.019s, Batch=0.027s, Total=0:00:14, ETA=0:00:00, Loss=3.4829, Loss_x=3.4829, Loss_u=0.0000, W=0.0000]


[Epoch 0] (naive) Test Accuracy: 0.2355
[Epoch 0] Test Accuracy: 23.5500


Training Progress: 100%|██████████| 546/546 [00:14<00:00, 38.23batch/s, Data=0.019s, Batch=0.026s, Total=0:00:14, ETA=0:00:00, Loss=2.8443, Loss_x=2.8443, Loss_u=0.0000, W=0.0000]


[Epoch 1] (naive) Test Accuracy: 0.3115
[Epoch 1] Test Accuracy: 31.1500


Training Progress: 100%|██████████| 546/546 [00:14<00:00, 38.56batch/s, Data=0.019s, Batch=0.026s, Total=0:00:14, ETA=0:00:00, Loss=2.5119, Loss_x=2.5119, Loss_u=0.0000, W=0.0000]


[Epoch 2] (naive) Test Accuracy: 0.3365
[Epoch 2] Test Accuracy: 33.6500


Training Progress: 100%|██████████| 546/546 [00:14<00:00, 38.25batch/s, Data=0.019s, Batch=0.026s, Total=0:00:14, ETA=0:00:00, Loss=2.3409, Loss_x=2.3409, Loss_u=0.0000, W=0.0000]


[Epoch 3] (naive) Test Accuracy: 0.3725
[Epoch 3] Test Accuracy: 37.2500


Training Progress: 100%|██████████| 546/546 [00:14<00:00, 38.41batch/s, Data=0.019s, Batch=0.026s, Total=0:00:14, ETA=0:00:00, Loss=2.1469, Loss_x=2.1469, Loss_u=0.0000, W=0.0000]


[Epoch 4] (naive) Test Accuracy: 0.3795
[Epoch 4] Test Accuracy: 37.9500


Training Progress: 100%|██████████| 546/546 [00:14<00:00, 38.62batch/s, Data=0.019s, Batch=0.026s, Total=0:00:14, ETA=0:00:00, Loss=2.0017, Loss_x=2.0017, Loss_u=0.0000, W=0.0000]


[Epoch 5] (naive) Test Accuracy: 0.3845
[Epoch 5] Test Accuracy: 38.4500


Training Progress: 100%|██████████| 546/546 [00:14<00:00, 38.47batch/s, Data=0.019s, Batch=0.026s, Total=0:00:14, ETA=0:00:00, Loss=1.8840, Loss_x=1.8840, Loss_u=0.0000, W=0.0000]


[Epoch 6] (naive) Test Accuracy: 0.4110
[Epoch 6] Test Accuracy: 41.1000


Training Progress: 100%|██████████| 546/546 [00:14<00:00, 38.06batch/s, Data=0.019s, Batch=0.027s, Total=0:00:14, ETA=0:00:00, Loss=1.8055, Loss_x=1.8055, Loss_u=0.0000, W=0.0000]


[Epoch 7] (naive) Test Accuracy: 0.4175
[Epoch 7] Test Accuracy: 41.7500


Training Progress: 100%|██████████| 546/546 [00:14<00:00, 38.15batch/s, Data=0.019s, Batch=0.027s, Total=0:00:14, ETA=0:00:00, Loss=1.6764, Loss_x=1.6764, Loss_u=0.0000, W=0.0000]


[Epoch 8] (naive) Test Accuracy: 0.4240
[Epoch 8] Test Accuracy: 42.4000


Training Progress: 100%|██████████| 546/546 [00:14<00:00, 38.29batch/s, Data=0.018s, Batch=0.026s, Total=0:00:14, ETA=0:00:00, Loss=1.5831, Loss_x=1.5831, Loss_u=0.0000, W=0.0000]


[Epoch 9] (naive) Test Accuracy: 0.4240
[Epoch 9] Test Accuracy: 42.4000


Training Progress: 100%|██████████| 546/546 [00:14<00:00, 37.89batch/s, Data=0.020s, Batch=0.027s, Total=0:00:14, ETA=0:00:00, Loss=1.5554, Loss_x=1.5554, Loss_u=0.0000, W=0.0000]


[Epoch 10] (naive) Test Accuracy: 0.4310
[Epoch 10] Test Accuracy: 43.1000


Training Progress: 100%|██████████| 546/546 [00:14<00:00, 38.79batch/s, Data=0.019s, Batch=0.026s, Total=0:00:14, ETA=0:00:00, Loss=1.4238, Loss_x=1.4238, Loss_u=0.0000, W=0.0000]


[Epoch 11] (naive) Test Accuracy: 0.4555
[Epoch 11] Test Accuracy: 45.5500


Training Progress: 100%|██████████| 546/546 [00:14<00:00, 38.06batch/s, Data=0.019s, Batch=0.027s, Total=0:00:14, ETA=0:00:00, Loss=1.3272, Loss_x=1.3272, Loss_u=0.0000, W=0.0000]


[Epoch 12] (naive) Test Accuracy: 0.4345
[Epoch 12] Test Accuracy: 43.4500


Training Progress: 100%|██████████| 546/546 [00:14<00:00, 38.47batch/s, Data=0.019s, Batch=0.026s, Total=0:00:14, ETA=0:00:00, Loss=1.2591, Loss_x=1.2591, Loss_u=0.0000, W=0.0000]


[Epoch 13] (naive) Test Accuracy: 0.4550
[Epoch 13] Test Accuracy: 45.5000


Training Progress: 100%|██████████| 546/546 [00:14<00:00, 38.14batch/s, Data=0.019s, Batch=0.027s, Total=0:00:14, ETA=0:00:00, Loss=1.1895, Loss_x=1.1895, Loss_u=0.0000, W=0.0000]


[Epoch 14] (naive) Test Accuracy: 0.4425
[Epoch 14] Test Accuracy: 44.2500


Training Progress: 100%|██████████| 546/546 [00:14<00:00, 38.13batch/s, Data=0.020s, Batch=0.026s, Total=0:00:14, ETA=0:00:00, Loss=1.0982, Loss_x=1.0982, Loss_u=0.0000, W=0.0000]


[Epoch 15] (naive) Test Accuracy: 0.4560
[Epoch 15] Test Accuracy: 45.6000


Training Progress: 100%|██████████| 546/546 [00:14<00:00, 37.60batch/s, Data=0.020s, Batch=0.027s, Total=0:00:14, ETA=0:00:00, Loss=0.9961, Loss_x=0.9961, Loss_u=0.0000, W=0.0000]


[Epoch 16] (naive) Test Accuracy: 0.4710
[Epoch 16] Test Accuracy: 47.1000


Training Progress: 100%|██████████| 546/546 [00:14<00:00, 38.13batch/s, Data=0.019s, Batch=0.027s, Total=0:00:14, ETA=0:00:00, Loss=0.9166, Loss_x=0.9166, Loss_u=0.0000, W=0.0000]


[Epoch 17] (naive) Test Accuracy: 0.4720
[Epoch 17] Test Accuracy: 47.2000


Training Progress: 100%|██████████| 546/546 [00:14<00:00, 38.03batch/s, Data=0.019s, Batch=0.027s, Total=0:00:14, ETA=0:00:00, Loss=0.8338, Loss_x=0.8338, Loss_u=0.0000, W=0.0000]


[Epoch 18] (naive) Test Accuracy: 0.4795
[Epoch 18] Test Accuracy: 47.9500


Training Progress: 100%|██████████| 546/546 [00:14<00:00, 38.18batch/s, Data=0.019s, Batch=0.026s, Total=0:00:14, ETA=0:00:00, Loss=0.7650, Loss_x=0.7650, Loss_u=0.0000, W=0.0000]


[Epoch 19] (naive) Test Accuracy: 0.4840
[Epoch 19] Test Accuracy: 48.4000


Training Progress: 100%|██████████| 546/546 [00:14<00:00, 38.45batch/s, Data=0.020s, Batch=0.026s, Total=0:00:14, ETA=0:00:00, Loss=0.6781, Loss_x=0.6781, Loss_u=0.0000, W=0.0000]


[Epoch 20] (naive) Test Accuracy: 0.4920
[Epoch 20] Test Accuracy: 49.2000


Training Progress: 100%|██████████| 546/546 [00:14<00:00, 37.70batch/s, Data=0.020s, Batch=0.027s, Total=0:00:14, ETA=0:00:00, Loss=0.6018, Loss_x=0.6018, Loss_u=0.0000, W=0.0000]


[Epoch 21] (naive) Test Accuracy: 0.4820
[Epoch 21] Test Accuracy: 48.2000


Training Progress: 100%|██████████| 546/546 [00:14<00:00, 38.29batch/s, Data=0.020s, Batch=0.026s, Total=0:00:14, ETA=0:00:00, Loss=0.5338, Loss_x=0.5338, Loss_u=0.0000, W=0.0000]


[Epoch 22] (naive) Test Accuracy: 0.5090
[Epoch 22] Test Accuracy: 50.9000


Training Progress: 100%|██████████| 546/546 [00:14<00:00, 37.74batch/s, Data=0.019s, Batch=0.027s, Total=0:00:14, ETA=0:00:00, Loss=0.4862, Loss_x=0.4862, Loss_u=0.0000, W=0.0000]


[Epoch 23] (naive) Test Accuracy: 0.4780
[Epoch 23] Test Accuracy: 47.8000


Training Progress: 100%|██████████| 546/546 [00:14<00:00, 38.18batch/s, Data=0.019s, Batch=0.027s, Total=0:00:14, ETA=0:00:00, Loss=0.3941, Loss_x=0.3941, Loss_u=0.0000, W=0.0000]


[Epoch 24] (naive) Test Accuracy: 0.5050
[Epoch 24] Test Accuracy: 50.5000


Training Progress: 100%|██████████| 546/546 [00:14<00:00, 38.25batch/s, Data=0.020s, Batch=0.027s, Total=0:00:14, ETA=0:00:00, Loss=0.3143, Loss_x=0.3143, Loss_u=0.0000, W=0.0000]


[Epoch 25] (naive) Test Accuracy: 0.5240
[Epoch 25] Test Accuracy: 52.4000


Training Progress: 100%|██████████| 546/546 [00:14<00:00, 38.40batch/s, Data=0.019s, Batch=0.026s, Total=0:00:14, ETA=0:00:00, Loss=0.2766, Loss_x=0.2766, Loss_u=0.0000, W=0.0000]


[Epoch 26] (naive) Test Accuracy: 0.5070
[Epoch 26] Test Accuracy: 50.7000


Training Progress: 100%|██████████| 546/546 [00:14<00:00, 38.36batch/s, Data=0.019s, Batch=0.026s, Total=0:00:14, ETA=0:00:00, Loss=0.2184, Loss_x=0.2184, Loss_u=0.0000, W=0.0000]


[Epoch 27] (naive) Test Accuracy: 0.5190
[Epoch 27] Test Accuracy: 51.9000


Training Progress: 100%|██████████| 546/546 [00:14<00:00, 37.97batch/s, Data=0.020s, Batch=0.027s, Total=0:00:14, ETA=0:00:00, Loss=0.1813, Loss_x=0.1813, Loss_u=0.0000, W=0.0000]


[Epoch 28] (naive) Test Accuracy: 0.5030
[Epoch 28] Test Accuracy: 50.3000


Training Progress: 100%|██████████| 546/546 [00:14<00:00, 37.81batch/s, Data=0.019s, Batch=0.027s, Total=0:00:14, ETA=0:00:00, Loss=0.1405, Loss_x=0.1405, Loss_u=0.0000, W=0.0000]


[Epoch 29] (naive) Test Accuracy: 0.5085
[Epoch 29] Test Accuracy: 50.8500


Training Progress: 100%|██████████| 546/546 [00:14<00:00, 37.58batch/s, Data=0.019s, Batch=0.027s, Total=0:00:14, ETA=0:00:00, Loss=0.1084, Loss_x=0.1084, Loss_u=0.0000, W=0.0000]


[Epoch 30] (naive) Test Accuracy: 0.5185
[Epoch 30] Test Accuracy: 51.8500


Training Progress: 100%|██████████| 546/546 [00:14<00:00, 37.48batch/s, Data=0.020s, Batch=0.027s, Total=0:00:14, ETA=0:00:00, Loss=0.0812, Loss_x=0.0812, Loss_u=0.0000, W=0.0000]


[Epoch 31] (naive) Test Accuracy: 0.5160
[Epoch 31] Test Accuracy: 51.6000


Training Progress: 100%|██████████| 546/546 [00:14<00:00, 38.17batch/s, Data=0.019s, Batch=0.027s, Total=0:00:14, ETA=0:00:00, Loss=0.0529, Loss_x=0.0529, Loss_u=0.0000, W=0.0000]


[Epoch 32] (naive) Test Accuracy: 0.5275
[Epoch 32] Test Accuracy: 52.7500


Training Progress: 100%|██████████| 546/546 [00:14<00:00, 37.43batch/s, Data=0.019s, Batch=0.027s, Total=0:00:14, ETA=0:00:00, Loss=0.0367, Loss_x=0.0367, Loss_u=0.0000, W=0.0000]


[Epoch 33] (naive) Test Accuracy: 0.5405
[Epoch 33] Test Accuracy: 54.0500


Training Progress: 100%|██████████| 546/546 [00:14<00:00, 37.68batch/s, Data=0.020s, Batch=0.027s, Total=0:00:14, ETA=0:00:00, Loss=0.0259, Loss_x=0.0259, Loss_u=0.0000, W=0.0000]


[Epoch 34] (naive) Test Accuracy: 0.5430
[Epoch 34] Test Accuracy: 54.3000


Training Progress: 100%|██████████| 546/546 [00:14<00:00, 38.17batch/s, Data=0.019s, Batch=0.027s, Total=0:00:14, ETA=0:00:00, Loss=0.0223, Loss_x=0.0223, Loss_u=0.0000, W=0.0000]


[Epoch 35] (naive) Test Accuracy: 0.5515
[Epoch 35] Test Accuracy: 55.1500


Training Progress: 100%|██████████| 546/546 [00:14<00:00, 38.30batch/s, Data=0.019s, Batch=0.027s, Total=0:00:14, ETA=0:00:00, Loss=0.0119, Loss_x=0.0119, Loss_u=0.0000, W=0.0000]


[Epoch 36] (naive) Test Accuracy: 0.5385
[Epoch 36] Test Accuracy: 53.8500


Training Progress: 100%|██████████| 546/546 [00:14<00:00, 37.46batch/s, Data=0.019s, Batch=0.027s, Total=0:00:14, ETA=0:00:00, Loss=0.0103, Loss_x=0.0103, Loss_u=0.0000, W=0.0000]


[Epoch 37] (naive) Test Accuracy: 0.5360
[Epoch 37] Test Accuracy: 53.6000


Training Progress: 100%|██████████| 546/546 [00:14<00:00, 37.80batch/s, Data=0.019s, Batch=0.027s, Total=0:00:14, ETA=0:00:00, Loss=0.0052, Loss_x=0.0052, Loss_u=0.0000, W=0.0000]


[Epoch 38] (naive) Test Accuracy: 0.5510
[Epoch 38] Test Accuracy: 55.1000


Training Progress: 100%|██████████| 546/546 [00:14<00:00, 37.55batch/s, Data=0.020s, Batch=0.027s, Total=0:00:14, ETA=0:00:00, Loss=0.0056, Loss_x=0.0056, Loss_u=0.0000, W=0.0000]


[Epoch 39] (naive) Test Accuracy: 0.5450
[Epoch 39] Test Accuracy: 54.5000


Training Progress: 100%|██████████| 546/546 [00:14<00:00, 37.83batch/s, Data=0.019s, Batch=0.027s, Total=0:00:14, ETA=0:00:00, Loss=0.0048, Loss_x=0.0048, Loss_u=0.0000, W=0.0000]


[Epoch 40] (naive) Test Accuracy: 0.5485
[Epoch 40] Test Accuracy: 54.8500


Training Progress: 100%|██████████| 546/546 [00:14<00:00, 37.82batch/s, Data=0.019s, Batch=0.027s, Total=0:00:14, ETA=0:00:00, Loss=0.0035, Loss_x=0.0035, Loss_u=0.0000, W=0.0000]


[Epoch 41] (naive) Test Accuracy: 0.5545
[Epoch 41] Test Accuracy: 55.4500


Training Progress: 100%|██████████| 546/546 [00:14<00:00, 37.69batch/s, Data=0.020s, Batch=0.027s, Total=0:00:14, ETA=0:00:00, Loss=0.0036, Loss_x=0.0036, Loss_u=0.0000, W=0.0000]


[Epoch 42] (naive) Test Accuracy: 0.5560
[Epoch 42] Test Accuracy: 55.6000


Training Progress: 100%|██████████| 546/546 [00:14<00:00, 38.01batch/s, Data=0.020s, Batch=0.027s, Total=0:00:14, ETA=0:00:00, Loss=0.0028, Loss_x=0.0028, Loss_u=0.0000, W=0.0000]


[Epoch 43] (naive) Test Accuracy: 0.5575
[Epoch 43] Test Accuracy: 55.7500


Training Progress: 100%|██████████| 546/546 [00:14<00:00, 36.64batch/s, Data=0.019s, Batch=0.028s, Total=0:00:15, ETA=0:00:00, Loss=0.0025, Loss_x=0.0025, Loss_u=0.0000, W=0.0000]


[Epoch 44] (naive) Test Accuracy: 0.5580
[Epoch 44] Test Accuracy: 55.8000


Training Progress: 100%|██████████| 546/546 [00:14<00:00, 37.43batch/s, Data=0.020s, Batch=0.027s, Total=0:00:14, ETA=0:00:00, Loss=0.0031, Loss_x=0.0031, Loss_u=0.0000, W=0.0000]


[Epoch 45] (naive) Test Accuracy: 0.5595
[Epoch 45] Test Accuracy: 55.9500


Training Progress: 100%|██████████| 546/546 [00:14<00:00, 38.08batch/s, Data=0.019s, Batch=0.027s, Total=0:00:14, ETA=0:00:00, Loss=0.0025, Loss_x=0.0025, Loss_u=0.0000, W=0.0000]


[Epoch 46] (naive) Test Accuracy: 0.5620
[Epoch 46] Test Accuracy: 56.2000


Training Progress: 100%|██████████| 546/546 [00:14<00:00, 38.31batch/s, Data=0.020s, Batch=0.026s, Total=0:00:14, ETA=0:00:00, Loss=0.0025, Loss_x=0.0025, Loss_u=0.0000, W=0.0000]


[Epoch 47] (naive) Test Accuracy: 0.5600
[Epoch 47] Test Accuracy: 56.0000


Training Progress: 100%|██████████| 546/546 [00:14<00:00, 37.75batch/s, Data=0.019s, Batch=0.027s, Total=0:00:14, ETA=0:00:00, Loss=0.0021, Loss_x=0.0021, Loss_u=0.0000, W=0.0000]


[Epoch 48] (naive) Test Accuracy: 0.5605
[Epoch 48] Test Accuracy: 56.0500


Training Progress: 100%|██████████| 546/546 [00:14<00:00, 37.47batch/s, Data=0.020s, Batch=0.027s, Total=0:00:14, ETA=0:00:00, Loss=0.0028, Loss_x=0.0028, Loss_u=0.0000, W=0.0000]


[Epoch 49] (naive) Test Accuracy: 0.5600
[Epoch 49] Test Accuracy: 56.0000
Best acc:
0.0
Mean acc:
54.720000000000006
