#### **Simple Question Answering Using a Transformer Decoder**

We will explore how to train a character-level tranformer language model for a simple question answering task. Given a question of the form `Where was [X] born?` where `[X]` is the name of a public figure, the model will be trained to predict the output `[Y]` which is the name of the birthplace of that person. Our goal is to first pretrain the model on a wikipedia corpus (next character prediction) from which it is expected to acquire knowledge of persons and their birthplaces. Then we finetune the model with supervised training on `(x,y)` sequence pairs of the following form:

`x: Where was Albert Einstein born?%Germany%□□□□□□□□□□□□□□`

`y: □□□□□□□□□□□□□□□□□□□□□□□□□□□□□□%Germany%□□□□□□□□□□□□□□□`

where `x` is the input sequence and `y` is the predicted output sequence and `□` is a special padding token. This is a simple next character prediction task, however we do not want the model to predict the question itself, only the answer, which is why in the output sequence, we replace all characters from the question with the padding token and only have the model predict the characters from the answer. 

i.e. instead of

`y: here was Albert Einstein born?%Germany%□□□□□□□□□□□□□□□`

we use 

`y: □□□□□□□□□□□□□□□□□□□□□□□□□□□□□□%Germany%□□□□□□□□□□□□□□□`

We've also used a special token `%` to mark the beginning and end of the span containing the answer.

**The idea is that by training the model on this task, it can learn to answer a question by retreiving information pertaining to the answer from it's pretrained knowledge.** After training, we can test this idea by giving the model an input sequence which does not contain an answer, i.e. after the start of answer token `%`, we fill the rest of the sequence with padding tokens: 

`x: Where was Enrico Fermi born?%□□□□□□□□□□□□□□□□□□□□□□□□□`

Then if the predicted output sequence contains the right answer, then it will support our idea. We also make sure that person names which were not in the training set will be used during testing.


We will first train the model on the finetuning task without pretraining it and then look at the difference in performance with and without pretraining. 

In [54]:
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import math
from tqdm import tqdm
import psutil
import wandb
wandb.login()

True

In [9]:
mask_token = "<MASK>"
pad_token = "<PAD>"

In [10]:
# first get the character vocabulry from the pretraining dataset
with open("birth_place_data/wiki.txt", 'r') as file:
    pretrain_text = file.read()

vocab = sorted(list(set(pretrain_text)) + [mask_token, pad_token])
print(f"vocabulary: {vocab}")

