# LSTM Meta Learner on MiniImageNet Dataset

Please Download data using link and save it in save folder as of this notebook after extracting: https://drive.google.com/file/d/1rV3aj_hgfNTfCakffpPm7Vhpr1in87CR


In [1]:
#import libraries
from __future__ import division, print_function, absolute_import
import os
import re
import pdb
import copy
import glob
import numpy as np
import torch
import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import PIL.Image as PILI
from tqdm import tqdm
from collections import OrderedDict
import random
import logging
import os
from torchvision.datasets.utils import download_file_from_google_drive, extract_archive

#### Step 1: Data Loader 

In [2]:
class EpisodeDataset(data.Dataset):

    def __init__(self, root, phase='train', n_shot=5, n_eval=15, transform=None):
        """Args:
            root (str): path to data
            phase (str): train, val or test
            n_shot (int): how many examples per class for training (k/n_support)
            n_eval (int): how many examples per class for evaluation
                - n_shot + n_eval = batch_size for data.DataLoader of ClassDataset
            transform (torchvision.transforms): data augmentation
        """
        root = os.path.join(root, phase)
        self.labels = sorted(os.listdir(root))
        images = [glob.glob(os.path.join(root, label, '*')) for label in self.labels]
#         print (len(images))
#         for s in self.labels:
#             print (s)
#         for i in range(0,len(images)):
#             if len(images)<1:
#                 print ("found true")
#                 del images[i]
#                 del self.labels[i]
        self.episode_loader = [data.DataLoader(
            ClassDataset(images=images[idx], label=idx, transform=transform),
            batch_size=n_shot+n_eval, shuffle=True, num_workers=0) for idx, _ in enumerate(self.labels)]

    def __getitem__(self, idx):
        return next(iter(self.episode_loader[idx]))

    def __len__(self):
        return len(self.labels)

In [3]:
class ClassDataset(data.Dataset):

    def __init__(self, images, label, transform=None):
        """Args:
            images (list of str): each item is a path to an image of the same label
            label (int): the label of all the images
        """
        self.images = images
        self.label = label
        self.transform = transform

    def __getitem__(self, idx):
        image = PILI.open(self.images[idx]).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)

        return image, self.label

    def __len__(self):
        return len(self.images)

In [4]:
class EpisodicSampler(data.Sampler):

    def __init__(self, total_classes, n_class, n_episode):
        self.total_classes = total_classes
        self.n_class = n_class
        self.n_episode = n_episode

    def __iter__(self):
        for i in range(self.n_episode):
            yield torch.randperm(self.total_classes)[:self.n_class]

    def __len__(self):
        return self.n_episode

