### Sanity Checks for each individual component in transformer.py

In [1]:
from transformer import Transformer
from transformer import LayerNorm
import torch
import torch.nn as nn
import pytest
import math
import torch.nn.functional as F

#### Scaled Dot Product Attention

In [3]:
from transformer import scaled_dot_product_attention

def test_scaled_dot_product(q, k, v):
    """Check ``scaled_dot_product_attention`` against a direct reference.

    The reference is ``softmax(q @ k^T / sqrt(d_k)) @ v`` where ``d_k`` is the
    size of the last dimension of ``q``. The key tensor is transposed over its
    last two dimensions, so 2-D inputs are treated exactly as with ``k.T``
    while inputs with leading batch dimensions are supported as well.

    Args:
        q: Query tensor; its last dimension is ``d_k``.
        k: Key tensor with the same last dimension as ``q``.
        v: Value tensor with as many rows as ``k``.

    Raises:
        AssertionError: If the returned values or the returned attention
            weights are not close to the reference.
    """
    values, attention = scaled_dot_product_attention(q, k, v)

    # transpose(-2, -1) matches ``k.T`` for 2-D input and also handles batches.
    d_k = q.size(-1)
    scores = (q @ k.transpose(-2, -1)) / math.sqrt(d_k)
    expected_attention = F.softmax(scores, dim=-1)
    expected_values = expected_attention @ v

    # Separate asserts so a failure names the output that is wrong.
    assert torch.allclose(values, expected_values), \
        "attention-weighted values differ from the reference"
    assert torch.allclose(attention, expected_attention), \
        "attention weights differ from the reference"
    print("TEST: Scaled Dot Product PASSED!")


test_scaled_dot_product(torch.ones((2,2)), torch.ones((2,2)), torch.ones((2,2)))
test_scaled_dot_product(torch.ones((2,1)), torch.ones((2, 1)), torch.ones((2,1)))



TEST: Scaled Dot Product PASSED!
TEST: Scaled Dot Product PASSED!


#### Multiheaded Attention

In [4]:
#In order to test this out, let us assume that the batch dimension = 1, the max sequence length is 2, and d model is 4
#and also let us try to make the Wq, Wk, and Wv matrices the identity
#The weight matrix actually

from transformer import MultiheadedAttention

def test_multihead_attention():
    """Pass an all-ones tensor through ``MultiheadedAttention`` built with
    ``identity = True`` and require the output to be close to the input.

    Raises:
        AssertionError: If the output is not close to the input.
    """
    # Shape (1, 2, 4); the last dimension matches d_model below.
    ones_input = torch.ones((1, 2, 4))
    attention_layer = MultiheadedAttention(num_heads=2, d_model=4, identity=True)

    result = attention_layer(ones_input)

    assert torch.allclose(ones_input, result)
    print("TEST: Multiheaded Attention PASSED!")

test_multihead_attention()

TEST: Multiheaded Attention PASSED!


#### Cross Attention

In [5]:
from transformer import MultiheadedCrossAttention

def test_cross_attention():
    """Pass two all-ones tensors through ``MultiheadedCrossAttention`` built
    with ``identity = True`` and require the output to be close to the first.

    Raises:
        AssertionError: If the output is not close to the first input.
    """
    shape = (1, 2, 4)
    first = torch.ones(shape)
    second = torch.ones(shape)

    cross_attention = MultiheadedCrossAttention(num_heads=2, d_model=4, identity=True)
    result = cross_attention(first, second)

    assert torch.allclose(first, result)
    print("TEST: MH Cross Attention PASSED!")

test_cross_attention()

TEST: MH Cross Attention PASSED!


#### ENCODER

In [6]:
from transformer import Encoder
from transformer import EncoderChain

##Create random input which matches batch size, max_sequence_length, and d_model

def test_encoder_output():
    """Run one ``Encoder`` on random input and check the output size.

    Only the size is verified: the output must have the same size as the
    random input tensor.

    Raises:
        AssertionError: If the output size differs from the input size.
    """
    batch_size, max_sequence_length, d_model, num_heads = 3, 10, 6, 2

    x = torch.rand((batch_size, max_sequence_length, d_model))
    print(x.dtype)
    # Random square tensor passed as the mask argument.
    mask = torch.rand((max_sequence_length, max_sequence_length))

    encoder = Encoder(d_model=d_model, num_heads=num_heads, ffn_hidden=10, drop_prob=0.1)
    encoded = encoder(x, mask)

    assert encoded.size() == x.size()
    print("TEST: Encoder Throughput PASSED!")

def test_encoder_chain_output():
    """Check that an ``EncoderChain`` keeps the input size and differs from a lone ``Encoder``.

    Builds an ``EncoderChain`` from ``chain_length`` freshly constructed
    ``Encoder`` layers plus one separate ``Encoder``, feeds both the same
    random input and mask, and asserts that the chain output has the input's
    size and is not close to the single layer's output.

    Raises:
        AssertionError: If either check fails.
    """
    batch_size = 3
    d_model = 6
    max_sequence_length = 10
    num_heads = 2
    chain_length = 3

    x = torch.rand((batch_size, max_sequence_length, d_model))
    mask = torch.rand((max_sequence_length, max_sequence_length))

    enc = Encoder(d_model=d_model, num_heads=num_heads, ffn_hidden=10, drop_prob=0.1)
    enc_chain = EncoderChain(
        *[
            Encoder(
                d_model=d_model,
                num_heads=num_heads,
                ffn_hidden=10,
                drop_prob=0.1,
            )
            for _ in range(chain_length)
        ]
    )
    output = enc_chain(x, mask)
    output_single = enc(x, mask)

    assert output.size() == x.size()
    # NOTE(review): ``enc`` is built separately from the layers inside the
    # chain, so this compares two independently constructed models; the
    # inequality presumably holds because of differing initial weights rather
    # than because of chaining - confirm.
    assert not torch.allclose(output, output_single)
    print("TEST: Encoder Chain Throughput PASSED!")

test_encoder_output()
test_encoder_chain_output()
    

torch.float32
TEST: Encoder Throughput PASSED!
TEST: Encoder Chain Throughput PASSED!


#### DECODER

In [7]:
from transformer import Decoder
from transformer import DecoderChain

##Create random input which matches batch size, max_sequence_length, and d_model

def test_decoder_output():
    """Run one ``Decoder`` on random input and check the output size.

    The same random tensor is passed for both tensor arguments and the same
    random square tensor for both mask arguments; only the output size is
    verified.

    Raises:
        AssertionError: If the output size differs from the input size.
    """
    batch_size, max_sequence_length, d_model, num_heads = 3, 10, 6, 2

    x = torch.rand((batch_size, max_sequence_length, d_model))
    mask = torch.rand((max_sequence_length, max_sequence_length))

    decoder = Decoder(d_model=d_model, num_heads=num_heads, ffn_hidden=10, drop_prob=0.1)
    decoded = decoder(x, x, mask, mask)

    assert decoded.size() == x.size()
    print("TEST: Decoder Throughput PASSED!")

def test_decoder_chain_output():
    """Check that a ``DecoderChain`` keeps the input size and differs from a lone ``Decoder``.

    Builds a ``DecoderChain`` from ``chain_length`` freshly constructed
    ``Decoder`` layers plus one separate ``Decoder``, feeds both the same
    random input (used for both tensor arguments) and the same random mask
    (used for both mask arguments), and asserts that the chain output has the
    input's size and is not close to the single layer's output.

    Raises:
        AssertionError: If either check fails.
    """
    batch_size = 3
    d_model = 6
    max_sequence_length = 10
    num_heads = 2
    chain_length = 3

    x = torch.rand((batch_size, max_sequence_length, d_model))
    mask = torch.rand((max_sequence_length, max_sequence_length))

    dec = Decoder(d_model=d_model, num_heads=num_heads, ffn_hidden=10, drop_prob=0.1)
    dec_chain = DecoderChain(
        *[
            Decoder(
                d_model=d_model,
                num_heads=num_heads,
                ffn_hidden=10,
                drop_prob=0.1,
            )
            for _ in range(chain_length)
        ]
    )
    output = dec_chain(x, x, mask, mask)
    output_single = dec(x, x, mask, mask)

    assert output.size() == x.size()
    # NOTE(review): ``dec`` is built separately from the layers inside the
    # chain, so this compares two independently constructed models; the
    # inequality presumably holds because of differing initial weights rather
    # than because of chaining - confirm.
    assert not torch.allclose(output, output_single)
    print("TEST: Decoder Chain Throughput PASSED!")

test_decoder_output()
test_decoder_chain_output()

TEST: Decoder Throughput PASSED!
TEST: Decoder Chain Throughput PASSED!


#### LAYER NORM

In [8]:
def test_layer_norm():
    """Compare the project's ``LayerNorm`` with ``torch.nn.LayerNorm``.

    Both normalizers are applied to the same 4x4 upper-triangular matrix of
    ones, using the same ``eps``, and their outputs must be close.

    Raises:
        AssertionError: If the two outputs are not close.
    """
    x = torch.triu(torch.ones((4, 4)))

    # Implementation under test.
    actual = LayerNorm(d_model=4, eps=1e-5)(x)

    # Reference implementation from PyTorch.
    reference = nn.LayerNorm(normalized_shape=4, eps=1e-05)(x)

    assert torch.allclose(actual, reference)
    print("TEST: Layer Norm Generation PASSED!")

test_layer_norm()

TEST: Layer Norm Generation PASSED!


#### POSITIONAL ENCODER

In [9]:
from transformer import PositionalEncoder

def test_positional_encoder():
    """Smoke-test ``PositionalEncoder.generate``.

    Constructs a ``PositionalEncoder`` and calls ``generate()``; the test only
    confirms that both steps complete without raising.
    """
    encoder = PositionalEncoder(max_sequence_length=10, d_model=6)

    encoding = encoder.generate()

    # TODO(review): placeholder assertion - ``encoding`` is never inspected, so
    # nothing about the generated values is actually verified here.
    assert 1 == 1
    print("TEST: PE Encodings PASSED!")

test_positional_encoder()
    

TEST: PE Encodings PASSED!


#### DATASET CREATOR

In [12]:
from dataset import RecursionDatasetCreator
from torch.utils.data import Dataset, DataLoader

def test_dataset_creator():
    """Build the recursion dataset and print every batch from a ``DataLoader``.

    This is a visual check only: a ``RecursionDatasetCreator`` is built from
    ``./recursions.txt`` and ``./solutions.txt``, its dataset is wrapped in a
    ``DataLoader`` with a batch size of 3, and each batch is printed. Nothing
    is asserted.
    """
    START_TOKEN = '<sos>'
    PADDING_TOKEN = '<pad>'
    END_TOKEN = '<eos>'

    # The two vocabularies differ only in the order of '0' and '9'.
    recursion_vocabulary = [
        'a', '_', 'n', '+', '1', '=', '*', '/', '-', '^',
        '2', '3', '4', '5', '6', '7', '8', '0', '9', '(', ')', ' ',
        START_TOKEN, PADDING_TOKEN, END_TOKEN,
    ]
    solution_vocabulary = [
        'a', '_', 'n', '+', '1', '=', '*', '/', '-', '^',
        '2', '3', '4', '5', '6', '7', '8', '9', '0', '(', ')', ' ',
        START_TOKEN, PADDING_TOKEN, END_TOKEN,
    ]

    creator = RecursionDatasetCreator(
        "./recursions.txt",
        "./solutions.txt",
        100,
        10,
        recursion_vocabulary,
        solution_vocabulary,
    )
    dataset = creator.create_recursion_dataset()

    batch_size = 3
    for batch in DataLoader(dataset, batch_size):
        print(batch)

    print("TEST: Dataset Loader PASSED!")

test_dataset_creator()

4.0
[('a_n4', 'a_n4', 'a_n4'), ('an =', 'an =', 'an =')]
[('a_n4', 'a_n4', 'a_n4'), ('an =', 'an =', 'an =')]
[('a_n4', 'a_n4', 'a_n4'), ('an =', 'an =', 'an =')]
[('a_n4',), ('an =',)]
TEST: Dataset Loader PASSED!


#### Tokenizer and Mask Generation Unit Test

In [13]:
from transformer import Tokenizer
from transformer import generate_masks_tokenized

def test_tokenizer():
    """Tokenize and pad sample sentences, then print the generated masks.

    Three input and three output sentences are tokenized with their own
    ``Tokenizer`` and padded to length 5 with start and end tokens. The padded
    batches must have one row per sentence and 5 columns. The encoder, decoder
    and cross masks returned by ``generate_masks_tokenized`` are printed for
    visual inspection; their contents are not asserted.

    Raises:
        AssertionError: If a padded batch does not have the expected shape.
    """
    START = '<sos>'
    PAD = '<pad>'
    END = '<eos>'
    padded_length = 5

    # The two vocabularies differ only in the order of '0' and '9'.
    recursion_vocabulary = ['a', '_', 'n', '+', '1', '=', '*', '/', '-', '^',
                            '2', '3', '4', '5', '6', '7', '8', '0', '9', '(', ')', ' ', START, PAD, END]

    solution_vocabulary = ['a', '_', 'n', '+', '1', '=', '*', '/', '-', '^',
                           '2', '3', '4', '5', '6', '7', '8', '9', '0', '(', ')', ' ', START, PAD, END]

    in_tkzr = Tokenizer(recursion_vocabulary, START, PAD, END)

    in_sentences = ["an", "27", "2"]
    tok_input = in_tkzr.tokenize(in_sentences)
    tok_input = in_tkzr.pad(tok_input, padded_length, start=True, end=True)
    print(tok_input)
    assert tuple(tok_input.shape) == (len(in_sentences), padded_length), \
        "padded input batch has an unexpected shape"

    out_sentences = ["a", "7", "21"]
    out_tkzr = Tokenizer(solution_vocabulary, START, PAD, END)
    tok_output = out_tkzr.tokenize(out_sentences)
    tok_output = out_tkzr.pad(tok_output, padded_length, start=True, end=True)
    print(tok_output)
    assert tuple(tok_output.shape) == (len(out_sentences), padded_length), \
        "padded output batch has an unexpected shape"

    # PAD sits at index 23 in both vocabularies; derive it rather than
    # hard-coding the number so the call follows the vocabulary definitions.
    # NOTE(review): presumably these two arguments are the pad-token ids for
    # the input and output side - confirm against generate_masks_tokenized.
    in_pad_id = recursion_vocabulary.index(PAD)
    out_pad_id = solution_vocabulary.index(PAD)

    enc_mask, dec_mask, cross_mask = generate_masks_tokenized(tok_input, tok_output, in_pad_id, out_pad_id)
    print(enc_mask)
    print(dec_mask)
    print(cross_mask)
    print("TEST Passed! ENCODER IS GOOD, DECODER is GOOD, CROSS is GOOD!")

test_tokenizer()
    
    

torch.Size([3, 5])
tensor([[22,  0,  2, 24, 23],
        [22, 10, 15, 24, 23],
        [22, 10, 24, 23, 23]])
torch.Size([3, 5])
tensor([[22,  0, 24, 23, 23],
        [22, 15, 24, 23, 23],
        [22, 10,  4, 24, 23]])
tensor([[[-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -1.0000e+09],
         [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -1.0000e+09],
         [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -1.0000e+09],
         [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -1.0000e+09],
         [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09]],

        [[-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -1.0000e+09],
         [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -1.0000e+09],
         [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -1.0000e+09],
         [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -1.0000e+09],
         [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09]],

        [[-0.000

In [5]:
from serialization import serialize
from serialization import unserialize

def test_serialize_unserialize(expression):
    """Print ``expression`` next to its serialize/unserialize round trip.

    Args:
        expression: Expression string passed to ``serialize``.
    """
    round_tripped = unserialize(serialize(expression))
    print(expression, round_tripped)


test_serialize_unserialize("(1111 ^111)*11^(1+ -1)")

def test_serialize_two(expression):
    """Print ``expression`` with its serialized form, then with its round trip.

    Args:
        expression: Expression string passed to ``serialize``.
    """
    serialized = serialize(expression)
    print(expression, serialized)

    # serialize is called a second time here, as in the original statement order.
    restored = unserialize(serialize(expression))
    print(expression, restored)

test_serialize_two("1*(1 + 1) + 1")

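#ok so there is this flaw in the serialize and unserialize algorithm: the round trip does not preserve operator precedence.
#"1*(1 + 1) + 1" serializes to "* 1 + + 1 1 1" and comes back as "(1 * ((1 + 1) + 1))" instead of "((1 * (1 + 1)) + 1)".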

(1111 ^111)*11^(1+ -1) ((1111 ^ 111) * (11 ^ (1 + -1)))
1*(1 + 1) + 1 * 1 + + 1 1 1
1*(1 + 1) + 1 (1 * ((1 + 1) + 1))