vocabulary: ['\n', ' ', '!', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '<MASK>', '<PAD>', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '`', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '\xa0', '£', '\xad', 'Á', 'Å', 'É', 'Ó', 'Ö', 'Ø', 'Ü', 'ß', 'à', 'á', 'ã', 'ä', 'å', 'æ', 'ç', 'è', 'é', 'ë', 'í', 'ï', 'ñ', 'ó', 'ô', 'ö', 'ø', 'ü', 'ý', 'ă', 'ą', 'ć', 'Č', 'č', 'ě', 'ğ', 'ī', 'İ', 'ı', 'ł', 'ń', 'ō', 'Ő', 'ő', 'œ', 'ř', 'ś', 'ş', 'Š', 'š', 'ť', 'ū', 'Ż', 'ż', 'Ž', 'ž', 'ș', 'Γ', 'Μ', 'ά', 'έ', 'α', 'γ', 'η', 'ι', 'κ', 'ν', 'ο', 'ρ', 'ς', 'τ', 'υ', 'ω', 'ώ', 'Ј', 'А', 'В', 'Г', 'И', 'К', 'П', 'Р', 'С', 'а', 'б', 'в', 'г', 'д', 'е', 'и', 'й', 'к', 'л', 'м', 'н', 'о', 'р', 'с', 'т', 'ц', 'ч', 'ь', 'я', 'ћ', 'א', 'ג', 'ה', 'ו', 'ז', 'ח', 'י', '

Create pytorch dataset for finetuning

In [50]:
class NameBirthplaceDataset(Dataset):
    def __init__(self, vocab, mask_token, pad_token, block_size=128, split="train"):
        self.vocab= vocab
        self.ctoi = {c:i for i,c in enumerate(vocab)}
        self.mask_token = mask_token 
        self.pad_token = pad_token
        self.block_size = block_size
        if split == "train":
            data_filename="birth_place_data/birth_places_train.tsv"
        elif split == "dev":
            data_filename="birth_place_data/birth_places_dev.tsv"
    
        self.data = self.read_data(data_filename)
         
    def read_data(self, filename):
        with open(filename, 'r') as f: 
            lines = f.readlines()
        return lines    

    @property
    def pad_token_index(self):
        return self.ctoi[self.pad_token]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        line = self.data[index]
        question, answer = line.strip().split('\t') 
        question, answer = list(question), list(answer)
        x = question + [self.mask_token] + answer + [self.mask_token] 
        x = x + (self.block_size-len(x)) * [self.pad_token] 
        y = [self.pad_token] * (len(question) -1) + x[len(question):]
        x = x[:-1] 

        x = torch.tensor([self.ctoi[c] for c in x], dtype=torch.long)
        y = torch.tensor([self.ctoi[c] for c in y], dtype=torch.long)
        return x, y


In [51]:
train_data = NameBirthplaceDataset(vocab, mask_token, pad_token)
dev_data = NameBirthplaceDataset(vocab, mask_token, pad_token, split="dev")

In [52]:
x, y = train_data[0]

def decode_token_indices(x):
    return "".join([vocab[i] for i in x])

x_decoded = decode_token_indices(x)
y_decoded = decode_token_indices(y)
print(x_decoded)
print(y_decoded)

Where was Khatchig Mouradian born?<MASK>Lebanon<MASK><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD>
<PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><MASK>Lebanon<MASK><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><P

#### Create the transformer question answering model

In [59]:
class TransformerQA(torch.nn.Module):
    def __init__(self, vocab_size, blocks_size, pad_token_index, embedding_dim=16, feedforward_dim=64, num_heads=1, num_layers=1, dropout_rate=0.1, device="cpu"):
        super().__init__()
        self.vocab_size = vocab_size
        self.block_size = blocks_size
        self.pad_token_index = pad_token_index
        # embedding layers
        self.emb = torch.nn.Embedding(vocab_size, embedding_dim)
        c = 0.01        
        torch.nn.init.uniform_(self.emb.weight, -c, c)

        # static positional encoding (max length set to 1024)
        self.pos_emb = torch.zeros(size=(1024, embedding_dim), device=device)
        for pos in range(1024):
            for i in range(0, embedding_dim, 2):
                self.pos_emb[pos, i] = math.sin(pos / (10000 ** ((2 * i)/embedding_dim)))
                if i+1 < embedding_dim:
                    self.pos_emb[pos, i+1] = math.cos(pos / (10000 ** ((2 * (i+1))/embedding_dim)))

        # transformer decoder
        decoder_layer = torch.nn.TransformerDecoderLayer(d_model=embedding_dim, nhead=num_heads, dim_feedforward=feedforward_dim, activation='gelu', dropout=dropout_rate, batch_first=True)
        self.transformer_decoder = torch.nn.TransformerDecoder(decoder_layer, num_layers=num_layers)

        # create output layer (computes output class logits for each item in sequence)
        self.output_layer =  torch.nn.Linear(embedding_dim, vocab_size)
        # tie the output layer weights with the embedding layer weights
        self.output_layer.weight = self.emb.weight


    def create_causal_mask(self, input):
        _, L, _ = input.shape
        # create an L x L matrix with ones on and below diagonal
        mask = torch.tril(torch.ones(size=(L,L), device=input.device))
        # create mask in which the positions where there is a zero is filled with -infinity 
        mask = mask.masked_fill((mask==0), float("-inf"))
        return mask

    # forward pass
    def forward(self, x, y=None):
        # get embeddings for batch of input sequences of length L
        x = self.emb(x) # shape: (B,L,D)
        # add positional embedding
        x = x + self.pos_emb[:x.shape[1]] # shape: (B,L,D)
        # pass through transformer decoder layers
        mask = self.create_causal_mask(x)
        x = self.transformer_decoder(x, x, tgt_mask=mask) # shape: (B,L,D)
        # compute output logits
        x = self.output_layer(x) # shape: (B,L,vocab_size)

        if y==None:
            return x

        # reshape
        x = x.view(-1,x.shape[-1]) # shape: (B*L,vocab_size)
        y = y.view(-1) # shape: (B*L,)
        # compute cross entropy loss
        loss = F.cross_entropy(x, y, ignore_index=self.pad_token_index)
        return x, loss
    
    """    
    @torch.no_grad()
    def generate(self, subword2idx, block_size, temperature=1.0, topk=None, start_token="<s>", end_token="</s>", max_len=30, device="cpu"):
        self.eval()
        # generate one token at a time
        x = torch.full(size=(1,1), fill_value=subword2idx[start_token], dtype=torch.long, device=device)
        tokens = [x.item()]
        for _ in range(max_len):
            # crop the input sequence so that it doesn't exceed block size (only keep the last block_size tokens in the sequence to generate the next token)
            x = x[:,-block_size:]
            logits = self.forward(x) # shape: (1,L,V)
            # rescale the logits with the temperature
            logits = logits / temperature
            if topk is not None:
                topk_logits, idx = torch.sort(logits[0,-1,:], descending=True)
                # sample from the distribution for the last word in the sequence
                p = F.softmax(topk_logits, dim=-1) # shape: (V,)
                next_word_idx = idx[torch.multinomial(p, num_samples=1)]
            else:             
                # sample from the distribution for the last word in the sequence
                p = F.softmax(logits[:,-1,:], dim=-1) # shape: (V,)
                next_word_idx = torch.multinomial(p, num_samples=1)
            # append to the sequence
            x = torch.cat((x, next_word_idx.view(1,1)), dim=1)
            tokens.append(next_word_idx.item())

        self.train()
        return tokens
        """

# training loop
def train(model, optimizer, scheduler, train_dataloader, val_dataloader, device="cpu", num_epochs=10, val_every=1, save_every=None, log_metrics=None):
    avg_loss = 0
    train_acc = 0
    val_loss = 0
    val_acc = 0
    pp = 0
    for epoch in range(num_epochs):
        num_correct = 0
        num_total = 0
        pbar = tqdm(train_dataloader, desc="Epochs")
        for batch in pbar:
            inputs, targets = batch
            # move batch to device
            inputs, targets = inputs.to(device), targets.to(device)
            # reset gradients
            optimizer.zero_grad()
            # forward pass
            logits, loss = model(inputs, targets)
            # backward pass
            loss.backward()
            # optimizer step
            optimizer.step()
            avg_loss = 0.9* avg_loss + 0.1*loss.item()
            y_pred = logits.argmax(dim=-1) # shape (B*L)
            y = targets.view(-1) # shape (B*L)
            mask = (y != -1)
            num_correct += (torch.eq(y[mask], y_pred[mask])).sum().item()
            num_total += len(y[mask])
            train_acc = num_correct / num_total        
            
            pbar.set_description(f"Epoch {epoch + 1}, EMA Train Loss: {avg_loss:.3f}, Train Accuracy: {train_acc: .3f}, Val Loss: {val_loss: .3f}, Val Accuracy: {val_acc: .3f}")  

            if log_metrics:
                metrics = {"Batch loss" : loss.item(), "Moving Avg Loss" : avg_loss, "Val Loss": val_loss}
                log_metrics(metrics)

        scheduler.step()
        if epoch%val_every == 0:
            # compute validation loss
            val_loss, val_acc = validation(model, val_dataloader, device=device)
            pbar.set_description(f"Epoch {epoch + 1}, EMA Train Loss: {avg_loss:.3f}, Train Accuracy: {train_acc: .3f}, Val Loss: {val_loss: .3f}, Val Accuracy: {val_acc: .3f}") 

        if save_every is not None:
            if (epoch+1) % save_every == 0:
                save_model_checkpoint(model, optimizer, epoch, avg_loss)

def validation(model, val_dataloader, device="cpu"):
    model.eval()
    val_losses = torch.zeros(len(val_dataloader))
    with torch.no_grad():
        num_correct = 0
        num_total = 0
        for i,batch in enumerate(val_dataloader):
            inputs, targets = batch = batch
            inputs, targets = inputs.to(device), targets.to(device)
            logits, loss = model(inputs, targets)
            y_pred = logits.argmax(dim=-1) # shape (B*L)
            y = targets.view(-1) # shape (B*L)
            mask = (y != -1)
            num_correct += (torch.eq(y[mask], y_pred[mask])).sum().item()
            num_total += len(y[mask])
            val_losses[i] = loss.item()
    model.train()
    val_loss = val_losses.mean().item()
    val_accuracy = num_correct / num_total
    return val_loss, val_accuracy


def save_model_checkpoint(model, optimizer, epoch=None, loss=None):
    # Save the model and optimizer state_dict
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }

    # Save the checkpoint to a file
    torch.save(checkpoint, 'qa_model_checkpoint.pth')
    print(f"Saved model checkpoint!")


def load_model_checkpoint(model, optimizer):
    checkpoint = torch.load('qa_model_checkpoint.pth')
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    model.train()
    print("Loaded model from checkpoint!")
    return model, optimizer      

In [60]:
B = 32
D = 32
vocab_size = len(vocab)
num_heads = 4
num_layers = 2
learning_rate = 1e-4
DEVICE = "cuda"

train_dataloader = DataLoader(train_data, batch_size=B, shuffle=True, pin_memory=True, num_workers=2)
val_dataloader = DataLoader(dev_data, batch_size=B, shuffle=True, pin_memory=True, num_workers=2)


model = TransformerQA(vocab_size, blocks_size=128, pad_token_index=train_data.pad_token_index, embedding_dim=D, feedforward_dim=4*D, num_heads=num_heads, num_layers=num_layers, dropout_rate=0.2, device=DEVICE).to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
scheduler =  torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.95)
#model, optimizer = load_model_checkpoint(model, optimizer)


num_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters in transformer network: {num_params/1e6} M")
print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")

Total number of parameters in transformer network: 0.042432 M
RAM used: 1014.29 MB


In [61]:
train(model, optimizer, scheduler, train_dataloader, val_dataloader, device=DEVICE, num_epochs=50, save_every=5, val_every=1) #, log_metrics=log_metrics)