In [5]:
def prepare_data(args):

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    
    transform1= transforms.Compose([
            transforms.Resize(args['image_size'] * 8 // 7),
            transforms.CenterCrop(args['image_size']),
            transforms.ToTensor(),
            normalize])
    
    transform2 = transforms.Compose([
            transforms.RandomResizedCrop(args['image_size']),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(
                brightness=0.4,
                contrast=0.4,
                saturation=0.4,
                hue=0.2),
            transforms.ToTensor(),
            normalize])
    
    train_set = EpisodeDataset(args['data_root'], 'train', args['n_shot'], args['n_eval'],
        transform=transform2)

    val_set = EpisodeDataset(args['data_root'], 'val', args['n_shot'], args['n_eval'],
        transform=transform1)

    test_set = EpisodeDataset(args['data_root'], 'test', args['n_shot'], args['n_eval'],
        transform=transform1)

    train_loader = data.DataLoader(train_set, num_workers=4, pin_memory=True,
        batch_sampler=EpisodicSampler(len(train_set), args['n_class'], args['episode']))

    val_loader = data.DataLoader(val_set, num_workers=2, pin_memory=True,
        batch_sampler=EpisodicSampler(len(val_set), args['n_class'], args['episode_val']))

    test_loader = data.DataLoader(test_set, num_workers=2, pin_memory=True,
        batch_sampler=EpisodicSampler(len(test_set), args['n_class'], args['episode_val']))
    return train_loader, val_loader, test_loader


#### Step 2: Learner 

In [6]:
class Learner(nn.Module):

    def __init__(self, image_size, bn_eps, bn_momentum, n_classes):
        super(Learner, self).__init__()
        self.model = nn.ModuleDict({'features': nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(3, 32, 3, padding=1)),
            ('norm1', nn.BatchNorm2d(32, bn_eps, bn_momentum)),
            ('relu1', nn.ReLU(inplace=False)),
            ('pool1', nn.MaxPool2d(2)),

            ('conv2', nn.Conv2d(32, 32, 3, padding=1)),
            ('norm2', nn.BatchNorm2d(32, bn_eps, bn_momentum)),
            ('relu2', nn.ReLU(inplace=False)),
            ('pool2', nn.MaxPool2d(2)),

            ('conv3', nn.Conv2d(32, 32, 3, padding=1)),
            ('norm3', nn.BatchNorm2d(32, bn_eps, bn_momentum)),
            ('relu3', nn.ReLU(inplace=False)),
            ('pool3', nn.MaxPool2d(2)),

            ('conv4', nn.Conv2d(32, 32, 3, padding=1)),
            ('norm4', nn.BatchNorm2d(32, bn_eps, bn_momentum)),
            ('relu4', nn.ReLU(inplace=False)),
            ('pool4', nn.MaxPool2d(2))]))
        })

        clr_in = image_size // 2**4
        self.model.update({'cls': nn.Linear(32 * clr_in * clr_in, n_classes)})
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        x = self.model.features(x)
        x = torch.reshape(x, [x.size(0), -1])
        outputs = self.model.cls(x)
        return outputs

    def get_flat_params(self):
        return torch.cat([p.view(-1) for p in self.model.parameters()], 0)

    def copy_flat_params(self, cI):
        idx = 0
        for p in self.model.parameters():
            plen = p.view(-1).size(0)
            p.data.copy_(cI[idx: idx+plen].view_as(p))
            idx += plen

    def transfer_params(self, learner_w_grad, cI):
        # Use load_state_dict only to copy the running mean/var in batchnorm, the values of the parameters
        #  are going to be replaced by cI
        self.load_state_dict(learner_w_grad.state_dict())
        #  replace nn.Parameters with tensors from cI (NOT nn.Parameters anymore).
        idx = 0
        for m in self.model.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.Linear):
                wlen = m._parameters['weight'].view(-1).size(0)
                m._parameters['weight'] = cI[idx: idx+wlen].view_as(m._parameters['weight']).clone()
                idx += wlen
                if m._parameters['bias'] is not None:
                    blen = m._parameters['bias'].view(-1).size(0)
                    m._parameters['bias'] = cI[idx: idx+blen].view_as(m._parameters['bias']).clone()
                    idx += blen

    def reset_batch_stats(self):
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.reset_running_stats()

#### Step 4: Meta Learner 

In [7]:
class MetaLSTMCell(nn.Module):
    """C_t = f_t * C_{t-1} + i_t * \tilde{C_t}"""
    def __init__(self, input_size, hidden_size, n_learner_params):
        super(MetaLSTMCell, self).__init__()
        """Args:
            input_size (int): cell input size, default = 20
            hidden_size (int): should be 1
            n_learner_params (int): number of learner's parameters
        """
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.n_learner_params = n_learner_params
        self.WF = nn.Parameter(torch.Tensor(input_size + 2, hidden_size))
        self.WI = nn.Parameter(torch.Tensor(input_size + 2, hidden_size))
        self.cI = nn.Parameter(torch.Tensor(n_learner_params, 1))
        self.bI = nn.Parameter(torch.Tensor(1, hidden_size))
        self.bF = nn.Parameter(torch.Tensor(1, hidden_size))

        self.reset_parameters()

    def reset_parameters(self):
        for weight in self.parameters():
            nn.init.uniform_(weight, -0.01, 0.01)

        # want initial forget value to be high and input value to be low so that 
        #  model starts with gradient descent
        nn.init.uniform_(self.bF, 4, 6)
        nn.init.uniform_(self.bI, -5, -4)

    def init_cI(self, flat_params):
        self.cI.data.copy_(flat_params.unsqueeze(1))

    def forward(self, inputs, hx=None):
        """Args:
            inputs = [x_all, grad]:
                x_all (torch.Tensor of size [n_learner_params, input_size]): outputs from previous LSTM
                grad (torch.Tensor of size [n_learner_params]): gradients from learner
            hx = [f_prev, i_prev, c_prev]:
                f (torch.Tensor of size [n_learner_params, 1]): forget gate
                i (torch.Tensor of size [n_learner_params, 1]): input gate
                c (torch.Tensor of size [n_learner_params, 1]): flattened learner parameters
        """
        x_all, grad = inputs
        batch, _ = x_all.size()

        if hx is None:
            f_prev = torch.zeros((batch, self.hidden_size)).to(self.WF.device)
            i_prev = torch.zeros((batch, self.hidden_size)).to(self.WI.device)
            c_prev = self.cI
            hx = [f_prev, i_prev, c_prev]

        f_prev, i_prev, c_prev = hx
        
        # f_t = sigmoid(W_f * [grad_t, loss_t, theta_{t-1}, f_{t-1}] + b_f)
        f_next = torch.mm(torch.cat((x_all, c_prev, f_prev), 1), self.WF) + self.bF.expand_as(f_prev)
        # i_t = sigmoid(W_i * [grad_t, loss_t, theta_{t-1}, i_{t-1}] + b_i)
        i_next = torch.mm(torch.cat((x_all, c_prev, i_prev), 1), self.WI) + self.bI.expand_as(i_prev)
        # next cell/params
        c_next = torch.sigmoid(f_next).mul(c_prev) - torch.sigmoid(i_next).mul(grad)

        return c_next, [f_next, i_next, c_next]

    def extra_repr(self):
        s = '{input_size}, {hidden_size}, {n_learner_params}'
        return s.format(**self.__dict__)


