In [20]:
# import torch
# print(f"torch: {torch.__version__}")
# import torchvision
# print(f"torchvisio: {torchvision.__version__}")
# import torchaudio
# print(f"torchaudio: {torchaudio.__version__}")
# import torchtext
# print(f"torchtext: {torchtext.__version__}")
# import torchdata
# print(f"torchdata: {torchdata.__version__}")

In [1]:
import hydra

In [2]:
# Report the installed Hydra version (nothing else visible in this notebook uses hydra).
print(str(hydra.__version__))

1.3.2


In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
import numpy as np
import random

import torchvision
import torchvision.transforms as transforms

from src.utils.optim.schedulers import CosineWarmup
from src.models.sequence.rnns.rnn import RNN


import os
import argparse
from tqdm.auto import tqdm


import wandb
# Authenticate with Weights & Biases; the wandb.init / wandb.log calls in the training cells need a logged-in session.
wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33myuqinzhou[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [22]:
# Use CUDA if present.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('Device available now:', device)

# Compare the device *type*: a torch.device never compares equal to the plain string 'cuda',
# so the previous `device == 'cuda'` check was always False and cuDNN autotuning never turned on.
if device.type == 'cuda':
    # Let cuDNN benchmark kernel choices; worthwhile here because input shapes stay fixed across batches.
    cudnn.benchmark = True

Device available now: cpu


## Defining functions

### Hyperparameters

In [30]:
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
# Dataset
# NOTE: only 'cifar10' is implemented by the dataset cell; the other choices raise NotImplementedError there.
parser.add_argument('--dataset', default='cifar10', choices=['cifar10', 'listops', 'imdb', 'aan', 'pathfinder'], type=str, help='Dataset')
# LRA naming: 'imdb' is the TEXT task, 'aan' is the RETRIEVAL task.
parser.add_argument('--grayscale', action='store_true', help='Use grayscale CIFAR10')

# Dataloader
parser.add_argument('--num_workers', default=0, type=int, help='Number of workers to use for dataloader')
parser.add_argument('--batch_size', default=50, type=int, help='Batch size')

# Optimizer
parser.add_argument('--lr', default=0.01, type=float, help='Learning rate')
# The RNN layers are built with lr = args.lr * args.lr_factor.
parser.add_argument('--lr_factor', default=0.25, type=float, help='Multiplier on --lr used for the RNN layers')
parser.add_argument('--weight_decay', default=0.05, type=float, help='Weight decay')

# Scheduler
# int, not float: the value feeds range(start_epoch, args.epochs), which rejects floats such as 100.0.
parser.add_argument('--epochs', default=200, type=int, help='Training epochs')

# Model
parser.add_argument('--n_layers', default=2, type=int, help='Number of layers') #6
parser.add_argument('--d_model', default=5, type=int, help='Model dimension') #512
parser.add_argument('--d_hidden', default=5, type=int, help='Hidden (state) dimension ') #384
parser.add_argument('--dropout', default=0.1, type=float, help='Dropout')
# store_false: prenorm defaults to True and passing --prenorm switches to post-norm.
parser.add_argument('--prenorm', action='store_false', help='Disable prenorm (use post-norm); prenorm is on by default')
# NOTE(review): the model-building cells do not pass this to RNNbased yet.
parser.add_argument('--norm', default='BN', choices=['LN', 'BN'], help='Norm types')
parser.add_argument('--cell', default='rnn', type=str, help='RNN\'s cell')

# General
parser.add_argument('--resume', '-r', action='store_true', help='Resume from checkpoint')
parser.add_argument('--seed', default=None, type=int, help='Random seed (unset: non-deterministic)')

# parse_known_args tolerates extra arguments injected by the Jupyter kernel (e.g. `-f <connection file>`).
args, unknown = parser.parse_known_args()

# Seed Python, NumPy and torch RNGs only when --seed is given; the default keeps the previous unseeded behavior.
if args.seed is not None:
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

In [31]:
# Show the parsed configuration (this Namespace is also handed to wandb.init as the run config).
print(repr(args))

Namespace(batch_size=50, cell='rnn', d_hidden=5, d_model=5, dataset='cifar10', dropout=0.1, epochs=200, grayscale=False, lr=0.01, lr_factor=0.25, n_layers=2, norm='BN', num_workers=0, prenorm=True, resume=False, weight_decay=0.05)


In [69]:
# d = RNN(d_input = 3, d_model = 5, lr = args.lr * args.lr_factor, cell = "rnn", return_output=True, transposed=False, dropout=0)

In [70]:
# Debug: print the per-parameter optimizer overrides (`_optim`) of the standalone RNN layer `d`.
# NOTE: `d` only exists if the commented-out cell above has been run.
# getattr with a default avoids an AttributeError on parameters that carry no `_optim` dict.
for param in d.parameters():
    print(getattr(param, "_optim", None))

{'weight_decay': 0.0, 'lr': 0.0025}
{'weight_decay': 0.0, 'lr': 0.0025}


### Datasets

In [37]:
def split_train_val(train, val_split, seed=42):
    """Deterministically split a dataset into train and validation subsets.

    Args:
        train: dataset (anything with ``__len__`` and ``__getitem__``) to split.
        val_split: fraction in [0, 1] of the samples assigned to validation.
        seed: seed for the split's private RNG. The default 42 reproduces the historical
            split. Calling this on two datasets of equal length with the same seed yields
            the same index partition.

    Returns:
        ``(train_subset, val_subset)`` as ``torch.utils.data.Subset`` objects.

    Raises:
        ValueError: if ``val_split`` is outside [0, 1]. Without this check, random_split
            would accept the resulting negative length and produce meaningless subsets.
    """
    if not 0.0 <= val_split <= 1.0:
        raise ValueError(f"val_split must be in [0, 1], got {val_split}")
    train_len = int(len(train) * (1.0 - val_split))
    # A dedicated seeded generator makes the split reproducible without touching the global RNG.
    # The dataset cell relies on this: it splits two separately constructed copies of the train
    # set and keeps the train part of one and the val part of the other.
    train, val = torch.utils.data.random_split(
        train,
        (train_len, len(train) - train_len),
        generator=torch.Generator().manual_seed(seed),
    )
    return train, val

In [38]:
if args.dataset == 'cifar10':
    if args.grayscale:
        # Single channel with fixed grayscale statistics; each 32x32 image becomes a (1024, 1) sequence.
        transform = transforms.Compose([
            transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Normalize(mean=122.6 / 255.0, std=61.0 / 255.0),
            transforms.Lambda(lambda x: x.view(1, 1024).t())
        ])
    else:
        # Per-channel CIFAR10 normalization; each 32x32 RGB image becomes a (1024, 3) sequence.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
            transforms.Lambda(lambda x: x.view(3, 1024).t())
        ])

    # S4 is trained on sequences with no data augmentation!
    transform_train = transform_test = transform

    cifar_root = './data/cifar/'
    # Train and val come from two separately constructed copies of the CIFAR10 train set (so
    # they can carry different transforms). split_train_val's fixed seed gives both copies the
    # same permutation, so the kept train and val parts are disjoint.
    trainset = torchvision.datasets.CIFAR10(
        root=cifar_root, train=True, download=True, transform=transform_train)
    trainset, _ = split_train_val(trainset, val_split=0.1)

    valset = torchvision.datasets.CIFAR10(
        root=cifar_root, train=True, download=True, transform=transform_test)
    _, valset = split_train_val(valset, val_split=0.1)

    testset = torchvision.datasets.CIFAR10(
        root=cifar_root, train=False, download=True, transform=transform_test)

    d_input = 1 if args.grayscale else 3  # features per sequence step
    d_output = 10  # CIFAR10 classes

