In [1]:
import pandas as pd
import numpy as np
import math
import ast

import torch
from torch import nn
import torch.nn.functional as F
from google.colab import drive

# Seed torch's global random number generator (used by the torch.randint calls
# in get_batch and by default nn.Module weight initialisation).
torch.manual_seed(1337)
# Mount Google Drive at /content/drive; `drive` comes from google.colab, so this
# line only works inside a Colab runtime.
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!wget https://raw.githubusercontent.com/piyush-jena/pos_tagging_transformer/main/data/test_data.csv
!wget https://raw.githubusercontent.com/piyush-jena/pos_tagging_transformer/main/data/train_data.csv
!wget https://raw.githubusercontent.com/piyush-jena/pos_tagging_transformer/main/data/valid_data.csv

In [2]:
# Run on the GPU when torch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [25]:
# Dataset splits fetched into the working directory by the download cell.
test_data, train_data, valid_data = map(
    pd.read_csv,
    ('test_data.csv', 'train_data.csv', 'valid_data.csv'),
)
# Word-vector table read from the mounted Drive folder; later cells use its
# 'word' and 'vectors' columns.
word_vectors = pd.read_csv('drive/MyDrive/wv.csv')

In [4]:
# Number of rows in the word-vector table.
dictionary_size = word_vectors['vectors'].shape[0]

# Map each word to its vector. A 'vectors' cell is a bracketed, space-separated
# string; the first and last characters (the brackets) are stripped before the
# numbers are parsed. np.fromstring yields float64, so these tensors are float64.
dictionary = {
    word: torch.from_numpy(np.fromstring(raw_vector[1:-1], sep=' '))
    for word, raw_vector in zip(word_vectors['word'], word_vectors['vectors'])
}

In [5]:
def cmp(s, dt, t):
    """Print how a hand-derived gradient compares with autograd's.

    s  -- label shown in the first column of the report line.
    dt -- manually computed gradient tensor.
    t  -- tensor whose populated ``.grad`` is the reference value.

    The comparison broadcasts, so ``dt`` and ``t.grad`` may differ by
    singleton dimensions. Nothing is returned; one line is printed.
    """
    reference = t.grad
    is_exact = (dt == reference).all().item()
    is_close = torch.allclose(dt, reference)
    largest_gap = torch.abs(dt - reference).max().item()
    print(f'{s:15s} | exact: {str(is_exact):5s} | approximate: {str(is_close):5s} | maxdiff: {largest_gap}')

In [6]:
# Sizes and hyperparameters shared by the cells of this notebook.
d_model = 64  # width of each token vector; the data-prep cell builds 64-wide zero vectors for OOV/padding tokens
num_heads = 4  # not referenced by the hand-written forward pass (which uses two heads); presumably passed to Encoder -- TODO confirm
max_sequence_length = 64  # not referenced in the visible cells -- TODO confirm whether it is still needed
ffn_hidden = 256  # hidden width of the feed-forward weights (W11/W12/W21/W22 in the manual cells)
num_layers = 4  # presumably the number of EncoderLayer blocks given to Encoder -- TODO confirm
learning_rate = 1e-3  # presumably consumed by a training loop outside the visible cells
max_iters = 10000  # presumably consumed by a training loop outside the visible cells
eval_interval = 500  # presumably consumed by a training loop outside the visible cells
eval_iters = 10  # number of batches estimate_loss draws per split

block_size = 10  # tokens per example: shorter sentences are padded to this length, longer ones are windowed in get_batch
batch_size = 1  # sequences per get_batch call; the manual backprop cells index y[0] and so assume a batch of one

In [7]:
# Label appended to the tag sequence for every padding position.
# NOTE(review): confirm 45 does not collide with a real tag id in the dataset's label map.
PAD_TAG = 45
# Width of the zero vector used for out-of-vocabulary tokens and for padding.
EMBED_DIM = 64