class MetaLearner(nn.Module):

    def __init__(self, input_size, hidden_size, n_learner_params):
        super(MetaLearner, self).__init__()
        """Args:
            input_size (int): for the first LSTM layer, default = 4
            hidden_size (int): for the first LSTM layer, default = 20
            n_learner_params (int): number of learner's parameters
        """
        self.lstm = nn.LSTMCell(input_size=input_size, hidden_size=hidden_size)
        self.metalstm = MetaLSTMCell(input_size=hidden_size, hidden_size=1, n_learner_params=n_learner_params)

    def forward(self, inputs, hs=None):
        """Args:
            inputs = [loss, grad_prep, grad]
                loss (torch.Tensor of size [1, 2])
                grad_prep (torch.Tensor of size [n_learner_params, 2])
                grad (torch.Tensor of size [n_learner_params])
            hs = [(lstm_hn, lstm_cn), [metalstm_fn, metalstm_in, metalstm_cn]]
        """
        loss, grad_prep, grad = inputs
        loss = loss.expand_as(grad_prep)
        inputs = torch.cat((loss, grad_prep), 1)   # [n_learner_params, 4]

        if hs is None:
            hs = [None, None]

        lstmhx, lstmcx = self.lstm(inputs, hs[0])
        flat_learner_unsqzd, metalstm_hs = self.metalstm([lstmhx, grad], hs[1])

        return flat_learner_unsqzd.squeeze(), [(lstmhx, lstmcx), metalstm_hs]

#### Step 4: utils 

In [8]:
def accuracy(output, target, topk=(1,)):
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res[0].item() if len(res) == 1 else [r.item() for r in res]

In [9]:
def preprocess_grad_loss(x):
    p = 10
    indicator = (x.abs() >= np.exp(-p)).to(torch.float32)

    # preproc1
    x_proc1 = indicator * torch.log(x.abs() + 1e-8) / p + (1 - indicator) * -1
    # preproc2
    x_proc2 = indicator * torch.sign(x) + (1 - indicator) * np.exp(p) * x
    return torch.stack((x_proc1, x_proc2), 1)

#### Step 5: Main

In [10]:
def meta_test(eps, eval_loader, learner_w_grad, learner_wo_grad, metalearner, args):
    for subeps, (episode_x, episode_y) in enumerate(tqdm(eval_loader, ascii=True)):
        train_input = episode_x[:, :args['n_shot']].reshape(-1, *episode_x.shape[-3:]) # [n_class * n_shot, :]
        train_target = torch.LongTensor(np.repeat(range(args['n_class']), args['n_shot'])) # [n_class * n_shot]
        test_input = episode_x[:, args['n_shot']:].reshape(-1, *episode_x.shape[-3:]) # [n_class * n_eval, :]
        test_target = torch.LongTensor(np.repeat(range(args['n_class']), args['n_eval'])) # [n_class * n_eval]

        # Train learner with metalearner
        learner_w_grad.reset_batch_stats()
        learner_wo_grad.reset_batch_stats()
        learner_w_grad.train()
        learner_wo_grad.eval()
        cI = train_learner(learner_w_grad, metalearner, train_input, train_target, args)

        learner_wo_grad.transfer_params(learner_w_grad, cI)
        output = learner_wo_grad(test_input)
        loss = learner_wo_grad.criterion(output, test_target)
        acc = accuracy(output, test_target)
        print ("Validation/Test Loss: {}, and Accuracy {}".format(loss.item(), acc))
        

    return acc


In [11]:
def train_learner(learner_w_grad, metalearner, train_input, train_target, args):
    cI = metalearner.metalstm.cI.data
    hs = [None]
    for _ in range(args['epoch']):
        for i in range(0, len(train_input), args['batch_size']):
            x = train_input[i:i+args['batch_size']]
            y = train_target[i:i+args['batch_size']]

            # get the loss/grad
            learner_w_grad.copy_flat_params(cI)
            output = learner_w_grad(x)
            loss = learner_w_grad.criterion(output, y)
            acc = accuracy(output, y)
            learner_w_grad.zero_grad()
            loss.backward()
            grad = torch.cat([p.grad.data.view(-1) / args['batch_size'] for p in learner_w_grad.parameters()], 0)

            # preprocess grad & loss and metalearner forward
            grad_prep = preprocess_grad_loss(grad)  # [n_learner_params, 2]
            loss_prep = preprocess_grad_loss(loss.data.unsqueeze(0)) # [1, 2]
            metalearner_input = [loss_prep, grad_prep, grad.unsqueeze(1)]
            cI, h = metalearner(metalearner_input, hs[-1])
            hs.append(h)
            print("training loss: {:8.6f} acc: {:6.3f}, mean grad: {:8.6f}".format(loss, acc, torch.mean(grad)))

    return cI


In [12]:
def main(args):
    seed = 2019
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Get data
    train_loader, val_loader, test_loader = prepare_data(args)
    
    # Set up learner, meta-learner
    learner_w_grad = Learner(args['image_size'], args['bn_eps'], args['bn_momentum'], args['n_class'])
    learner_wo_grad = copy.deepcopy(learner_w_grad)
    metalearner = MetaLearner(args['input_size'], args['hidden_size'], learner_w_grad.get_flat_params().size(0))
    metalearner.metalstm.init_cI(learner_w_grad.get_flat_params())

    # Set up loss, optimizer, learning rate scheduler
    optim = torch.optim.Adam(metalearner.parameters(), args['lr'])

    if args['mode'] == 'test':
        _ = meta_test(args['episode'], test_loader, learner_w_grad, learner_wo_grad, metalearner, args)
        return
    best_acc = 0.0
    # Meta-training
    for eps, (episode_x, episode_y) in enumerate(train_loader):
        # episode_x.shape = [n_class, n_shot + n_eval, c, h, w]
        # episode_y.shape = [n_class, n_shot + n_eval] --> NEVER USED
        train_input = episode_x[:, :args['n_shot']].reshape(-1, *episode_x.shape[-3:]) # [n_class * n_shot, :]
        train_target = torch.LongTensor(np.repeat(range(args['n_shot']), args['n_shot'])) # [n_class * n_shot]
        test_input = episode_x[:, args['n_shot']:].reshape(-1, *episode_x.shape[-3:]) # [n_class * n_eval, :]
        test_target = torch.LongTensor(np.repeat(range(args['n_shot']), args['n_eval'])) # [n_class * n_eval]

        # Train learner with metalearner
        learner_w_grad.reset_batch_stats()
        learner_wo_grad.reset_batch_stats()
        learner_w_grad.train()
        learner_wo_grad.train()
        cI = train_learner(learner_w_grad, metalearner, train_input, train_target, args)

        # Train meta-learner with validation loss
        learner_wo_grad.transfer_params(learner_w_grad, cI)
        output = learner_wo_grad(test_input)
        loss = learner_wo_grad.criterion(output, test_target)
        acc = accuracy(output, test_target)
        
        optim.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(metalearner.parameters(), args['grad_clip'])
        optim.step()
        
        print (eps)
        # Meta-validation
        if eps % 10 == 0 and eps != 0:
            acc = meta_test(eps, val_loader, learner_w_grad, learner_wo_grad, metalearner, args)
            if acc > best_acc:
                best_acc = acc
                print ("* Best accuracy so far *\n")

    print("Done")