Epoch 1, EMA Train Loss: 5.287, Train Accuracy:  0.004, Val Loss:  0.000, Val Accuracy:  0.000: 100%|██████████| 63/63 [00:00<00:00, 74.51it/s]
Epoch 2, EMA Train Loss: 4.953, Train Accuracy:  0.004, Val Loss:  5.225, Val Accuracy:  0.005: 100%|██████████| 63/63 [00:00<00:00, 100.37it/s]
Epoch 3, EMA Train Loss: 4.636, Train Accuracy:  0.006, Val Loss:  4.886, Val Accuracy:  0.006: 100%|██████████| 63/63 [00:00<00:00, 96.28it/s] 
Epoch 4, EMA Train Loss: 4.359, Train Accuracy:  0.006, Val Loss:  4.577, Val Accuracy:  0.006: 100%|██████████| 63/63 [00:00<00:00, 96.20it/s] 
Epoch 5, EMA Train Loss: 4.128, Train Accuracy:  0.006, Val Loss:  4.306, Val Accuracy:  0.006: 100%|██████████| 63/63 [00:00<00:00, 98.02it/s] 


Saved model checkpoint!


Epoch 6, EMA Train Loss: 3.915, Train Accuracy:  0.006, Val Loss:  4.072, Val Accuracy:  0.006: 100%|██████████| 63/63 [00:00<00:00, 96.22it/s] 
Epoch 7, EMA Train Loss: 3.743, Train Accuracy:  0.006, Val Loss:  3.877, Val Accuracy:  0.006: 100%|██████████| 63/63 [00:00<00:00, 90.98it/s]
Epoch 8, EMA Train Loss: 3.612, Train Accuracy:  0.006, Val Loss:  3.712, Val Accuracy:  0.006: 100%|██████████| 63/63 [00:00<00:00, 98.26it/s] 
Epoch 9, EMA Train Loss: 3.499, Train Accuracy:  0.010, Val Loss:  3.582, Val Accuracy:  0.006: 100%|██████████| 63/63 [00:00<00:00, 101.65it/s]
Epoch 10, EMA Train Loss: 3.418, Train Accuracy:  0.016, Val Loss:  3.476, Val Accuracy:  0.016: 100%|██████████| 63/63 [00:00<00:00, 100.28it/s]


Saved model checkpoint!


