# Baseline few-shot learner with CORD - Nearest neighbors

In [1]:
import pandas as pd
import numpy as np
import torch
from torch import optim
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from tqdm import tqdm_notebook
from tqdm import trange
from transformers import set_seed

SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
set_seed(SEED)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE

'cuda'

In [2]:
import warnings
warnings.filterwarnings("ignore")

## Load data

In [3]:
from datasets import load_dataset_builder, load_dataset, concatenate_datasets
from pprint import pprint
import datasets

datasets.utils.logging.set_verbosity_error()

DATASET_URI = 'katanaml/cord'

dsbuild = load_dataset_builder(DATASET_URI)
pprint(dsbuild.info.features.keys())

print("No. of labels: ", len(dsbuild.info.features['ner_tags'].feature.names))

data = load_dataset(DATASET_URI)
print(data)

dict_keys(['id', 'words', 'bboxes', 'ner_tags', 'image_path'])
No. of labels:  23


  0%|          | 0/3 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['id', 'words', 'bboxes', 'ner_tags', 'image_path'],
        num_rows: 800
    })
    test: Dataset({
        features: ['id', 'words', 'bboxes', 'ner_tags', 'image_path'],
        num_rows: 100
    })
    validation: Dataset({
        features: ['id', 'words', 'bboxes', 'ner_tags', 'image_path'],
        num_rows: 100
    })
})


## Data sampling function

Pick random sample of given support set size, query set size, no of classes and the dataset size. 

The function will be used to sample data every episode.

In [4]:
LABELS = dsbuild.info.features['ner_tags'].feature.names

# select 15 for training, 8 for testing
TRAINLABELS = np.random.choice(LABELS, 13, replace=False)
TESTLABELS = np.random.choice(LABELS, 10, replace=False)

HDIM = 768
# TODO: remove 'O' label?

Sampler below will be used for every episode to generate a sample of data to use in the episode.

This will by default return all entities in a sample eventhough they are selected for a particular entity. We need additional preprocessing before the modeling stage to only train with the tokens that correspond to our target entity.

In [5]:
def sample_data(dataset, n_way, n_support, n_query, labels):
    # randomly select B labels to sample in the
    B = np.random.choice(labels, n_way, replace=False)
    B = [LABELS.index(c) for c in B]
    cls_sample = {}  # class -> sample
    
    # select support and query samples for each of these labels
    for cls in B:
        # select data containing the tag
        sample = dataset.filter(lambda x: cls in x,
                                input_columns=['ner_tags'],
                               keep_in_memory=True)
        # sample a small subset of the result
        sample = sample.shuffle(keep_in_memory=True) \
                       .select(np.arange(n_support+n_query), 
                                keep_in_memory=True)
        cls_sample[cls] = sample
    
    return {
        'samples': cls_sample,
        'n_way': n_way,
        'n_support': n_support,
        'n_query': n_query
    }

In [6]:
id2label = data['test'].features['ner_tags'].feature.int2str

def align_labels_with_tokens(labels, word_ids):
    new_labels = []
    current_word = None
    for word_id in word_ids:
        if word_id != current_word:
            # Start of a new word!
            current_word = word_id
            label = -100 if word_id is None else id2label(labels[word_id])
            new_labels.append(label)
        elif word_id is None:
            # Special token
            new_labels.append(-100)
        else:
            # Same word as previous token
            label = labels[word_id]
            label = id2label(label)
            new_labels.append(label)

    return new_labels

def align_all_labels(encoded, orig_data): 
    all_labels = []
    for i in range(encoded['input_ids'].shape[0]): # for each instance
        labels = align_labels_with_tokens(orig_data['ner_tags'][i], encoded.word_ids(i))
        all_labels.append(labels)
    return all_labels

def getwords(wordids, words):
    return [words[i]  if i is not None else -100 for i in wordids]

def decode_tokens(input_ids):
    decoded = []
    for inp in input_ids:
        decoded.append([processor.decode(i) for i in inp])
    return decoded

def load_image(data):
    # update dataset with image
    data['image'] = Image.open(data['image_path']).convert('RGB')
    return data

In [7]:
def encode(sample, processor, model):
    enc = processor(sample['image'], sample['words'],
                   boxes = sample['bboxes'],
                    return_tensors='pt',
                   truncation=True, padding=True)
    with torch.no_grad():
        enc = enc.to(DEVICE)
        out = model(**enc)
    out = out.last_hidden_state.detach()
    out = out[:,:enc['input_ids'].shape[1],:]  # select only the word tokens
    return enc, out

def encode_and_filter(sample, processor, model, load_image_on_sample=False):
    # output dimension: n_way, n_support+n_query, hiddendim
    output_tensors = []
    
    for cls in sample['samples'].keys():
        instance = sample['samples'][cls]
        if load_image_on_sample:
            instance = instance.map(load_image,
                                    keep_in_memory=True,
                                   num_proc=4
                                   )
        enc, out = encode(instance, processor, model)
        aligned_labels = align_all_labels(enc, instance)
        
        # output dimension: n_support+n_query, hiddendim
        selected = []
        for i, tag in enumerate(aligned_labels):
            targetclass = LABELS[cls]
            # ids of tokens of the target class
            idxs = [idx                
                    for idx, tok in enumerate(tag)
                    if tok == targetclass ]
            # randomly select 1 token of the class
            selidx = np.random.choice(idxs, 1)
            selected.append(out[i,selidx,:])
        selected = torch.cat(selected)
        output_tensors.append(selected)
        
    output_tensors = torch.stack(output_tensors)
    
    assert output_tensors.shape == (sample['n_way'], sample['n_support']+sample['n_query'], HDIM)
    return output_tensors

## Implement Nearest Neighbor Model

In [8]:
class Nearest(nn.Module):
    def __init__(self, encoder, processor, k):
        super(Nearest, self).__init__()
        self.encoder = encoder
        self.processor = processor
        self.k = k
        
        # exclude encoder params
        for p in self.encoder.parameters():
            p.requires_grad = False

    def set_forward_loss(self, sample, load_image_on_sample=False):
        n_way = sample['n_way']
        n_support = sample['n_support']
        n_query = sample['n_query']
        
        sample_enc = encode_and_filter(sample, self.processor, self.encoder, load_image_on_sample)

        x_support = sample_enc[:, :n_support]
        x_query = sample_enc[:, n_support:]
        
        x_support = x_support.contiguous().view(n_way * n_support, *x_support.size()[2:])
        x_query = x_query.contiguous().view(n_way * n_query, *x_query.size()[2:])
        
        target_sup = torch.arange(0, n_way) \
                      .view(n_way, 1, 1)  \
                      .expand(n_way, n_support, 1) \
                      .contiguous() \
                      .view(n_way * n_support) \
                      .long() \
                      .to(DEVICE)
        target_quer = torch.arange(0, n_way) \
                      .view(n_way, 1, 1)  \
                      .expand(n_way, n_query, 1) \
                      .contiguous() \
                      .view(n_way * n_query) \
                      .long() \
                      .to(DEVICE)
        
        distmat = x_query.matmul(x_support.transpose(0,1))  # dot product distance
        sortidx = distmat.argsort(descending=True)
        pred = target_sup[sortidx[:,:self.k]].mode().values
        
        acc_val = torch.eq(target_quer, pred).float().mean()

        return 10, {
            'loss': 10,   # dummy loss
            'acc': acc_val.item(),
            'y_hat': pred
            }

    def distance(self, x, y):
        # x: N x D
        # y: M x D
        n = x.size(0)
        m = y.size(0)
        d = x.size(1)
        assert d == y.size(1)

        x = x.unsqueeze(1).expand(n, m, d)
        y = y.unsqueeze(0).expand(n, m, d)

        return torch.pow(x - y, 2).sum(2)

## Training and Validation

In [9]:
def train(model,
          dataset, n_way, n_support, n_query, 
          max_epoch, epoch_size,
          load_image=False):
    epoch = 0
    
    while epoch < max_epoch:
        running_loss = 0
        running_acc = 0
        
        for episode in trange(epoch_size, desc="Epoch {:d}/{:d} train".format(epoch+1, max_epoch)):
            sample = sample_data(dataset,
                                 n_way, n_support, n_query,
                                 TRAINLABELS)
            loss, output = model.set_forward_loss(sample, load_image)
            running_loss += output['loss']
            running_acc += output['acc']
            
        epoch_loss = running_loss / epoch_size
        epoch_acc = running_acc / epoch_size
        print('Epoch {:d} -- Loss: {:.4f} Acc: {:.4f}'.format(epoch+1,epoch_loss, epoch_acc))
        epoch += 1

In [10]:
def test(model, 
         dataset, n_way, n_support, n_query, 
         test_episode,
            load_image_on_sample=False):
    running_loss = 0.0
    running_acc = 0.0
    for episode in trange(test_episode):
        sample = sample_data(dataset,
                             n_way, n_support, n_query, 
                             TESTLABELS)
        loss, output = model.set_forward_loss(sample, load_image_on_sample)
        running_loss += output['loss']
        running_acc += output['acc']
    avg_loss = running_loss / test_episode
    avg_acc = running_acc / test_episode
    print('Test results -- Loss: {:.4f} Acc: {:.4f}'.format(avg_loss, avg_acc))
    return avg_acc

## Train the model

In [18]:
N_WAY = 5
N_SUPPORT = 5
N_QUERY = 5
K = 3

MAX_EPOCHS = 1
EPOCH_SIZE = 50

In [12]:
from transformers import AutoProcessor, AutoModelForTokenClassification, AutoModel
processor = AutoProcessor.from_pretrained('microsoft/layoutlmv3-base', apply_ocr=False)
model = AutoModel.from_pretrained('microsoft/layoutlmv3-base')
model.eval()
model.to(DEVICE);

In [21]:
baseline = Nearest(model, processor, K)

In [14]:
dataset = concatenate_datasets([data['train'], data['test']])

In [15]:
dataset = dataset.map(load_image, keep_in_memory=True, num_proc=3)  
load_image_on_sample = False

Map (num_proc=3):   0%|          | 0/900 [00:00<?, ? examples/s]

In [None]:
train(baseline,
      dataset,
      N_WAY, N_SUPPORT, N_QUERY,
      MAX_EPOCHS, EPOCH_SIZE,
      load_image_on_sample)

In [22]:
acc = test(baseline,
      dataset,
      N_WAY, N_SUPPORT, N_QUERY,
      EPOCH_SIZE,
      load_image_on_sample)

  0%|          | 0/50 [00:00<?, ?it/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

  2%|▏         | 1/50 [00:05<04:09,  5.09s/it]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

  4%|▍         | 2/50 [00:10<04:12,  5.26s/it]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

  6%|▌         | 3/50 [00:15<04:07,  5.27s/it]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

  8%|▊         | 4/50 [00:20<03:45,  4.91s/it]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

 10%|█         | 5/50 [00:26<04:05,  5.47s/it]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

 12%|█▏        | 6/50 [00:30<03:38,  4.98s/it]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

 14%|█▍        | 7/50 [00:34<03:25,  4.78s/it]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

 16%|█▌        | 8/50 [00:39<03:19,  4.75s/it]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

 18%|█▊        | 9/50 [00:44<03:21,  4.92s/it]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

 20%|██        | 10/50 [00:50<03:20,  5.00s/it]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

 22%|██▏       | 11/50 [00:55<03:25,  5.26s/it]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

 24%|██▍       | 12/50 [01:00<03:15,  5.13s/it]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

 26%|██▌       | 13/50 [01:05<03:01,  4.91s/it]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

 28%|██▊       | 14/50 [01:09<02:49,  4.72s/it]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

 30%|███       | 15/50 [01:14<02:50,  4.87s/it]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

 32%|███▏      | 16/50 [01:19<02:42,  4.79s/it]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

 34%|███▍      | 17/50 [01:23<02:33,  4.64s/it]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

 36%|███▌      | 18/50 [01:27<02:24,  4.51s/it]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

 38%|███▊      | 19/50 [01:31<02:15,  4.38s/it]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

 40%|████      | 20/50 [01:37<02:21,  4.73s/it]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

 42%|████▏     | 21/50 [01:41<02:11,  4.53s/it]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

 44%|████▍     | 22/50 [01:46<02:09,  4.62s/it]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

 46%|████▌     | 23/50 [01:51<02:06,  4.69s/it]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

 48%|████▊     | 24/50 [01:54<01:54,  4.41s/it]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

 50%|█████     | 25/50 [01:58<01:47,  4.29s/it]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

 52%|█████▏    | 26/50 [02:04<01:50,  4.62s/it]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

 54%|█████▍    | 27/50 [02:08<01:44,  4.54s/it]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

 56%|█████▌    | 28/50 [02:13<01:42,  4.67s/it]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

 58%|█████▊    | 29/50 [02:17<01:32,  4.42s/it]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

 60%|██████    | 30/50 [02:21<01:27,  4.37s/it]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

 62%|██████▏   | 31/50 [02:25<01:20,  4.25s/it]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

 64%|██████▍   | 32/50 [02:30<01:19,  4.42s/it]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

 66%|██████▌   | 33/50 [02:35<01:15,  4.47s/it]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

 68%|██████▊   | 34/50 [02:39<01:11,  4.46s/it]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

 70%|███████   | 35/50 [02:44<01:09,  4.61s/it]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

 72%|███████▏  | 36/50 [02:48<00:59,  4.28s/it]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

 74%|███████▍  | 37/50 [02:53<00:59,  4.54s/it]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

 76%|███████▌  | 38/50 [02:58<00:55,  4.67s/it]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

 78%|███████▊  | 39/50 [03:03<00:53,  4.82s/it]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

 80%|████████  | 40/50 [03:06<00:44,  4.41s/it]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

 82%|████████▏ | 41/50 [03:11<00:40,  4.50s/it]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

 84%|████████▍ | 42/50 [03:15<00:35,  4.42s/it]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

 86%|████████▌ | 43/50 [03:19<00:30,  4.36s/it]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

 88%|████████▊ | 44/50 [03:24<00:26,  4.49s/it]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

 90%|█████████ | 45/50 [03:30<00:24,  4.86s/it]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

 92%|█████████▏| 46/50 [03:34<00:18,  4.74s/it]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

 94%|█████████▍| 47/50 [03:39<00:13,  4.65s/it]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

 96%|█████████▌| 48/50 [03:44<00:09,  4.69s/it]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

 98%|█████████▊| 49/50 [03:47<00:04,  4.43s/it]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

Filter:   0%|          | 0/900 [00:00<?, ? examples/s]

100%|██████████| 50/50 [03:52<00:00,  4.64s/it]

Test results -- Loss: 10.0000 Acc: 0.5224





In [23]:
print("Accuracy: ", acc)

Accuracy:  0.522399990260601