In [13]:
if __name__ == '__main__':
    args_train={'mode':'train','n_shot':5,'n_eval':15,'n_class':5,'input_size':4,'hidden_size':20,'lr':1e-3,'episode':50,
      'episode_val':50,'epoch':8,'batch_size':16,'image_size':84,'grad_clip':0.25,'bn_momentum': 0.95,'bn_eps': 1e-3,
       'data': "miniimagenet",'data_root': "./miniImagenet/", 'resume': None}
    
    
    args_test={'mode':'test','n_shot':5,'n_eval':15,'n_class':5,'input_size':4,'hidden_size':20,'lr':1e-3,'episode':50,
      'episode_val':10,'epoch':8,'batch_size':16,'image_size':84,'grad_clip':0.25,'bn_momentum': 0.95,'bn_eps': 1e-3,
       'data': "miniimagenet",'data_root': "./miniImagenet/", 'resume': None}
    
    
    print (" BEGIN TRAINING: ")
    main(args_train)
    
    
    print ("BEGIN TESTING")
    main(args_test)
    
    

 BEGIN TRAINING: 
training loss: 1.614745 acc: 31.250, mean grad: -0.000029
training loss: 2.118628 acc:  0.000, mean grad: 0.000112
training loss: 1.540848 acc: 37.500, mean grad: -0.000018
training loss: 1.856331 acc: 11.111, mean grad: 0.000140
0
training loss: 1.668032 acc: 18.750, mean grad: -0.000022
training loss: 1.827717 acc: 22.222, mean grad: -0.000048
training loss: 1.620747 acc: 18.750, mean grad: -0.000035
training loss: 1.624162 acc: 55.556, mean grad: -0.000062
1
training loss: 1.648508 acc: 12.500, mean grad: -0.000081
training loss: 2.014055 acc:  0.000, mean grad: -0.000178
training loss: 1.610496 acc: 18.750, mean grad: -0.000079
training loss: 1.832938 acc: 11.111, mean grad: -0.000149
2
training loss: 1.715487 acc: 12.500, mean grad: -0.000034
training loss: 1.952520 acc: 11.111, mean grad: -0.000026
training loss: 1.683787 acc: 12.500, mean grad: -0.000045
training loss: 1.769854 acc: 11.111, mean grad: -0.000026
3
training loss: 1.756584 acc: 18.750, mean grad: 

  0%|          | 0/5 [00:00<?, ?it/s]

10
training loss: 1.506445 acc: 31.250, mean grad: 0.000080
training loss: 1.924212 acc: 33.333, mean grad: -0.000004
training loss: 1.506989 acc: 25.000, mean grad: 0.000075
training loss: 1.708387 acc: 33.333, mean grad: -0.000000


 20%|##        | 1/5 [00:02<00:09,  2.47s/it]

Validation/Test Loss: 1.593719720840454, and Accuracy 21.33333396911621
training loss: 1.500121 acc: 37.500, mean grad: 0.000032
training loss: 1.905619 acc: 11.111, mean grad: 0.000107
training loss: 1.510576 acc: 43.750, mean grad: 0.000034
training loss: 1.721232 acc: 33.333, mean grad: 0.000089


 40%|####      | 2/5 [00:03<00:06,  2.13s/it]

Validation/Test Loss: 1.4927806854248047, and Accuracy 33.333335876464844
training loss: 1.534272 acc: 31.250, mean grad: 0.000059
training loss: 1.644273 acc: 33.333, mean grad: -0.000053
training loss: 1.533480 acc: 31.250, mean grad: 0.000069
training loss: 1.506519 acc: 33.333, mean grad: -0.000056


 60%|######    | 3/5 [00:05<00:03,  1.90s/it]

Validation/Test Loss: 1.6011277437210083, and Accuracy 22.666667938232422
training loss: 1.741906 acc: 18.750, mean grad: 0.000026
training loss: 1.586810 acc: 11.111, mean grad: -0.000130
training loss: 1.722555 acc: 12.500, mean grad: 0.000022
training loss: 1.462085 acc: 44.444, mean grad: -0.000112


 80%|########  | 4/5 [00:06<00:01,  1.77s/it]

