# Imports & Installations

In [1]:
import numpy as np
import pandas as pd
import torch
import re
import matplotlib.pyplot as plt

In [2]:
!pip install -q x-transformers

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m88.2/88.2 kB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m80.6/80.6 kB[0m [31m5.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m103.0/103.0 kB[0m [31m7.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.6/61.6 kB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m
[?25h

# Read Data

In [3]:
def read_txt(path):
    with open(path,'r', encoding='utf-8') as f:
        lines = f.read().split('\n')
    text_pairs = []
    mx = 0 
    mn = 100
    x = []
    for _line in lines:
        arr = _line.split(':')
        if len(arr) == 1:
            continue
        text_pairs.append({'Event Type': arr[0], \
                          'Feynman Diagram': arr[1],   \
                           'Amplitude': arr[-2],        \
                           'Squared Amplitude': arr[-1] \
                          })
    return text_pairs

final_pairs = [read_txt(
            f'/kaggle/input/squared-amplitudes-test-data/SYMBA - Test Data/QED-2-to-2-diag-TreeLevel-{i}.txt')
            for i in range(10)]
final_pairs = [xx for x in final_pairs for xx in x]

# df = pd.DataFrame(final_pairs,columns=['Event Type','Feynman Diagram', 'Amplitude', 'Squared Amplitude'],dtype=['str','str','str','str'])
df = pd.DataFrame(final_pairs, columns=['Event Type', 'Feynman Diagram', 'Amplitude', 'Squared Amplitude'])
df = df.astype({'Event Type': 'string', 'Feynman Diagram': 'string', 'Amplitude': 'string', 'Squared Amplitude': 'string'})

from sklearn.model_selection import train_test_split

train_df, temp_df = train_test_split(df, test_size=0.2, random_state=42, shuffle=True)
val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42, shuffle=True)

print(f"Train size: {len(train_df)}, Validation size: {len(val_df)}, Test size: {len(test_df)}")

train_df.to_csv('train.csv')
val_df.to_csv('val.csv')
test_df.to_csv('test.csv')

Train size: 12441, Validation size: 1555, Test size: 1556


# Vocab Class

In [4]:
def reporthook(t):
    """
    https://github.com/tqdm/tqdm.
    """
    last_b = [0]

    def inner(b=1, bsize=1, tsize=None):
        """
        b: int, optional
        Number of blocks just transferred [default: 1].
        bsize: int, optional
        Size of each block (in tqdm units) [default: 1].
        tsize: int, optional
        Total size (in tqdm units). If [default: None] remains unchanged.
        """
        if tsize is not None:
            t.total = tsize
        t.update((b - last_b[0]) * bsize)
        last_b[0] = b

    return inner

In [5]:
import torch
import logging
import os
import zipfile
import gzip
from urllib.request import urlretrieve
from tqdm import tqdm
import tarfile
from functools import partial

logger = logging.getLogger(__name__)


def _infer_shape(f):
    num_lines, vector_dim = 0, None
    for line in f:
        if vector_dim is None:
            row = line.rstrip().split(b" ")
            vector = row[1:]
            # Assuming word, [vector] format
            if len(vector) > 2:
                # The header present in some (w2v) formats contains two elements.
                vector_dim = len(vector)
                num_lines += 1  # First element read
        else:
            num_lines += 1
    f.seek(0)
    return num_lines, vector_dim


class Vectors(object):

    def __init__(self, name, cache=None,
                 url=None, unk_init=None, max_vectors=None):
        """
        Args:

            name: name of the file that contains the vectors
            cache: directory for cached vectors
            url: url for download if vectors not found in cache
            unk_init (callback): by default, initialize out-of-vocabulary word vectors
                to zero vectors; can be any function that takes in a Tensor and returns a Tensor of the same size
            max_vectors (int): this can be used to limit the number of
                pre-trained vectors loaded.
                Most pre-trained vector sets are sorted
                in the descending order of word frequency.
                Thus, in situations where the entire set doesn't fit in memory,
                or is not needed for another reason, passing `max_vectors`
                can limit the size of the loaded set.
        """

        cache = '.vector_cache' if cache is None else cache
        self.itos = None
        self.stoi = None
        self.vectors = None
        self.dim = None
        self.unk_init = torch.Tensor.zero_ if unk_init is None else unk_init
        self.cache(name, cache, url=url, max_vectors=max_vectors)

    def __getitem__(self, token):
        if token in self.stoi:
            return self.vectors[self.stoi[token]]
        else:
            return self.unk_init(torch.Tensor(self.dim))

    def cache(self, name, cache, url=None, max_vectors=None):
        import ssl
        ssl._create_default_https_context = ssl._create_unverified_context
        if os.path.isfile(name):
            path = name
            if max_vectors:
                file_suffix = '_{}.pt'.format(max_vectors)
            else:
                file_suffix = '.pt'
            path_pt = os.path.join(cache, os.path.basename(name)) + file_suffix
        else:
            path = os.path.join(cache, name)
            if max_vectors:
                file_suffix = '_{}.pt'.format(max_vectors)
            else:
                file_suffix = '.pt'
            path_pt = path + file_suffix

        if not os.path.isfile(path_pt):
            if not os.path.isfile(path) and url:
                logger.info('Downloading vectors from {}'.format(url))
                if not os.path.exists(cache):
                    os.makedirs(cache)
                dest = os.path.join(cache, os.path.basename(url))
                if not os.path.isfile(dest):
                    with tqdm(unit='B', unit_scale=True, miniters=1, desc=dest) as t:
                        try:
                            urlretrieve(url, dest, reporthook=reporthook(t))
                        except KeyboardInterrupt as e:  # remove the partial zip file
                            os.remove(dest)
                            raise e
                logger.info('Extracting vectors into {}'.format(cache))
                ext = os.path.splitext(dest)[1][1:]
                if ext == 'zip':
                    with zipfile.ZipFile(dest, "r") as zf:
                        zf.extractall(cache)
                elif ext == 'gz':
                    if dest.endswith('.tar.gz'):
                        with tarfile.open(dest, 'r:gz') as tar:
                            tar.extractall(path=cache)
            if not os.path.isfile(path):
                raise RuntimeError('no vectors found at {}'.format(path))

            logger.info("Loading vectors from {}".format(path))
            ext = os.path.splitext(path)[1][1:]
            if ext == 'gz':
                open_file = gzip.open
            else:
                open_file = open

            vectors_loaded = 0
            with open_file(path, 'rb') as f:
                num_lines, dim = _infer_shape(f)
                if not max_vectors or max_vectors > num_lines:
                    max_vectors = num_lines

                itos, vectors, dim = [], torch.zeros((max_vectors, dim)), None

                for line in tqdm(f, total=max_vectors):
                    # Explicitly splitting on " " is important, so we don't
                    # get rid of Unicode non-breaking spaces in the vectors.
                    entries = line.rstrip().split(b" ")

                    word, entries = entries[0], entries[1:]
                    if dim is None and len(entries) > 1:
                        dim = len(entries)
                    elif len(entries) == 1:
                        logger.warning("Skipping token {} with 1-dimensional "
                                       "vector {}; likely a header".format(word, entries))
                        continue
                    elif dim != len(entries):
                        raise RuntimeError(
                            "Vector for token {} has {} dimensions, but previously "
                            "read vectors have {} dimensions. All vectors must have "
                            "the same number of dimensions.".format(word, len(entries),
                                                                    dim))

                    try:
                        if isinstance(word, bytes):
                            word = word.decode('utf-8')
                    except UnicodeDecodeError:
                        logger.info("Skipping non-UTF8 token {}".format(repr(word)))
                        continue

                    vectors[vectors_loaded] = torch.tensor([float(x) for x in entries])
                    vectors_loaded += 1
                    itos.append(word)

                    if vectors_loaded == max_vectors:
                        break

            self.itos = itos
            self.stoi = {word: i for i, word in enumerate(itos)}
            self.vectors = torch.Tensor(vectors).view(-1, dim)
            self.dim = dim
            logger.info('Saving vectors to {}'.format(path_pt))
            if not os.path.exists(cache):
                os.makedirs(cache)
            torch.save((self.itos, self.stoi, self.vectors, self.dim), path_pt)
        else:
            logger.info('Loading vectors from {}'.format(path_pt))
            self.itos, self.stoi, self.vectors, self.dim = torch.load(path_pt)

    def __len__(self):
        return len(self.vectors)

    def get_vecs_by_tokens(self, tokens, lower_case_backup=False):
        """Look up embedding vectors of tokens.

        Args:
            tokens: a token or a list of tokens. if `tokens` is a string,
                returns a 1-D tensor of shape `self.dim`; if `tokens` is a
                list of strings, returns a 2-D tensor of shape=(len(tokens),
                self.dim).
            lower_case_backup : Whether to look up the token in the lower case.
                If False, each token in the original case will be looked up;
                if True, each token in the original case will be looked up first,
                if not found in the keys of the property `stoi`, the token in the
                lower case will be looked up. Default: False.

        Examples:
            >>> examples = ['chip', 'baby', 'Beautiful']
            >>> vec = text.vocab.GloVe(name='6B', dim=50)
            >>> ret = vec.get_vecs_by_tokens(tokens, lower_case_backup=True)
        """
        to_reduce = False

        if not isinstance(tokens, list):
            tokens = [tokens]
            to_reduce = True

        if not lower_case_backup:
            indices = [self[token] for token in tokens]
        else:
            indices = [self[token] if token in self.stoi
                       else self[token.lower()]
                       for token in tokens]

        vecs = torch.stack(indices)
        return vecs[0] if to_reduce else vecs


class GloVe(Vectors):
    url = {
        '42B': 'http://nlp.stanford.edu/data/glove.42B.300d.zip',
        '840B': 'http://nlp.stanford.edu/data/glove.840B.300d.zip',
        'twitter.27B': 'http://nlp.stanford.edu/data/glove.twitter.27B.zip',
        '6B': 'http://nlp.stanford.edu/data/glove.6B.zip',
    }

    def __init__(self, name='840B', dim=300, **kwargs):
        url = self.url[name]
        name = 'glove.{}.{}d.txt'.format(name, str(dim))
        super(GloVe, self).__init__(name, url=url, **kwargs)


class FastText(Vectors):

    url_base = 'https://dl.fbaipublicfiles.com/fasttext/vectors-wiki/wiki.{}.vec'

    def __init__(self, language="en", **kwargs):
        url = self.url_base.format(language)
        name = os.path.basename(url)
        super(FastText, self).__init__(name, url=url, **kwargs)


class CharNGram(Vectors):

    name = 'charNgram.txt'
    url = ('http://www.logos.t.u-tokyo.ac.jp/~hassy/publications/arxiv2016jmt/'
           'jmt_pre-trained_embeddings.tar.gz')

    def __init__(self, **kwargs):
        super(CharNGram, self).__init__(self.name, url=self.url, **kwargs)

    def __getitem__(self, token):
        vector = torch.Tensor(1, self.dim).zero_()
        if token == "<unk>":
            return self.unk_init(vector)
        chars = ['#BEGIN#'] + list(token) + ['#END#']
        num_vectors = 0
        for n in [2, 3, 4]:
            end = len(chars) - n + 1
            grams = [chars[i:(i + n)] for i in range(end)]
            for gram in grams:
                gram_key = '{}gram-{}'.format(n, ''.join(gram))
                if gram_key in self.stoi:
                    vector += self.vectors[self.stoi[gram_key]]
                    num_vectors += 1
        if num_vectors > 0:
            vector /= num_vectors
        else:
            vector = self.unk_init(vector)
        return vector


pretrained_aliases = {
    "charngram.100d": partial(CharNGram),
    "fasttext.en.300d": partial(FastText, language="en"),
    "fasttext.simple.300d": partial(FastText, language="simple"),
    "glove.42B.300d": partial(GloVe, name="42B", dim="300"),
    "glove.840B.300d": partial(GloVe, name="840B", dim="300"),
    "glove.twitter.27B.25d": partial(GloVe, name="twitter.27B", dim="25"),
    "glove.twitter.27B.50d": partial(GloVe, name="twitter.27B", dim="50"),
    "glove.twitter.27B.100d": partial(GloVe, name="twitter.27B", dim="100"),
    "glove.twitter.27B.200d": partial(GloVe, name="twitter.27B", dim="200"),
    "glove.6B.50d": partial(GloVe, name="6B", dim="50"),
    "glove.6B.100d": partial(GloVe, name="6B", dim="100"),
    "glove.6B.200d": partial(GloVe, name="6B", dim="200"),
    "glove.6B.300d": partial(GloVe, name="6B", dim="300")
}
"""Mapping from string name to factory function"""

'Mapping from string name to factory function'

In [6]:
from collections import defaultdict
import logging
import torch
from tqdm import tqdm
from collections import Counter
# from torchtext.vocab import (
#     pretrained_aliases,  # not in legacy
#     Vectors,  # not in legacy
# )
from typing import List
logger = logging.getLogger(__name__)


class Vocab(object):
    """Defines a vocabulary object that will be used to numericalize a field.

    Attributes:
        freqs: A collections.Counter object holding the frequencies of tokens
            in the data used to build the Vocab.
        stoi: A collections.defaultdict instance mapping token strings to
            numerical identifiers.
        itos: A list of token strings indexed by their numerical identifiers.
    """

    # TODO (@mttk): Populate classs with default values of special symbols
    UNK = '<unk>'

    def __init__(self, counter, max_size=None, min_freq=1, specials=('<unk>', '<pad>'),
                 vectors=None, unk_init=None, vectors_cache=None, specials_first=True):
        """Create a Vocab object from a collections.Counter.

        Args:
            counter: collections.Counter object holding the frequencies of
                each value found in the data.
            max_size: The maximum size of the vocabulary, or None for no
                maximum. Default: None.
            min_freq: The minimum frequency needed to include a token in the
                vocabulary. Values less than 1 will be set to 1. Default: 1.
            specials: The list of special tokens (e.g., padding or eos) that
                will be prepended to the vocabulary. Default: ['<unk'>, '<pad>']
            vectors: One of either the available pretrained vectors
                or custom pretrained vectors (see Vocab.load_vectors);
                or a list of aforementioned vectors
            unk_init (callback): by default, initialize out-of-vocabulary word vectors
                to zero vectors; can be any function that takes in a Tensor and
                returns a Tensor of the same size. Default: 'torch.zeros'
            vectors_cache: directory for cached vectors. Default: '.vector_cache'
            specials_first: Whether to add special tokens into the vocabulary at first.
                If it is False, they are added into the vocabulary at last.
                Default: True.
        """
        self.freqs = counter
        counter = counter.copy()
        min_freq = max(min_freq, 1)

        self.itos = list()
        self.unk_index = None
        if specials_first:
            self.itos = list(specials)
            # only extend max size if specials are prepended
            max_size = None if max_size is None else max_size + len(specials)

        # frequencies of special tokens are not counted when building vocabulary
        # in frequency order
        for tok in specials:
            if tok in counter:
                del counter[tok]

        # sort by frequency, then alphabetically
        words_and_frequencies = sorted(counter.items(), key=lambda tup: tup[0])
        words_and_frequencies.sort(key=lambda tup: tup[1], reverse=True)

        for word, freq in words_and_frequencies:
            if freq < min_freq or len(self.itos) == max_size:
                break
            self.itos.append(word)

        if Vocab.UNK in specials:  # hard-coded for now
            unk_index = specials.index(Vocab.UNK)  # position in list
            # account for ordering of specials, set variable
            self.unk_index = unk_index if specials_first else len(self.itos) + unk_index
            self.stoi = defaultdict(self._default_unk_index)
        else:
            self.stoi = defaultdict()

        if not specials_first:
            self.itos.extend(list(specials))

        # stoi is simply a reverse dict for itos
        self.stoi.update({tok: i for i, tok in enumerate(self.itos)})

        self.vectors = None
        if vectors is not None:
            self.load_vectors(vectors, unk_init=unk_init, cache=vectors_cache)
        else:
            assert unk_init is None and vectors_cache is None

    def forward(self, tokens: List[str]) -> List[int]:
        return self.lookup_indices(tokens)

    def set_default_index(self, idx):
        self.unk_index = idx
        
    def _default_unk_index(self):
        return self.unk_index

    def __getitem__(self, token):
        return self.stoi.get(token, self.stoi.get(Vocab.UNK))

    def __getstate__(self):
        # avoid picking defaultdict
        attrs = dict(self.__dict__)
        # cast to regular dict
        attrs['stoi'] = dict(self.stoi)
        return attrs

    def __setstate__(self, state):
        if state.get("unk_index", None) is None:
            stoi = defaultdict()
        else:
            stoi = defaultdict(self._default_unk_index)
        stoi.update(state['stoi'])
        state['stoi'] = stoi
        self.__dict__.update(state)

    def __eq__(self, other):
        if self.freqs != other.freqs:
            return False
        if self.stoi != other.stoi:
            return False
        if self.itos != other.itos:
            return False
        if self.vectors != other.vectors:
            return False
        return True

    def __len__(self):
        return len(self.itos)

    def lookup_indices(self, tokens):
        indices = [self.__getitem__(token) for token in tokens]
        return indices

    def extend(self, v, sort=False):
        words = sorted(v.itos) if sort else v.itos
        for w in words:
            if w not in self.stoi:
                self.itos.append(w)
                self.stoi[w] = len(self.itos) - 1

    def load_vectors(self, vectors, **kwargs):
        """
        Args:
            vectors: one of or a list containing instantiations of the
                GloVe, CharNGram, or Vectors classes. Alternatively, one
                of or a list of available pretrained vectors:

                charngram.100d
                fasttext.en.300d
                fasttext.simple.300d
                glove.42B.300d
                glove.840B.300d
                glove.twitter.27B.25d
                glove.twitter.27B.50d
                glove.twitter.27B.100d
                glove.twitter.27B.200d
                glove.6B.50d
                glove.6B.100d
                glove.6B.200d
                glove.6B.300d

            Remaining keyword arguments: Passed to the constructor of Vectors classes.
        """
        if not isinstance(vectors, list):
            vectors = [vectors]
        for idx, vector in enumerate(vectors):
            if isinstance(vector, str):
                # Convert the string pretrained vector identifier
                # to a Vectors object
                if vector not in pretrained_aliases:
                    raise ValueError("Got string input vector {}, but allowed pretrained vectors are {}".format(vector, list(pretrained_aliases.keys())))
                vectors[idx] = pretrained_aliases[vector](**kwargs)
            elif not isinstance(vector, Vectors):
                raise ValueError( "Got input vectors of type {}, expected str or Vectors object".format(type(vector)))
        tot_dim = sum(v.dim for v in vectors)
        self.vectors = torch.Tensor(len(self), tot_dim)
        for i, token in enumerate(self.itos):
            start_dim = 0
            for v in vectors:
                end_dim = start_dim + v.dim
                self.vectors[i][start_dim:end_dim] = v[token.strip()]
                start_dim = end_dim
            assert(start_dim == tot_dim)

    def set_vectors(self, stoi, vectors, dim, unk_init=torch.Tensor.zero_):
        """
        Set the vectors for the Vocab instance from a collection of Tensors.

        Args:
            stoi: A dictionary of string to the index of the associated vector
                in the `vectors` input argument.
            vectors: An indexed iterable (or other structure supporting __getitem__) that
                given an input index, returns a FloatTensor representing the vector
                for the token associated with the index. For example,
                vector[stoi["string"]] should return the vector for "string".
            dim: The dimensionality of the vectors.
            unk_init (callback): by default, initialize out-of-vocabulary word vectors
                to zero vectors; can be any function that takes in a Tensor and
                returns a Tensor of the same size. Default: 'torch.zeros'
        """
        self.vectors = torch.Tensor(len(self), dim)
        for i, token in enumerate(self.itos):
            wv_index = stoi.get(token, None)
            if wv_index is not None:
                self.vectors[i] = vectors[wv_index]
            else:
                self.vectors[i] = unk_init(self.vectors[i])


class SubwordVocab(Vocab):

    def __init__(self, counter, max_size=None, specials=('<pad>'),
                 vectors=None, unk_init=torch.Tensor.zero_):
        """Create a revtok subword vocabulary from a collections.Counter.

        Args:
            counter: collections.Counter object holding the frequencies of
                each word found in the data.
            max_size: The maximum size of the subword vocabulary, or None for no
                maximum. Default: None.
            specials: The list of special tokens (e.g., padding or eos) that
                will be prepended to the vocabulary in addition to an <unk>
                token.
            vectors: One of either the available pretrained vectors
                or custom pretrained vectors (see Vocab.load_vectors);
                or a list of aforementioned vectors
            unk_init (callback): by default, initialize out-of-vocabulary word vectors
                to zero vectors; can be any function that takes in a Tensor and
                returns a Tensor of the same size. Default: 'torch.zeros
        """
        try:
            import revtok
        except ImportError:
            print("Please install revtok.")
            raise

        # Hardcode unk_index as subword_vocab has no specials_first argument
        self.unk_index = (specials.index(SubwordVocab.UNK)
                          if SubwordVocab.UNK in specials else None)

        if self.unk_index is None:
            self.stoi = defaultdict()
        else:
            self.stoi = defaultdict(self._default_unk_index)

        self.stoi.update({tok: i for i, tok in enumerate(specials)})
        self.itos = specials.copy()

        self.segment = revtok.SubwordSegmenter(counter, max_size)

        max_size = None if max_size is None else max_size + len(self.itos)

        # sort by frequency/entropy, then alphabetically
        toks = sorted(self.segment.vocab.items(), key=lambda tup: (len(tup[0]) != 1, -tup[1], tup[0]))

        for tok, _ in toks:
            if len(self.itos) == max_size:
                break
            self.itos.append(tok)
            self.stoi[tok] = len(self.itos) - 1

        if vectors is not None:
            self.load_vectors(vectors, unk_init=unk_init)


def build_vocab_from_iterator(iterator, num_lines=None):
    """
    Build a Vocab from an iterator.

    Args:
        iterator: Iterator used to build Vocab. Must yield list or iterator of tokens.
        num_lines: The expected number of elements returned by the iterator.
            (Default: None)
            Optionally, if known, the expected number of elements can be passed to
            this factory function for improved progress reporting.
    """

    counter = Counter()
    with tqdm(unit_scale=0, unit='lines', total=num_lines) as t:
        for tokens in iterator:
            counter.update(tokens)
            t.update(1)
    word_vocab = Vocab(counter)
    return word_vocab

# Tokenizer

In [7]:
from collections import Counter, OrderedDict
from itertools import cycle
import re
import random
# from torchtext.vocab import vocab # Built it custom
from tqdm import tqdm
import warnings

class Tokenizer:
    """
    Tokenizer for processing symbolic mathematical expressions.
    """
    def __init__(self, df, index_token_pool_size, momentum_token_pool_size, special_symbols, UNK_IDX, to_replace):
        self.amps = df.Amplitude.tolist()
        self.sqamps = df['Squared Amplitude'].tolist()

        # Issue warnings if token pool sizes are too small
        if index_token_pool_size < 100:
            warnings.warn(f"Index token pool size ({index_token_pool_size}) is small. Consider increasing it.", UserWarning)
        if momentum_token_pool_size < 100:
            warnings.warn(f"Momentum token pool size ({momentum_token_pool_size}) is small. Consider increasing it.", UserWarning)
        
        # Generate token pools
        self.tokens_pool = [f"_{i}" for i in range(index_token_pool_size)]
        self.momentum_pool = [f"MOMENTUM_{i}" for i in range(momentum_token_pool_size)]
        
        # Regular expression patterns for token replacement
        self.pattern_momentum = re.compile(r'\b[p]_\d{1,}\b')
        self.pattern_num_123 = re.compile(r'\b(?![ps]_)\w+_\d{1,}\b')
        self.pattern_special = re.compile(r'\b\w+_+\w+\b\\')
        self.pattern_prop = re.compile(r'Prop')
        self.pattern_int = re.compile(r'int\{')
        self.pattern_operators = {
            '+': re.compile(r'\+'), '-': re.compile(r'-'), '*': re.compile(r'\*'),
            ',': re.compile(r','), '^': re.compile(r'\^'), '%': re.compile(r'%'),
            '}': re.compile(r'\}'), '(': re.compile(r'\('), ')': re.compile(r'\)')
        }
        self.function_opening = re.compile(r'(\w+_\{)')
        self.pattern_mass = re.compile(r'\b\w+_\w\b')
        self.pattern_s = re.compile(r'\b\w+_\d{2,}\b')
        self.pattern_reg_prop = re.compile(r'\b\w+_\d{1}\b')
        self.pattern_antipart = re.compile(r'(\w)_\w+_\d+\(X\)\^\(\*\)')
        self.pattern_part = re.compile(r'(\w)_\w+_\d+\(X\)')
        self.pattern_index = re.compile(r'\b\w+_\w+_\d{2,}\b')
        
        self.special_symbols = special_symbols
        self.UNK_IDX = UNK_IDX
        self.to_replace = to_replace

    @staticmethod
    def remove_whitespace(expression):
        """Remove all forms of whitespace from the expression."""
        return re.sub(r'\s+', '', expression)

    @staticmethod
    def split_expression(expression):
        """Split the expression by space delimiter."""
        return re.split(r' ', expression)

    def build_tgt_vocab(self):
        """Build vocabulary for target sequences."""
        counter = Counter()
        for eqn in tqdm(self.sqamps, desc='Processing target vocab'):
            counter.update(self.tgt_tokenize(eqn))
        voc = Vocab(OrderedDict(counter), specials=self.special_symbols[:], specials_first=True)
        voc.set_default_index(self.UNK_IDX)
        return voc

    def build_src_vocab(self, seed):
        """Build vocabulary for source sequences."""
        counter = Counter()
        for diag in tqdm(self.amps, desc='Processing source vocab'):
            counter.update(self.src_tokenize(diag, seed))
        voc = Vocab(OrderedDict(counter), specials=self.special_symbols[:], specials_first=True)
        voc.set_default_index(self.UNK_IDX)
        return voc
    
    def src_replace(self, ampl, seed):
        """Replace indexed and momentum variables with tokenized equivalents."""
        ampl = self.remove_whitespace(ampl)
        
        random.seed(seed)
        token_cycle = cycle(random.sample(self.tokens_pool, len(self.tokens_pool)))
        momentum_cycle = cycle(random.sample(self.momentum_pool, len(self.momentum_pool)))
        
        # Replace momentum tokens
        temp_ampl = ampl
        momentum_mapping = {match: next(momentum_cycle) for match in set(self.pattern_momentum.findall(ampl))}
        for key, value in momentum_mapping.items():
            temp_ampl = temp_ampl.replace(key, value)


        def replace_123_match(match):
            word, num = match.group().rsplit('_', 1)
            # if word == 's':
            #     # Mandstein
            #     return match.group()
            # if word == 'p':
            #     # return match.group()  # Keep 'p_X' unchanged
            #     return f'MOMENTUM_{num}'
            # IDX_POOL.setdefault(num, len(IDX_POOL) + 1)
            return f"{word}_INDEX_{next(token_cycle)}"
        # Replace index tokens
        # num_123_mapping = {match: match.rsplit('_',1) + next(token_cycle) for match in set(self.pattern_num_123.findall(ampl))}
        # for key, value in num_123_mapping.items():
        #     temp_ampl = temp_ampl.replace(key, value)
        temp_ampl = re.sub(self.pattern_num_123,replace_123_match,ampl)
        return temp_ampl
    
    def src_tokenize(self, ampl, seed):
        """Tokenize source expression, optionally applying replacements."""
        temp_ampl = self.src_replace(ampl, seed) if self.to_replace else ampl
        temp_ampl = temp_ampl.replace('\\\\', '\\').replace('\\', ' \\ ').replace('%', '')
        
        for symbol, pattern in self.pattern_operators.items():
            temp_ampl = pattern.sub(f' {symbol} ', temp_ampl)

        temp_ampl = self.function_opening.sub(r'\1 ', temp_ampl)
        
        temp_ampl = re.sub(r' {2,}', ' ', temp_ampl)
        return [token for token in self.split_expression(temp_ampl) if token]

    def tgt_tokenize(self, sqampl):
        """Tokenize target expression."""
        sqampl = self.remove_whitespace(sqampl)
        temp_sqampl = sqampl
        
        for symbol, pattern in self.pattern_operators.items():
            temp_sqampl = pattern.sub(f' {symbol} ', temp_sqampl)
        
        for pattern in [self.pattern_reg_prop, self.pattern_mass, self.pattern_s]:
            temp_sqampl = pattern.sub(lambda match: f' {match.group(0)} ', temp_sqampl)
        
        temp_sqampl = re.sub(r' {2,}', ' ', temp_sqampl)
        return [token for token in self.split_expression(temp_sqampl) if token]

In [8]:
BOS_IDX, PAD_IDX, EOS_IDX, UNK_IDX, SEP_IDX = 0, 1, 2, 3, 4
special_symbols = ['<S>', '<PAD>', '</S>', '<UNK>', '<SEP>']
tokenizer = Tokenizer(train_df, 500, 500, special_symbols, UNK_IDX, False)

In [9]:
def normalize_indices(tokenizer, expressions, index_token_pool_size=50, momentum_token_pool_size=50):
    # Function to replace indices with a new set of tokens for each expression
    def replace_indices(token_list, index_map):
        new_index = (f"_{i}" for i in range(index_token_pool_size))  # Local generator for new indices
        new_tokens = []
        for token in token_list:
            if "INDEX_" in token:
                if token not in index_map:
                    try:
                        index_map[token] = token.rsplit('_',1)[0] + next(new_index)
                    except StopIteration:
                        # Handle the case where no more indices are available
                        raise ValueError("Ran out of unique indices, increase token_pool_size")
                new_tokens.append(index_map[token])
            else:
                new_tokens.append(token)
        return new_tokens

    def replace_momenta(token_list, index_map):
        new_index = (f"MOMENTUM_{i}" for i in range(momentum_token_pool_size))  # Local generator for new indices
        new_tokens = []
        for token in token_list:
            if "MOMENTUM_" in token:
                if token not in index_map:
                    try:
                        index_map[token] = next(new_index)
                    except StopIteration:
                        # Handle the case where no more indices are available
                        raise ValueError("Ran out of unique indices, increase momentum_token_pool_size")
                new_tokens.append(index_map[token])
            else:
                new_tokens.append(token)
        return new_tokens

    normalized_expressions = []
    # Replace indices in each expression randomly
    for expr in tqdm(expressions,desc="Normalizing.."):
        toks = tokenizer.src_tokenize(expr,42)
        normalized_expressions.append(replace_momenta(replace_indices(toks, {}), {}))

    return normalized_expressions


def aug_data(df):
    # Extract columns
    amps = df['Amplitude']
    sqamps = df['Squared Amplitude']

    # Data augmentation
    n_samples = 1 #args.n_samples
    aug_amps = []

    for amp in tqdm(amps, desc='processing'):
        random_seed = [random.randint(1, 1000) for _ in range(n_samples)]
        for seed in random_seed:
            aug_amps.append(tokenizer.src_replace(amp, seed))
    aug_sqamps = [sqamp for sqamp in sqamps for _ in range(n_samples)]

    if True:
        normal_amps = normalize_indices(tokenizer, aug_amps, 500, 500)
        aug_amps = []
        for amp in normal_amps:
            aug_amps.append("".join(amp))

    # Create augmented DataFrame
    df_aug = pd.DataFrame({"Amplitude": aug_amps, "Squared Amplitude": aug_sqamps})

    return df_aug

In [10]:
train_df_aug = aug_data(train_df)

processing: 100%|██████████| 12441/12441 [00:06<00:00, 1939.89it/s]
Normalizing..: 100%|██████████| 12441/12441 [00:02<00:00, 5838.41it/s]


In [11]:
val_df_aug = aug_data(val_df)
test_df_aug = aug_data(test_df)

processing: 100%|██████████| 1555/1555 [00:00<00:00, 1923.91it/s]
Normalizing..: 100%|██████████| 1555/1555 [00:00<00:00, 6024.53it/s]
processing: 100%|██████████| 1556/1556 [00:00<00:00, 1948.63it/s]
Normalizing..: 100%|██████████| 1556/1556 [00:00<00:00, 5774.30it/s]


In [12]:
train_df_aug.to_csv('train_df_aug.csv')
val_df_aug.to_csv('val_df_aug.csv')
test_df_aug.to_csv('test_df_aug.csv')

In [13]:
tokenizer2 = Tokenizer(train_df_aug, 500, 500, special_symbols, UNK_IDX, False)

In [14]:
src_vocab2 = tokenizer2.build_src_vocab(42)
len(src_vocab2.itos)

Processing source vocab: 100%|██████████| 12441/12441 [00:01<00:00, 7657.53it/s]


459

In [15]:
tgt_vocab2 = tokenizer2.build_tgt_vocab()

Processing target vocab: 100%|██████████| 12441/12441 [00:01<00:00, 8103.24it/s]


In [16]:
len(tgt_vocab2.itos)

59

# Torch Dataset

In [17]:
from torch.utils.data import Dataset
import torch

def causal_mask(size):
    """Create a causal mask for a sequence of given size."""
    mask = torch.triu(torch.ones((1, size, size)), diagonal=1).int()
    return mask == 0
    
class Data(Dataset):
    """
    Custom PyTorch dataset for handling data.

    Args:
        df (DataFrame): DataFrame containing data.
    """

    def __init__(self, df, tokenizer, config, src_vocab, tgt_vocab):
        super(Data, self).__init__()
        self.tgt_vals = df['Squared Amplitude']
        self.src_vals = df['Amplitude']
        self.tgt_tokenize = tokenizer.tgt_tokenize
        self.src_tokenize = tokenizer.src_tokenize
        self.bos_token = torch.tensor([BOS_IDX], dtype=torch.int64)
        self.eos_token = torch.tensor([EOS_IDX], dtype=torch.int64)
        self.pad_token = torch.tensor([PAD_IDX], dtype=torch.int64)
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab
        self.config = config

    def __len__(self):
        """
        Get the length of the dataset.

        Returns:
            int: Length of the dataset.
        """
        return len(self.src_vals)

    def __getitem__(self, idx):
        """
        Get an item from the dataset at the specified index.

        Args:
            idx (int): Index of the item.

        Returns:
            tuple: Tuple containing source and target tensors.
        """
        # print(f'index: {idx}')
        src_tokenized = self.src_tokenize(self.src_vals[idx],self.config.seed)
        tgt_tokenized = self.tgt_tokenize(self.tgt_vals[idx])
        src_ids = self.src_vocab.forward(src_tokenized)
        tgt_ids = self.tgt_vocab.forward(tgt_tokenized)

        enc_num_padding_tokens = self.config.src_max_len - len(src_ids) - 2
        dec_num_padding_tokens = self.config.tgt_max_len - len(tgt_ids) - 1
        # print(f'src_ids: {len(src_ids)} tgt_ids:  {len(tgt_ids)} enc_num: {enc_num_padding_tokens} dec_num: {dec_num_padding_tokens} \n' )
        if self.config.truncate:
            if enc_num_padding_tokens < 0:
                src_ids = src_ids[:self.config.src_max_len-2]
            if dec_num_padding_tokens < 0:
                tgt_ids = tgt_ids[:self.config.tgt_max_len-1]
        else:
            if enc_num_padding_tokens < 0 or dec_num_padding_tokens < 0:
                raise ValueError("Sentence is too long")
        src_tensor = torch.cat(
            [
                self.bos_token,
                torch.tensor(src_ids, dtype=torch.int64),
                self.eos_token,
                torch.tensor([self.pad_token] *
                             enc_num_padding_tokens, dtype=torch.int64),
            ],
            dim=0,
        )

        tgt_tensor = torch.cat(
            [
                self.bos_token,
                torch.tensor(tgt_ids, dtype=torch.int64),
                self.eos_token,
                torch.tensor([self.pad_token] *
                             dec_num_padding_tokens, dtype=torch.int64),
            ],
            dim=0,
        )

        label = torch.cat(
            [
                torch.tensor(tgt_ids, dtype=torch.int64),
                self.eos_token,
                torch.tensor([self.pad_token] * dec_num_padding_tokens, dtype=torch.int64),
            ],
            dim=0,
        )

        src_mask = (src_tensor != self.pad_token).unsqueeze(0).unsqueeze(0).int() # (1, 1, seq_len)
        tgt_mask = (tgt_tensor != self.pad_token).unsqueeze(0).int() & causal_mask(tgt_tensor.size(0)) # (1, seq_len) & (1, seq_len, seq_len),

        return src_tensor, tgt_tensor, label, src_mask, tgt_mask#, len(src_ids), len(tgt_ids)

    @staticmethod
    def get_data(df_train, df_test, df_valid, config, tokenizer, src_vocab,tgt_vocab):
        """
        Create datasets (train, test, and valid)

        Returns:
            dict: Dictionary containing train, test, and valid datasets.
        """
        train = Data(df_train, tokenizer, config,src_vocab,tgt_vocab)
        test = Data(df_test, tokenizer, config,src_vocab,tgt_vocab)
        valid = Data(df_valid, tokenizer, config,src_vocab,tgt_vocab)

        return {'train': train, 'test': test, 'valid': valid}

In [18]:
class data_config:
    src_max_len = 300
    tgt_max_len = 325
    truncate = False
    seed = 42

In [19]:
dataset = Data.get_data(train_df_aug, test_df_aug, val_df_aug, data_config, tokenizer2, src_vocab2, tgt_vocab2)

# Transformer

In [20]:
from x_transformers import XTransformer

In [21]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm
import numpy as np

# Hyperparameters
batch_size = 32
# learning_rate = 5e-5

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Create dataloaders
dataloaders = {
    split: DataLoader(
        dataset[split],
        batch_size=batch_size if split == 'train' else 64,
        shuffle=(split == 'train'),
        pin_memory=True,
        num_workers=4
    ) for split in ['train', 'valid', 'test']
}

In [22]:
def calculate_accuracy(model, dataloader, device):
    model.eval()
    model.to(device)
    total_seq = 0
    correct_seq = 0
    total_tokens = 0
    correct_tokens = 0
    loss_sum = 0
    
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Calculating accuracy"):
            # Move batch to device
            src_tensor, tgt_tensor, label, src_mask, tgt_mask = batch
            start_tokens = (torch.zeros(src_tensor.shape[0],1)).long().cuda()
            src_tensor = src_tensor.to(device)
            tgt_tensor = tgt_tensor.to(device)
            label = label.to(device)
            src_mask = src_mask.squeeze(1,2).to(device=device,dtype=torch.bool)
            tgt_mask = tgt_mask.squeeze(1,2).to(device=device,dtype=torch.bool)
            # Generate predictions
            outputs = model.generate(
                src_tensor, start_tokens, 325, mask = src_mask)
            
            total_tokens += outputs.numel()
            correct_tokens += ((outputs == label)).sum().item()
            
            batch_size = label.size(0)
            seq_correct = torch.all((outputs == label), dim=1)
            correct_seq += seq_correct.sum().item()
            total_seq += batch_size
    
    # Calculate metrics
    token_accuracy = correct_tokens / total_tokens if total_tokens > 0 else 0
    sequence_accuracy = correct_seq / total_seq if total_seq > 0 else 0
    avg_loss = loss_sum / len(dataloader)
    
    return {
        'token_accuracy': token_accuracy,
        'sequence_accuracy': sequence_accuracy,
        'loss': avg_loss
    }

In [23]:
def calculate_line_params(point1, point2):

    x1, y1 = point1
    x2, y2 = point2

    # Check if the x coordinates are the same to avoid division by zero
    if x1 == x2:
        raise ValueError(
            "The x coordinates of the two points must be different to define a straight line.")

    # Calculate the slope (m)
    m = (y2 - y1) / (x2 - x1)

    # Calculate the intercept (b)
    b = y1 - m * x1

    return m, b

In [24]:
weight_decay = 0.01
# num_epochs = 12
grad_accumulation_steps = 1
max_grad_norm = 1.0
warmup_steps = 1000

In [25]:
from torch.optim.lr_scheduler import LambdaLR
import copy

model = XTransformer(
    dim = 512,
    enc_num_tokens = 459,
    enc_depth = 3,
    enc_heads = 8,
    enc_max_seq_len = 300,
    dec_num_tokens = 59,
    dec_depth = 3,
    dec_heads = 8,
    dec_max_seq_len = 325,
    # mult=8,
    tie_token_emb = False,      # tie embeddings of encoder and decoder
)


start_lr = 1e-3
end_lr = 1e-8
warmup_steps = 250
num_epochs = 30

optimizer = AdamW(model.parameters(), lr=start_lr, weight_decay=weight_decay)
m_warm, c_warm = calculate_line_params( (0, end_lr), (warmup_steps, start_lr))

def lam_warm(step): 
    return (1/start_lr)*(m_warm*step + c_warm)
                        
warm_scheduler = LambdaLR(optimizer, lr_lambda=lam_warm)

m_decay, c_decay = calculate_line_params( (0, start_lr), (num_epochs, end_lr))

def lam(epoch): 
    return (1/start_lr) * (m_decay*epoch + c_decay)
    
lr_scheduler = LambdaLR(optimizer, lr_lambda=lam)

# Move model to device
model = model.to(device)
global_steps = 0

# Training loop
best_valid_loss = float('inf')
best_valid_epoch = 0
for epoch in range(num_epochs):
    # Training phase
    model.train()
    train_loss = 0
    train_steps = 0
    
    progress_bar = tqdm(dataloaders['train'], desc=f"Epoch {epoch+1}/{num_epochs} [Train]")
    for i, batch in enumerate(progress_bar):
        
        # Move batch to device
        src_tensor, tgt_tensor, label, src_mask, tgt_mask = batch
        
        src_tensor = src_tensor.to(device)
        tgt_tensor = tgt_tensor.to(device)
        src_mask = src_mask.squeeze(1,2).to(device=device,dtype=torch.bool)
        
        # Forward pass
        loss = model(src_tensor, tgt_tensor, mask = src_mask)
        loss = loss / grad_accumulation_steps
        loss.backward()
        
        # Update metrics
        train_loss += loss.item() * grad_accumulation_steps
        train_steps += 1
        global_steps += 1
        # Update progress bar
        progress_bar.set_postfix({'loss': train_loss / train_steps})
        
        # Gradient accumulation
        if (i + 1) % grad_accumulation_steps == 0:
            # Clip gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            
            # Update weights
            optimizer.step()
            # scheduler.step()
            optimizer.zero_grad()
        
            if global_steps < warmup_steps:
                warm_scheduler.step()
            elif global_steps <= warmup_steps + grad_accumulation_steps:
                print('Warmup completed!')
            
        src_tensor = src_tensor.to('cpu')
        tgt_tensor = tgt_tensor.to('cpu')
        src_mask = src_mask.to(device='cpu')   

    # Calculate average training loss
    avg_train_loss = train_loss / train_steps
    if global_steps > warmup_steps:
        lr_scheduler.step(epoch)
        print('LR changed to:', lr_scheduler.get_last_lr())
        
    # Validation phase
    model.eval()
    valid_loss = 0
    valid_steps = 0
    
    with torch.no_grad():
        progress_bar = tqdm(dataloaders['valid'], desc=f"Epoch {epoch+1}/{num_epochs} [Valid]")
        for batch in progress_bar:
            # Move batch to device
            src_tensor, tgt_tensor, label, src_mask, tgt_mask = batch
            
            src_tensor = src_tensor.to(device)
            tgt_tensor = tgt_tensor.to(device)
            src_mask = src_mask.squeeze(1,2).to(device=device,dtype=torch.bool)
            # Forward pass
            loss = model(src_tensor, tgt_tensor, mask = src_mask)
            
            # Update metrics
            valid_loss += loss.item()
            valid_steps += 1
            
            # Update progress bar
            progress_bar.set_postfix({'loss': valid_loss / valid_steps})
            src_tensor = src_tensor.to('cpu')
            tgt_tensor = tgt_tensor.to('cpu')
            src_mask = src_mask.to(device='cpu')   
            
    # Calculate average validation loss
    avg_valid_loss = valid_loss / valid_steps
    
    # Print epoch summary
    print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {avg_train_loss:.5f}, Valid Loss: {avg_valid_loss:.5f}")
    
    # Save checkpoint if validation loss improved
    if avg_valid_loss < best_valid_loss:
        best_valid_loss = avg_valid_loss
        best_valid_epoch = epoch + 1
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': lr_scheduler.state_dict(),
            'train_loss': avg_train_loss,
            'valid_loss': avg_valid_loss,
        }, f'best_model_checkpoint_{epoch+1}.pt')
        print(f"Model checkpoint saved (Valid Loss: {avg_valid_loss:.5f})")

    if (epoch + 1) % 5 == 0:
        tmodel = copy.deepcopy(model)
        tmodel.load_state_dict(torch.load(f'/kaggle/working/best_model_checkpoint_{best_valid_epoch}.pt')['model_state_dict'])
        test_metrics = calculate_accuracy(tmodel, dataloaders['test'], device)
        print(test_metrics)
# Final evaluation on test set
model.eval()
test_loss = 0
test_steps = 0

with torch.no_grad():
    progress_bar = tqdm(dataloaders['test'], desc="Testing")
    for batch in progress_bar:
        # Move batch to device
        src_tensor, tgt_tensor, label, src_mask, tgt_mask = batch
        # labels = batch['labels'].to(device) if 'labels' in batch else None
        src_tensor = src_tensor.to(device)
        tgt_tensor = tgt_tensor.to(device)
        src_mask = src_mask.squeeze(1,2).to(device=device,dtype=torch.bool)
        # Forward pass
        loss = model(src_tensor, tgt_tensor, mask = src_mask)
        
        # Update metrics
        test_loss += loss.item()
        test_steps += 1
        
        # Update progress bar
        progress_bar.set_postfix({'loss': test_loss / test_steps})
        
        src_tensor = src_tensor.to('cpu')
        tgt_tensor = tgt_tensor.to('cpu')
        src_mask = src_mask.to(device='cpu')   
# Calculate average test loss
avg_test_loss = test_loss / test_steps
print(f"Test Loss: {avg_test_loss:.4f}")
print("Training completed!")

Epoch 1/30 [Train]:  64%|██████▍   | 250/389 [01:17<00:42,  3.31it/s, loss=0.459]

Warmup completed!


Epoch 1/30 [Train]:  65%|██████▍   | 251/389 [01:17<00:41,  3.31it/s, loss=0.458]

Warmup completed!


Epoch 1/30 [Train]: 100%|██████████| 389/389 [01:59<00:00,  3.27it/s, loss=0.309]


LR changed to: [0.001]


Epoch 1/30 [Valid]: 100%|██████████| 25/25 [00:05<00:00,  4.29it/s, loss=0.0235]


Epoch 1/30 - Train Loss: 0.30878, Valid Loss: 0.02353
Model checkpoint saved (Valid Loss: 0.02353)


Epoch 2/30 [Train]: 100%|██████████| 389/389 [01:57<00:00,  3.30it/s, loss=0.0085]


LR changed to: [0.000966667]


Epoch 2/30 [Valid]: 100%|██████████| 25/25 [00:05<00:00,  4.47it/s, loss=0.0228]


Epoch 2/30 - Train Loss: 0.00850, Valid Loss: 0.02285
Model checkpoint saved (Valid Loss: 0.02285)


Epoch 3/30 [Train]: 100%|██████████| 389/389 [01:57<00:00,  3.30it/s, loss=0.00435]


LR changed to: [0.0009333340000000001]


Epoch 3/30 [Valid]: 100%|██████████| 25/25 [00:05<00:00,  4.45it/s, loss=0.00133]


Epoch 3/30 - Train Loss: 0.00435, Valid Loss: 0.00133
Model checkpoint saved (Valid Loss: 0.00133)


Epoch 4/30 [Train]: 100%|██████████| 389/389 [01:58<00:00,  3.30it/s, loss=0.00575]


LR changed to: [0.000900001]


Epoch 4/30 [Valid]: 100%|██████████| 25/25 [00:05<00:00,  4.37it/s, loss=0.00567]


Epoch 4/30 - Train Loss: 0.00575, Valid Loss: 0.00567


Epoch 5/30 [Train]: 100%|██████████| 389/389 [01:58<00:00,  3.29it/s, loss=0.00438]


LR changed to: [0.000866668]


Epoch 5/30 [Valid]: 100%|██████████| 25/25 [00:05<00:00,  4.41it/s, loss=0.00396]
  tmodel.load_state_dict(torch.load(f'/kaggle/working/best_model_checkpoint_{best_valid_epoch}.pt')['model_state_dict'])


Epoch 5/30 - Train Loss: 0.00438, Valid Loss: 0.00396


Calculating accuracy: 100%|██████████| 25/25 [01:03<00:00,  2.54s/it]


{'token_accuracy': 0.9085169072572672, 'sequence_accuracy': 0.781491002570694, 'loss': 0.0}


Epoch 6/30 [Train]: 100%|██████████| 389/389 [01:58<00:00,  3.29it/s, loss=0.00371]


LR changed to: [0.000833335]


Epoch 6/30 [Valid]: 100%|██████████| 25/25 [00:05<00:00,  4.34it/s, loss=0.00266]


Epoch 6/30 - Train Loss: 0.00371, Valid Loss: 0.00266


Epoch 7/30 [Train]: 100%|██████████| 389/389 [01:58<00:00,  3.29it/s, loss=0.00535]


LR changed to: [0.000800002]


Epoch 7/30 [Valid]: 100%|██████████| 25/25 [00:05<00:00,  4.41it/s, loss=0.0023]


Epoch 7/30 - Train Loss: 0.00535, Valid Loss: 0.00230


Epoch 8/30 [Train]: 100%|██████████| 389/389 [01:58<00:00,  3.29it/s, loss=0.00212]


LR changed to: [0.000766669]


Epoch 8/30 [Valid]: 100%|██████████| 25/25 [00:05<00:00,  4.38it/s, loss=0.0021]


Epoch 8/30 - Train Loss: 0.00212, Valid Loss: 0.00210


Epoch 9/30 [Train]: 100%|██████████| 389/389 [01:58<00:00,  3.29it/s, loss=0.00345]


LR changed to: [0.000733336]


Epoch 9/30 [Valid]: 100%|██████████| 25/25 [00:05<00:00,  4.42it/s, loss=0.00445]


Epoch 9/30 - Train Loss: 0.00345, Valid Loss: 0.00445


Epoch 10/30 [Train]: 100%|██████████| 389/389 [01:58<00:00,  3.29it/s, loss=0.00417]


LR changed to: [0.000700003]


Epoch 10/30 [Valid]: 100%|██████████| 25/25 [00:05<00:00,  4.37it/s, loss=0.00149]


Epoch 10/30 - Train Loss: 0.00417, Valid Loss: 0.00149


Calculating accuracy: 100%|██████████| 25/25 [01:03<00:00,  2.53s/it]


{'token_accuracy': 0.9022463911409927, 'sequence_accuracy': 0.7667095115681234, 'loss': 0.0}


Epoch 11/30 [Train]: 100%|██████████| 389/389 [01:58<00:00,  3.29it/s, loss=0.00146]


LR changed to: [0.0006666700000000001]


Epoch 11/30 [Valid]: 100%|██████████| 25/25 [00:05<00:00,  4.41it/s, loss=0.0012]


Epoch 11/30 - Train Loss: 0.00146, Valid Loss: 0.00120
Model checkpoint saved (Valid Loss: 0.00120)


Epoch 12/30 [Train]: 100%|██████████| 389/389 [01:58<00:00,  3.28it/s, loss=0.0011]


LR changed to: [0.0006333370000000001]


Epoch 12/30 [Valid]: 100%|██████████| 25/25 [00:05<00:00,  4.34it/s, loss=0.00136]


Epoch 12/30 - Train Loss: 0.00110, Valid Loss: 0.00136


Epoch 13/30 [Train]: 100%|██████████| 389/389 [01:58<00:00,  3.29it/s, loss=0.0032]


LR changed to: [0.000600004]


Epoch 13/30 [Valid]: 100%|██████████| 25/25 [00:05<00:00,  4.41it/s, loss=0.000869]


Epoch 13/30 - Train Loss: 0.00320, Valid Loss: 0.00087
Model checkpoint saved (Valid Loss: 0.00087)


Epoch 14/30 [Train]: 100%|██████████| 389/389 [01:58<00:00,  3.29it/s, loss=0.00117]


LR changed to: [0.000566671]


Epoch 14/30 [Valid]: 100%|██████████| 25/25 [00:05<00:00,  4.41it/s, loss=0.00202]


Epoch 14/30 - Train Loss: 0.00117, Valid Loss: 0.00202


Epoch 15/30 [Train]: 100%|██████████| 389/389 [01:58<00:00,  3.29it/s, loss=0.0025]


LR changed to: [0.000533338]


Epoch 15/30 [Valid]: 100%|██████████| 25/25 [00:05<00:00,  4.40it/s, loss=0.00158]


Epoch 15/30 - Train Loss: 0.00250, Valid Loss: 0.00158


Calculating accuracy: 100%|██████████| 25/25 [01:03<00:00,  2.54s/it]


{'token_accuracy': 0.928904488827368, 'sequence_accuracy': 0.8412596401028277, 'loss': 0.0}


Epoch 16/30 [Train]: 100%|██████████| 389/389 [01:58<00:00,  3.29it/s, loss=0.000816]


LR changed to: [0.0005000050000000001]


Epoch 16/30 [Valid]: 100%|██████████| 25/25 [00:05<00:00,  4.36it/s, loss=0.000926]


Epoch 16/30 - Train Loss: 0.00082, Valid Loss: 0.00093


Epoch 17/30 [Train]: 100%|██████████| 389/389 [01:58<00:00,  3.30it/s, loss=0.000786]


LR changed to: [0.00046667200000000006]


Epoch 17/30 [Valid]: 100%|██████████| 25/25 [00:05<00:00,  4.34it/s, loss=0.000754]


Epoch 17/30 - Train Loss: 0.00079, Valid Loss: 0.00075
Model checkpoint saved (Valid Loss: 0.00075)


Epoch 18/30 [Train]: 100%|██████████| 389/389 [01:57<00:00,  3.30it/s, loss=0.00274]


LR changed to: [0.00043333900000000003]


Epoch 18/30 [Valid]: 100%|██████████| 25/25 [00:05<00:00,  4.38it/s, loss=0.00133]


Epoch 18/30 - Train Loss: 0.00274, Valid Loss: 0.00133


Epoch 19/30 [Train]: 100%|██████████| 389/389 [01:57<00:00,  3.30it/s, loss=0.000837]


LR changed to: [0.0004000060000000001]


Epoch 19/30 [Valid]: 100%|██████████| 25/25 [00:05<00:00,  4.39it/s, loss=0.000749]


Epoch 19/30 - Train Loss: 0.00084, Valid Loss: 0.00075
Model checkpoint saved (Valid Loss: 0.00075)


Epoch 20/30 [Train]: 100%|██████████| 389/389 [01:57<00:00,  3.30it/s, loss=0.000677]


LR changed to: [0.000366673]


Epoch 20/30 [Valid]: 100%|██████████| 25/25 [00:05<00:00,  4.41it/s, loss=0.0008]


Epoch 20/30 - Train Loss: 0.00068, Valid Loss: 0.00080


Calculating accuracy: 100%|██████████| 25/25 [01:01<00:00,  2.47s/it]


{'token_accuracy': 0.9345659481906269, 'sequence_accuracy': 0.8553984575835476, 'loss': 0.0}


Epoch 21/30 [Train]: 100%|██████████| 389/389 [01:57<00:00,  3.30it/s, loss=0.00114]


LR changed to: [0.00033334000000000006]


Epoch 21/30 [Valid]: 100%|██████████| 25/25 [00:05<00:00,  4.43it/s, loss=0.00085]


Epoch 21/30 - Train Loss: 0.00114, Valid Loss: 0.00085


Epoch 22/30 [Train]: 100%|██████████| 389/389 [01:57<00:00,  3.30it/s, loss=0.000733]


LR changed to: [0.00030000700000000003]


Epoch 22/30 [Valid]: 100%|██████████| 25/25 [00:05<00:00,  4.44it/s, loss=0.000646]


Epoch 22/30 - Train Loss: 0.00073, Valid Loss: 0.00065
Model checkpoint saved (Valid Loss: 0.00065)


Epoch 23/30 [Train]: 100%|██████████| 389/389 [01:57<00:00,  3.30it/s, loss=0.000548]


LR changed to: [0.00026667399999999995]


Epoch 23/30 [Valid]: 100%|██████████| 25/25 [00:05<00:00,  4.45it/s, loss=0.000664]


Epoch 23/30 - Train Loss: 0.00055, Valid Loss: 0.00066


Epoch 24/30 [Train]: 100%|██████████| 389/389 [01:57<00:00,  3.30it/s, loss=0.000982]


LR changed to: [0.00023334099999999997]


Epoch 24/30 [Valid]: 100%|██████████| 25/25 [00:05<00:00,  4.38it/s, loss=0.00265]


Epoch 24/30 - Train Loss: 0.00098, Valid Loss: 0.00265


Epoch 25/30 [Train]: 100%|██████████| 389/389 [01:57<00:00,  3.30it/s, loss=0.000689]


LR changed to: [0.000200008]


Epoch 25/30 [Valid]: 100%|██████████| 25/25 [00:05<00:00,  4.39it/s, loss=0.000572]


Epoch 25/30 - Train Loss: 0.00069, Valid Loss: 0.00057
Model checkpoint saved (Valid Loss: 0.00057)


Calculating accuracy: 100%|██████████| 25/25 [01:01<00:00,  2.47s/it]


{'token_accuracy': 0.943974688550524, 'sequence_accuracy': 0.884318766066838, 'loss': 0.0}


Epoch 26/30 [Train]: 100%|██████████| 389/389 [01:57<00:00,  3.30it/s, loss=0.000467]


LR changed to: [0.00016667500000000003]


Epoch 26/30 [Valid]: 100%|██████████| 25/25 [00:05<00:00,  4.44it/s, loss=0.000507]


Epoch 26/30 - Train Loss: 0.00047, Valid Loss: 0.00051
Model checkpoint saved (Valid Loss: 0.00051)


Epoch 27/30 [Train]: 100%|██████████| 389/389 [01:57<00:00,  3.30it/s, loss=0.000421]


LR changed to: [0.00013334200000000005]


Epoch 27/30 [Valid]: 100%|██████████| 25/25 [00:05<00:00,  4.35it/s, loss=0.000443]


Epoch 27/30 - Train Loss: 0.00042, Valid Loss: 0.00044
Model checkpoint saved (Valid Loss: 0.00044)


Epoch 28/30 [Train]: 100%|██████████| 389/389 [01:57<00:00,  3.30it/s, loss=0.000379]


LR changed to: [0.00010000899999999997]


Epoch 28/30 [Valid]: 100%|██████████| 25/25 [00:05<00:00,  4.42it/s, loss=0.000467]


Epoch 28/30 - Train Loss: 0.00038, Valid Loss: 0.00047


Epoch 29/30 [Train]: 100%|██████████| 389/389 [01:57<00:00,  3.30it/s, loss=0.000363]


LR changed to: [6.6676e-05]


Epoch 29/30 [Valid]: 100%|██████████| 25/25 [00:05<00:00,  4.34it/s, loss=0.000447]


Epoch 29/30 - Train Loss: 0.00036, Valid Loss: 0.00045


Epoch 30/30 [Train]: 100%|██████████| 389/389 [01:57<00:00,  3.30it/s, loss=0.000326]


LR changed to: [3.334300000000002e-05]


Epoch 30/30 [Valid]: 100%|██████████| 25/25 [00:05<00:00,  4.45it/s, loss=0.000391]


Epoch 30/30 - Train Loss: 0.00033, Valid Loss: 0.00039
Model checkpoint saved (Valid Loss: 0.00039)


Calculating accuracy: 100%|██████████| 25/25 [01:01<00:00,  2.47s/it]


{'token_accuracy': 0.9542258255882935, 'sequence_accuracy': 0.9087403598971723, 'loss': 0.0}


Testing: 100%|██████████| 25/25 [00:05<00:00,  4.38it/s, loss=0.000463]

Test Loss: 0.0005
Training completed!





In [26]:
# # ckpt_path = '/kaggle/working/best_model_checkpoint_10.pt'
# ckpt_path = '/kaggle/input/e1d1-trf-ep15/best_model_checkpoint_13.pt'
# model.load_state_dict(torch.load(ckpt_path)['model_state_dict'])
# calculate_accuracy(model, dataloaders['test'], device)