def _encode_split(frame, embeddings, min_len, embed_dim=EMBED_DIM, pad_tag=PAD_TAG):
    """Turn one dataframe split into per-sentence lists of tensors.

    frame      -- DataFrame whose 'tokens' and 'pos_tags' columns hold Python
                  list literals stored as strings.
    embeddings -- mapping from token to its vector tensor.
    min_len    -- sentences shorter than this are right-padded up to it.
    embed_dim  -- width of the zero vector used for unknown/padding tokens.
    pad_tag    -- label assigned to padding positions.

    Returns ``(features, labels)``: two lists with one entry per sentence, each
    entry being a list of tensors (token vectors / 0-d tag tensors).
    Sentences longer than ``min_len`` are kept at full length.
    """
    features, labels = [], []
    for raw_tokens, raw_tags in zip(frame['tokens'], frame['pos_tags']):
        tokens = ast.literal_eval(raw_tokens)
        pos_tags = ast.literal_eval(raw_tags)

        # Unknown tokens fall back to an all-zero vector.
        sentence_x = [
            embeddings[token] if token in embeddings else torch.zeros(embed_dim)
            for token in tokens
        ]
        # Indexed by token position, so a tag list shorter than the token list
        # still raises IndexError exactly as before.
        sentence_y = [torch.tensor(pos_tags[j]) for j in range(len(tokens))]

        n_pad = min_len - len(tokens)
        if n_pad > 0:
            sentence_x += [torch.zeros(embed_dim) for _ in range(n_pad)]
            sentence_y += [torch.tensor(pad_tag) for _ in range(n_pad)]

        features.append(sentence_x)
        labels.append(sentence_y)
    return features, labels


# One (features, labels) pair per split; get_batch reads these module-level lists.
X_train, y_train = _encode_split(train_data, dictionary, block_size)
X_valid, y_valid = _encode_split(valid_data, dictionary, block_size)
X_test, y_test = _encode_split(test_data, dictionary, block_size)

In [8]:
def get_batch(split):
    """Sample a random mini-batch of fixed-length windows from one split.

    split -- 'train' or 'valid'; any other value selects the test split.

    Returns ``(x, y)`` on ``device``: ``x`` is a float32 tensor stacked from
    ``batch_size`` windows of ``block_size`` token vectors, ``y`` the matching
    tag tensor. Sentences are drawn with replacement. Every stored sentence is
    expected to hold at least ``block_size`` items (the data-prep cell pads
    shorter ones); longer sentences contribute one random contiguous window.
    """
    if split == 'train':
        X, Y = X_train, y_train
    elif split == 'valid':
        X, Y = X_valid, y_valid
    else:
        X, Y = X_test, y_test

    xtemp = []
    ytemp = []
    for i in torch.randint(len(X), (batch_size, )).tolist():
        tokens, tags = X[i], Y[i]
        if len(tokens) > block_size:
            # Valid window starts are 0 .. len - block_size inclusive. randint's
            # upper bound is exclusive, so the +1 is required for the final
            # window (the one ending on the last token) to be reachable.
            j = torch.randint(len(tokens) - block_size + 1, (1, )).item()
            tokens = tokens[j:j + block_size]
            tags = tags[j:j + block_size]
        xtemp.append(torch.stack(tokens))
        ytemp.append(torch.stack(tags))

    # Dictionary vectors are float64; the model works in float32.
    x = torch.stack(xtemp).type(torch.FloatTensor).to(device)
    y = torch.stack(ytemp).to(device)
    return x, y

In [9]:
# Tag ids that are treated as interchangeable when scoring a prediction.
# NOTE(review): presumably noun/pronoun-like, verb and adjective/adverb tags --
# confirm against the dataset's label map.
_TAG_GROUPS = (
    frozenset({21, 22, 23, 24, 25, 28, 29}),
    frozenset({37, 38, 39, 40, 41, 42}),
    frozenset({16, 17, 18, 30, 31, 32}),
)


def _tag_group(tag):
    """Return the index of the group containing ``tag``, or -1 for any other tag."""
    for group_id, members in enumerate(_TAG_GROUPS):
        if tag in members:
            return group_id
    return -1


@torch.no_grad()
def estimate_loss(model):
    """Estimate mean loss and coarse-grained accuracy on 'train' and 'valid'.

    model -- callable as ``model(X, Y)`` returning ``(logits, loss)`` where
             ``logits`` is 2-D with one row per token position.

    For each split, ``eval_iters`` batches are drawn with ``get_batch``.
    A position counts as correct when the predicted tag and the target tag
    fall in the same group of ``_TAG_GROUPS`` (all remaining tags form one
    further group).

    The prediction is the arg-max tag. Sampling a tag from the softmax instead
    would make the reported accuracy random from call to call and would consume
    the global RNG during evaluation.

    Returns ``(out, accuracy)``: two dicts keyed by split name holding the mean
    loss (0-d tensor) and the accuracy (float). The model is switched to eval
    mode for the duration and back to train mode before returning.
    """
    out = {}
    accuracy = {}
    model.eval()
    for split in ['train', 'valid']:
        losses = torch.zeros(eval_iters)
        total = 0
        correct = 0
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()

            N, _ = logits.shape
            predicted = logits.argmax(dim=-1).tolist()
            expected = Y.view(N).tolist()
            total += N
            correct += sum(
                _tag_group(p) == _tag_group(t) for p, t in zip(predicted, expected)
            )

        out[split] = losses.mean()
        accuracy[split] = (float(correct)/total)
    model.train()
    return out, accuracy