Validation/Test Loss: 1.732759952545166, and Accuracy 17.33333396911621
training loss: 1.693859 acc: 31.250, mean grad: -0.000059
training loss: 1.936478 acc: 22.222, mean grad: 0.000046
training loss: 1.684606 acc: 31.250, mean grad: -0.000056
training loss: 1.776152 acc: 33.333, mean grad: 0.000030


100%|##########| 5/5 [00:08<00:00,  1.64s/it]

Validation/Test Loss: 1.7159435749053955, and Accuracy 17.33333396911621
* Best accuracy so far *






training loss: 1.700276 acc: 31.250, mean grad: 0.000071
training loss: 2.000686 acc:  0.000, mean grad: 0.000026
training loss: 1.675115 acc: 25.000, mean grad: 0.000052
training loss: 1.837954 acc:  0.000, mean grad: 0.000042
11
training loss: 1.669216 acc: 37.500, mean grad: 0.000142
training loss: 1.709039 acc: 11.111, mean grad: -0.000022
training loss: 1.624813 acc: 37.500, mean grad: 0.000142
training loss: 1.569145 acc: 22.222, mean grad: -0.000020
12
training loss: 1.657539 acc: 31.250, mean grad: -0.000030
training loss: 1.824310 acc: 11.111, mean grad: -0.000055
training loss: 1.626104 acc: 31.250, mean grad: -0.000019
training loss: 1.705141 acc: 11.111, mean grad: -0.000063
13
training loss: 1.409093 acc: 31.250, mean grad: 0.000014
training loss: 1.824450 acc:  0.000, mean grad: -0.000046
training loss: 1.419711 acc: 37.500, mean grad: 0.000012
training loss: 1.668188 acc: 11.111, mean grad: -0.000059
14
training loss: 1.565368 acc: 31.250, mean grad: -0.000007
training l

  0%|          | 0/5 [00:00<?, ?it/s]

20
training loss: 1.665912 acc:  6.250, mean grad: 0.000017
training loss: 1.890723 acc:  0.000, mean grad: 0.000004
training loss: 1.665951 acc:  6.250, mean grad: 0.000016
training loss: 1.742665 acc: 11.111, mean grad: 0.000013

 20%|##        | 1/5 [00:02<00:08,  2.14s/it]


Validation/Test Loss: 1.6213276386260986, and Accuracy 25.33333396911621
training loss: 1.635617 acc: 12.500, mean grad: 0.000032
training loss: 2.005629 acc:  0.000, mean grad: -0.000053
training loss: 1.624176 acc: 18.750, mean grad: 0.000035
training loss: 1.841934 acc: 11.111, mean grad: -0.000035


 40%|####      | 2/5 [00:03<00:05,  1.91s/it]

Validation/Test Loss: 1.6874781847000122, and Accuracy 13.333333969116211
training loss: 1.529977 acc: 31.250, mean grad: 0.000043
training loss: 1.874977 acc: 11.111, mean grad: 0.000018
training loss: 1.538827 acc: 25.000, mean grad: 0.000039
training loss: 1.724627 acc: 11.111, mean grad: 0.000017


 60%|######    | 3/5 [00:04<00:03,  1.73s/it]

Validation/Test Loss: 1.6405235528945923, and Accuracy 18.666667938232422
training loss: 1.724140 acc: 37.500, mean grad: -0.000021
training loss: 1.898619 acc: 11.111, mean grad: -0.000044
training loss: 1.706659 acc: 37.500, mean grad: -0.000018
training loss: 1.768893 acc: 22.222, mean grad: -0.000053


 80%|########  | 4/5 [00:06<00:01,  1.59s/it]

Validation/Test Loss: 1.68801748752594, and Accuracy 20.0
training loss: 1.571136 acc: 18.750, mean grad: 0.000060
training loss: 1.941251 acc:  0.000, mean grad: -0.000142
training loss: 1.564841 acc: 37.500, mean grad: 0.000065
training loss: 1.789155 acc:  0.000, mean grad: -0.000127


100%|##########| 5/5 [00:07<00:00,  1.47s/it]

Validation/Test Loss: 1.6470879316329956, and Accuracy 20.0
* Best accuracy so far *






