In [1]:
!pip install -q x-transformers

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m88.7/88.7 kB[0m [31m5.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m82.5/82.5 kB[0m [31m5.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m103.0/103.0 kB[0m [31m7.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.6/61.6 kB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m
[?25h

# Import

In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from dataclasses import dataclass
from typing import List, Tuple
import sympy

from sympy import expand
from sympy import sympify
from sympy import series
from sympy import Symbol, symbols
from sympy import im, I

from tqdm import tqdm

import re

import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

In [3]:
BOS_IDX = 0
PAD_IDX = 1
UNK_IDX = 2
EOS_IDX = 11

In [4]:
from dataclasses import dataclass, field, fields
from typing import Optional

@dataclass
class Config:
    experiment_name: Optional[str] = "seq2seq_transformer"
    root_dir: Optional[str] = "./"
    device: Optional[str] = "cuda:0"
        
    #training parameters
    epochs: Optional[int] = 100
    seed: Optional[int] = 42
    use_half_precision: Optional[bool] = True

    # scheduler parameters
    scheduler_type: Optional[str] = "cosine_annealing_warm_restart" # multi_step or none
    T_0: Optional[int] = 10
    T_mult: Optional[int] = 1

    # optimizer parameters
    optimizer_type: Optional[str] = "adam" # sgd or adam
    optimizer_lr: Optional[float] = 0.0001   
    optimizer_momentum: Optional[float] = 0.9
    optimizer_weight_decay: Optional[float] = 0.0001
    optimizer_no_decay: Optional[list] = field(default_factory=list)
    clip_grad_norm: Optional[float] = -1
        
    # Model Parameters
    model_name: Optional[str] = "seq2seq_transformer"
    hybrid: Optional[bool] = True
    embedding_size: Optional[int] = 128
    hidden_dim: Optional[int] = 128
    pff_dim: Optional[int] = 512
    nhead: Optional[int] = 8
    num_encoder_layers: Optional[int] = 2
    num_decoder_layers: Optional[int] = 6
    dropout: Optional[int] = 0.2
    pretrain: Optional[bool] = False
    input_emb_size: Optional[int] = 64
    max_input_points: Optional[int] = 210
    src_vocab_size: Optional[int] = 32
    tgt_vocab_size: Optional[int] = 22

    # Criterion
    criterion: Optional[str] = "cross_entropy"
        
    def print_config(self):
        print("="*50+"\nConfig\n"+"="*50)
        for field in fields(self):
            print(field.name.ljust(30), getattr(self, field.name))
        print("="*50)

    def save(self, root_dir):
        path = root_dir + "/config.txt"
        with open(path, "w") as f:
            f.write("="*50+"\nConfig\n"+"="*50 + "\n")
            for field in fields(self):
                f.write(field.name.ljust(30) + ": " + str(getattr(self, field.name)) + "\n")
            f.write("="*50)   

# Load and clean

In [5]:
pth = '/kaggle/input/data-no-dup/final_data_6519.csv'

df = pd.read_csv(pth)

In [6]:
def spt(i, order=4):
    # print(df['expansion'].iloc[i])
    # return df['expansion'].iloc[i].split('x**')
    expr = sympify(df['expansion'].iloc[i]).evalf(4).as_poly()
    # print(expr)
    # print(expr.free_symbols)
    coeffs = expr.all_coeffs()[::-1]
    if len(coeffs) < order + 1:
        coeffs += [0]*(order - len(coeffs) + 1)
    return coeffs

In [7]:
df_clean = pd.DataFrame(columns = df.columns)
coeffs = []
for idx, row in tqdm(df.iterrows(),total=len(df)):
    try:
        coeff = spt(idx)
        coeffs.append(coeff)
        df_clean.loc[len(df_clean)] = row
    except Exception as ex:
        continue

100%|██████████| 6519/6519 [00:22<00:00, 289.98it/s]


In [8]:
df_clean['coefficients'] = coeffs

In [9]:
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

def split_data(data, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15):
    train_data, temp_data = train_test_split(data, train_size=train_ratio, random_state=42)
    val_size = val_ratio / (val_ratio + test_ratio)
    val_data, test_data = train_test_split(temp_data, train_size=val_size, random_state=42)

    data = {
        'train': train_data,
        'valid': val_data,
        'test': test_data
    }
    return train_data, val_data, test_data

# Preprocessing

In [10]:
def replace_exponent(expr: str) -> str:
    # Step 1: Replace powers **2 to **4 with ^2 to ^4
    expr = re.sub(r'(\b\w+\b)\s*\*\*\s*([2-4])', r'<\1^\2>', expr)
 
    # Step 2: Replace plain variables with variable_1 (not touching already transformed ones)
    # Negative lookbehind for ^ or _, to avoid changing x^2 or x_1
    expr = re.sub(r'(?<![\^_])\bx\b(?![\^_])', '<x^1>', expr)

    return expr

In [11]:
def dec_preproc(input_str, num_token='<NUM>'):
    
    exp = input_str
    e = sympy.Symbol('e')
    expr = str(sympify(exp).evalf(6,subs={'e':sympy.core.numbers.E})).replace(' ','')
    expr = replace_exponent(expr)
    expr_arr = expr.replace('+', ' + ').replace('*',' * ').replace('-',' - ').split(' ')

    expr_mod = []

    def check_float(f):
        try:
            _ = float(f)
            return True
        except (ValueError, TypeError):
            return False
    
    for i in expr_arr:
        if i == '':
            continue

        if check_float(i) or check_float(i[:-1]):
            for char in str(i):
                expr_mod.append(char)
        else:
            expr_mod.append(i)
    return expr_mod

# Tokenizers

In [12]:
src_vocab = []

src_vocab += ['<PAD>', '<SOS>', '<UNK>', 'x']
src_vocab += [str(i) for i in range(10)]
src_vocab += ['pi', 's+', 's-', 'E']
src_vocab += ['add', 'mul', 'pow']
src_vocab += ['sin', 'cos', 'tan', 'cot']
src_vocab += ['asin', 'acos', 'atan', 'acot']
src_vocab += ['ln', 'exp']
src_vocab += ['<EOS>']

In [13]:
tgt_vocab = []

tgt_vocab += ['<PAD>', '<SOS>', '<UNK>']
tgt_vocab += [f'<x^{i}>' for i in range(1,5)]
tgt_vocab += [str(i) for i in range(10)]
tgt_vocab += ['*', '+', '-', '.']
tgt_vocab += ['<EOS>']

In [14]:
len(tgt_vocab), len(src_vocab)

(22, 32)

In [15]:
class Tokenizer:
    def __init__(self, vocab):
        # Initialize with default vocabulary if none provided
        self.vocab = vocab
        
        # Create mappings
        self.token_to_idx = {token: idx for idx, token in enumerate(self.vocab)}
        self.idx_to_token = {idx: token for idx, token in enumerate(self.vocab)}
        
        # Special tokens
        self.pad_idx = self.token_to_idx['<PAD>']
        self.sos_idx = self.token_to_idx['<SOS>']
        self.eos_idx = self.token_to_idx['<EOS>']
        self.unk_idx = self.token_to_idx['<UNK>']
    
    def encode(self, tokens, add_special_tokens=True, max_length=None):
        """
        Encode a list of tokens into indices
        
        Args:
            tokens (list): List of tokens to encode
            add_special_tokens (bool): Whether to add SOS and EOS tokens
            max_length (int, optional): Maximum length to pad/truncate to
            
        Returns:
            list: List of token indices
        """
        if add_special_tokens:
            tokens = ['<SOS>'] + tokens + ['<EOS>']
        
        # Convert tokens to indices
        indices = [self.token_to_idx.get(token, self.unk_idx) for token in tokens]
        
        # Handle padding/truncation if max_length specified
        if max_length is not None:
            if len(indices) < max_length:
                # Pad sequence
                indices += [self.pad_idx] * (max_length - len(indices))
            else:
                # Truncate sequence
                indices = indices[:max_length]
        
        return indices
    
    def decode(self, indices, remove_special_tokens=True):
        """
        Decode a list of indices back into tokens
        
        Args:
            indices (list): List of indices to decode
            remove_special_tokens (bool): Whether to remove special tokens
            
        Returns:
            list: List of decoded tokens
        """
        # Convert indices to tokens
        tokens = [self.idx_to_token.get(idx, '<UNK>') for idx in indices]
        
        # Remove special tokens if requested
        if remove_special_tokens:
            tokens = [token for token in tokens if token not in ['<PAD>', '<SOS>', '<EOS>']]
        
        return tokens
    
    def batch_encode(self, batch_tokens, add_special_tokens=True, max_length=None, return_tensors=False):
        """
        Encode a batch of token lists
        
        Args:
            batch_tokens (list): List of token lists to encode
            add_special_tokens (bool): Whether to add SOS and EOS tokens
            max_length (int, optional): Maximum length to pad/truncate to
            return_tensors (bool): Whether to return PyTorch tensors
            
        Returns:
            list or torch.Tensor: Batch of encoded sequences
        """
        encoded_batch = [self.encode(tokens, add_special_tokens, max_length) for tokens in batch_tokens]
        
        # If max_length not specified, pad to the longest sequence in batch
        if max_length is None and encoded_batch:
            max_len = max(len(seq) for seq in encoded_batch)
            encoded_batch = [seq + [self.pad_idx] * (max_len - len(seq)) for seq in encoded_batch]
        
        # Convert to tensors if requested
        if return_tensors:
            import torch
            encoded_batch = torch.tensor(encoded_batch, dtype=torch.long)
        
        return encoded_batch
    
    def save_vocabulary(self, filepath):
        """Save vocabulary to a file"""
        with open(filepath, 'w') as f:
            for token in self.vocab:
                f.write(f"{token}\n")
        print(f"Vocabulary saved to {filepath}")
    
    @classmethod
    def from_file(cls, filepath):
        """Load vocabulary from a file"""
        with open(filepath, 'r') as f:
            vocab = [line.strip() for line in f]
        return cls(vocab)
    
    def __len__(self):
        """Return the size of the vocabulary"""
        return len(self.vocab)
    
    def add_tokens(self, new_tokens):
        """
        Add new tokens to the vocabulary
        
        Args:
            new_tokens (list): List of tokens to add
            
        Returns:
            int: Number of tokens added
        """
        tokens_added = 0
        for token in new_tokens:
            if token not in self.vocab:
                self.vocab.append(token)
                self.token_to_idx[token] = len(self.vocab) - 1
                self.idx_to_token[len(self.vocab) - 1] = token
                tokens_added += 1
        
        return tokens_added

# data handling

In [16]:
import torch
from torch.utils.data import Dataset, DataLoader

class Taylor_data(Dataset):
    
    def __init__(self, df, src_vocab, tgt_vocab):
        self.enc_tokenizer = Tokenizer(src_vocab)
        self.dec_tokenizer = Tokenizer(tgt_vocab)

        self.df = df
        
        self.src = []
        self.tgt = []
        self.build_dataset()
    
    def build_dataset(self):
        for idx, row in tqdm(self.df.iterrows(), total=len(self.df)):
            if len(row['prefix']) > 200:
                continue
                
            tgt_preproc = dec_preproc(row['expansion'])
            src_ids = self.enc_tokenizer.encode(eval(row['prefix']))
            tgt_ids = self.dec_tokenizer.encode(tgt_preproc)
            
            self.src.append(src_ids)
            self.tgt.append(tgt_ids)
        
        print('Built Dataset')
    def __len__(self):
        return len(self.src)

    def __getitem__(self, idx):

        return torch.tensor(self.src[idx]).long(), torch.tensor(self.tgt[idx]).long()

In [17]:
def collate_fn(batch):
    src_batch, tgt_batch, num_batch = [], [], []
    for (src_sample, tgt_sample) in batch:
        
        src_batch.append(src_sample)
        tgt_batch.append(tgt_sample)
    

    src_batch = pad_sequence(src_batch, padding_value=PAD_IDX, batch_first=True)
    tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_IDX, batch_first=True)
    return src_batch, tgt_batch

In [18]:
def get_dataloaders(datasets, train_bs, val_bs, test_bs):
    """
    Get data loaders for training, validation, and testing.

    Args:
    - datasets: Dictionary containing train, validation, and test datasets
    - train_bs: Batch size for training
    - val_bs: Batch size for validation
    - test_bs: Batch size for testing

    Returns:
    - dataloaders: Dictionary containing train, validation, and test data loaders
    """
    train_dataloader = DataLoader(datasets['train'], batch_size=train_bs,
                                  shuffle=True, num_workers=2, pin_memory=True, collate_fn=collate_fn)
    val_dataloader = DataLoader(datasets['valid'], batch_size=val_bs,
                                  shuffle=True, num_workers=2, pin_memory=True, collate_fn=collate_fn)
    test_dataloader = DataLoader(datasets['test'], batch_size=test_bs,
                                  shuffle=False, num_workers=2, pin_memory=False, collate_fn=collate_fn)
    
    dataloaders = {
        "train":train_dataloader,
        "test":test_dataloader,
        "valid":val_dataloader
        }
    
    return dataloaders

# Trainer utils

In [19]:
import os
import random

import torch
import numpy as np


class AverageMeter:
    """
    Computes and stores the average and current value
    """

    def __init__(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def seed_everything(seed: int):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

def generate_square_subsequent_mask(sz, device):
    mask = (torch.triu(torch.ones((sz, sz), device=device)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask

def create_mask(src, tgt, device):
    src_seq_len = src.shape[1]
    tgt_seq_len = tgt.shape[1]

    tgt_mask = generate_square_subsequent_mask(tgt_seq_len, device)
    src_mask = torch.zeros((src_seq_len, src_seq_len), device=device).type(torch.bool)

    src_padding_mask = (torch.zeros((src.shape[0], src_seq_len), device=device)).type(torch.bool)
    tgt_padding_mask = (tgt == PAD_IDX)
    tgt_mask = tgt_mask
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask

def sequence_accuracy(y_pred, y_true):

    count = 0
    total = len(y_pred)
    for (predicted_tokens, original_tokens) in zip(y_pred, y_true):
        original_tokens = original_tokens.tolist()
        predicted_tokens = predicted_tokens.tolist()
        if original_tokens == predicted_tokens:
            count = count+1

    return count/total

# Model

In [20]:
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import Transformer
import math

# https://github.com/neerajanand321/SYMBA_Pytorch/blob/main/models/seq2seq_transformer.py
class TokenEmbedding(nn.Module):
    ''' helper Module to convert tensor of input indices into corresponding tensor of token embeddings'''
    
    def __init__(self, vocab_size: int, emb_size):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: Tensor):
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)

class PositionalEncoding(nn.Module):
    ''' helper Module that adds positional encoding to the token embedding to introduce a notion of word order.'''
    
    def __init__(self,
                 emb_size: int,
                 dropout: float,
                 maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        den = torch.exp(- torch.arange(0, emb_size, 2)* math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        pos_embedding = pos_embedding.unsqueeze(0)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: Tensor):
        return self.dropout(token_embedding + self.pos_embedding[:, :token_embedding.size(1), :])


class Model(nn.Module):
    '''Seq2Seq Network'''
    
    def __init__(self,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 emb_size: int,
                 nhead: int,
                 src_vocab_size: int,
                 tgt_vocab_size: int,
                 input_emb_size: int,
                 max_input_points: int,
                 dim_feedforward: int = 512,
                 dropout: float = 0.1,):
        super(Model, self).__init__()
        self.transformer = Transformer(d_model=emb_size,
                                       nhead=nhead,
                                       num_encoder_layers=num_encoder_layers,
                                       num_decoder_layers=num_decoder_layers,
                                       dim_feedforward=dim_feedforward,
                                       dropout=dropout,
                                       batch_first=True)
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(emb_size, dropout=dropout)

    def forward(self,
                src: Tensor,
                trg: Tensor,
                src_mask: Tensor,
                tgt_mask: Tensor,
                src_padding_mask: Tensor,
                tgt_padding_mask: Tensor,
                memory_key_padding_mask: Tensor):
        src_emb = self.src_tok_emb(src)
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))

        
        
        outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None,
                                src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor):
        return self.transformer.encoder(self.src_tok_emb(src), src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        return self.transformer.decoder(self.positional_encoding(self.tgt_tok_emb(tgt)), memory, tgt_mask)

# Predictor

In [21]:
import os
import torch


class Predictor:
    """
    Predictor class for generating predictions using a trained model.
    """
    def __init__(self, config):
        """
        Initialize Predictor object.

        Args:
        - config: Configuration object containing model parameters
        """
        self.config = config
        self.device = torch.device(self.config.device)

        # Get the model
        self.model = self.get_model()
        self.model.to(self.device)

        # Load the best checkpoint
        self.logs_dir = os.path.join(self.config.root_dir, self.config.experiment_name)
        path = os.path.join(self.logs_dir, "best_checkpoint.pth")
        self.model.load_state_dict(torch.load(path)["state_dict"])
        
        # Set the model to evaluation mode
        self.model.eval()
        
    def get_model(self):
        if self.config.model_name == "seq2seq_transformer":
            model = Model(num_encoder_layers=self.config.num_encoder_layers,
                          num_decoder_layers=self.config.num_decoder_layers,
                          emb_size=self.config.embedding_size,
                          nhead=self.config.nhead,
                          src_vocab_size=self.config.src_vocab_size,
                          tgt_vocab_size=self.config.tgt_vocab_size,
                          input_emb_size=self.config.input_emb_size,
                          max_input_points=self.config.max_input_points,
                          )

        
        return model
    
    def generate_square_subsequent_mask(self, sz, device):
        mask = (torch.triu(torch.ones((sz, sz), device=device)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask
    
    def greedy_decode(self, src, src_mask, max_len, start_symbol, src_padding_mask=None):
        src = src.to(self.device)
        src_mask = src_mask.to(self.device)
        src_padding_mask = src_padding_mask.to(self.device)
        dim = 1

        memory = self.model.encode(src, src_mask)
        memory = memory.to(self.device)
        dim = 1
        ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(self.device)
        for i in range(max_len-1):

            tgt_mask = (self.generate_square_subsequent_mask(ys.size(1), self.device).type(torch.bool)).to(self.device)

            out = self.model.decode(ys, memory, tgt_mask)
            prob = self.model.generator(out[:, -1])

            _, next_word = torch.max(prob, dim=1)
            next_word = next_word.item()

            ys = torch.cat([ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=dim)
            if next_word == EOS_IDX:
                break

        return ys


    def predict(self, x):
        self.model.eval()
        
        if self.config.model_name == "seq2seq_transformer":
            src = x
            num_tokens = src.shape[1]

            src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
            src_padding_mask = torch.zeros(1, num_tokens).type(torch.bool)
            tgt_tokens = self.greedy_decode(src, src_mask, max_len=256, start_symbol=BOS_IDX, src_padding_mask=src_padding_mask).flatten()

            return tgt_tokens
        else:
            ys = torch.ones(1, 1).fill_(BOS_IDX).type(torch.long).to(self.device)
            e_mask = torch.zeros(1, x.shape[1]).type(torch.bool).to(self.device)
            memory = self.model.encoder(x, e_mask)

            for idx in range(1, 256):
                d_mask = torch.triu(torch.full((ys.size(1), ys.size(1)), float('-inf')), diagonal=1).to(self.device)
                d_out = self.model.decoder(ys, memory, e_mask, d_mask)

                prob = self.model.generator(d_out[:, -1])
                _, next_word = torch.max(prob, dim=1)
                next_word = next_word.item()
                ys = torch.cat([ys, torch.ones(1, 1).type_as(x.data).fill_(next_word)], dim=1)
                if next_word == EOS_IDX:
                    break

            return ys.flatten()

# Trainer

In [22]:
import os

import torch
import numpy as np
import pandas as pd
from tqdm import tqdm


class Trainer:
    """
    Trainer class for training and evaluating a PyTorch model.
    """
    def __init__(self, config, dataloaders):
        """
        Initialize Trainer object.

        Args:
        - config: Configuration object containing training parameters
        - dataloaders: Dictionary containing data loaders for train, validation, and test sets
        """
        self.config = config
        self.device = torch.device(self.config.device)
        self.dataloaders = dataloaders

        seed_everything(self.config.seed)

        self.scaler = torch.cuda.amp.GradScaler()
        if self.config.use_half_precision:
            self.dtype = torch.float16
        else:
            self.dtype = torch.float32

        # Initialize model, optimizer, scheduler, and criterion
        self.model = self.get_model()
        self.model.to(self.device)
        self.optimizer = self.get_optimizer()
        self.scheduler = self.get_scheduler()
        self.criterion = self.get_criterion()

        # Initialize training-related variables
        self.current_epoch = 0
        self.best_accuracy = -1
        self.best_val_loss = 1e6
        self.train_loss_list = []
        self.valid_loss_list = []
        self.valid_accuracy_tok_list = []

        # Create directory for saving logs
        self.logs_dir = os.path.join(self.config.root_dir, self.config.experiment_name)
        os.makedirs(self.logs_dir, exist_ok=True)

    def get_model(self):
        """
        Initialize and return the model based on the configuration.
        """
        if self.config.model_name == "seq2seq_transformer":
            model = Model(num_encoder_layers=self.config.num_encoder_layers,
                          num_decoder_layers=self.config.num_decoder_layers,
                          emb_size=self.config.embedding_size,
                          nhead=self.config.nhead,
                          src_vocab_size=self.config.src_vocab_size,
                          tgt_vocab_size=self.config.tgt_vocab_size,
                          input_emb_size=self.config.input_emb_size,
                          max_input_points=self.config.max_input_points,
                          )

        return model

    def get_optimizer(self):
        """
        Initialize and return the optimizer based on the configuration.
        """
        optimizer_parameters = self.model.parameters()

        if self.config.optimizer_type == "sgd":
            optimizer = torch.optim.SGD(optimizer_parameters, lr=self.config.optimizer_lr, momentum=self.config.optimizer_momentum,)
        elif self.config.optimizer_type == "adam":
            optimizer = torch.optim.Adam(optimizer_parameters, lr=self.config.optimizer_lr, eps=1e-8, weight_decay=self.config.optimizer_weight_decay)
        elif self.config.optimizer_type == "adamw":
            optimizer = torch.optim.AdamW(optimizer_parameters, lr=self.config.optimizer_lr, eps=1e-8, weight_decay=self.config.optimizer_weight_decay)
        else:
            raise NotImplementedError
        
        return optimizer
    
    def get_scheduler(self):
        """
        Initialize and return the learning rate scheduler based on the configuration.
        """
        if self.config.scheduler_type == "multi_step":
            scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=self.config.scheduler_milestones, gamma=self.config.scheduler_gamma)
        elif self.config.scheduler_type == "reduce_lr_on_plateau":
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='min', patience=2)
        elif self.config.scheduler_type == "cosine_annealing_warm_restart":
            scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(self.optimizer, self.config.T_0, self.config.T_mult)
        elif self.config.scheduler_type == "none":
            scheduler = None
        else:
            raise NotImplementedError
        
        return scheduler

    
    def get_criterion(self):
        """
        Initialize and return the loss function based on the configuration.
        """
        if self.config.criterion == "cross_entropy":
            criterion = torch.nn.CrossEntropyLoss()
        else:
            raise NotImplementedError
        
        return criterion

    def train_one_epoch(self):
        """
        Train the model for one epoch.
        """
        self.model.train()
        pbar = tqdm(self.dataloaders['train'], total=len(self.dataloaders['train']))
        pbar.set_description(f"[{self.current_epoch+1}/{self.config.epochs}] Train")
        running_loss = AverageMeter()
        for src, tgt in pbar:
            src = src.to(self.device)
            tgt = tgt.to(self.device)

            bs = src.size(0)

            with torch.autocast(device_type='cuda', dtype=self.dtype):
                if self.config.model_name == "seq2seq_transformer":
                    src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt[:, :-1], self.device)
                    logits = self.model(src, tgt[:, :-1], src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, src_padding_mask)
                    loss = self.criterion(logits.reshape(-1, logits.shape[-1]), tgt[:, 1:].reshape(-1))
                else:
                    logits = self.model(src, tgt[:, :-1])
                    loss = self.criterion(logits.reshape(-1, logits.shape[-1]), tgt[:, 1:].reshape(-1))
                
            running_loss.update(loss.item(), bs)
            pbar.set_postfix(loss=running_loss.avg)
            
            self.optimizer.zero_grad()
            self.scaler.scale(loss).backward()

            if self.config.clip_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.clip_grad_norm)
            self.scaler.step(self.optimizer)
            self.scaler.update()

        return running_loss.avg

    def evaluate(self, phase):
        """
        Evaluate the model on validation or test data.

        Args:
        - phase: Phase of evaluation, either "valid" or "test".

        Returns:
        - Tuple containing average token accuracy and average loss.
        """
        self.model.eval()
        
        pbar = tqdm(self.dataloaders[phase], total=len(self.dataloaders[phase]))
        pbar.set_description(f"[{self.current_epoch+1}/{self.config.epochs}] {phase.capitalize()}")
        running_loss = AverageMeter()
        running_acc_tok = AverageMeter()
        
        
        for src, tgt in pbar:
            src = src.to(self.device)
            tgt = tgt.to(self.device)
            bs = src.size(0)
            
            with torch.autocast(device_type='cuda', dtype=self.dtype):
                if self.config.model_name == "seq2seq_transformer":
                    with torch.no_grad():
                        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt[:, :-1], self.device)
                        logits = self.model(src, tgt[:, :-1], src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, src_padding_mask)
                        loss = self.criterion(logits.reshape(-1, logits.shape[-1]), tgt[:, 1:].reshape(-1))
                else:
                    with torch.no_grad():
                        logits = self.model(src, tgt[:, :-1])
                        loss = self.criterion(logits.reshape(-1, logits.shape[-1]), tgt[:, 1:].reshape(-1))

            y_pred = torch.argmax(logits.reshape(-1, logits.shape[-1]), 1)
            correct = (y_pred == tgt[:, 1:].reshape(-1)).cpu().numpy().mean()
            
            running_loss.update(loss.item(), bs)
            running_acc_tok.update(correct, bs)
            
        return running_acc_tok.avg, running_loss.avg

    def train(self):
        """
        Main training loop.
        """
        start_epoch = self.current_epoch
        for self.current_epoch in range(start_epoch, self.config.epochs):
            training_loss = self.train_one_epoch() 
            valid_accuracy_tok, valid_loss = self.evaluate("valid")
            
            self.train_loss_list.append(round(training_loss, 7))
            self.valid_loss_list.append(round(valid_loss, 7))
            self.valid_accuracy_tok_list.append(round(valid_accuracy_tok, 7))
            
            if self.scheduler == "multi_step":
                self.scheduler.step()
            elif self.scheduler == "reduce_lr_on_plateau":
                self.scheduler.step(valid_loss)
                
            if valid_loss<self.best_val_loss:
                self.best_val_loss = valid_loss

            self.save_model("last_checkpoint.pth")

            if valid_accuracy_tok > self.best_accuracy:
                print(f"==> Best Accuracy improved to {round(valid_accuracy_tok, 7)} from {self.best_accuracy}")
                self.best_accuracy = round(valid_accuracy_tok, 7)
                self.save_model("best_checkpoint.pth")
            
            self.log_results()

        
    def save_model(self, file_name):
        """
        Save model checkpoints.
        """
        state_dict = self.model.state_dict()
        torch.save({
                "epoch": self.current_epoch + 1,
                "state_dict": state_dict,
                'optimizer': self.optimizer.state_dict(),
                "train_loss_list": self.train_loss_list,
                "valid_loss_list": self.valid_loss_list,
                "valid_accuracy_tok_list": self.valid_accuracy_tok_list,
            }, os.path.join(self.logs_dir, file_name))

    def log_results(self):
        """
        Log training results to a CSV file.
        """
        data_list = [self.train_loss_list, self.valid_loss_list, self.valid_accuracy_tok_list]
        column_list = ['train_losses', 'valid_losses', 'token_valid_accuracy']
        
        df_data = np.array(data_list).T
        df = pd.DataFrame(df_data, columns=column_list)
        df.to_csv(os.path.join(self.logs_dir, "logs.csv"))
        
    def test_seq_acc(self):
        """
        Evaluate model's sequence accuracy on test data.
        """
        file = os.path.join(self.logs_dir, "best_checkpoint.pth")
        state_dict = torch.load(file, map_location=self.device)['state_dict']
        self.model.load_state_dict(state_dict)
        
        test_accuracy_tok, _ = self.evaluate("test")
        
        predictor = Predictor(self.config)
        
        print("Calculating Sequence Accuracy for predictions")
        pbar = tqdm(self.dataloaders["test"], total=len(self.dataloaders["test"]))
        pbar.set_description(f"Test")
        
        y_preds = []
        y_true = []
        for src, tgt in pbar:
            src = src.to(self.device)
            tgt = tgt.numpy()
            bs = src.size(0)
            y_pred = predictor.predict(src[0].unsqueeze(0)) #only one example from each batch
            y_preds.append(y_pred.cpu().numpy())
            y_true.append(np.trim_zeros(tgt[0]))
        print(y_preds[1], y_true[1])
        test_accuracy_seq = sequence_accuracy(y_true, y_preds)
        f= open(os.path.join(self.logs_dir, "score.txt"),"w+")
        f.write(f"Token Accuracy = {(round(test_accuracy_tok, 7))}\n")
        f.write(f"Sequence Accuracy = {(round(test_accuracy_seq, 7))}\n")
        f.close()
        print(f"Test Accuracy: {round(test_accuracy_tok, 7)} | Valid Accuracy: {self.best_accuracy}") 
        print(f"Test Sequence Accuracy: {test_accuracy_seq}")


In [23]:
# def main():
df_train, df_valid, df_test = split_data(df_clean)
datasets = {
    'train': Taylor_data(df_train, src_vocab, tgt_vocab),
    'valid': Taylor_data(df_valid, src_vocab, tgt_vocab),
    'test': Taylor_data(df_test, src_vocab, tgt_vocab)
}
# dataloaders = get_dataloaders(datasets, )

100%|██████████| 4064/4064 [00:11<00:00, 357.71it/s]


Built Dataset


100%|██████████| 871/871 [00:02<00:00, 385.00it/s]


Built Dataset


100%|██████████| 872/872 [00:02<00:00, 368.93it/s]

Built Dataset





In [24]:
dataloaders = get_dataloaders(datasets, 16, 128, 1)

In [25]:
BOS_IDX = 1
PAD_IDX = 0
UNK_IDX = 2
EOS_IDX = 21

In [26]:
config = Config()
trainer = Trainer(config, dataloaders)

  self.scaler = torch.cuda.amp.GradScaler()


In [27]:
trainer.train()

[1/100] Train: 100%|██████████| 231/231 [00:09<00:00, 23.99it/s, loss=1.04]
[1/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 26.97it/s]


==> Best Accuracy improved to 0.814773 from -1


[2/100] Train: 100%|██████████| 231/231 [00:08<00:00, 25.96it/s, loss=0.642]
[2/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 34.84it/s]


==> Best Accuracy improved to 0.8291174 from 0.814773


[3/100] Train: 100%|██████████| 231/231 [00:08<00:00, 26.16it/s, loss=0.579]
[3/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 33.89it/s]


==> Best Accuracy improved to 0.8474968 from 0.8291174


[4/100] Train: 100%|██████████| 231/231 [00:08<00:00, 25.68it/s, loss=0.545]
[4/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 34.97it/s]


==> Best Accuracy improved to 0.8519207 from 0.8474968


[5/100] Train: 100%|██████████| 231/231 [00:08<00:00, 26.19it/s, loss=0.515]
[5/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 35.31it/s]


==> Best Accuracy improved to 0.8610919 from 0.8519207


[6/100] Train: 100%|██████████| 231/231 [00:08<00:00, 25.94it/s, loss=0.494]
[6/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 36.21it/s]
[7/100] Train: 100%|██████████| 231/231 [00:09<00:00, 25.62it/s, loss=0.49] 
[7/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 34.11it/s]


==> Best Accuracy improved to 0.8659865 from 0.8610919


[8/100] Train: 100%|██████████| 231/231 [00:08<00:00, 26.16it/s, loss=0.469]
[8/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 35.34it/s]


==> Best Accuracy improved to 0.8687012 from 0.8659865


[9/100] Train: 100%|██████████| 231/231 [00:08<00:00, 26.05it/s, loss=0.458]
[9/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 35.69it/s]
[10/100] Train: 100%|██████████| 231/231 [00:08<00:00, 26.23it/s, loss=0.453]
[10/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 34.98it/s]


==> Best Accuracy improved to 0.8731965 from 0.8687012


[11/100] Train: 100%|██████████| 231/231 [00:09<00:00, 25.54it/s, loss=0.439]
[11/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 35.59it/s]


==> Best Accuracy improved to 0.8763033 from 0.8731965


[12/100] Train: 100%|██████████| 231/231 [00:08<00:00, 26.06it/s, loss=0.435]
[12/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 34.43it/s]


==> Best Accuracy improved to 0.8785771 from 0.8763033


[13/100] Train: 100%|██████████| 231/231 [00:08<00:00, 26.28it/s, loss=0.438]
[13/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 30.02it/s]
[14/100] Train: 100%|██████████| 231/231 [00:08<00:00, 26.06it/s, loss=0.423]
[14/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 34.92it/s]
[15/100] Train: 100%|██████████| 231/231 [00:08<00:00, 25.86it/s, loss=0.418]
[15/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 34.05it/s]
[16/100] Train: 100%|██████████| 231/231 [00:08<00:00, 26.09it/s, loss=0.416]
[16/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 35.54it/s]


==> Best Accuracy improved to 0.8824905 from 0.8785771


[17/100] Train: 100%|██████████| 231/231 [00:08<00:00, 26.17it/s, loss=0.407]
[17/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 33.81it/s]
[18/100] Train: 100%|██████████| 231/231 [00:09<00:00, 25.66it/s, loss=0.407]
[18/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 33.80it/s]


==> Best Accuracy improved to 0.8854014 from 0.8824905


[19/100] Train: 100%|██████████| 231/231 [00:08<00:00, 26.07it/s, loss=0.405]
[19/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 34.58it/s]
[20/100] Train: 100%|██████████| 231/231 [00:08<00:00, 26.11it/s, loss=0.401]
[20/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 35.93it/s]


==> Best Accuracy improved to 0.8861566 from 0.8854014


[21/100] Train: 100%|██████████| 231/231 [00:09<00:00, 25.63it/s, loss=0.397]
[21/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 34.74it/s]


==> Best Accuracy improved to 0.8865182 from 0.8861566


[22/100] Train: 100%|██████████| 231/231 [00:08<00:00, 26.13it/s, loss=0.394]
[22/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 35.22it/s]
[23/100] Train: 100%|██████████| 231/231 [00:08<00:00, 26.08it/s, loss=0.387]
[23/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 34.96it/s]
[24/100] Train: 100%|██████████| 231/231 [00:08<00:00, 26.11it/s, loss=0.378]
[24/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 33.85it/s]
[25/100] Train: 100%|██████████| 231/231 [00:08<00:00, 25.70it/s, loss=0.382]
[25/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 35.50it/s]
[26/100] Train: 100%|██████████| 231/231 [00:08<00:00, 26.46it/s, loss=0.376]
[26/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 34.63it/s]


==> Best Accuracy improved to 0.8880329 from 0.8865182


[27/100] Train: 100%|██████████| 231/231 [00:08<00:00, 26.23it/s, loss=0.374]
[27/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 36.11it/s]


==> Best Accuracy improved to 0.8901149 from 0.8880329


[28/100] Train: 100%|██████████| 231/231 [00:09<00:00, 25.38it/s, loss=0.371]
[28/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 35.83it/s]
[29/100] Train: 100%|██████████| 231/231 [00:08<00:00, 26.07it/s, loss=0.371]
[29/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 35.59it/s]


==> Best Accuracy improved to 0.8904302 from 0.8901149


[30/100] Train: 100%|██████████| 231/231 [00:08<00:00, 25.88it/s, loss=0.362]
[30/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 35.75it/s]


==> Best Accuracy improved to 0.8925449 from 0.8904302


[31/100] Train: 100%|██████████| 231/231 [00:08<00:00, 25.98it/s, loss=0.361]
[31/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 30.82it/s]
[32/100] Train: 100%|██████████| 231/231 [00:08<00:00, 25.98it/s, loss=0.362]
[32/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 35.27it/s]


==> Best Accuracy improved to 0.8927173 from 0.8925449


[33/100] Train: 100%|██████████| 231/231 [00:08<00:00, 26.07it/s, loss=0.359]
[33/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 35.22it/s]
[34/100] Train: 100%|██████████| 231/231 [00:08<00:00, 25.95it/s, loss=0.353]
[34/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 35.39it/s]
[35/100] Train: 100%|██████████| 231/231 [00:09<00:00, 25.56it/s, loss=0.351]
[35/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 34.51it/s]


==> Best Accuracy improved to 0.8938555 from 0.8927173


[36/100] Train: 100%|██████████| 231/231 [00:08<00:00, 26.29it/s, loss=0.349]
[36/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 34.92it/s]
[37/100] Train: 100%|██████████| 231/231 [00:08<00:00, 26.17it/s, loss=0.345]
[37/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 35.32it/s]


==> Best Accuracy improved to 0.8952365 from 0.8938555


[38/100] Train: 100%|██████████| 231/231 [00:08<00:00, 25.82it/s, loss=0.344]
[38/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 35.84it/s]


==> Best Accuracy improved to 0.8953019 from 0.8952365


[39/100] Train: 100%|██████████| 231/231 [00:08<00:00, 26.39it/s, loss=0.34] 
[39/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 32.02it/s]


==> Best Accuracy improved to 0.8963053 from 0.8953019


[40/100] Train: 100%|██████████| 231/231 [00:08<00:00, 26.40it/s, loss=0.339]
[40/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 35.59it/s]
[41/100] Train: 100%|██████████| 231/231 [00:08<00:00, 26.01it/s, loss=0.339]
[41/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 35.01it/s]
[42/100] Train: 100%|██████████| 231/231 [00:08<00:00, 25.92it/s, loss=0.34] 
[42/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 35.40it/s]
[43/100] Train: 100%|██████████| 231/231 [00:08<00:00, 26.29it/s, loss=0.335]
[43/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 35.47it/s]


==> Best Accuracy improved to 0.8974132 from 0.8963053


[44/100] Train: 100%|██████████| 231/231 [00:08<00:00, 26.13it/s, loss=0.332]
[44/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 34.17it/s]


==> Best Accuracy improved to 0.8983773 from 0.8974132


[45/100] Train: 100%|██████████| 231/231 [00:08<00:00, 25.72it/s, loss=0.33] 
[45/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 34.38it/s]
[46/100] Train: 100%|██████████| 231/231 [00:08<00:00, 25.83it/s, loss=0.326]
[46/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 35.56it/s]
[47/100] Train: 100%|██████████| 231/231 [00:08<00:00, 26.18it/s, loss=0.331]
[47/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 34.19it/s]
[48/100] Train: 100%|██████████| 231/231 [00:09<00:00, 25.21it/s, loss=0.323]
[48/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 35.42it/s]
[49/100] Train: 100%|██████████| 231/231 [00:09<00:00, 25.48it/s, loss=0.323]
[49/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 34.69it/s]


==> Best Accuracy improved to 0.899369 from 0.8983773


[50/100] Train: 100%|██████████| 231/231 [00:09<00:00, 25.54it/s, loss=0.32] 
[50/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 35.42it/s]
[51/100] Train: 100%|██████████| 231/231 [00:08<00:00, 25.76it/s, loss=0.32] 
[51/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 35.39it/s]
[52/100] Train: 100%|██████████| 231/231 [00:09<00:00, 25.20it/s, loss=0.317]
[52/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 35.60it/s]
[53/100] Train: 100%|██████████| 231/231 [00:08<00:00, 25.88it/s, loss=0.314]
[53/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 32.40it/s]
[54/100] Train: 100%|██████████| 231/231 [00:08<00:00, 25.73it/s, loss=0.316]
[54/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 34.14it/s]


==> Best Accuracy improved to 0.900502 from 0.899369


[55/100] Train: 100%|██████████| 231/231 [00:09<00:00, 25.53it/s, loss=0.312]
[55/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 35.31it/s]
[56/100] Train: 100%|██████████| 231/231 [00:09<00:00, 25.48it/s, loss=0.311]
[56/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 34.86it/s]
[57/100] Train: 100%|██████████| 231/231 [00:08<00:00, 25.85it/s, loss=0.312]
[57/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 34.58it/s]


==> Best Accuracy improved to 0.9008622 from 0.900502


[58/100] Train: 100%|██████████| 231/231 [00:08<00:00, 25.77it/s, loss=0.308]
[58/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 33.50it/s]
[59/100] Train: 100%|██████████| 231/231 [00:09<00:00, 25.13it/s, loss=0.306]
[59/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 33.61it/s]


==> Best Accuracy improved to 0.901219 from 0.9008622


[60/100] Train: 100%|██████████| 231/231 [00:09<00:00, 25.57it/s, loss=0.301]
[60/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 35.22it/s]
[61/100] Train: 100%|██████████| 231/231 [00:08<00:00, 25.72it/s, loss=0.301]
[61/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 34.32it/s]


==> Best Accuracy improved to 0.9013636 from 0.901219


[62/100] Train: 100%|██████████| 231/231 [00:09<00:00, 25.15it/s, loss=0.302]
[62/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 33.98it/s]


==> Best Accuracy improved to 0.9014002 from 0.9013636


[63/100] Train: 100%|██████████| 231/231 [00:09<00:00, 25.60it/s, loss=0.299]
[63/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 33.34it/s]


==> Best Accuracy improved to 0.9020313 from 0.9014002


[64/100] Train: 100%|██████████| 231/231 [00:08<00:00, 26.04it/s, loss=0.299]
[64/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 34.83it/s]
[65/100] Train: 100%|██████████| 231/231 [00:08<00:00, 26.04it/s, loss=0.296]
[65/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 33.77it/s]
[66/100] Train: 100%|██████████| 231/231 [00:09<00:00, 25.51it/s, loss=0.294]
[66/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 33.69it/s]


==> Best Accuracy improved to 0.9029605 from 0.9020313


[67/100] Train: 100%|██████████| 231/231 [00:08<00:00, 25.90it/s, loss=0.294]
[67/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 33.87it/s]
[68/100] Train: 100%|██████████| 231/231 [00:08<00:00, 26.11it/s, loss=0.291]
[68/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 34.87it/s]
[69/100] Train: 100%|██████████| 231/231 [00:09<00:00, 25.64it/s, loss=0.288]
[69/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 33.84it/s]


==> Best Accuracy improved to 0.9040429 from 0.9029605


[70/100] Train: 100%|██████████| 231/231 [00:08<00:00, 25.94it/s, loss=0.287]
[70/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 35.89it/s]
[71/100] Train: 100%|██████████| 231/231 [00:08<00:00, 25.99it/s, loss=0.285]
[71/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 34.43it/s]


==> Best Accuracy improved to 0.9052028 from 0.9040429


[72/100] Train: 100%|██████████| 231/231 [00:08<00:00, 25.99it/s, loss=0.285]
[72/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 35.17it/s]
[73/100] Train: 100%|██████████| 231/231 [00:09<00:00, 25.62it/s, loss=0.281]
[73/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 34.86it/s]
[74/100] Train: 100%|██████████| 231/231 [00:08<00:00, 25.97it/s, loss=0.284]
[74/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 35.60it/s]
[75/100] Train: 100%|██████████| 231/231 [00:08<00:00, 26.13it/s, loss=0.281]
[75/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 34.58it/s]
[76/100] Train: 100%|██████████| 231/231 [00:09<00:00, 25.43it/s, loss=0.278]
[76/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 35.36it/s]


==> Best Accuracy improved to 0.9054257 from 0.9052028


[77/100] Train: 100%|██████████| 231/231 [00:08<00:00, 25.94it/s, loss=0.276]
[77/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 36.19it/s]
[78/100] Train: 100%|██████████| 231/231 [00:08<00:00, 26.10it/s, loss=0.273]
[78/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 33.89it/s]
[79/100] Train: 100%|██████████| 231/231 [00:08<00:00, 25.88it/s, loss=0.273]
[79/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 32.23it/s]


==> Best Accuracy improved to 0.9066054 from 0.9054257


[80/100] Train: 100%|██████████| 231/231 [00:08<00:00, 25.69it/s, loss=0.27] 
[80/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 34.01it/s]
[81/100] Train: 100%|██████████| 231/231 [00:09<00:00, 25.50it/s, loss=0.271]
[81/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 34.58it/s]
[82/100] Train: 100%|██████████| 231/231 [00:08<00:00, 25.90it/s, loss=0.27] 
[82/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 34.27it/s]
[83/100] Train: 100%|██████████| 231/231 [00:09<00:00, 25.47it/s, loss=0.266]
[83/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 35.68it/s]
[84/100] Train: 100%|██████████| 231/231 [00:09<00:00, 25.38it/s, loss=0.27] 
[84/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 33.77it/s]
[85/100] Train: 100%|██████████| 231/231 [00:08<00:00, 25.97it/s, loss=0.264]
[85/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 34.77it/s]
[86/100] Train: 100%|██████████| 231/231 [00:08<00:00, 25.71it/s, loss=0.264]
[86/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 35.79it/s]
[87/100] Train: 100%

==> Best Accuracy improved to 0.9072825 from 0.9066054


[88/100] Train: 100%|██████████| 231/231 [00:08<00:00, 25.92it/s, loss=0.262]
[88/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 34.40it/s]


==> Best Accuracy improved to 0.9074697 from 0.9072825


[89/100] Train: 100%|██████████| 231/231 [00:08<00:00, 25.97it/s, loss=0.258]
[89/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 35.60it/s]
[90/100] Train: 100%|██████████| 231/231 [00:09<00:00, 25.35it/s, loss=0.255]
[90/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 34.82it/s]
[91/100] Train: 100%|██████████| 231/231 [00:08<00:00, 25.82it/s, loss=0.254]
[91/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 35.47it/s]


==> Best Accuracy improved to 0.9079218 from 0.9074697


[92/100] Train: 100%|██████████| 231/231 [00:08<00:00, 25.96it/s, loss=0.257]
[92/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 34.40it/s]


==> Best Accuracy improved to 0.9084729 from 0.9079218


[93/100] Train: 100%|██████████| 231/231 [00:09<00:00, 25.64it/s, loss=0.253]
[93/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 34.68it/s]
[94/100] Train: 100%|██████████| 231/231 [00:08<00:00, 25.87it/s, loss=0.25] 
[94/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 35.18it/s]
[95/100] Train: 100%|██████████| 231/231 [00:08<00:00, 25.89it/s, loss=0.247]
[95/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 34.94it/s]
[96/100] Train: 100%|██████████| 231/231 [00:09<00:00, 25.66it/s, loss=0.247]
[96/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 33.61it/s]
[97/100] Train: 100%|██████████| 231/231 [00:09<00:00, 25.47it/s, loss=0.246]
[97/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 32.86it/s]
[98/100] Train: 100%|██████████| 231/231 [00:08<00:00, 25.91it/s, loss=0.244]
[98/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 35.01it/s]
[99/100] Train: 100%|██████████| 231/231 [00:08<00:00, 25.78it/s, loss=0.241]
[99/100] Valid: 100%|██████████| 7/7 [00:00<00:00, 35.47it/s]
[100/100] Train: 100

In [28]:
len(tgt_vocab)

22

In [29]:
trainer.test_seq_acc()

  state_dict = torch.load(file, map_location=self.device)['state_dict']
[100/100] Test: 100%|██████████| 786/786 [00:10<00:00, 76.63it/s]
  self.model.load_state_dict(torch.load(path)["state_dict"])


Calculating Sequence Accuracy for predictions


Test: 100%|██████████| 786/786 [01:57<00:00,  6.71it/s]

[ 1  7 20  7 11  8 13 13 13 14 17  6 19  7 20  8 13 13 13 13 14 17  5 18
  7 20 12 17  4 19  3 19  8 20  7 21] [ 1  7 20  7 11  8 13 13 13 14 17  6 18  7 20  8 13 13 13 13 14 17  5 19
  7 20 12 17  4 19  3 18  8 20  7 21]
Test Accuracy: 0.7534554 | Valid Accuracy: 0.9084729
Test Sequence Accuracy: 0.11323155216284987



