In [1]:
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Ignore warnings
# NOTE(review): this hides every warning, including ones that flag real problems;
# consider filtering specific categories instead.
import warnings
warnings.filterwarnings("ignore")

import logging

# Named logger "Debug" (not used in the visible cells); `info` is a plain alias
# for print, used for progress messages.
debug = logging.getLogger("Debug")
info  = print
plt.ion();  # interactive mode; the trailing ';' suppresses the returned context-manager repr



<contextlib.ExitStack at 0x7d0681b0c880>

In [2]:
# NOTE(review): unpinned install; pin the transformers version used for this run
# (e.g. `%pip install -q transformers==<version>`) so Restart & Run All is reproducible.
%pip install transformers



In [3]:
def _select_device():
    """Pick the best available backend (CUDA, then Apple MPS, then CPU).

    Returns the chosen torch.device and the arguments of the banner to print.
    """
    if torch.cuda.is_available():
        chosen = torch.device("cuda")
        return chosen, ("Running CUDA Mode:", chosen, torch.cuda.get_device_name(0))
    if torch.backends.mps.is_available():
        chosen = torch.device("mps")
        return chosen, ("Running MPS Mode:", chosen)
    chosen = torch.device("cpu")
    return chosen, ("Running CPU Mode:", chosen)


device, _banner = _select_device()
print(*_banner)



Running CUDA Mode: cuda Tesla P100-PCIE-16GB


## DataLoader

In [20]:
# Special tokens: sequence start/end markers, the out-of-vocabulary fallback,
# and the filler used to right-pad formulas to a common length.
START_TOKEN, END_TOKEN, UNK_TOKEN, PAD_TOKEN = "<START>", "<END>", "<UNK>", "<PAD>"

class Vocabulary:
    """
    Token <-> id lookup tables plus per-token frequency counts.

    Arguments:
        freq_dict (dict): token -> occurrence count; its length is the vocabulary size N.
        wd_to_id (dict): token -> integer id.
        id_to_wd (dict): integer id -> token (inverse of wd_to_id).
    """
    def __init__(self, freq_dict, wd_to_id, id_to_wd):
        self.freq_dict, self.wd_to_id, self.id_to_wd = freq_dict, wd_to_id, id_to_wd
        self.N = len(self.freq_dict)

    def get_id(self, word):
        """Return the id of `word`, or the id of UNK_TOKEN when `word` is unknown."""
        return self.wd_to_id[word] if word in self.wd_to_id else self.wd_to_id[UNK_TOKEN]