else:
    raise NotImplementedError(
        f"Dataset {args.dataset!r} is not implemented in this notebook; only 'cifar10' is supported."
    )

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [39]:
# Dataloaders: only the training split is shuffled; val/test keep a fixed order.
loader_kwargs = dict(batch_size=args.batch_size, num_workers=args.num_workers)
trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, **loader_kwargs)
valloader = torch.utils.data.DataLoader(valset, shuffle=False, **loader_kwargs)
testloader = torch.utils.data.DataLoader(testset, shuffle=False, **loader_kwargs)


In [40]:
# Shape sanity check on one training batch: inputs should be (batch, 1024, d_input).
images, labels = next(iter(trainloader))
print(images.shape, labels)

torch.Size([50, 1024, 3]) tensor([1, 9, 6, 6, 6, 9, 6, 9, 7, 3, 1, 2, 8, 0, 9, 4, 3, 8, 0, 6, 2, 3, 8, 2,
        4, 7, 5, 3, 8, 0, 1, 3, 3, 3, 1, 4, 8, 7, 8, 0, 3, 3, 0, 1, 3, 5, 4, 0,
        8, 2])


In [41]:
# Training-set size after the 90/10 split, and the number of training batches per epoch.
(len(trainset), len(trainloader))

(45000, 900)

NameError: name 'torch' is not defined

### Model architecture

In [43]:
class RNNbased(nn.Module):
    """Residual stack of RNN blocks, mean-pooled over the sequence and decoded to class logits.

    Each block maps (B, L, d_model) -> (B, L, d_model):
        [norm] -> RNN (d_model -> d_hidden) -> dropout -> Linear(d_hidden, 2*d_model) + GLU
        -> dropout -> residual add -> [norm]
    The norm is applied before the RNN when ``prenorm`` is True, and after the residual add
    otherwise.

    Args:
        d_input: features per input step (1 for grayscale, 3 for RGB CIFAR10 in this notebook).
        d_output: number of output classes.
        lr: learning rate forwarded to every RNN layer. The debug cell shows it as the ``lr``
            in each RNN parameter's ``_optim`` dict, which ``setup_optimizer`` turns into its
            own param group.
        cell: RNN cell type forwarded to every RNN layer.
        d_model: width of the residual stream.
        d_hidden: RNN hidden (state) width.
        n_layers: number of residual blocks.
        dropout: dropout probability.
        prenorm: pre-norm (True) or post-norm (False) residual blocks.
        norm: 'BN' (BatchNorm1d over d_model; default, previous behavior) or 'LN'
            (LayerNorm over d_model).

    Raises:
        ValueError: if ``norm`` is neither 'BN' nor 'LN'.
    """

    def __init__(
        self,
        d_input,
        d_output,
        lr,
        cell='rnn',
        d_model=256,
        d_hidden=128,
        n_layers=2,
        dropout=0.2,
        prenorm=True,
        norm='BN',
    ):
        super().__init__()
        if norm not in ('BN', 'LN'):
            raise ValueError(f"norm must be 'BN' or 'LN', got {norm!r}")

        self.prenorm = prenorm
        self.norm_type = norm

        # Linear encoder (d_input = 1 for grayscale and 3 for RGB), playing the role of an embedding layer.
        self.encoder = nn.Linear(d_input, d_model)

        # Stack RNN layers as residual blocks.
        self.layers = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.dropouts = nn.ModuleList()
        self.FFNs = nn.ModuleList()

        for _ in range(n_layers):
            self.layers.append(
                RNN(d_input=d_model, d_model=d_hidden, lr=lr, cell=cell, return_output=True, transposed=False, dropout=0)
            )
            self.norms.append(nn.BatchNorm1d(d_model) if norm == 'BN' else nn.LayerNorm(d_model))
            # NOTE(review): nn.Dropout1d treats dim 1 as the channel dim. Inputs here are
            # (B, L, C), so it zeroes whole time steps rather than feature channels. Confirm intended.
            self.dropouts.append(nn.Dropout1d(dropout))
            # Linear to 2*d_model followed by GLU -> d_model (shall test FFN + GELU later).
            self.FFNs.append(nn.Sequential(nn.Linear(d_hidden, d_model * 2), nn.GLU()))

        # Linear decoder
        self.decoder = nn.Linear(d_model, d_output)

    def _apply_norm(self, norm, x):
        """Normalize a (B, L, d_model) tensor; BatchNorm1d wants channels on dim 1, so transpose around it."""
        if self.norm_type == 'BN':
            return norm(x.transpose(-1, -2)).transpose(-1, -2)
        return norm(x)

    def forward(self, x):
        """Classify a batch of sequences.

        Args:
            x: (B, L, d_input) tensor.

        Returns:
            (B, d_output) logits.
        """
        x = self.encoder(x)  # (B, L, d_input) -> (B, L, d_model)

        for layer, norm, dropout, ffn in zip(self.layers, self.norms, self.dropouts, self.FFNs):
            # Each iteration maps (B, L, d_model) -> (B, L, d_model).
            z = x
            if self.prenorm:
                z = self._apply_norm(norm, z)

            # Apply recurrence; the returned state is ignored.
            z, _ = layer(z)  # (B, L, d_model) -> (B, L, d_hidden)

            # Dropout on the output of the recurrence block.
            z = dropout(z)

            # Gated projection back to the residual width (Linear + GLU).
            z = ffn(z)  # (B, L, d_hidden) -> (B, L, d_model)
            z = dropout(z)

            # Residual connection.
            x = z + x

            if not self.prenorm:
                x = self._apply_norm(norm, x)

        # Average-pool over the sequence length: (B, L, d_model) -> (B, d_model).
        x = x.mean(dim=1)

        # Decode to class logits: (B, d_model) -> (B, d_output).
        return self.decoder(x)

In [67]:
# Small example model for inspecting the architecture; the cell type now follows --cell
# instead of being hard-coded to 'rnn' (same result with the default).
example = RNNbased(d_input=d_input,
                   d_output=d_output,
                   cell=args.cell,
                   lr=args.lr * args.lr_factor,
                   d_model=args.d_model,
                   d_hidden=args.d_hidden,
                   n_layers=args.n_layers,
                   dropout=args.dropout,
                   prenorm=args.prenorm)
print(example)

RNNbased(
  (encoder): Linear(in_features=3, out_features=5, bias=True)
  (layers): ModuleList(
    (0): RNN(
      (cell): RNNCell(
        (W_hx): Linear(in_features=5, out_features=5, bias=False)
        (activate): Tanh()
        (W_hh): Linear(in_features=5, out_features=5, bias=False)
      )
    )
    (1): RNN(
      (cell): RNNCell(
        (W_hx): Linear(in_features=5, out_features=5, bias=False)
        (activate): Tanh()
        (W_hh): Linear(in_features=5, out_features=5, bias=False)
      )
    )
  )
  (norms): ModuleList(
    (0): BatchNorm1d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): BatchNorm1d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (dropouts): ModuleList(
    (0): Dropout1d(p=0.1, inplace=False)
    (1): Dropout1d(p=0.1, inplace=False)
  )
  (FFNs): ModuleList(
    (0): Sequential(
      (0): Linear(in_features=5, out_features=10, bias=True)
      (1): GLU(dim=-1)
    )
    (1): Sequential(
      (0): 

In [59]:
def setup_optimizer(model, lr, weight_decay, epochs, eta_min=1e-7, warmup_ratio=0.1):
    """Build an AdamW optimizer with per-parameter hyperparameter groups and a cosine-warmup scheduler.

    Parameters carrying an ``_optim`` dict (e.g. ``{'weight_decay': 0.0, 'lr': 0.0025}``) are
    grouped by identical dicts; each group uses those overrides. All other parameters use
    ``lr`` / ``weight_decay``. A summary line per param group is printed.

    Args:
        model: module whose parameters are optimized.
        lr: learning rate for the general (non-``_optim``) parameters.
        weight_decay: AdamW weight decay for the general parameters.
        epochs: scheduler horizon (``T_max``). The training cell steps the scheduler once per epoch.
        eta_min: minimum learning rate of the cosine schedule (default keeps the previous 1e-7).
        warmup_ratio: warm-up length as a fraction of ``epochs``. ``int(epochs * warmup_ratio) + 1``
            steps are used (default keeps the previous 10%).

    Returns:
        ``(optimizer, scheduler)``.
    """
    all_parameters = list(model.parameters())

    # General parameters don't carry the special `_optim` key.
    general_params = [p for p in all_parameters if not hasattr(p, "_optim")]
    optimizer = optim.AdamW(general_params, lr=lr, weight_decay=weight_decay)

    # Distinct override dicts: frozensets make them hashable for de-duplication; sorting fixes the group order.
    overrides = [getattr(p, "_optim") for p in all_parameters if hasattr(p, "_optim")]
    hps = [dict(s) for s in sorted(dict.fromkeys(frozenset(hp.items()) for hp in overrides))]

    for hp in hps:
        # The None default keeps parameters without `_optim` out of every special group.
        group_params = [p for p in all_parameters if getattr(p, "_optim", None) == hp]
        optimizer.add_param_group({"params": group_params, **hp})

    scheduler = CosineWarmup(
        optimizer,
        T_max=epochs,
        eta_min=eta_min,
        warmup_step=int(epochs * warmup_ratio) + 1,
    )

    # Print one line per param group, showing the hyperparameters that any override sets.
    keys = sorted({k for hp in hps for k in hp.keys()})
    for i, g in enumerate(optimizer.param_groups):
        group_hps = {k: g.get(k, None) for k in keys}
        print(' | '.join([
            f"Optimizer group {i}",
            f"{len(g['params'])} tensors",
        ] + [f"{k} {v}" for k, v in group_hps.items()]))

    return optimizer, scheduler

In [61]:
# Build an optimizer/scheduler for the `example` model to inspect its parameter groups.
optimizer, scheduler = setup_optimizer(
    example,
    lr=args.lr,
    weight_decay=args.weight_decay,
    epochs=args.epochs,
)

Optimizer group 0 | 12 tensors | lr 0.0004761904761904762 | weight_decay 0.05
Optimizer group 1 | 4 tensors | lr 0.00011904761904761905 | weight_decay 0.0


### Training

In [62]:
# Training
def train(model, optimizer, criterion, dataloader=None, print_every=50):
    """Run one training epoch and return the mean per-batch loss.

    Logs the running mean loss to W&B after every batch and prints it every ``print_every``
    batches. Accuracy is shown in the progress-bar description only.

    Args:
        model: module to train (switched to train mode).
        optimizer: optimizer stepping ``model``'s parameters.
        criterion: loss function ``criterion(outputs, targets)``.
        dataloader: iterable of ``(inputs, targets)`` batches with ``len``; defaults to the
            global ``trainloader``.
        print_every: print the running loss every this many batches.

    Returns:
        Mean loss over the epoch's batches, or NaN if the loader yields no batches (previously
        an UnboundLocalError).
    """
    if dataloader is None:
        dataloader = trainloader
    model.train()
    train_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    # total= lets tqdm render a real progress bar instead of a bare iteration counter.
    pbar = tqdm(enumerate(dataloader), total=len(dataloader))

    for batch_idx, (inputs, targets) in pbar:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        num_batches = batch_idx + 1
        avg_loss = train_loss / num_batches

        wandb.log({"Train/Batch loss": avg_loss})

        if (batch_idx % print_every) == 0:
            print(f"Batch: {batch_idx}, Avg: {avg_loss}")

        pbar.set_description(
            'Batch Idx: (%d/%d) | Loss: %.3f | Acc: %.3f%% (%d/%d)' %
            (batch_idx, len(dataloader), avg_loss, 100.*correct/total, correct, total)
        )
    return train_loss / num_batches if num_batches else float('nan')



def eval(model, criterion, dataloader):
    """Evaluate ``model`` on ``dataloader`` without gradients and return accuracy in percent.

    Note: this shadows the built-in ``eval`` in the notebook namespace; the name is kept for callers.

    Args:
        model: module to evaluate (switched to eval mode).
        criterion: loss function; its running mean is shown in the progress bar only.
        dataloader: iterable of ``(inputs, targets)`` batches with ``len``.

    Returns:
        Accuracy in percent (0-100), or NaN if the loader yields no samples (previously a
        ZeroDivisionError).
    """
    model.eval()
    eval_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        # total= lets tqdm render a real progress bar instead of a bare iteration counter.
        pbar = tqdm(enumerate(dataloader), total=len(dataloader))
        for batch_idx, (inputs, targets) in pbar:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            eval_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            pbar.set_description(
                'Batch Idx: (%d/%d) | Loss: %.3f | Acc: %.3f%% (%d/%d)' %
                (batch_idx, len(dataloader), eval_loss/(batch_idx+1), 100.*correct/total, correct, total)
            )

    return 100. * correct / total if total else float('nan')

### Training run (loss and accuracy are logged to W&B)

In [81]:
best_acc = 0  # best validation accuracy so far (the training loop compares val_acc against it)
start_epoch = 0  # start from epoch 0 or last checkpoint epoch
# To resume an existing run after a kernel restart, set WANDB_RUN_ID to that run's id before running
# this cell; the checkpoint path is derived from the same id. Otherwise a fresh id is generated and
# wandb.init(resume='allow') starts a new run.
run_id = os.environ.get("WANDB_RUN_ID") or wandb.util.generate_id()
CHECKPOINT_PATH = f'./checkpoint/checkpoint_{run_id}.pth'
print(CHECKPOINT_PATH)

./checkpoint/checkpoint_tvli55po.pth


In [82]:
total_runs = 1
criterion = nn.CrossEntropyLoss()  # one instance serves train, validation and test

for run in range(total_runs):
    # Run 0 reuses run_id / CHECKPOINT_PATH from the previous cell, so an existing run can be resumed.
    # Later runs get their own id and checkpoint file instead of resuming run 0.
    run_tag = run_id if run == 0 else f"{run_id}-{run}"
    ckpt_path = (CHECKPOINT_PATH if run == 0
                 else os.path.join(os.path.dirname(CHECKPOINT_PATH), f'checkpoint_{run_tag}.pth'))

    # Per-run training state.
    best_acc = 0
    start_epoch = 0
    val_acc = None

    wandb.init(
        id=run_tag,
        project="test",
        # Model + Cell + Run
        name=f"RNN_Vanilla_{run}",
        config=args,
        resume='allow')

    try:
        # Build model/optimizer/scheduler on every path: resuming needs existing objects to load state into.
        # The model is moved to `device` before the optimizer is built, so loaded optimizer state
        # is placed on the same device as the parameters.
        print('==> Building model / ...')
        model = RNNbased(d_input=d_input,
                         d_output=d_output,
                         lr=args.lr * args.lr_factor,
                         cell=args.cell,
                         d_model=args.d_model,
                         d_hidden=args.d_hidden,
                         n_layers=args.n_layers,
                         dropout=args.dropout,
                         prenorm=args.prenorm).to(device)
        optimizer, scheduler = setup_optimizer(model, lr=args.lr, weight_decay=args.weight_decay, epochs=args.epochs)

        if wandb.run.resumed:
            print('==> Resuming from checkpoint...')
            # map_location lets a checkpoint written on GPU load on a CPU-only machine (and vice versa).
            # Not using wandb.restore('checkpoint.tar') because of an encoding error.
            checkpoint = torch.load(ckpt_path, map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            best_acc = checkpoint['acc']
            # The checkpoint is written after epoch `epoch` completes, so continue with the next one.
            # It is only written on validation improvement, so this resumes from the best epoch.
            start_epoch = checkpoint['epoch'] + 1

        ## defining training, validating and testing
        pbar = tqdm(range(start_epoch, int(args.epochs)))
        for epoch in pbar:
            # Learning rates in effect for this epoch: group 0 holds the general parameters,
            # group 1 (when present) the first group of parameters with `_optim` overrides.
            last_lrs = scheduler.get_last_lr()
            lr_log = {"Epoch": epoch, "lr_general": last_lrs[0]}
            if len(last_lrs) > 1:
                lr_log["lr_special"] = last_lrs[1]
            wandb.log(lr_log)

            if val_acc is None:
                pbar.set_description('Epoch: %d' % (epoch))
            else:
                pbar.set_description('Epoch: %d | Val acc: %1.3f' % (epoch, val_acc))

            print('==> Training...')
            epoch_loss = train(model=model, optimizer=optimizer, criterion=criterion)
            wandb.log({"Epoch": epoch, "Train/Epoch Loss": epoch_loss})

            print('==> Validating...')
            val_acc = eval(model=model, criterion=criterion, dataloader=valloader)
            wandb.log({"Epoch": epoch, "Val/Val acc": val_acc})

            scheduler.step()  # update lr for the next epoch

            # Save a checkpoint whenever validation accuracy improves.
            if val_acc > best_acc:
                state = {
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),  # lr for the next epoch
                    'acc': val_acc,
                    'epoch': epoch,
                }
                os.makedirs(os.path.dirname(ckpt_path) or '.', exist_ok=True)
                torch.save(state, ckpt_path)
                # wandb.save(ckpt_path)
                best_acc = val_acc

        # NOTE(review): this evaluates the final-epoch weights, not the best checkpoint saved above.
        print('==> Testing...')
        test_acc = eval(model=model, criterion=criterion, dataloader=testloader)
        wandb.log({"Test/Test acc": test_acc})
    except BaseException:
        # Close the W&B run as failed (also on KeyboardInterrupt) instead of leaving it open.
        wandb.finish(exit_code=1)
        raise
    else:
        # Mark the run as finished
        wandb.finish()

VBox(children=(Label(value='0.024 MB of 0.024 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
Epoch,▁▁▁▁██
Train/Batch loss,████▇▇▇▇▇▇▇▇▇▇▇▇▆▆▆▆▆▆▆▆▄▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁
Train/Epoch Loss,▁
Val/Val acc,▁
lr,▄▁█▂

0,1
Epoch,1.0
Train/Batch loss,2.09261
Train/Epoch Loss,2.27602
Val/Val acc,17.14
lr,0.00024


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01675181388333158, max=1.0)…

==> Building model / ...
Optimizer group 0 | 12 tensors | lr 0.0004761904761904762 | weight_decay 0.05
Optimizer group 1 | 4 tensors | lr 0.00011904761904761905 | weight_decay 0.0


Epoch: 0:   0%|          | 0/200 [00:00<?, ?it/s]

==> Training...




Batch: 0, Avg: 2.3438615798950195




Batch: 50, Avg: 2.352909635095035




Batch: 100, Avg: 2.331971444705925




Batch: 150, Avg: 2.3215812278899137




Batch: 200, Avg: 2.3131653740631406




Batch: 250, Avg: 2.3026826144214647




Batch: 300, Avg: 2.293034563032892




Batch: 350, Avg: 2.286021816085207




Batch: 400, Avg: 2.2779499402367267




Batch: 450, Avg: 2.270503469158435




Batch: 500, Avg: 2.2634049013941113




Batch: 550, Avg: 2.2559005755478156




Batch: 600, Avg: 2.2488120102049307




Batch: 650, Avg: 2.2413552118702786




Batch: 700, Avg: 2.234190433749799




Batch: 750, Avg: 2.2277191736091786




Batch: 800, Avg: 2.221757888198643




Batch: 850, Avg: 2.215749137550067


Batch Idx: (899/900) | Loss: 2.210 | Acc: 16.522% (7435/45000): : 900it [03:53,  3.85it/s]


==> Validating...


Batch Idx: (99/100) | Loss: 2.105 | Acc: 21.560% (1078/5000): : 100it [00:08, 11.84it/s]
Epoch: 1 | Val acc: 21.560:   0%|          | 1/200 [04:02<13:23:44, 242.33s/it]

==> Training...




Batch: 0, Avg: 2.1363587379455566




Batch: 50, Avg: 2.0947599013646445




Batch: 100, Avg: 2.0936288810012362




Batch: 150, Avg: 2.0914106447965106




Batch: 200, Avg: 2.085969399456954




Batch: 250, Avg: 2.080829320200886




Batch: 300, Avg: 2.0789904586500505




Batch: 350, Avg: 2.0764146754544686




Batch: 400, Avg: 2.0742721370330774




Batch: 450, Avg: 2.073858916098686




Batch: 500, Avg: 2.0750527891094337




Batch: 550, Avg: 2.0754631580328553




Batch: 600, Avg: 2.074043715654713




Batch: 650, Avg: 2.0747376923920005




Batch: 700, Avg: 2.074551219606876




Batch: 750, Avg: 2.0728394145495725




Batch: 800, Avg: 2.072138767563895




Batch: 850, Avg: 2.0720395264418228


Batch Idx: (899/900) | Loss: 2.070 | Acc: 21.904% (9857/45000): : 900it [03:48,  3.94it/s]


==> Validating...


Batch Idx: (99/100) | Loss: 2.037 | Acc: 22.720% (1136/5000): : 100it [00:08, 12.22it/s]
Epoch: 2 | Val acc: 22.720:   1%|          | 2/200 [07:59<13:08:49, 239.04s/it]

==> Training...




Batch: 0, Avg: 2.2228989601135254




Batch: 50, Avg: 2.0424792415955486




Batch: 100, Avg: 2.0458858721327076




Batch: 150, Avg: 2.041804799970412




Batch: 200, Avg: 2.0388152569680664




Batch: 250, Avg: 2.0401026822656276




Batch: 300, Avg: 2.038749077787431




Batch: 350, Avg: 2.036048333869021




Batch: 400, Avg: 2.033569275292375




Batch: 450, Avg: 2.0313177124623976




Batch: 500, Avg: 2.0279384587339297




Batch: 550, Avg: 2.030109417416872




Batch: 600, Avg: 2.02917235088031




Batch: 650, Avg: 2.0281574649195515




Batch: 700, Avg: 2.0242219971182003




Batch: 750, Avg: 2.0230739470645687




Batch: 800, Avg: 2.023500199026234




Batch: 850, Avg: 2.0214989675618225


Batch Idx: (899/900) | Loss: 2.020 | Acc: 24.107% (10848/45000): : 900it [03:48,  3.94it/s]


==> Validating...


Batch Idx: (99/100) | Loss: 1.997 | Acc: 25.640% (1282/5000): : 100it [00:08, 12.35it/s]
Epoch: 3 | Val acc: 25.640:   2%|▏         | 3/200 [11:55<13:01:21, 237.98s/it]

==> Training...




Batch: 0, Avg: 1.8451147079467773




Batch: 50, Avg: 1.9943430961347093




Batch: 100, Avg: 1.9927877919508679




Batch: 150, Avg: 1.9942537506684561




Batch: 200, Avg: 1.9972219983143593




Batch: 250, Avg: 1.9954601502513505




Batch: 300, Avg: 1.9949078310367674




Batch: 350, Avg: 1.9939959375267355




Batch: 400, Avg: 1.994781742369445




Batch: 450, Avg: 1.995919249274514




Batch: 500, Avg: 1.9961761804873834




Batch: 550, Avg: 1.9953054918784194




Batch: 600, Avg: 1.9947256813826855




Batch: 650, Avg: 1.9959495442986672




Batch: 700, Avg: 1.9974948222218838




Batch: 750, Avg: 1.997178655172315




Batch: 800, Avg: 1.9985641278280004




Batch: 850, Avg: 1.9981115985000737


Batch Idx: (899/900) | Loss: 1.997 | Acc: 25.122% (11305/45000): : 900it [03:52,  3.87it/s]


==> Validating...


Batch Idx: (99/100) | Loss: 1.988 | Acc: 26.000% (1300/5000): : 100it [00:08, 11.99it/s]
Epoch: 4 | Val acc: 26.000:   2%|▏         | 4/200 [15:56<13:01:11, 239.14s/it]

==> Training...




Batch: 0, Avg: 2.2407357692718506




Batch: 50, Avg: 1.9843457165886373




Batch: 100, Avg: 1.9961516656497917




Batch: 150, Avg: 1.9890808198625678




Batch: 200, Avg: 1.9955175219483636




Batch: 250, Avg: 2.0006637991187106




Batch: 300, Avg: 2.001642516285082




Batch: 350, Avg: 1.9994675182549022




Batch: 400, Avg: 1.9954101510178717




Batch: 450, Avg: 1.994510409572965




Batch: 500, Avg: 1.9955311114202716




Batch: 550, Avg: 1.9967334876259528




Batch: 600, Avg: 1.9962522812968682




Batch: 650, Avg: 1.9955623854873001




Batch: 700, Avg: 1.9950861383266694




Batch: 750, Avg: 1.9954237509345247




Batch: 800, Avg: 1.9940524425697088




Batch: 850, Avg: 1.9944224041020968


Batch Idx: (899/900) | Loss: 1.993 | Acc: 25.391% (11426/45000): : 900it [03:50,  3.91it/s]


==> Validating...


Batch Idx: (99/100) | Loss: 1.983 | Acc: 25.440% (1272/5000): : 100it [00:07, 12.62it/s]
Epoch: 5 | Val acc: 25.440:   2%|▎         | 5/200 [19:54<12:55:53, 238.74s/it]

==> Training...




Batch: 0, Avg: 1.9140570163726807




Batch: 50, Avg: 1.9473319053649902




Batch: 100, Avg: 1.9738628049888234




Batch: 150, Avg: 1.9809631227657496




Batch: 200, Avg: 1.97920061937019




Batch: 250, Avg: 1.9829230816715742




Batch: 300, Avg: 1.9854135620237585




Batch: 350, Avg: 1.9899782716718495




Batch: 400, Avg: 1.9955439175156287




Batch: 450, Avg: 1.9927079688154674




Batch: 500, Avg: 1.9943043828724387




Batch: 550, Avg: 1.9939172241951724




Batch: 600, Avg: 1.994413694803806




Batch: 650, Avg: 1.9941324396616853


Batch Idx: (700/900) | Loss: 1.994 | Acc: 25.546% (8954/35050): : 701it [02:59,  3.90it/s]
Epoch: 5 | Val acc: 25.440:   2%|▎         | 5/200 [22:54<14:53:31, 274.93s/it]

Batch: 700, Avg: 1.9937379936688977





KeyboardInterrupt: 