Epoch 11, EMA Train Loss: 3.339, Train Accuracy:  0.016, Val Loss:  3.395, Val Accuracy:  0.016: 100%|██████████| 63/63 [00:00<00:00, 102.74it/s]
Epoch 12, EMA Train Loss: 3.289, Train Accuracy:  0.016, Val Loss:  3.328, Val Accuracy:  0.016: 100%|██████████| 63/63 [00:00<00:00, 103.51it/s]
Epoch 13, EMA Train Loss: 3.256, Train Accuracy:  0.016, Val Loss:  3.276, Val Accuracy:  0.016: 100%|██████████| 63/63 [00:00<00:00, 101.11it/s]
Epoch 14, EMA Train Loss: 3.224, Train Accuracy:  0.016, Val Loss:  3.235, Val Accuracy:  0.016: 100%|██████████| 63/63 [00:00<00:00, 99.93it/s] 
Epoch 15, EMA Train Loss: 3.189, Train Accuracy:  0.016, Val Loss:  3.205, Val Accuracy:  0.016: 100%|██████████| 63/63 [00:00<00:00, 98.03it/s] 


Saved model checkpoint!


Epoch 16, EMA Train Loss: 3.162, Train Accuracy:  0.016, Val Loss:  3.180, Val Accuracy:  0.016: 100%|██████████| 63/63 [00:00<00:00, 96.04it/s]
Epoch 17, EMA Train Loss: 3.167, Train Accuracy:  0.016, Val Loss:  3.164, Val Accuracy:  0.016: 100%|██████████| 63/63 [00:00<00:00, 100.58it/s]
Epoch 18, EMA Train Loss: 3.139, Train Accuracy:  0.016, Val Loss:  3.149, Val Accuracy:  0.016: 100%|██████████| 63/63 [00:00<00:00, 102.35it/s]
Epoch 19, EMA Train Loss: 3.144, Train Accuracy:  0.016, Val Loss:  3.137, Val Accuracy:  0.016: 100%|██████████| 63/63 [00:00<00:00, 103.06it/s]
Epoch 20, EMA Train Loss: 3.128, Train Accuracy:  0.016, Val Loss:  3.129, Val Accuracy:  0.016: 100%|██████████| 63/63 [00:00<00:00, 103.60it/s]


Saved model checkpoint!


Epoch 21, EMA Train Loss: 3.124, Train Accuracy:  0.016, Val Loss:  3.121, Val Accuracy:  0.016: 100%|██████████| 63/63 [00:00<00:00, 97.56it/s] 
Epoch 22, EMA Train Loss: 3.119, Train Accuracy:  0.016, Val Loss:  3.117, Val Accuracy:  0.016: 100%|██████████| 63/63 [00:00<00:00, 99.86it/s] 
Epoch 23, EMA Train Loss: 3.101, Train Accuracy:  0.016, Val Loss:  3.113, Val Accuracy:  0.016: 100%|██████████| 63/63 [00:00<00:00, 100.22it/s]
Epoch 24, EMA Train Loss: 3.098, Train Accuracy:  0.016, Val Loss:  3.103, Val Accuracy:  0.016: 100%|██████████| 63/63 [00:00<00:00, 103.70it/s]
Epoch 25, EMA Train Loss: 3.105, Train Accuracy:  0.016, Val Loss:  3.098, Val Accuracy:  0.016: 100%|██████████| 63/63 [00:00<00:00, 101.64it/s]


Saved model checkpoint!


Epoch 26, EMA Train Loss: 3.062, Train Accuracy:  0.016, Val Loss:  3.090, Val Accuracy:  0.016: 100%|██████████| 63/63 [00:00<00:00, 96.06it/s] 
Epoch 27, EMA Train Loss: 3.031, Train Accuracy:  0.016, Val Loss:  3.046, Val Accuracy:  0.016: 100%|██████████| 63/63 [00:00<00:00, 96.23it/s] 
Epoch 28, EMA Train Loss: 3.004, Train Accuracy:  0.016, Val Loss:  3.013, Val Accuracy:  0.016: 100%|██████████| 63/63 [00:00<00:00, 101.89it/s]
Epoch 29, EMA Train Loss: 2.970, Train Accuracy:  0.016, Val Loss:  2.978, Val Accuracy:  0.016: 100%|██████████| 63/63 [00:00<00:00, 101.43it/s]
Epoch 30, EMA Train Loss: 2.947, Train Accuracy:  0.016, Val Loss:  2.945, Val Accuracy:  0.016: 100%|██████████| 63/63 [00:00<00:00, 99.88it/s] 


Saved model checkpoint!