class LatexFormulaDataset(Dataset):
    """Latex Formula Dataset: Image and Text"""

    def __init__(self, csv_file, root_dir, transform = None, max_examples = None, vocab = None):
        """
        Arguments:
            csv_file (string): Path to the csv file with image name and text.
                Column 0 holds the image file name (relative to root_dir) and a
                'formula' column holds the space-separated LaTeX tokens.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample's image.
            max_examples (int, optional): If given, only the first max_examples
                rows of the csv are used.
            vocab (Vocabulary, optional): Existing vocabulary to encode formulas
                with (e.g. the training set's, so token ids agree between a
                training and an evaluation dataset). When None, a vocabulary is
                built from this dataset's formulas.

        After construction, self.df['formula'] holds int64 id tensors of length
        self.maxlen: <START>, the tokens, <END>, then <PAD> up to maxlen.
        """
        info("Loading Dataset...")
        self.df = pd.read_csv(csv_file)

        # Slice the dataset if max_examples is not None. .copy() makes the frame
        # independent, so the column assignments below do not write into a slice.
        if max_examples is not None:
            self.df = self.df.iloc[:max_examples, :].copy()

        self.root_dir = root_dir
        self.transform = transform

        # Whitespace tokenisation, wrapped in sequence markers.
        self.df['formula'] = self.df['formula'].apply(lambda x: [START_TOKEN] + x.split() + [END_TOKEN])

        # Longest tokenised formula (0 for an empty csv); every formula is padded to it.
        self.maxlen = max((len(formula) for formula in self.df['formula']), default=0)
        self.df['formula'] = self.df['formula'].apply(lambda x: x + [PAD_TOKEN] * (self.maxlen - len(x)))

        # A freshly built vocabulary sees the padded formulas, so <PAD> gets an id.
        self.vocab = vocab if vocab is not None else self.construct_vocab()
        self.df['formula'] = self.df['formula'].apply(
            lambda formula: torch.tensor([self.vocab.get_id(wd) for wd in formula], dtype=torch.int64))
        info("Loaded.")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """
        Returns a dict sample {'image': image, 'formula': ids}.

        'image' is read from root_dir (and passed through `transform` if set);
        'formula' is the padded int64 id tensor built in __init__.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.root_dir,
                                self.df.iloc[idx, 0])
        image = io.imread(img_name)
        # Read by column name rather than position so column order in the csv does not matter.
        formula = self.df['formula'].iloc[idx]

        sample = {'image': image, 'formula': formula}

        if self.transform:
            sample['image'] = self.transform(sample['image'])

        return sample

    def construct_vocab(self):
        """
        Constructs vocabulary from the dataset formulas (already tokenised and padded).

        Ids are assigned in first-seen order. <UNK> is always present (its count is
        set to 1) and <PAD> is guaranteed to exist even when no formula needed
        padding, because callers look up its id (e.g. as a loss ignore_index).
        """
        freq_dict = {}
        for formula in self.df['formula']:
            for wd in formula:
                freq_dict[wd] = freq_dict.get(wd, 0) + 1
        freq_dict[UNK_TOKEN] = 1
        # Only added (with count 0) when absent, so existing ids are unchanged.
        freq_dict.setdefault(PAD_TOKEN, 0)

        wd_to_id = {wd: i for i, wd in enumerate(freq_dict)}
        id_to_wd = {i: wd for wd, i in wd_to_id.items()}
        return Vocabulary(freq_dict, wd_to_id, id_to_wd)

def get_dataloader(csv_path, image_root, batch_size, transform = None, max_examples = None, shuffle = True):
    """
    Build a LatexFormulaDataset and wrap it in a DataLoader.

    Arguments:
        csv_path (string): csv with image names and formulas.
        image_root (string): directory containing the images.
        batch_size (int): samples per batch.
        transform (callable, optional): image transform applied per sample.
        max_examples (int, optional): keep only the first max_examples rows.
        shuffle (bool): reshuffle every epoch (default True). Pass False for
            evaluation loaders so the sample order is stable.

    Returns:
        (dataloader, dataset)
    """
    dataset = LatexFormulaDataset(csv_path, image_root, max_examples=max_examples, transform=transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
    return dataloader, dataset
 

## Encoder: ViT

In [5]:
from transformers import ViTFeatureExtractor, ViTModel
from torch.nn import functional as F

# Initialize ViT model
# Later cells read `model_name` and `vit_model` (e.g. `vit_model.config`); without
# these assignments a fresh-kernel Run All fails with NameError.
model_name = 'google/vit-base-patch16-224'
vit_model = ViTModel.from_pretrained(model_name)


Downloading (…)rocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]

Downloading config.json:   0%|          | 0.00/69.7k [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [40]:
# Inspect the pretrained backbone's configuration. hidden_size is the field
# ViTEncoder reads to size its projection layer. Needs `vit_model` from the ViT cell above.
vit_model.config

ViTConfig {
  "_name_or_path": "google/vit-base-patch16-224",
  "architectures": [
    "ViTForImageClassification"
  ],
  "attention_probs_dropout_prob": 0.0,
  "encoder_stride": 16,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 768,
  "id2label": {
    "0": "tench, Tinca tinca",
    "1": "goldfish, Carassius auratus",
    "2": "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias",
    "3": "tiger shark, Galeocerdo cuvieri",
    "4": "hammerhead, hammerhead shark",
    "5": "electric ray, crampfish, numbfish, torpedo",
    "6": "stingray",
    "7": "cock",
    "8": "hen",
    "9": "ostrich, Struthio camelus",
    "10": "brambling, Fringilla montifringilla",
    "11": "goldfinch, Carduelis carduelis",
    "12": "house finch, linnet, Carpodacus mexicanus",
    "13": "junco, snowbird",
    "14": "indigo bunting, indigo finch, indigo bird, Passerina cyanea",
    "15": "robin, American robin, Turdus migratorius",
    "16": "bulbul",
 

In [69]:
class ViTEncoder(nn.Module):
    """
    Frozen pretrained ViT backbone plus a trainable linear projection.

    forward(x) runs the ViT without gradients, takes the first token of its last
    hidden state (index 0 on the sequence axis) and projects it to a 512-dim
    context vector.

    Arguments:
        pretrained_path (str): checkpoint name/path passed to ViTModel.from_pretrained.
        lr, warmup_steps, total_steps: accepted for backward compatibility;
            not used by this module.
    """
    def __init__(self,
                 pretrained_path='google/vit-base-patch16-224',
                 lr=1e-3,
                 warmup_steps=500,
                 total_steps=10000):
        super().__init__()
        # Load the checkpoint named by the argument rather than a global name.
        self.vit = ViTModel.from_pretrained(pretrained_path)
        # requires_grad_(False) freezes every backbone parameter. Assigning
        # `module.requires_grad = False` would only set a plain Python attribute.
        self.vit.requires_grad_(False)
        self.layer = torch.nn.Linear(self.vit.config.hidden_size, 512)

    def forward(self, x):
        """
        x: pixel batch for the ViT (presumably (M, 3, 224, 224) - TODO confirm).
        Returns the projected first-token features, shape (M, 512).
        """
        with torch.no_grad():  # backbone is frozen; don't build its autograd graph
            vit_out = self.vit(x)
        first_token = vit_out.last_hidden_state[:, 0, :]
        return self.layer(first_token)

## Decoder

In [70]:
class DecoderLSTM(nn.Module):
    """
    LSTMCell decoder that unrolls an encoder context vector into a sequence of
    vocabulary logits.

    Inputs (M is the batch size):

    context_size : size of the context vector [shape: (M, context_size)]. The context
        is also used as the initial hidden state, so it must equal hidden_size.
    vocab : Vocabulary; vocab.N is the output size and vocab.wd_to_id[START_TOKEN]
        is the id fed at the first step.
    max_seq_len : number of decoding steps, i.e. the length of the produced sequence.
    n_layers : stored for reference only; a single nn.LSTMCell is used.
    hidden_size : size of the hidden/cell state vectors [shape: (M, hidden_size)]
    embed_size : size of the token embedding vectors [shape: (M, embed_size)]
    """
    def __init__(self, context_size, vocab, max_seq_len, n_layers = 1, hidden_size = 512, embed_size = 512):
        super().__init__()
        self.context_size = context_size
        self.vocab = vocab
        self.vocab_size = vocab.N
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.embed_size = embed_size
        # Each step sees the (constant) context concatenated with the current token embedding.
        self.input_size = context_size + embed_size
        self.max_seq_len = max_seq_len

        self.embed = nn.Embedding(self.vocab_size, embed_size)
        self.lstm = nn.LSTMCell(self.input_size, self.hidden_size)
        self.linear = nn.Linear(hidden_size, self.vocab_size)

    def forward(self, context, target_tensor = None):
        """
        Decode max_seq_len steps.

        context : encoder output [shape: (M, context_size)].
        target_tensor : ground-truth token ids [shape: (M, >= max_seq_len)], optional.
            If given, every step after the first independently uses teacher forcing
            with probability 0.5 (the embedding of target_tensor[:, i-1] is fed).
            Otherwise, and always when target_tensor is None, the embedding of the
            previous step's argmax prediction is fed back.

        Returns (logits [shape: (M, max_seq_len, vocab_size)], final hidden, final cell).
        """
        # Create every new tensor on the same device as the context.
        dev = context.device
        batch_size = context.shape[0]
        if target_tensor is not None:
            target_tensor = target_tensor.to(dev)

        # The context doubles as the initial hidden state; the cell state starts at zero.
        hidden = context
        cell = torch.zeros(batch_size, self.hidden_size, device=dev)

        # One <START> id per example -> embedding of shape (M, embed_size), even when
        # M == 1 (a blanket .squeeze() would drop the batch axis and break torch.cat).
        start_ids = torch.full((batch_size,), self.vocab.wd_to_id[START_TOKEN],
                               dtype=torch.long, device=dev)
        init_embed = self.embed(start_ids)

        outputs = []
        output = None
        for i in range(self.max_seq_len):
            # Drawn on every step (including i == 0) so the random stream is consumed uniformly.
            r = torch.rand(1)
            if i == 0:
                embedding = init_embed
            elif r > 0.5 and target_tensor is not None:
                # Teacher forcing: feed the ground-truth previous token.
                embedding = self.embed(target_tensor[:, i - 1])
            else:
                # Free running: feed back the previous step's most likely token.
                embedding = self.embed(torch.argmax(output, dim = 1))

            lstm_input = torch.cat([context, embedding], dim = 1)
            hidden, cell = self.lstm(lstm_input, (hidden, cell))
            output = self.linear(hidden)
            outputs.append(output)

        output_tensor = torch.stack(outputs, dim = 1)  # (M, max_seq_len, vocab_size)
        return output_tensor, hidden, cell
    

In [71]:
class HandwritingToLatexModel(nn.Module):
    """
    Image-to-LaTeX model: a ViTEncoder produces a context vector that a
    DecoderLSTM unrolls into per-step vocabulary logits.

    ViTEncoder() projects to 512 features, so context_size is expected to be 512.
    The remaining arguments are forwarded unchanged to DecoderLSTM.
    """
    def __init__(self, context_size, vocab, max_seq_len, n_layers, hidden_size, embed_size):
        super().__init__()
        self.encoder = ViTEncoder()
        self.decoder = DecoderLSTM(
            context_size, vocab, max_seq_len,
            n_layers=n_layers, hidden_size=hidden_size, embed_size=embed_size,
        )

    def forward(self, image, target_tensor = None):
        """Encode `image`, decode (optionally teacher-forced on `target_tensor`), return the logits only."""
        logits, _final_hidden, _final_cell = self.decoder(self.encoder(image), target_tensor)
        return logits

In [72]:
import time
import math
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
from tqdm import tqdm

# NOTE(review): 'agg' is a non-interactive, file-only backend. Figures created after
# this line will not render inline, and it overrides the earlier plt.ion(). Confirm
# this is intended (nothing in the visible cells plots).
plt.switch_backend('agg')
def asMinutes(s):
    """Format a duration in seconds as 'Xm Ys' (floored minutes, truncated seconds)."""
    whole_minutes = math.floor(s / 60)
    return '%dm %ds' % (whole_minutes, s - whole_minutes * 60)
def timeSince(since, percent):
    """
    Return 'elapsed (- remaining)' for a job started at time `since` that is
    `percent` (a fraction in (0, 1]) complete. The remaining time is a linear
    extrapolation of the elapsed time.
    """
    elapsed = time.time() - since
    remaining = elapsed / (percent) - elapsed
    return '%s (- %s)' % (asMinutes(elapsed), asMinutes(remaining))
    
def saveModel(save_path, model_state, optimiser_state, loss):
    """
    Write a training checkpoint to `save_path` with torch.save.

    The saved dict has keys 'model_state_dict', 'optimizer_state_dict' and 'loss'.
    """
    checkpoint = dict(
        model_state_dict=model_state,
        optimizer_state_dict=optimiser_state,
        loss=loss,
    )
    torch.save(checkpoint, save_path)

In [73]:
def generated_formula(output, vocab):
    """
    Greedy-decode one sequence of logits into a space-joined token string.

    output: [shape: (Max_length, vocab_size)]; the argmax over dim 1 picks each token.
    vocab: object whose `id_to_wd` maps token id -> token string.
    """
    token_ids = output.argmax(dim=1).tolist()
    return ' '.join(vocab.id_to_wd[token_id] for token_id in token_ids)

In [74]:
def train_epoch(dataloader,model, optimizer, criterion, log_every=300):
    """
    Run one epoch of training and return the mean per-batch loss.

    Arguments:
        dataloader: yields dicts with 'image' and 'formula' tensors. If its dataset
            has a `vocab`, sample decodes are printed using vocab.id_to_wd.
        model: called as model(images, target_tensor); its output is treated as
            logits of shape (batch, seq_len, vocab_size) and permuted to
            (batch, vocab_size, seq_len) for `criterion`.
        optimizer: stepped once per batch.
        criterion: loss taking (logits (B, V, L), targets (B, L)).
        log_every (int): every `log_every` batches, print the first example's
            predicted and actual formula; 0/None disables this. The model is called
            with the targets, so these predictions may be partly teacher-forced.

    Returns:
        float: summed batch losses divided by len(dataloader) (0.0 if it is empty).
    """
    total_loss = 0.0
    num_batches = len(dataloader)
    # Loop-invariant: look the vocabulary up once instead of on every batch.
    vocab = getattr(dataloader.dataset, 'vocab', None)

    pb = tqdm(dataloader, desc="Batch")
    for batch_idx, data in enumerate(pb, start=1):
        input_tensor = data['image'].to(device)
        target_tensor = data['formula'].to(device)
        outputs = model(input_tensor, target_tensor)

        if log_every and vocab is not None and batch_idx % log_every == 0:
            predicted_ids = torch.argmax(outputs[0], dim=-1).tolist()
            actual_ids = target_tensor[0].tolist()
            print(f"Generated: {' '.join(vocab.id_to_wd[t] for t in predicted_ids)}")
            print(f"Actual: {' '.join(vocab.id_to_wd[t] for t in actual_ids)}")

        # CrossEntropyLoss expects the class axis second: (B, L, V) -> (B, V, L).
        loss = criterion(outputs.permute(0, 2, 1), target_tensor)

        # backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        total_loss += loss_value
        pb.set_description(f"Loss: {loss_value}")

    return total_loss / num_batches if num_batches else 0.0

def train(train_dataloader, model, n_epochs, optimizer = None, learning_rate=0.001, print_every=1, save_interval=2, save_prefix = 'model', save_dir='/kaggle/working'):
    """
    Train `model` for `n_epochs` epochs with cross-entropy that ignores <PAD> targets.

    Arguments:
        train_dataloader: loader whose dataset exposes `vocab` (for the <PAD> id).
        model: the network to train (put into train mode here).
        n_epochs (int): number of epochs.
        optimizer: optional; defaults to Adam(model.parameters(), lr=learning_rate).
        learning_rate (float): used only when no optimizer is given.
        print_every (int): log progress and the mean loss of the last
            `print_every` epochs every `print_every` epochs.
        save_interval (int): checkpoint every `save_interval` epochs to
            `{save_dir}/{save_prefix}_epoch_{epoch}.pt`.
        save_prefix (str): checkpoint file-name prefix.
        save_dir (str): checkpoint directory (default '/kaggle/working').

    Returns:
        list[float]: the mean training loss of every epoch, in order.
    """
    start = time.time()
    epoch_losses = []
    print_loss_total = 0.0  # Reset every print_every

    if optimizer is None:
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    pad_id = train_dataloader.dataset.vocab.wd_to_id[PAD_TOKEN]
    criterion = nn.CrossEntropyLoss(ignore_index=pad_id).to(device) #as stated in assignment

    model.train()
    for epoch in range(1, n_epochs + 1):
        info(f"Epoch {epoch}")
        loss = train_epoch(train_dataloader, model, optimizer, criterion)
        epoch_losses.append(loss)
        print_loss_total += loss

        if epoch % save_interval == 0:
            save_path = os.path.join(save_dir, f'{save_prefix}_epoch_{epoch}.pt')
            saveModel(save_path, model.state_dict(), optimizer.state_dict(), loss)

        # Only log when an average has actually been computed for this window;
        # logging every epoch read an unassigned average whenever print_every > 1.
        if epoch % print_every == 0:
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0.0
            info('%s (%d %d%%) %.4f' % (timeSince(start, epoch / n_epochs), epoch, epoch / n_epochs * 100, print_loss_avg))

    return epoch_losses

In [75]:

batch_size = 32
vocab_size = 1000    # NOTE(review): not used by the visible cells; the model takes its size from the dataset vocab (vocab.N)
CONTEXT_SIZE = 512   # ViTEncoder projects to 512 features; also used as the decoder's initial hidden state
HIDDEN_SIZE = 512    # must equal CONTEXT_SIZE for that reason
EMBED_SIZE = 512
MAX_EXAMPLES = 1000  # NOTE(review): the data-loading cell passes max_examples=None (full dataset)
IMAGE_SIZE = 224     # input resolution of google/vit-base-patch16-224
# Per-channel statistics used by the pretrained ViT's image processor, which maps
# [0, 1] pixels to [-1, 1]. The backbone is frozen, so it cannot adapt to a different
# input range. Assumes RGB (3-channel) images.
VIT_MEAN = [0.5, 0.5, 0.5]
VIT_STD = [0.5, 0.5, 0.5]

# image processing: HxWxC image -> CxHxW float tensor -> 224x224 -> ViT normalization
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.Normalize(mean=VIT_MEAN, std=VIT_STD),
])

In [76]:
# Part (a): load the full synthetic training set. max_examples=None, so MAX_EXAMPLES
# from the config cell is not applied here. The paths are specific to the Kaggle input mount.
train_csv_path = "/kaggle/input/assignment-4-dataset/col_774_A4_2023/SyntheticData/train.csv"
image_root_path = "/kaggle/input/assignment-4-dataset/col_774_A4_2023/SyntheticData/images/"
train_dataloader, train_dataset = get_dataloader(train_csv_path, image_root_path, batch_size, transform, max_examples=None)

Loading Dataset...
Loaded.


In [77]:
# Create the network on the selected device. max_seq_len is the dataset's padded
# formula length, so the decoder emits one logit vector per target position.
model = HandwritingToLatexModel(
    context_size=CONTEXT_SIZE,
    vocab=train_dataset.vocab,
    max_seq_len=train_dataset.maxlen,
    n_layers=1,
    hidden_size=HIDDEN_SIZE,
    embed_size=EMBED_SIZE,
).to(device)

In [None]:
# 20 epochs, checkpointing after every epoch with the 'model3.0' file prefix.
train(train_dataloader, model, n_epochs=20, save_interval=1, save_prefix='model3.0')

Epoch 1


Loss: 3.021838426589966:  13%|█▎        | 299/2344 [04:06<27:18,  1.25it/s] 

Generated: <START> $ { _ ( L } ( x ) } ( x ) { ( x ) { ( x ) { ( x ) { { ( } } } { 1 } { { 2 } } { { 2 } } 1 } ) { { { ( x ) ) 1 } } 1 ) } { { { 2 } } x ) } ) { ) } $ <END> <END> ) } ) ) { { ) } ) ) { ) { { } ) { { } } { { { } } ) } ) { ) { ) { { { } } ) ) ) ) ) { ) { { } { ) } { ) ) ) { ) ) { { { } ) } { { } ) ) ) { { { } } { ) } ) ) { { } { { { } } } { ) } { { { } } ) ) ) { ) { ) { ) { { { } } { { } } ) } ) ) ) ) ) ) ) { { ) } ) ) { { { } } ) } ) { { } ) ) { { } } ) } ) ) { ) { { } { { } ) } ) { ) ) ) { ) ) { ) ) { ) { ) ) ) { { ) ) { { } } ) } { { { } } ) ) } ) { ) ) { ) ) ) ) { { { } { ) ) { ) ) ) ) { ) { { { } } ) } ) ) { ) ) { { { } } { ) ) { { } ) { ) ) { { } { ) } { { { } ) } ) { ) ) ) { ) ) ) { { ) ) ) ) { ) { { } ) { ) { ) ) { { } { { { } } } { { } } ) ) { { } } ) } { { { } } { { } } } } ) } ) { { { } } } { { } ) } ) ) { { } ) } { ) ) ) { ) ) { { } { { } ) } ) { { { } } ) } { ) ) { ) ) { { } ) { ) ) { { } { ) } ) ) ) { { { } } { ) } { ) ) { ) { ) { ) { { } ) ) ) { ) { { { } )

Loss: 2.9704935550689697:  26%|██▌       | 599/2344 [08:20<25:15,  1.15it/s]

Generated: <START> $ { ^ { 2 } } } ( { } } { \, _ { = 0 , \quad ( { ) _ { 0 } ( = 0 _ _ { 0 } ( x { n } ( x ) { ) } ) { { { { } } , { n } n } } } { { { \ \ _ { 0 } ^ { 2 } 2 } } { 2 } } { { { } } { _ { { 2 } } } . $ <END> <END> <END> , , , , , $ <END> <END> <END> <END> <END> , $ <END> , , , , , , , , , , , , , , $ , , , $ <END> <END> , , , , $ <END> , $ <END> <END> <END> , , $ <END> , , $ , $ , , , $ <END> , $ , , $ <END> <END> , $ <END> , $ <END> , , $ , $ , $ <END> { , $ <END> <END> <END> , , $ <END> <END> , , , $ <END> <END> , , $ , , , , $ <END> , $ <END> <END> <END> <END> , $ <END> , $ <END> , $ , , , , , , $ , $ <END> , , $ , , $ <END> <END> , $ <END> <END> <END> , , $ , , , , $ <END> , $ , $ , , , $ , , $ , $ <END> <END> <END> <END> , , $ , , , , , $ , , , , , , , $ , $ , $ <END> , , $ <END> , , $ , $ <END> , $ , $ , $ , $ <END> { $ <END> <END> <END> } } , , $ <END> <END> <END> , , $ <END> , $ , $ , $ <END> , $ , $ , , , $ <END> <END> <END> <END> <END> , $ , $ , , , , $ , , , , 

Loss: 2.6897499561309814:  38%|███▊      | 899/2344 [12:26<20:00,  1.20it/s]

Generated: <START> $ S = } { { } } 2 } } ( { } } ( ( { ^ { 2 } } ( _ { 0 } } { { } } ) $ $ $ <END> <END> $ $ $ <END> $ $ <END> <END> <END> <END> <END> $ $ <END> <END> $ <END> <END> $ <END> <END> <END> <END> <END> $ $ $ <END> <END> <END> <END> $ <END> <END> <END> $ $ $ $ <END> $ <END> <END> <END> <END> <END> <END> } . $ . $ <END> $ <END> <END> $ $ $ $ $ <END> $ $ <END> <END> <END> $ <END> . $ $ $ $ $ <END> <END> <END> $ <END> $ $ <END> <END> <END> $ <END> $ <END> <END> <END> $ <END> <END> $ $ <END> $ <END> <END> $ $ <END> <END> <END> $ <END> <END> $ $ <END> <END> $ <END> <END> $ <END> $ $ <END> $ $ $ $ <END> . $ $ $ $ $ $ <END> $ $ <END> $ <END> $ $ $ $ $ <END> $ <END> <END> <END> . $ <END> . $ <END> <END> . $ <END> <END> <END> . $ $ $ $ $ $ <END> <END> $ <END> <END> $ $ $ <END> <END> <END> $ <END> <END> . $ <END> . $ $ $ <END> <END> <END> $ $ $ $ $ <END> <END> $ $ $ $ $ $ $ <END> <END> $ $ $ $ <END> <END> $ $ <END> <END> <END> <END> . $ . $ <END> <END> <END> <END> <END> $ <END> $ $ $ $

Loss: 2.8686411380767822:  51%|█████     | 1199/2344 [16:34<17:43,  1.08it/s]

Generated: <START> $ { _ { } = { _ = \frac _ { i } ( x ) ) { \frac { { } { 2 } } { _ { i } } { { } } { 2 } } } . ^ x ) { 2 $ <END> { ) } $ $ <END> , $ ) $ , , $ $ $ $ $ <END> $ $ $ <END> $ <END> $ $ <END> $ <END> <END> , $ <END> ) $ <END> ) $ <END> <END> $ <END> ) , $ <END> <END> $ , $ <END> <END> $ ) , , $ ) , , , , $ ) $ , $ <END> <END> <END> , $ ) , , $ <END> <END> $ <END> <END> <END> <END> ) $ <END> <END> <END> , , $ <END> <END> ) $ $ $ <END> , $ ) , , , $ <END> <END> $ , $ ) = 0 $ <END> $ <END> , $ $ $ <END> , $ $ $ $ <END> <END> , $ <END> { 2 } ) } $ , $ ) , $ ) $ <END> , $ ) $ , $ $ $ <END> <END> , , , $ <END> , , , $ ) $ , $ <END> <END> $ <END> ) , , $ <END> , , , , $ ) , $ ) , $ <END> , $ ) $ <END> <END> , , , , , $ <END> , $ <END> { , , $ <END> { , , $ <END> { , $ <END> , $ <END> , $ <END> , $ ) $ <END> , $ ) $ <END> , $ <END> { 2 } ) $ <END> <END> , , $ <END> , $ ) , $ ) $ , $ ) , $ ) , , , , , , $ <END> { , $ ) $ , $ ) = $ <END> ) $ <END> <END> , $ ) , $ ) $ <END> , , , , $

Loss: 2.9898264408111572:  64%|██████▍   | 1499/2344 [20:45<11:55,  1.18it/s]

Generated: <START> $ \frac { { \partial } } { \partial } } { \partial } } { { { \frac { 1 } { } } } \frac { 1 } { 2 } } } { \mu } } } $ <END> <END> $ <END> <END> <END> <END> <END> } $ $ $ $ $ <END> <END> $ <END> <END> <END> } $ $ <END> $ $ <END> <END> $ <END> $ $ <END> <END> } $ <END> <END> <END> } } $ $ $ <END> <END> <END> <END> } $ $ <END> <END> } } $ <END> $ <END> <END> } } $ $ $ <END> <END> $ $ <END> $ <END> <END> } } $ <END> $ $ <END> $ <END> $ <END> <END> } $ <END> <END> <END> } } $ $ <END> $ $ $ <END> <END> $ <END> $ $ $ <END> <END> <END> } } $ . $ <END> $ $ $ <END> $ <END> $ <END> <END> } $ <END> $ $ $ <END> <END> <END> } } $ <END> <END> } $ <END> <END> } } . $ . $ $ <END> $ $ $ <END> <END> <END> } $ <END> $ <END> $ $ $ <END> <END> } } . $ . $ $ <END> <END> <END> } } $ . $ <END> $ <END> $ <END> $ <END> $ $ $ <END> $ <END> <END> } } $ . $ . $ <END> $ $ <END> <END> } $ $ $ <END> $ $ $ <END> $ <END> $ $ $ $ $ $ $ $ $ <END> <END> <END> <END> <END> <END> <END> <END> } , $ ) } $ <END

Loss: 2.832369804382324:  77%|███████▋  | 1799/2344 [24:53<07:19,  1.24it/s] 

Generated: <START> $ \int _ { \mathrm { = { 2 } = \int _ ^ { \infty { 0 } } { { { { { 2 } } } x ) } { { { { ^ { - \frac } { { { { } } } { { _ } { \mu } } ^ _ x ^ { \mu } } $ <END> { , , , , , , $ , , , $ <END> , $ , , $ , , $ <END> { { } } , $ <END> , , , , , $ <END> , $ , $ , $ , , $ , $ ) , , $ <END> , $ <END> , , , $ <END> { { } } { , $ <END> { , $ , , , $ <END> { { } } } } { { } } , $ , , $ <END> , , $ <END> , $ , $ <END> , , $ , , , $ , $ <END> , $ , $ <END> { { } } { , $ ) , $ , , $ <END> , , $ <END> , $ , , , $ <END> { { } , $ , $ , , $ <END> , $ , $ <END> , $ <END> , , , $ , , , , , $ <END> { , , $ , , $ <END> , , , $ , , , $ , , , $ <END> , , , , $ <END> , , , , $ , , $ , $ <END> , $ , , $ , $ <END> { , , , , $ , , , $ <END> { , , , , $ <END> { , , , , $ <END> , , $ <END> { , $ , , $ <END> { , , , $ , , , , , , $ , $ , $ , , , , , $ <END> { , $ , , $ , $ , , $ , $ , $ <END> { { } , , $ <END> , , , , , , $ <END> { , , , , $ <END> { , , $ <END> , $ , , $ , , , , , , , , $ <END> 

Loss: 2.7443127632141113:  81%|████████  | 1904/2344 [26:19<05:55,  1.24it/s]