In [10]:
class Head(nn.Module):
    """One self-attention head (no masking, no dropout).

    embd_size -- width of the incoming token vectors.
    num_heads -- number of heads sharing the model width; this head projects
                 to ``embd_size // num_heads`` features.
    """

    def __init__(self, embd_size, num_heads):
        super().__init__()
        self.query = nn.Linear(embd_size, embd_size // num_heads, bias=False)
        self.key = nn.Linear(embd_size, embd_size // num_heads, bias=False)
        self.value = nn.Linear(embd_size, embd_size // num_heads, bias=False)

    def forward(self, x):
        """Attend over the second-to-last axis of ``x``.

        x -- tensor of shape (batch, tokens, embd_size).
        Returns a tensor of shape (batch, tokens, embd_size // num_heads).
        """
        k = self.key(x)
        q = self.query(x)
        v = self.value(x)
        # Scaled dot-product attention divides by sqrt(d_k), where d_k is the
        # width of the projected keys -- not the width of the input embedding.
        head_size = k.size(-1)
        wei = q @ k.transpose(-1, -2) / (head_size ** 0.5)
        wei = F.softmax(wei, dim=-1)
        out = wei @ v
        return out

class MultiheadAttention(nn.Module):
    """Run ``num_heads`` Head modules in parallel and mix their outputs.

    Each head yields ``d_model // num_heads`` features; the per-head results
    are concatenated along the feature axis and sent through one linear layer
    mapping ``d_model`` to ``d_model`` features.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        head_modules = []
        for _ in range(num_heads):
            head_modules.append(Head(d_model, num_heads))
        self.heads = nn.ModuleList(head_modules)
        self.proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        """Return the projected concatenation of every head's output for ``x``."""
        per_head = [head(x) for head in self.heads]
        merged = torch.cat(per_head, dim=-1)
        return self.proj(merged)

class LayerNorm(nn.Module):
    """Layer normalisation over the trailing ``len(parameters_shape)`` axes.

    parameters_shape -- sizes of the trailing axes that are normalised; also
                        the shape of the learnable scale and shift.
    eps              -- constant added to the variance before the square root.
    """

    def __init__(self, parameters_shape, eps=1e-5):
        super().__init__()
        self.parameters_shape = parameters_shape
        self.eps = eps
        # Scale starts at one and shift at zero.
        self.gamma = nn.Parameter(torch.ones(parameters_shape))
        self.beta = nn.Parameter(torch.zeros(parameters_shape))

    def forward(self, inputs):
        """Normalise ``inputs`` to zero mean / unit variance, then scale and shift."""
        # Trailing axes to reduce over: -1, -2, ..., -len(parameters_shape).
        reduce_axes = list(range(-1, -len(self.parameters_shape) - 1, -1))
        centre = inputs.mean(dim=reduce_axes, keepdim=True)
        # Biased variance: plain mean of squared deviations.
        spread = ((inputs - centre) ** 2).mean(dim=reduce_axes, keepdim=True)
        denom = (spread + self.eps).sqrt()
        normalised = (inputs - centre) / denom
        return self.gamma * normalised + self.beta


class FeedForward(nn.Module):
    """Position-wise two-layer network: Linear -> Tanh -> Linear.

    d_model -- input and output width.
    hidden  -- width of the intermediate layer.
    """

    def __init__(self, d_model, hidden):
        super().__init__()
        widen = nn.Linear(d_model, hidden)
        squash = nn.Tanh()
        narrow = nn.Linear(hidden, d_model)
        self.net = nn.Sequential(widen, squash, narrow)

    def forward(self, x):
        """Apply the network to the last axis of ``x``."""
        result = self.net(x)
        return result


class PositionalEncoding(nn.Module):
    """Fixed sinusoidal position table.

    d_model -- width of each position vector (odd widths are supported).
    max_len -- number of positions precomputed.

    The table is stored as the non-trainable buffer ``pe`` with shape
    (max_len, d_model): even columns hold sines, odd columns cosines.
    """

    def __init__(self, d_model, max_len: int = 5000):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even column,
        # so the last frequency has no cosine slot; without trimming, the
        # assignment fails with a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return the first ``x.size(1)`` rows of the table.

        Only the length of ``x`` along axis 1 is used; ``x`` itself is not
        added to the result -- the caller does the addition.
        """
        return self.pe[:x.size(1)]


class EncoderLayer(nn.Module):
    """Pre-norm transformer encoder block.

    Applies two residual sub-layers in sequence: multi-head self-attention and
    a feed-forward network, each fed with a layer-normalised copy of its input.
    """

    def __init__(self, d_model, ffn_hidden, num_heads):
        super().__init__()
        self.attention = MultiheadAttention(d_model=d_model, num_heads=num_heads)
        self.ffn = FeedForward(d_model=d_model, hidden=ffn_hidden)
        self.norm1 = LayerNorm(parameters_shape=[d_model])
        self.norm2 = LayerNorm(parameters_shape=[d_model])

    def forward(self, x):
        """Return ``x`` after the attention and feed-forward residual updates."""
        attended = self.attention(self.norm1(x))
        x = x + attended
        transformed = self.ffn(self.norm2(x))
        return x + transformed

class Encoder(nn.Module):
    """Stack of encoder layers with a per-token classification head.

    d_model, ffn_hidden, num_heads -- forwarded to every EncoderLayer.
    num_layers -- number of EncoderLayer blocks.
    block_size -- number of positions in the positional table.
    d_input    -- accepted but not used by this class.
    d_output   -- number of output classes per token.
    """

    def __init__(self, d_model, ffn_hidden, num_heads, num_layers, block_size, d_input, d_output):
        super().__init__()
        self.position_embedding_table = PositionalEncoding(d_model, max_len=block_size)
        blocks = []
        for _ in range(num_layers):
            blocks.append(EncoderLayer(d_model, ffn_hidden, num_heads))
        self.layers = nn.Sequential(*blocks)
        self.norm = LayerNorm(parameters_shape=[d_model])
        self.linear = nn.Linear(d_model, d_output)

    def forward(self, x, target=None):
        """Compute per-token logits and, when ``target`` is given, the loss.

        Without ``target``: returns ``(logits, None)`` with logits shaped
        (batch, tokens, d_output).
        With ``target``: logits are flattened to (batch * tokens, d_output),
        the target to (batch * tokens,), and the cross-entropy between them is
        returned alongside the flattened logits.
        """
        hidden = x + self.position_embedding_table(x)
        hidden = self.norm(self.layers(hidden))
        logits = self.linear(hidden)

        if target is None:
            return logits, None

        n_batch, n_tokens, n_classes = logits.shape
        flat_logits = logits.view(n_batch * n_tokens, n_classes)
        flat_target = target.view(n_batch * n_tokens)
        return flat_logits, F.cross_entropy(flat_logits, flat_target)

In [11]:
# Draw one training batch; the manual forward/backward cells below operate on this x, y.
x, y = get_batch('train')

In [19]:
# Hand-initialised parameters for a small two-layer, two-head encoder that the
# following cells run forward and differentiate by hand. All random tensors are
# drawn from the generator `g`, so the order of the draws below fixes their values.
n = batch_size  # not referenced again in the visible cells

g = torch.Generator().manual_seed(2147483647) # for reproducibility

# Sinusoidal position table with block_size rows.
position_embedding_table = PositionalEncoding(d_model, max_len = block_size)

# Output projection from d_model features to 47 logits.
W = torch.randn((d_model, 47), generator=g) * (5/3)/((d_model * 47)**0.5)
b = torch.randn(47,            generator=g) * 0.1

# Attention projections, named <kind><layer><head>; each maps d_model -> d_model // 2.
# Layer 1, head 1.
key11 = torch.randn((d_model, d_model // 2),          generator=g) * 0.1
query11 = torch.randn((d_model, d_model // 2),          generator=g) * 0.1
value11 = torch.randn((d_model, d_model // 2),          generator=g) * 0.1

# Layer 1, head 2.
key12 = torch.randn((d_model, d_model // 2),          generator=g) * 0.1
query12 = torch.randn((d_model, d_model // 2),          generator=g) * 0.1
value12 = torch.randn((d_model, d_model // 2),          generator=g) * 0.1

# Layer 2, head 1.
key21 = torch.randn((d_model, d_model // 2),          generator=g) * 0.1
query21 = torch.randn((d_model, d_model // 2),          generator=g) * 0.1
value21 = torch.randn((d_model, d_model // 2),          generator=g) * 0.1

# Layer 2, head 2.
key22 = torch.randn((d_model, d_model // 2),          generator=g) * 0.1
query22 = torch.randn((d_model, d_model // 2),          generator=g) * 0.1
value22 = torch.randn((d_model, d_model // 2),          generator=g) * 0.1

# Projection applied to the concatenated heads of layer 1.
linear1 = torch.randn((d_model, d_model), generator=g) * (5/3)/((d_model * d_model)**0.5)
bias1 = torch.randn(d_model,            generator=g) * 0.1

# Projection applied to the concatenated heads of layer 2.
linear2 = torch.randn((d_model, d_model), generator=g) * (5/3)/((d_model * d_model)**0.5)
bias2 = torch.randn(d_model,            generator=g) * 0.1

# Feed-forward weights of layer 1: d_model -> ffn_hidden -> d_model.
W11 = torch.randn((d_model, ffn_hidden), generator=g) * (5/3)/((d_model * ffn_hidden)**0.5)
b11 = torch.randn(ffn_hidden,            generator=g) * 0.1

W12 = torch.randn((ffn_hidden, d_model), generator=g) * (5/3)/((ffn_hidden * d_model)**0.5)
b12 = torch.randn(d_model,            generator=g) * 0.1

# Feed-forward weights of layer 2: d_model -> ffn_hidden -> d_model.
W21 = torch.randn((d_model, ffn_hidden), generator=g) * (5/3)/((d_model * ffn_hidden)**0.5)
b21 = torch.randn(ffn_hidden,            generator=g) * 0.1

W22 = torch.randn((ffn_hidden, d_model), generator=g) * (5/3)/((ffn_hidden * d_model)**0.5)
b22 = torch.randn(d_model,            generator=g) * 0.1

# Collect every tensor above, report the total element count and enable autograd on each.
parameters = [W, b, W11, W12, W21, W22, b11, b12, b21, b22, linear1, bias1, linear2, bias2, key11, query11, value11, key21, query21, value21, key12, query12, value12, key22, query22, value22]
print(sum(p.nelement() for p in parameters)) # number of parameters in total
for p in parameters:
  p.requires_grad = True

102127


In [20]:
# Manual forward pass of the two-layer, two-head encoder, written out op by op so
# that every intermediate tensor can be differentiated by hand in a later cell.
# NOTE(review): the parameters are created without a device argument while
# get_batch moves x to `device`; confirm this cell is only run with device == 'cpu'.
pos_emb = position_embedding_table(x)
embd = x + pos_emb

# C is the embedding width; every attention score below is divided by C ** 0.5.
_, _, C = embd.shape

# ---- Layer 1, head 1 ----
k11 = embd @ key11
q11 = embd @ query11
v11 = embd @ value11

wei11 = q11 @ k11.transpose(-1, -2) / (C ** 0.5)
swei11 = F.softmax(wei11, dim=-1)
out11 = swei11 @ v11

# ---- Layer 1, head 2 ----
k12 = embd @ key12
q12 = embd @ query12
v12 = embd @ value12

wei12 = q12 @ k12.transpose(-1, -2) / (C ** 0.5)
swei12 = F.softmax(wei12, dim=-1)
out12 = swei12 @ v12

# ---- Layer 1: merge heads, residual from embd, then tanh feed-forward with residual ----
attention1 = torch.cat([out11, out12], dim=-1)
attention_output1 = embd + attention1 @ linear1 + bias1
ffn11 = attention_output1 @ W11 + b11
affn11 = F.tanh(ffn11)
ffn12 = affn11 @ W12 + b12
encoder_output1 = attention_output1 + ffn12

# ---- Layer 2, head 1 ----
k21 = encoder_output1 @ key21
q21 = encoder_output1 @ query21
v21 = encoder_output1 @ value21

wei21 = q21 @ k21.transpose(-1, -2) / (C ** 0.5)
swei21 = F.softmax(wei21, dim=-1)
out21 = swei21 @ v21

# ---- Layer 2, head 2 ----
k22 = encoder_output1 @ key22
q22 = encoder_output1 @ query22
v22 = encoder_output1 @ value22

wei22 = q22 @ k22.transpose(-1, -2) / (C ** 0.5)
swei22 = F.softmax(wei22, dim=-1)
out22 = swei22 @ v22

# ---- Layer 2: merge heads, feed-forward with residual, output projection ----
attention2 = torch.cat([out21, out22], dim=-1)
# NOTE(review): unlike layer 1, no residual from encoder_output1 is added here.
# The hand-written backward pass in this notebook matches the forward as written,
# so adding the residual would require changing both cells together.
attention_output2 = attention2 @ linear2 + bias2
ffn21 = attention_output2 @ W21 + b21
affn21 = F.tanh(ffn21)
ffn22 = affn21 @ W22 + b22
encoder_output2 = attention_output2 + ffn22
logits = encoder_output2 @ W + b

# ---- Cross-entropy, decomposed into elementary steps ----
B, T, C1 = logits.shape
logits1 = logits.view(B*T, C1)
logit_maxes = logits1.max(1, keepdim=True).values
norm_logits = logits1 - logit_maxes # subtract max for numerical stability
counts = norm_logits.exp()
counts_sum = counts.sum(1, keepdims=True)
counts_sum_inv = counts_sum**-1 # if I use (1.0 / counts_sum) instead then I can't get backprop to be bit exact...
probs = counts * counts_sum_inv
logprobs = probs.log()
# Rows are indexed with range(block_size) and targets with y[0], so this assumes
# a batch holding a single sequence of block_size tokens.
loss = -logprobs[range(block_size), y[0]].mean()

# Clear any stale parameter gradients and ask autograd to keep the gradient of
# every intermediate so it can be compared with the hand-derived values.
for p in parameters:
    p.grad = None
for t in [logits1, logits, encoder_output2, encoder_output1, ffn22, ffn12, affn21, ffn21, # afaik there is no cleaner way
          attention_output2, attention2, out22, swei22, wei22, v22,
         q22, k22, out21, swei21, wei21, v21, q21, k21, affn11, ffn11,
         attention_output1, attention1, out12, swei12, wei12, v12,
         q12, k12, out11, swei11, wei11, v11, q11, k11]:
    t.retain_grad()

In [21]:
# Let autograd populate .grad on the parameters and retained intermediates;
# these are the reference values the hand-derived gradients are checked against.
loss.backward()

In [22]:
# Hand-derived backward pass. Each d<name> is the gradient of the loss with
# respect to <name> from the manual forward cell, computed in reverse order.
# The literals 10 and 32 below equal block_size and d_model // 2 for the values
# set in the hyperparameter cell; torch.squeeze calls drop the size-1 batch axis.

# ---- Cross-entropy ----
dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(block_size), y[0]] = -1.0/block_size
dprobs = (1.0 / probs) * dlogprobs
dcounts_sum_inv = (counts * dprobs).sum(1, keepdim=True)
dcounts = counts_sum_inv * dprobs
dcounts_sum = (-counts_sum**-2) * dcounts_sum_inv
# counts feeds both probs directly and counts_sum, so the two contributions add.
dcounts += torch.ones_like(counts) * dcounts_sum
dnorm_logits = counts * dcounts
dlogits1 = dnorm_logits.clone()
dlogit_maxes = (-dnorm_logits).sum(1, keepdim=True)
# Route the gradient of the row maximum back to the position that held it.
dlogits1 += F.one_hot(logits1.max(1).indices, num_classes=logits1.shape[1]) * dlogit_maxes
dlogits = torch.unsqueeze(dlogits1, 0)

# ---- Output projection: logits = encoder_output2 @ W + b ----
dencoder_output2 = torch.unsqueeze(torch.squeeze(dlogits) @ W.T, 0)
dW = torch.squeeze(encoder_output2.transpose(-2, -1) @ dlogits)
db = torch.squeeze(dlogits.sum(1))
# encoder_output2 = attention_output2 + ffn22, so both summands receive the same gradient.
dffn22 = dencoder_output2.clone()
dattention_output2 = dencoder_output2.clone()

# ---- Layer 2 feed-forward ----
daffn21 = dffn22 @ W22.T
dW22 = torch.squeeze(affn21.transpose(-2, -1) @ dffn22)
db22 = torch.squeeze(dffn22.sum(1))
# Derivative of tanh expressed through its output: 1 - tanh(x)**2.
dffn21 = daffn21 * (1 - affn21**2)

# ---- Layer 2 head-merge projection ----
dattention_output2 += dffn21 @ W21.T
dW21 = torch.squeeze(attention_output2.transpose(-2, -1) @ dffn21)
db21 = torch.squeeze(dffn21.sum(1))
dattention2 = dattention_output2 @ linear2.T
dlinear2 = torch.squeeze(attention2.transpose(-2, -1) @ dattention_output2)
dbias2 = torch.squeeze(dattention_output2.sum(1))
# Undo the concatenation: first 32 features belong to head 1, the rest to head 2.
dout21 = dattention2[:,:,:32]
dout22 = dattention2[:,:,32:]

# ---- Layer 2, head 2 ----
dswei22 = dout22 @ v22.transpose(-2, -1)
dv22 = torch.squeeze(swei22.transpose(-2, -1) @ dout22)

# Softmax backward, one row at a time: the row Jacobian is diag(s) - outer(s, s).
dwei22 = torch.zeros([10, 10])
for i in range(10):
    dwei22[i, :] = (torch.unsqueeze(dswei22[0][i], 0) @ (-torch.outer(swei22[0][i], swei22[0][i]) + torch.diag(swei22[0][i])))
dwei22 = torch.unsqueeze(dwei22, 0)

# wei = q @ k^T / sqrt(C)
dk22 = dwei22.transpose(-2, -1) @ q22 / (C ** 0.5)
dq22 = dwei22 @ k22 / (C ** 0.5)

dvalue22 = torch.squeeze(encoder_output1.transpose(-2, -1) @ dv22)
dkey22 = torch.squeeze(encoder_output1.transpose(-2, -1) @ dk22)
dquery22 = torch.squeeze(encoder_output1.transpose(-2, -1) @ dq22)

# encoder_output1 feeds q, k and v of both layer-2 heads; start with head 2's share.
dencoder_output1 = dv22 @ value22.transpose(-2, -1) + dq22 @ query22.transpose(-2, -1) + dk22 @ key22.transpose(-2, -1)

# ---- Layer 2, head 1 ----
dswei21 = dout21 @ v21.transpose(-2, -1)
dv21 = torch.squeeze(swei21.transpose(-2, -1) @ dout21)

dwei21 = torch.zeros([10, 10])
for i in range(10):
    dwei21[i, :] = (torch.unsqueeze(dswei21[0][i], 0) @ (-torch.outer(swei21[0][i], swei21[0][i]) + torch.diag(swei21[0][i])))
dwei21 = torch.unsqueeze(dwei21, 0)

dk21 = dwei21.transpose(-2, -1) @ q21 / (C ** 0.5)
dq21 = dwei21 @ k21 / (C ** 0.5)

dvalue21 = torch.squeeze(encoder_output1.transpose(-2, -1) @ dv21)
dkey21 = torch.squeeze(encoder_output1.transpose(-2, -1) @ dk21)
dquery21 = torch.squeeze(encoder_output1.transpose(-2, -1) @ dq21)

# Add head 1's share to complete the gradient of encoder_output1.
dencoder_output1 += dv21 @ value21.transpose(-2, -1) + dq21 @ query21.transpose(-2, -1) + dk21 @ key21.transpose(-2, -1)

# ---- Layer 1 feed-forward: encoder_output1 = attention_output1 + ffn12 ----
dffn12 = dencoder_output1.clone()
dattention_output1 = dencoder_output1.clone()

daffn11 = dffn12 @ W12.T
dW12 = torch.squeeze(affn11.transpose(-2, -1) @ dffn12)
db12 = torch.squeeze(dffn12.sum(1))
dffn11 = daffn11 * (1 - affn11**2)

# ---- Layer 1 head-merge projection ----
dattention_output1 += dffn11 @ W11.T
dW11 = torch.squeeze(attention_output1.transpose(-2, -1) @ dffn11)
db11 = torch.squeeze(dffn11.sum(1))
dattention1 = dattention_output1 @ linear1.T
dlinear1 = torch.squeeze(attention1.transpose(-2, -1) @ dattention_output1)
dbias1 = torch.squeeze(dattention_output1.sum(1))
dout11 = dattention1[:,:,:32]
dout12 = dattention1[:,:,32:]

# ---- Layer 1, head 2 ----
dswei12 = dout12 @ v12.transpose(-2, -1)
dv12 = torch.squeeze(swei12.transpose(-2, -1) @ dout12)

dwei12 = torch.zeros([10, 10])
for i in range(10):
    dwei12[i, :] = (torch.unsqueeze(dswei12[0][i], 0) @ (-torch.outer(swei12[0][i], swei12[0][i]) + torch.diag(swei12[0][i])))
dwei12 = torch.unsqueeze(dwei12, 0)

dk12 = dwei12.transpose(-2, -1) @ q12 / (C ** 0.5)
dq12 = dwei12 @ k12 / (C ** 0.5)

dvalue12 = torch.squeeze(embd.transpose(-2, -1) @ dv12)
dkey12 = torch.squeeze(embd.transpose(-2, -1) @ dk12)
dquery12 = torch.squeeze(embd.transpose(-2, -1) @ dq12)

# ---- Layer 1, head 1 ----
dswei11 = dout11 @ v11.transpose(-2, -1)
dv11 = torch.squeeze(swei11.transpose(-2, -1) @ dout11)

dwei11 = torch.zeros([10, 10])
for i in range(10):
    dwei11[i, :] = (torch.unsqueeze(dswei11[0][i], 0) @ (-torch.outer(swei11[0][i], swei11[0][i]) + torch.diag(swei11[0][i])))
dwei11 = torch.unsqueeze(dwei11, 0)

dk11 = dwei11.transpose(-2, -1) @ q11 / (C ** 0.5)
dq11 = dwei11 @ k11 / (C ** 0.5)

dvalue11 = torch.squeeze(embd.transpose(-2, -1) @ dv11)
dkey11 = torch.squeeze(embd.transpose(-2, -1) @ dk11)
dquery11 = torch.squeeze(embd.transpose(-2, -1) @ dq11)

In [23]:
# Compare every hand-derived gradient with the value autograd stored in .grad.
# Each call prints one report line: label | exact | approximate | maxdiff.
# The first argument is only the printed label, so it must name the tensor being
# checked (the W21 / b21 checks were previously mislabelled 'W22' / 'b22').

# ---- Loss, output projection and layer 2 ----
cmp('logits1', dlogits1, logits1)
cmp('logits', dlogits, logits)
cmp('encoder_output2', dencoder_output2, encoder_output2)
cmp('W', dW, W)
cmp('b', db, b)
cmp('ffn22', dffn22, ffn22)
cmp('affn21', daffn21, affn21)
cmp('W22', dW22, W22)
cmp('b22', db22, b22)
cmp('ffn21', dffn21, ffn21)
cmp('attention_output2', dattention_output2, attention_output2)
cmp('W21', dW21, W21)
cmp('b21', db21, b21)
cmp('attention2', dattention2, attention2)
cmp('linear2', dlinear2, linear2)
cmp('bias2', dbias2, bias2)
cmp('out21', dout21, out21)
cmp('out22', dout22, out22)
cmp('swei22', dswei22, swei22)
cmp('wei22', dwei22, wei22)
cmp('k22', dk22, k22)
cmp('q22', dq22, q22)
cmp('v22', dv22, v22)
cmp('value22', dvalue22, value22)
cmp('key22', dkey22, key22)
cmp('query22', dquery22, query22)
cmp('swei21', dswei21, swei21)
cmp('wei21', dwei21, wei21)
cmp('k21', dk21, k21)
cmp('q21', dq21, q21)
cmp('v21', dv21, v21)
cmp('value21', dvalue21, value21)
cmp('key21', dkey21, key21)
cmp('query21', dquery21, query21)
cmp('encoder_output1', dencoder_output1, encoder_output1)

# ---- Layer 1 ----
cmp('ffn12', dffn12, ffn12)
cmp('affn11', daffn11, affn11)
cmp('W12', dW12, W12)
cmp('b12', db12, b12)
cmp('ffn11', dffn11, ffn11)
cmp('attention_output1', dattention_output1, attention_output1)
cmp('W11', dW11, W11)
cmp('b11', db11, b11)
cmp('attention1', dattention1, attention1)
cmp('linear1', dlinear1, linear1)
cmp('bias1', dbias1, bias1)
cmp('out11', dout11, out11)
cmp('out12', dout12, out12)
cmp('swei12', dswei12, swei12)
cmp('wei12', dwei12, wei12)
cmp('k12', dk12, k12)
cmp('q12', dq12, q12)
cmp('v12', dv12, v12)
cmp('value12', dvalue12, value12)
cmp('key12', dkey12, key12)
cmp('query12', dquery12, query12)
cmp('swei11', dswei11, swei11)
cmp('wei11', dwei11, wei11)
cmp('k11', dk11, k11)
cmp('q11', dq11, q11)
cmp('v11', dv11, v11)
cmp('value11', dvalue11, value11)
cmp('key11', dkey11, key11)
cmp('query11', dquery11, query11)

logits1         | exact: False | approximate: True  | maxdiff: 1.4901161193847656e-08
logits          | exact: False | approximate: True  | maxdiff: 1.4901161193847656e-08
encoder_output2 | exact: False | approximate: True  | maxdiff: 1.3969838619232178e-09
W               | exact: False | approximate: True  | maxdiff: 1.4901161193847656e-08
b               | exact: False | approximate: True  | maxdiff: 1.4901161193847656e-08
ffn22           | exact: False | approximate: True  | maxdiff: 1.3969838619232178e-09
affn21          | exact: False | approximate: True  | maxdiff: 2.3283064365386963e-10
W22             | exact: False | approximate: True  | maxdiff: 9.313225746154785e-10
b22             | exact: False | approximate: True  | maxdiff: 3.725290298461914e-09
ffn21           | exact: False | approximate: True  | maxdiff: 2.3283064365386963e-10
attention_output2 | exact: False | approximate: True  | maxdiff: 1.1641532182693481e-09
W22             | exact: False | approximate: True  | 

--2023-12-28 23:50:06--  https://raw.githubusercontent.com/piyush-jena/pos_tagging_transformer/main/data/test_data.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 737353 (720K) [text/plain]
Saving to: ‘test_data.csv’


2023-12-28 23:50:07 (14.6 MB/s) - ‘test_data.csv’ saved [737353/737353]

--2023-12-28 23:50:07--  https://raw.githubusercontent.com/piyush-jena/pos_tagging_transformer/main/data/train_data.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.110.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 3235899 (3.1M) [text/plain]
Saving to: ‘train_data.csv’


2023-12-28 23: