## Environment and Import

In [1]:
%cd ..

/Users/zhouyuqin/Desktop/Thesis/experiments/state-spaces


In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
import numpy as np
import random

import torchvision
import torchvision.transforms as transforms


import os
import argparse
from tqdm.auto import tqdm

from src.utils.optim.schedulers import CosineWarmup
from src.models.sequence.rnns.cells.basic import RNNCell
from src.models.sequence.rnns.cells import CellBase
from src.models.sequence.rnns.rnn import RNN

# from models.s4.s4 import S4
from models.s4.s4d import S4D

In [4]:
# Prefer the GPU when PyTorch reports one; otherwise everything runs on the CPU.
backend_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(backend_name)
print('Device available now:', device)

Device available now: cpu


In [5]:
def set_seed(seed = 1234):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU and CUDA), asks cuDNN
    for deterministic kernels, and stores the seed in ``PYTHONHASHSEED``.

    Args:
        seed: integer seed shared by all generators.
    """
    # NOTE(review): PYTHONHASHSEED is presumably only read at interpreter
    # start-up, so this mainly affects subprocesses -- confirm if hash
    # ordering matters for the experiment.
    os.environ['PYTHONHASHSEED'] = str(seed)

    # The generators are independent, so one loop over the seeding functions
    # is equivalent to calling each of them in turn.
    for seed_fn in (np.random.seed, random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)

    # Required in addition to seeding when running on the cuDNN backend.
    torch.backends.cudnn.deterministic = True

set_seed()

## Hyperparameters

In [6]:
# Hyperparameters are declared argparse-style so the same cell works as a
# script; every value has a default, so the notebook runs without arguments.
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
# Optimizer
parser.add_argument('--lr', default=0.001, type=float, help='Learning rate')
parser.add_argument('--weight_decay', default=0.05, type=float, help='Weight decay')


# Scheduler
# parser.add_argument('--patience', default=10, type=float, help='Patience for learning rate scheduler')
# The epoch count feeds range() in the training loop, so it must be parsed as
# an int: with type=float, "--epochs 20" became 20.0 and range() raised TypeError.
parser.add_argument('--epochs', default=10, type=int, help='Training epochs')


# Dataset
parser.add_argument('--dataset', default='cifar10', choices=['mnist', 'cifar10'], type=str, help='Dataset')
parser.add_argument('--grayscale', action='store_true', help='Use grayscale CIFAR10')


# Dataloader
parser.add_argument('--num_workers', default=0, type=int, help='Number of workers to use for dataloader')
parser.add_argument('--batch_size', default=64, type=int, help='Batch size')


# Model
parser.add_argument('--n_layers', default=2, type=int, help='Number of layers')
parser.add_argument('--d_model', default=128, type=int, help='Model dimension')
parser.add_argument('--dropout', default=0.1, type=float, help='Dropout')
parser.add_argument('--prenorm', action='store_true', help='Prenorm')


# General
parser.add_argument('--resume', '-r', action='store_true', help='Resume from checkpoint')

# parse_known_args() tolerates unrecognised entries in sys.argv (presumably the
# kernel's own launch arguments when run inside Jupyter) instead of exiting.
# args = parser.parse_args()
args, unknown = parser.parse_known_args()

best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

In [7]:
# Data: report which dataset is about to be prepared, then dump the full
# namespace of parsed hyperparameters so the run configuration is recorded.
print(f'==> Preparing {args.dataset} data..')
print(args)

==> Preparing cifar10 data..
Namespace(batch_size=64, d_model=128, dataset='cifar10', dropout=0.1, epochs=10, grayscale=False, lr=0.001, n_layers=2, num_workers=0, prenorm=False, resume=False, weight_decay=0.05)


In [8]:
# Display whether this run will resume from ./checkpoint/ckpt.pth (the --resume flag).
args.resume

False

## Datasets

In [9]:
def split_train_val(train, val_split):
    """Randomly partition a dataset into training and validation subsets.

    The split uses a dedicated generator seeded with 42, so repeated calls on
    datasets of the same length select the same indices.

    Args:
        train: dataset to split (anything accepted by ``random_split``).
        val_split: fraction of the samples assigned to the validation subset.

    Returns:
        Tuple ``(train_subset, val_subset)``.
    """
    n_total = len(train)
    n_train = int(n_total * (1.0-val_split))
    n_val = n_total - n_train

    fixed_rng = torch.Generator().manual_seed(42)
    train_subset, val_subset = torch.utils.data.random_split(
        train, (n_train, n_val), generator=fixed_rng
    )
    return train_subset, val_subset

In [10]:
# Build the train / validation / test datasets for the selected benchmark.
# In every pipeline the final Lambda reshapes an image tensor with .view(C, L)
# and transposes it, so each sample becomes an (L, C) pixel sequence.
if args.dataset == 'cifar10':

    d_input = 1 if args.grayscale else 3
    d_output = 10

    if args.grayscale:
        pipeline = [
            transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Normalize(mean=122.6 / 255.0, std=61.0 / 255.0),
            transforms.Lambda(lambda x: x.view(1, 1024).t()),
        ]
    else:
        pipeline = [
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
            transforms.Lambda(lambda x: x.view(3, 1024).t()),
        ]
    transform = transforms.Compose(pipeline)

    # S4 is trained on sequences with no data augmentation!
    transform_train = transform_test = transform

    cifar_kwargs = dict(root='./data/cifar/', download=True)

    # The official training split is loaded twice and split with the same
    # fixed seed: the first copy keeps the train part, the second the val part.
    trainset, _ = split_train_val(
        torchvision.datasets.CIFAR10(train=True, transform=transform_train, **cifar_kwargs),
        val_split=0.1,
    )
    _, valset = split_train_val(
        torchvision.datasets.CIFAR10(train=True, transform=transform_test, **cifar_kwargs),
        val_split=0.1,
    )
    testset = torchvision.datasets.CIFAR10(
        train=False, transform=transform_test, **cifar_kwargs)

elif args.dataset == 'mnist':

    d_input = 1
    d_output = 10

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.view(1, 784).t()),
    ])
    transform_train = transform_test = transform

    mnist_kwargs = dict(root='./data', download=True)

    # Same double-load / fixed-seed split scheme as for CIFAR-10.
    trainset, _ = split_train_val(
        torchvision.datasets.MNIST(train=True, transform=transform_train, **mnist_kwargs),
        val_split=0.1,
    )
    _, valset = split_train_val(
        torchvision.datasets.MNIST(train=True, transform=transform_test, **mnist_kwargs),
        val_split=0.1,
    )
    testset = torchvision.datasets.MNIST(
        train=False, transform=transform_test, **mnist_kwargs)

else:
    raise NotImplementedError

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [11]:
# Dataloaders
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
valloader = torch.utils.data.DataLoader(
    valset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)


In [12]:
# Taking a single batch of the images
images, labels = next(iter(trainloader))
print(images.size(), labels)

torch.Size([64, 1024, 3]) tensor([4, 5, 8, 4, 3, 9, 8, 6, 9, 6, 4, 2, 8, 5, 5, 3, 1, 1, 7, 5, 8, 3, 2, 3,
        9, 6, 2, 8, 1, 8, 9, 8, 0, 4, 8, 8, 2, 4, 6, 0, 5, 4, 9, 6, 5, 0, 5, 0,
        5, 7, 0, 6, 5, 9, 0, 0, 6, 3, 7, 0, 4, 8, 7, 4])


In [13]:
# Number of training samples left after split_train_val held out 10% for validation.
len(trainset)

45000

## Model architecture

In [14]:
# Dropout broke in PyTorch 1.11
# Choose the dropout module according to the installed PyTorch (major, minor)
# version. The branches must form a single if/elif/else chain: with two
# separate `if` statements the 1.11 choice was immediately overwritten by the
# trailing `else` (nn.Dropout2d).
torch_version = tuple(map(int, torch.__version__.split('.')[:2]))
if torch_version == (1, 11):
    # Dropout broke in PyTorch 1.11
    print("WARNING: Dropout is bugged in PyTorch 1.11. Results may be worse.")
    dropout_fn = nn.Dropout
elif torch_version >= (1, 12):
    dropout_fn = nn.Dropout1d
else:
    dropout_fn = nn.Dropout2d
print(dropout_fn)

<class 'torch.nn.modules.dropout.Dropout1d'>


In [15]:
class Model(nn.Module):
    """Residual sequence classifier built from stacked RNN blocks.

    Pipeline: linear encoder -> ``n_layers`` residual blocks (optional
    pre-LayerNorm, RNN, dropout, Linear+GLU feed-forward, dropout) -> mean
    pooling over the sequence axis -> linear decoder.

    Args:
        d_input: number of features per sequence element (1 for grayscale,
            3 for RGB in this notebook).
        d_output: number of output classes.
        d_model: hidden width used by every block.
        n_layers: number of residual blocks.
        dropout: probability passed to the module-level ``dropout_fn``.
        prenorm: if True, apply LayerNorm to the block input before the RNN.
    """

    def __init__(
        self,
        d_input,
        d_output=10,
        d_model=256,
        n_layers=2,
        dropout=0.2,
        prenorm=True,
    ):
        super().__init__()

        self.prenorm = prenorm

        # Linear encoder (d_input = 1 for grayscale and 3 for RGB) (like embedding layer)
        self.encoder = nn.Linear(d_input, d_model)

        # Per-block components, kept in parallel lists and zipped in forward()
        self.layers = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.dropouts = nn.ModuleList()
        self.FFNs = nn.ModuleList()

        for _ in range(n_layers):
            # Size the recurrent layer from the constructor's d_model rather
            # than the global args.d_model, so the class also works when it is
            # instantiated with a width different from the parsed CLI value.
            # Alternatives tried here previously: nn.RNN(d_model, d_model,
            # nonlinearity="relu") and S4D(d_model, ...).
            self.layers.append(
                RNN(d_input=d_model, d_model=d_model, cell='rnn', return_output=True, transposed=False, dropout=0.0)
            )
            self.norms.append(nn.LayerNorm(d_model))
            self.dropouts.append(dropout_fn(dropout))
            # Linear doubles the width; GLU halves it again, back to d_model
            self.FFNs.append(nn.Sequential(nn.Linear(d_model, d_model*2), nn.GLU()))

        # Linear decoder
        self.decoder = nn.Linear(d_model, d_output)

    def forward(self, x):
        """
        Input x is shape (B, L, d_input)

        Returns the decoder output of shape (B, d_output).
        """
        x = self.encoder(x)  # (B, L, d_input) -> (B, L, d_model)
        for layer, norm, dropout, FFN in zip(self.layers, self.norms, self.dropouts, self.FFNs):
            # Each iteration of this loop will map (B, L, d_model) -> (B, L, d_model)
            z = x
            if self.prenorm:
                # Prenorm
                z = norm(z)
            # NOTE(review): there is no post-norm branch, so with prenorm=False
            # the LayerNorm modules are created but never applied -- confirm
            # this is intended.

            # Apply recurrence: the layer returns (output, state); the state is ignored
            z, _ = layer(z)

            # Dropout on the output of the recurrent block
            # NOTE(review): activations are (B, L, d_model) here; Dropout1d /
            # Dropout2d presumably treat dim 1 as the channel axis, i.e. they
            # would drop whole time steps rather than features -- confirm.
            z = dropout(z)

            # MLP + GLU
            z = FFN(z)

            # The same dropout module is applied a second time after the feed-forward
            z = dropout(z)

            # Residual connection
            x = z + x

        # Pooling: average pooling over the sequence length
        x = x.mean(dim=1)

        # Decode the outputs
        x = self.decoder(x)  # (B, d_model) -> (B, d_output)

        return x

In [16]:
# Model
print('==> Building model..')
model = Model(
    d_input=d_input,
    d_output=d_output,
    d_model=args.d_model,
    n_layers=args.n_layers,
    dropout=args.dropout,
    prenorm=args.prenorm,
)

model = model.to(device)
if device == 'cuda':
    cudnn.benchmark = True

==> Building model..


In [17]:
# Show the module tree, then display the number of trainable scalars as the
# cell's output value.
print(model)
trainable_sizes = [p.numel() for p in model.parameters() if p.requires_grad]
sum(trainable_sizes)

# Print model's state_dict
# print("Model's state_dict:")
# for param_tensor in model.state_dict():
#     print(param_tensor, "\t", model.state_dict()[param_tensor].size())

Model(
  (encoder): Linear(in_features=3, out_features=128, bias=True)
  (layers): ModuleList(
    (0-1): 2 x RNN(
      (cell): RNNCell(
        (W_hx): Linear(in_features=128, out_features=128, bias=False)
        (activate): Tanh()
        (W_hh): Linear(in_features=128, out_features=128, bias=False)
      )
    )
  )
  (norms): ModuleList(
    (0-1): 2 x LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  )
  (dropouts): ModuleList(
    (0-1): 2 x Dropout1d(p=0.1, inplace=False)
  )
  (FFNs): ModuleList(
    (0-1): 2 x Sequential(
      (0): Linear(in_features=128, out_features=256, bias=True)
      (1): GLU(dim=-1)
    )
  )
  (decoder): Linear(in_features=128, out_features=10, bias=True)
)


133898

## Training

### Optimizer and Learning rate schedule

In [24]:
## for S4, S4D

def setup_optimizer(model, lr, weight_decay, epochs):
    """Create an AdamW optimizer with per-parameter overrides plus a cosine LR schedule.

    Parameters that carry an ``_optim`` attribute (a dict of optimizer
    hyperparameters, e.g. ``{'weight_decay': 0.0, 'lr': 1e-07}``) are placed in
    their own parameter groups using those values; all other parameters use
    the ``lr`` / ``weight_decay`` given here. A one-line summary of every
    parameter group is printed.

    Args:
        model: module whose parameters are optimized.
        lr: learning rate for parameters without ``_optim``.
        weight_decay: weight decay for parameters without ``_optim``.
        epochs: ``T_max`` of the ``CosineAnnealingLR`` scheduler.

    Returns:
        Tuple ``(optimizer, scheduler)``.
    """
    every_param = list(model.parameters())

    # Separate ordinary parameters from those requesting special hyperparameters
    ordinary = []
    requested_hps = []
    for p in every_param:
        if hasattr(p, "_optim"):
            requested_hps.append(getattr(p, "_optim"))
        else:
            ordinary.append(p)

    # The first (default) group holds the ordinary parameters
    optimizer = optim.AdamW(ordinary, lr=lr, weight_decay=weight_decay)

    # Collapse identical hyperparameter dicts: frozensets of their items are
    # hashable, and dict.fromkeys keeps one entry per distinct value
    distinct = dict.fromkeys(frozenset(hp.items()) for hp in requested_hps)
    hps = [dict(items) for items in sorted(list(distinct))]

    # One extra parameter group per distinct override dict; parameters without
    # "_optim" compare as None and therefore never match
    for hp in hps:
        matching = [p for p in every_param if getattr(p, "_optim", None) == hp]
        optimizer.add_param_group({"params": matching, **hp})

    # Create a lr scheduler
    # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=patience, factor=0.2)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)

    # Print optimizer info: only hyperparameter names that occur in some
    # override dict are reported for each group
    keys = sorted({k for hp in hps for k in hp.keys()})
    for i, g in enumerate(optimizer.param_groups):
        fields = [f"Optimizer group {i}", f"{len(g['params'])} tensors"]
        fields.extend(f"{k} {g.get(k, None)}" for k in keys)
        print(' | '.join(fields))

    return optimizer, scheduler

# Build the AdamW optimizer and cosine-annealing scheduler for the model from
# the parsed hyperparameters; the call also prints one line per parameter group.
optimizer, scheduler = setup_optimizer(
    model, lr=args.lr, weight_decay=args.weight_decay, epochs=args.epochs
)

Optimizer group 0 | 20 tensors


In [316]:
### For our implementation
# NOTE: running this cell rebinds `optimizer` and `scheduler`, replacing the
# pair created by setup_optimizer() above.
# Warm-up length: 10% of the epochs (rounded down) plus one.
warmup_steps = int(args.epochs * 0.1) + 1
optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
scheduler = CosineWarmup(
    optimizer,
    T_max=args.epochs,
    eta_min=1e-7,
    warmup_step=warmup_steps,
)
print(warmup_steps)

2


In [25]:
# Print the optimizer's state_dict entry by entry. Inside "param_groups" the
# 'params' lists hold integer indices standing in for the parameter tensors.
print("Optimizer's state_dict:")
for var_name, value in optimizer.state_dict().items():
    print(var_name, "\t", value)


Optimizer's state_dict:
state 	 {}
param_groups 	 [{'lr': 0.001, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0.05, 'amsgrad': False, 'foreach': None, 'maximize': False, 'capturable': False, 'differentiable': False, 'fused': None, 'initial_lr': 0.001, 'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]}]


### Loss

In [26]:
# Classification loss; train() and eval() apply it to the model outputs and the target labels.
criterion = nn.CrossEntropyLoss()

### Training Process (train())

In [27]:
if args.resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
    # map_location remaps the stored tensors onto the device in use, so a
    # checkpoint written on a GPU machine can still be loaded on a CPU-only one.
    checkpoint = torch.load('./checkpoint/ckpt.pth', map_location=device)
    model.load_state_dict(checkpoint['model'])
    best_acc = checkpoint['acc']
    # NOTE(review): eval() stores the epoch that had just finished training, so
    # restarting from this value repeats that epoch; optimizer and scheduler
    # state are not part of the checkpoint and restart from scratch -- confirm
    # whether that is the intended resume behaviour.
    start_epoch = checkpoint['epoch']

In [28]:
###############################################################################
# Everything after this point is standard PyTorch training!
###############################################################################

# Training
def train(print_every=50):
    """Run one training epoch over the notebook-level ``trainloader``.

    Uses the global ``model``, ``optimizer``, ``criterion`` and ``device``.
    The progress-bar description shows the running average loss and accuracy.

    Args:
        print_every: additionally print the running average loss every this
            many batches; pass 0 or None to rely on the progress bar alone.
    """
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    # enumerate() has no length, so hand tqdm the batch count explicitly;
    # otherwise the bar cannot show a fraction or a remaining-time estimate.
    pbar = tqdm(enumerate(trainloader), total=len(trainloader))
    for batch_idx, (inputs, targets) in pbar:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # Running statistics for the epoch so far
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        if print_every and (batch_idx % print_every) == 0:
            print(f"Batch: {batch_idx}, Avg: {train_loss/(batch_idx+1)}")
        pbar.set_description(
            'Batch Idx: (%d/%d) | Loss: %.3f | Acc: %.3f%% (%d/%d)' %
            (batch_idx, len(trainloader), train_loss/(batch_idx+1), 100.*correct/total, correct, total)
        )


def eval(epoch, dataloader, checkpoint=False):
    """Evaluate the global ``model`` on ``dataloader`` and return its accuracy.

    Note: this function shadows Python's built-in ``eval`` within the notebook.

    Args:
        epoch: epoch index recorded in the checkpoint when one is written.
        dataloader: loader yielding ``(inputs, targets)`` batches.
        checkpoint: if True and the accuracy beats the global ``best_acc``,
            save the model to ``./checkpoint/ckpt.pth`` and update ``best_acc``.

    Returns:
        Accuracy in percent. It is returned whether or not ``checkpoint`` is
        set (previously the function returned None when ``checkpoint`` was
        False); an empty dataloader yields 0.0.
    """
    global best_acc
    model.eval()
    eval_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        # Pass the batch count so the progress bar can show a fraction
        pbar = tqdm(enumerate(dataloader), total=len(dataloader))
        for batch_idx, (inputs, targets) in pbar:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            eval_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            pbar.set_description(
                'Batch Idx: (%d/%d) | Loss: %.3f | Acc: %.3f%% (%d/%d)' %
                (batch_idx, len(dataloader), eval_loss/(batch_idx+1), 100.*correct/total, correct, total)
            )

    # Guard against an empty loader instead of dividing by zero
    acc = 100.*correct/total if total else 0.0

    # Save checkpoint.
    if checkpoint and acc > best_acc:
        state = {
            'model': model.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        os.makedirs('checkpoint', exist_ok=True)
        torch.save(state, './checkpoint/ckpt.pth')
        best_acc = acc

    return acc

In [29]:
# Epoch loop: train, validate (checkpointing on improvement), step the scheduler.
# val_acc starts as None so the first iteration never formats an undefined
# value; keying on `epoch == 0` raised NameError when resuming from a
# checkpoint with start_epoch > 0.
val_acc = None
pbar = tqdm(range(start_epoch, args.epochs))
for epoch in pbar:
    if val_acc is None:
        pbar.set_description('Epoch: %d' % (epoch))
    else:
        # Shows the validation accuracy of the previous epoch
        pbar.set_description('Epoch: %d | Val acc: %1.3f' % (epoch, val_acc))
    train()
    val_acc = eval(epoch, valloader, checkpoint=True)
    scheduler.step()
    # print(f"Epoch {epoch} learning rate: {scheduler.get_last_lr()}")

Epoch: 0:   0%|          | 0/10 [00:00<?, ?it/s]

Batch: 0, Avg: 2.298004150390625




Batch: 50, Avg: 2.086724566478355




Batch: 100, Avg: 2.057391896106229




Batch: 150, Avg: 2.0309115885109303




Batch: 200, Avg: 2.016208737643797




Batch: 250, Avg: 2.0009899733076058




Batch: 300, Avg: 1.9892766257853207




Batch: 350, Avg: 1.975150657175613




Batch: 400, Avg: 1.9663986797047375




Batch: 450, Avg: 1.9544992010767868




Batch: 500, Avg: 1.9449558741080308




Batch: 550, Avg: 1.937387930719476




Batch: 600, Avg: 1.9330700086476205




Batch: 650, Avg: 1.926252675679056




Batch: 700, Avg: 1.921979105591604


Batch Idx: (703/704) | Loss: 1.922 | Acc: 28.640% (12888/45000): : 704it [22:03,  1.88s/it]
Batch Idx: (78/79) | Loss: 1.824 | Acc: 32.640% (1632/5000): : 79it [00:52,  1.51it/s]
Epoch: 1 | Val acc: 32.640:  10%|█         | 1/10 [22:55<3:26:19, 1375.52s/it]

Batch: 0, Avg: 1.8721773624420166




Batch: 50, Avg: 1.8599746741500556




Batch: 100, Avg: 1.802751077283727




Batch: 150, Avg: 1.8038773828784362




Batch: 200, Avg: 1.8010317419298845




Batch: 250, Avg: 1.7994761367243124




Batch: 300, Avg: 1.7922027808091172




Batch: 350, Avg: 1.7875041302792367




Batch: 400, Avg: 1.7844640524904627




Batch: 450, Avg: 1.7814418119760416




Batch: 500, Avg: 1.7757178461718226




Batch: 550, Avg: 1.7701020251601232




Batch: 600, Avg: 1.7645734120129348




Batch: 650, Avg: 1.758605768420546




Batch: 700, Avg: 1.7508270575554667


Batch Idx: (703/704) | Loss: 1.751 | Acc: 36.529% (16438/45000): : 704it [21:56,  1.87s/it]
Batch Idx: (78/79) | Loss: 1.685 | Acc: 40.400% (2020/5000): : 79it [00:55,  1.43it/s]
Epoch: 2 | Val acc: 40.400:  20%|██        | 2/10 [45:46<3:03:04, 1373.03s/it]

Batch: 0, Avg: 1.7080577611923218




Batch: 50, Avg: 1.6902075515073889




Batch: 100, Avg: 1.6698996572211238




Batch: 150, Avg: 1.6569338839575154




Batch: 200, Avg: 1.65300056768294




Batch: 250, Avg: 1.6494339234325515




Batch: 300, Avg: 1.6600929059063478




Batch: 350, Avg: 1.668319285764993




Batch: 400, Avg: 1.6716786920281121




Batch: 450, Avg: 1.6735141081714842




Batch: 500, Avg: 1.6689388999444044




Batch: 550, Avg: 1.6663048875310245




Batch: 600, Avg: 1.6594256710093747




Batch: 650, Avg: 1.656609974881654




Batch: 700, Avg: 1.65251731855553


Batch Idx: (703/704) | Loss: 1.653 | Acc: 41.360% (18612/45000): : 704it [21:52,  1.86s/it]
Batch Idx: (78/79) | Loss: 1.608 | Acc: 42.160% (2108/5000): : 79it [00:51,  1.53it/s]
Epoch: 3 | Val acc: 42.160:  30%|███       | 3/10 [1:08:30<2:39:41, 1368.78s/it]

Batch: 0, Avg: 1.5215574502944946




Batch: 50, Avg: 1.5918174491209143




Batch: 100, Avg: 1.5881250660018165




Batch: 150, Avg: 1.580806410075813




Batch: 200, Avg: 1.5809063502212068




Batch: 250, Avg: 1.5760601250773882




Batch: 300, Avg: 1.5781035098522604




Batch: 350, Avg: 1.5783035761950024




Batch: 400, Avg: 1.578275391883089




Batch: 450, Avg: 1.5747348337110025




Batch: 500, Avg: 1.574222963727163




Batch: 550, Avg: 1.5733485957455504




Batch: 600, Avg: 1.5679396513892887




Batch: 650, Avg: 1.5692814207663002




Batch: 700, Avg: 1.5690519020662839


Batch Idx: (703/704) | Loss: 1.568 | Acc: 44.624% (20081/45000): : 704it [1:11:45,  6.12s/it]
Batch Idx: (78/79) | Loss: 1.573 | Acc: 44.140% (2207/5000): : 79it [00:52,  1.52it/s]
Epoch: 4 | Val acc: 44.140:  40%|████      | 4/10 [2:21:07<4:14:51, 2548.60s/it]

Batch: 0, Avg: 1.608256459236145




Batch: 50, Avg: 1.50482954230963




Batch: 100, Avg: 1.5044438532083342




Batch: 150, Avg: 1.5067985326249078




Batch: 200, Avg: 1.5069933441740957




Batch: 250, Avg: 1.5166791544492502




Batch: 300, Avg: 1.5222617150145115




Batch: 350, Avg: 1.5265727179002897




Batch: 400, Avg: 1.5221880329517354




Batch: 450, Avg: 1.5206320259364905




Batch: 500, Avg: 1.5201987070951632




Batch: 550, Avg: 1.5203881153394003




Batch: 600, Avg: 1.5187464424853714




Batch: 650, Avg: 1.5192422121534332




Batch: 700, Avg: 1.5184245060241852


Batch Idx: (703/704) | Loss: 1.518 | Acc: 46.389% (20875/45000): : 704it [21:39,  1.85s/it]
Batch Idx: (78/79) | Loss: 1.547 | Acc: 44.660% (2233/5000): : 79it [00:53,  1.48it/s]
Epoch: 5 | Val acc: 44.660:  50%|█████     | 5/10 [2:43:40<2:56:26, 2117.32s/it]

Batch: 0, Avg: 1.4050382375717163




Batch: 50, Avg: 1.491633361461116




Batch: 100, Avg: 1.4879884188718135




Batch: 150, Avg: 1.4843499163128682




Batch: 200, Avg: 1.481457106509612




Batch: 250, Avg: 1.4847005676938243




Batch: 300, Avg: 1.486024959142818




Batch: 350, Avg: 1.4872781259042245




Batch: 400, Avg: 1.4871612154040252




Batch: 450, Avg: 1.484542067722312




Batch: 500, Avg: 1.4824448571709579




Batch: 550, Avg: 1.478610132220869




Batch: 600, Avg: 1.4788036638012345




Batch: 650, Avg: 1.4782142100795623




Batch: 700, Avg: 1.4778676235386716


Batch Idx: (703/704) | Loss: 1.478 | Acc: 48.062% (21628/45000): : 704it [21:37,  1.84s/it]
Batch Idx: (78/79) | Loss: 1.496 | Acc: 46.440% (2322/5000): : 79it [00:51,  1.53it/s]
Epoch: 6 | Val acc: 46.440:  60%|██████    | 6/10 [3:06:09<2:03:45, 1856.27s/it]

Batch: 0, Avg: 1.3629142045974731




Batch: 50, Avg: 1.4668405757230871




Batch: 100, Avg: 1.4601059302245036




Batch: 150, Avg: 1.4465318302445065




Batch: 200, Avg: 1.447788856515837




Batch: 250, Avg: 1.4530785496966296




Batch: 300, Avg: 1.449582157736997




Batch: 350, Avg: 1.4473677190959964




Batch: 400, Avg: 1.4494663908000005




Batch: 450, Avg: 1.4475076293733855




Batch: 500, Avg: 1.4456111004728518




Batch: 550, Avg: 1.446248950196697




Batch: 600, Avg: 1.447979275478103




Batch: 650, Avg: 1.4463443605947421




Batch: 700, Avg: 1.4462058659797048


Batch Idx: (703/704) | Loss: 1.446 | Acc: 49.531% (22289/45000): : 704it [8:36:18, 44.00s/it]
Batch Idx: (78/79) | Loss: 1.491 | Acc: 47.080% (2354/5000): : 79it [00:51,  1.53it/s]
Epoch: 7 | Val acc: 47.080:  70%|███████   | 7/10 [11:43:20<9:29:42, 11394.00s/it]

Batch: 0, Avg: 1.476455807685852




Batch: 50, Avg: 1.442802209480136




Batch: 100, Avg: 1.4263146041643502




Batch: 150, Avg: 1.4287957481990587




Batch: 200, Avg: 1.417360781437129




Batch: 250, Avg: 1.4145674620016637




Batch: 300, Avg: 1.4138910493185355




Batch: 350, Avg: 1.4167415603273614




Batch: 400, Avg: 1.4157315608569214




Batch: 450, Avg: 1.4137103446571366




Batch: 500, Avg: 1.4167143063630887




Batch: 550, Avg: 1.417726729396467




Batch: 600, Avg: 1.4185057590884496




Batch: 650, Avg: 1.4189744951538226




Batch: 700, Avg: 1.4167527106621125


Batch Idx: (703/704) | Loss: 1.417 | Acc: 50.569% (22756/45000): : 704it [21:46,  1.86s/it]
Batch Idx: (78/79) | Loss: 1.446 | Acc: 49.020% (2451/5000): : 79it [00:51,  1.52it/s]
Epoch: 8 | Val acc: 49.020:  80%|████████  | 8/10 [12:05:58<4:33:18, 8199.08s/it] 

Batch: 0, Avg: 1.3239713907241821




Batch: 50, Avg: 1.3824400165501762




Batch: 100, Avg: 1.3896823962136071




Batch: 150, Avg: 1.3891515791021436




Batch: 200, Avg: 1.3856340317583795




Batch: 250, Avg: 1.3811477301605195




Batch: 300, Avg: 1.3813024143443948




Batch: 350, Avg: 1.3818576374964158




Batch: 400, Avg: 1.3833008783119278




Batch: 450, Avg: 1.3870755871489413




Batch: 500, Avg: 1.3894774734141109




Batch: 550, Avg: 1.3880858199566115




Batch: 600, Avg: 1.3879033889429344




Batch: 650, Avg: 1.39056781442484




Batch: 700, Avg: 1.3905285206568905


Batch Idx: (703/704) | Loss: 1.390 | Acc: 51.789% (23305/45000): : 704it [21:39,  1.85s/it]
Batch Idx: (78/79) | Loss: 1.412 | Acc: 49.620% (2481/5000): : 79it [00:52,  1.51it/s]
Epoch: 9 | Val acc: 49.620:  90%|█████████ | 9/10 [12:28:30<1:40:58, 6058.66s/it]

Batch: 0, Avg: 1.4281963109970093




Batch: 50, Avg: 1.3799313447054695




Batch: 100, Avg: 1.3666988245331415




Batch: 150, Avg: 1.3611986321329281




Batch: 200, Avg: 1.3737354320080126




Batch: 250, Avg: 1.3807935956939759




Batch: 300, Avg: 1.3766217580269342




Batch: 350, Avg: 1.380200047099013




Batch: 400, Avg: 1.3823961539756033




Batch: 450, Avg: 1.380147173240814




Batch: 500, Avg: 1.3776853660385528




Batch: 550, Avg: 1.3772931168170244




Batch: 600, Avg: 1.3776404815981669




Batch: 650, Avg: 1.3781920896697155




Batch: 700, Avg: 1.377910297539367


Batch Idx: (703/704) | Loss: 1.378 | Acc: 52.253% (23514/45000): : 704it [21:49,  1.86s/it]
Batch Idx: (78/79) | Loss: 1.409 | Acc: 49.760% (2488/5000): : 79it [00:51,  1.53it/s]
Epoch: 9 | Val acc: 49.620: 100%|██████████| 10/10 [12:51:12<00:00, 4627.21s/it] 


In [None]:
## xavier_uniform_ initialization + with tanh + 1) with bias + 2) without lr schedule + 3) other hyperparameters

## Test

In [30]:
# Final evaluation on the held-out test set, using the `epoch` value left by
# the training loop; checkpoint keeps its default (False), so nothing is saved.
eval(epoch, testloader)

Batch Idx: (156/157) | Loss: 1.423 | Acc: 50.150% (5015/10000): : 157it [01:44,  1.50it/s]