training loss: 1.486757 acc: 31.250, mean grad: -0.000055
training loss: 2.025447 acc:  0.000, mean grad: -0.000005
training loss: 1.480276 acc: 37.500, mean grad: -0.000052
training loss: 1.819491 acc:  0.000, mean grad: -0.000005
21
training loss: 1.632090 acc: 12.500, mean grad: 0.000004
training loss: 1.925757 acc:  0.000, mean grad: -0.000057
training loss: 1.629442 acc: 12.500, mean grad: 0.000001
training loss: 1.795984 acc:  0.000, mean grad: -0.000049
22
training loss: 1.429525 acc: 37.500, mean grad: -0.000008
training loss: 1.979416 acc:  0.000, mean grad: 0.000035
training loss: 1.445377 acc: 50.000, mean grad: -0.000014
training loss: 1.837744 acc: 11.111, mean grad: 0.000036
23
training loss: 1.601067 acc: 18.750, mean grad: -0.000043
training loss: 1.831364 acc: 11.111, mean grad: -0.000025
training loss: 1.603745 acc: 18.750, mean grad: -0.000039
training loss: 1.715759 acc: 22.222, mean grad: -0.000024
24
training loss: 1.410628 acc: 31.250, mean grad: 0.000090
trainin

  0%|          | 0/5 [00:00<?, ?it/s]

30
training loss: 1.643109 acc:  6.250, mean grad: -0.000026
training loss: 1.683440 acc: 33.333, mean grad: -0.000016
training loss: 1.646062 acc:  6.250, mean grad: -0.000028
training loss: 1.568876 acc: 55.556, mean grad: -0.000010


 20%|##        | 1/5 [00:02<00:09,  2.29s/it]

Validation/Test Loss: 1.6116549968719482, and Accuracy 25.33333396911621
training loss: 1.425703 acc: 37.500, mean grad: -0.000014
training loss: 1.692899 acc: 33.333, mean grad: 0.000003
training loss: 1.445959 acc: 37.500, mean grad: -0.000012
training loss: 1.577181 acc: 33.333, mean grad: 0.000005


 40%|####      | 2/5 [00:03<00:06,  2.02s/it]

Validation/Test Loss: 1.6073495149612427, and Accuracy 26.666667938232422
training loss: 1.634763 acc: 31.250, mean grad: 0.000029
training loss: 1.891266 acc: 11.111, mean grad: -0.000019
training loss: 1.651148 acc: 31.250, mean grad: 0.000027
training loss: 1.752326 acc: 22.222, mean grad: -0.000016


 60%|######    | 3/5 [00:05<00:03,  1.83s/it]

Validation/Test Loss: 1.6562232971191406, and Accuracy 25.33333396911621
training loss: 1.692269 acc:  6.250, mean grad: -0.000006
training loss: 1.890070 acc: 11.111, mean grad: -0.000047
training loss: 1.695880 acc:  6.250, mean grad: -0.000008
training loss: 1.766104 acc: 11.111, mean grad: -0.000038


 80%|########  | 4/5 [00:06<00:01,  1.69s/it]

Validation/Test Loss: 1.7049757242202759, and Accuracy 16.0
training loss: 1.670734 acc: 18.750, mean grad: -0.000013
training loss: 1.690246 acc: 22.222, mean grad: -0.000075
training loss: 1.677693 acc: 25.000, mean grad: -0.000018
training loss: 1.550957 acc: 22.222, mean grad: -0.000073


100%|##########| 5/5 [00:07<00:00,  1.56s/it]

Validation/Test Loss: 1.6256543397903442, and Accuracy 22.666667938232422
* Best accuracy so far *






training loss: 1.684431 acc: 12.500, mean grad: -0.000029
training loss: 1.780886 acc:  0.000, mean grad: 0.000048
training loss: 1.686307 acc: 12.500, mean grad: -0.000031
training loss: 1.642941 acc: 44.444, mean grad: 0.000037
31
training loss: 1.647676 acc:  6.250, mean grad: 0.000030
training loss: 1.690422 acc: 22.222, mean grad: -0.000002
training loss: 1.652482 acc:  6.250, mean grad: 0.000025
training loss: 1.598753 acc: 22.222, mean grad: -0.000001
32
training loss: 1.526454 acc: 31.250, mean grad: 0.000006
training loss: 1.912634 acc: 11.111, mean grad: -0.000096
training loss: 1.536687 acc: 25.000, mean grad: 0.000007
training loss: 1.787290 acc: 22.222, mean grad: -0.000084
33
training loss: 1.578569 acc: 31.250, mean grad: -0.000013
training loss: 1.786916 acc: 11.111, mean grad: -0.000004
training loss: 1.580722 acc: 31.250, mean grad: -0.000020
training loss: 1.685310 acc: 11.111, mean grad: -0.000012
34
training loss: 1.746173 acc: 12.500, mean grad: 0.000001
training 

  0%|          | 0/5 [00:00<?, ?it/s]

40
training loss: 1.494308 acc: 50.000, mean grad: 0.000016
training loss: 1.987978 acc: 11.111, mean grad: -0.000043
training loss: 1.503294 acc: 50.000, mean grad: 0.000013
training loss: 1.857346 acc: 11.111, mean grad: -0.000041


 20%|##        | 1/5 [00:02<00:09,  2.32s/it]

Validation/Test Loss: 1.6918948888778687, and Accuracy 18.666667938232422
training loss: 1.656846 acc: 18.750, mean grad: -0.000006
training loss: 1.927852 acc:  0.000, mean grad: 0.000033
training loss: 1.664261 acc: 18.750, mean grad: -0.000007
training loss: 1.794069 acc:  0.000, mean grad: 0.000032


 40%|####      | 2/5 [00:03<00:06,  2.09s/it]

Validation/Test Loss: 1.7808585166931152, and Accuracy 14.666666984558105
training loss: 1.549928 acc: 37.500, mean grad: 0.000022
training loss: 1.906957 acc:  0.000, mean grad: 0.000022
training loss: 1.564751 acc: 37.500, mean grad: 0.000021
training loss: 1.757457 acc: 11.111, mean grad: 0.000023


 60%|######    | 3/5 [00:05<00:03,  1.91s/it]

Validation/Test Loss: 1.668221116065979, and Accuracy 26.666667938232422
training loss: 1.504888 acc: 43.750, mean grad: -0.000049
training loss: 1.702780 acc:  0.000, mean grad: 0.000023
training loss: 1.531997 acc: 31.250, mean grad: -0.000047
training loss: 1.577613 acc: 11.111, mean grad: 0.000025


 80%|########  | 4/5 [00:06<00:01,  1.75s/it]

Validation/Test Loss: 1.5447510480880737, and Accuracy 29.33333396911621
training loss: 1.689559 acc: 18.750, mean grad: 0.000003
training loss: 1.956663 acc:  0.000, mean grad: 0.000041
training loss: 1.694549 acc: 18.750, mean grad: -0.000001
training loss: 1.821139 acc:  0.000, mean grad: 0.000038


100%|##########| 5/5 [00:08<00:00,  1.66s/it]

Validation/Test Loss: 1.5948500633239746, and Accuracy 28.0
* Best accuracy so far *






training loss: 1.583406 acc: 18.750, mean grad: 0.000002
training loss: 1.925111 acc: 11.111, mean grad: -0.000011
training loss: 1.594309 acc: 18.750, mean grad: -0.000001
training loss: 1.794685 acc: 11.111, mean grad: -0.000017
41
training loss: 1.561712 acc: 31.250, mean grad: -0.000009
training loss: 1.740330 acc:  0.000, mean grad: -0.000006
training loss: 1.578705 acc: 31.250, mean grad: -0.000010
training loss: 1.648667 acc:  0.000, mean grad: -0.000014
42
training loss: 1.810461 acc: 18.750, mean grad: -0.000084
training loss: 1.781053 acc:  0.000, mean grad: -0.000021
training loss: 1.804751 acc: 18.750, mean grad: -0.000083
training loss: 1.698363 acc:  0.000, mean grad: -0.000014
43
training loss: 1.562874 acc: 18.750, mean grad: -0.000004
training loss: 1.784615 acc: 11.111, mean grad: 0.000062
training loss: 1.580992 acc: 25.000, mean grad: 0.000003
training loss: 1.674384 acc: 11.111, mean grad: 0.000063
44
training loss: 1.350529 acc: 56.250, mean grad: -0.000018
traini

  0%|          | 0/1 [00:00<?, ?it/s]

training loss: 1.740956 acc: 18.750, mean grad: 0.000093
training loss: 1.960083 acc: 11.111, mean grad: -0.000050
training loss: 1.638315 acc: 25.000, mean grad: 0.000081
training loss: 1.742578 acc: 11.111, mean grad: -0.000011


100%|##########| 1/1 [00:02<00:00,  2.11s/it]

Validation/Test Loss: 1.7285420894622803, and Accuracy 10.666666984558105