Epoch 31, EMA Train Loss: 2.906, Train Accuracy:  0.016, Val Loss:  2.917, Val Accuracy:  0.016: 100%|██████████| 63/63 [00:00<00:00, 96.65it/s] 
Epoch 32, EMA Train Loss: 2.854, Train Accuracy:  0.016, Val Loss:  2.872, Val Accuracy:  0.016: 100%|██████████| 63/63 [00:00<00:00, 101.28it/s]
Epoch 33, EMA Train Loss: 2.828, Train Accuracy:  0.017, Val Loss:  2.838, Val Accuracy:  0.016: 100%|██████████| 63/63 [00:00<00:00, 97.00it/s] 
Epoch 34, EMA Train Loss: 2.776, Train Accuracy:  0.017, Val Loss:  2.784, Val Accuracy:  0.017: 100%|██████████| 63/63 [00:00<00:00, 96.24it/s]
Epoch 35, EMA Train Loss: 2.740, Train Accuracy:  0.018, Val Loss:  2.741, Val Accuracy:  0.018: 100%|██████████| 63/63 [00:00<00:00, 90.77it/s]


Saved model checkpoint!


Epoch 36, EMA Train Loss: 2.690, Train Accuracy:  0.018, Val Loss:  2.701, Val Accuracy:  0.018: 100%|██████████| 63/63 [00:00<00:00, 90.26it/s]
Epoch 37, EMA Train Loss: 2.656, Train Accuracy:  0.019, Val Loss:  2.659, Val Accuracy:  0.019: 100%|██████████| 63/63 [00:00<00:00, 93.26it/s]
Epoch 38, EMA Train Loss: 2.631, Train Accuracy:  0.019, Val Loss:  2.624, Val Accuracy:  0.019: 100%|██████████| 63/63 [00:00<00:00, 94.33it/s]
Epoch 39, EMA Train Loss: 2.594, Train Accuracy:  0.020, Val Loss:  2.587, Val Accuracy:  0.020: 100%|██████████| 63/63 [00:00<00:00, 80.54it/s]
Epoch 40, EMA Train Loss: 2.574, Train Accuracy:  0.020, Val Loss:  2.558, Val Accuracy:  0.021: 100%|██████████| 63/63 [00:00<00:00, 92.37it/s]


Saved model checkpoint!


Epoch 41, EMA Train Loss: 2.528, Train Accuracy:  0.021, Val Loss:  2.529, Val Accuracy:  0.021: 100%|██████████| 63/63 [00:00<00:00, 99.02it/s] 
Epoch 42, EMA Train Loss: 2.507, Train Accuracy:  0.021, Val Loss:  2.498, Val Accuracy:  0.021: 100%|██████████| 63/63 [00:00<00:00, 102.75it/s]
Epoch 43, EMA Train Loss: 2.485, Train Accuracy:  0.022, Val Loss:  2.466, Val Accuracy:  0.022: 100%|██████████| 63/63 [00:00<00:00, 95.63it/s]
Epoch 44, EMA Train Loss: 2.468, Train Accuracy:  0.022, Val Loss:  2.439, Val Accuracy:  0.023: 100%|██████████| 63/63 [00:00<00:00, 95.32it/s] 
Epoch 45, EMA Train Loss: 2.437, Train Accuracy:  0.022, Val Loss:  2.410, Val Accuracy:  0.023: 100%|██████████| 63/63 [00:00<00:00, 95.99it/s] 


Saved model checkpoint!


Epoch 46, EMA Train Loss: 2.413, Train Accuracy:  0.023, Val Loss:  2.376, Val Accuracy:  0.024: 100%|██████████| 63/63 [00:00<00:00, 98.44it/s] 
Epoch 47, EMA Train Loss: 2.390, Train Accuracy:  0.023, Val Loss:  2.353, Val Accuracy:  0.024: 100%|██████████| 63/63 [00:00<00:00, 94.52it/s] 
Epoch 48, EMA Train Loss: 2.382, Train Accuracy:  0.023, Val Loss:  2.330, Val Accuracy:  0.024: 100%|██████████| 63/63 [00:00<00:00, 91.05it/s]
Epoch 49, EMA Train Loss: 2.354, Train Accuracy:  0.023, Val Loss:  2.311, Val Accuracy:  0.025: 100%|██████████| 63/63 [00:00<00:00, 98.41it/s] 
Epoch 50, EMA Train Loss: 2.330, Train Accuracy:  0.023, Val Loss:  2.290, Val Accuracy:  0.025: 100%|██████████| 63/63 [00:00<00:00, 98.87it/s] 


Saved model checkpoint!
