# Neural Networks on MNIST

This Jupyter notebook explains various approaches for implementing neural networks that recognize digits on [MNIST](http://yann.lecun.com/exdb/mnist/) dataset.

## Preparing the MNIST dataset

Most deep learning frameworks provide APIs for loading famous datasets like MNIST (e.g., `torchvision.datasets.MNIST` in pytorch). The APIs are handy, but hide the important step for preparing a training data for a deep learning framework; when graduating from an example dataset to the real data, we must convert a training data of our interest into the data structure that is acceptable by a deep learning framework.

The cell below downloads the original distribution of the MNIST dataset on the Web, converts the dataset into `numpy` arrays, and saves the arrays as the file `mnist.npz` with keyword names.

In [1]:
import gzip
import os
import sys
import struct
import numpy as np

def read_image(fi):
    magic, n, rows, columns = struct.unpack(">IIII", fi.read(16))
    assert magic == 0x00000803
    assert rows == 28
    assert columns == 28
    rawbuffer = fi.read()
    assert len(rawbuffer) == n * rows * columns
    rawdata = np.frombuffer(rawbuffer, dtype='>u1', count=n*rows*columns)
    return rawdata.reshape(n, rows, columns).astype(np.float32) / 255.0

def read_label(fi):
    magic, n = struct.unpack(">II", fi.read(8))
    assert magic == 0x00000801
    rawbuffer = fi.read()
    assert len(rawbuffer) == n
    return np.frombuffer(rawbuffer, dtype='>u1', count=n)

if __name__ == '__main__':
    os.system('wget -N http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz')
    os.system('wget -N http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz')
    os.system('wget -N http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz')
    os.system('wget -N http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz')

    np.savez_compressed(
        'mnist',
        train_x=read_image(gzip.open('train-images-idx3-ubyte.gz', 'rb')),
        train_y=read_label(gzip.open('train-labels-idx1-ubyte.gz', 'rb')),
        test_x=read_image(gzip.open('t10k-images-idx3-ubyte.gz', 'rb')),
        test_y=read_label(gzip.open('t10k-labels-idx1-ubyte.gz', 'rb'))
    )

--2023-10-24 16:25:24--  http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Resolving yann.lecun.com (yann.lecun.com)... 2606:4700:3036::ac43:ab4c, 2606:4700:3034::6815:1d24, 172.67.171.76, ...
Connecting to yann.lecun.com (yann.lecun.com)|2606:4700:3036::ac43:ab4c|:80... connected.
HTTP request sent, awaiting response... 304 Not Modified
File ‘train-images-idx3-ubyte.gz’ not modified on server. Omitting download.

--2023-10-24 16:25:25--  http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Resolving yann.lecun.com (yann.lecun.com)... 2606:4700:3034::6815:1d24, 2606:4700:3036::ac43:ab4c, 172.67.171.76, ...
Connecting to yann.lecun.com (yann.lecun.com)|2606:4700:3034::6815:1d24|:80... connected.
HTTP request sent, awaiting response... 304 Not Modified
File ‘train-labels-idx1-ubyte.gz’ not modified on server. Omitting download.

--2023-10-24 16:25:25--  http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Resolving yann.lecun.com (yann.lecun.com)... 2606:4700:3034:

The file contains four numpy arrays (one tensor and array for each split of training and test sets) with the keywords:

+ `train_x`: $60000 \text{ (images)} \times 28 \text{ (y)} \times 28 \text{ (x)}$
+ `train_y`: $60000 \text{ (labels)}$
+ `test_x`: $10000 \text{ (images)} \times 28 \text{ (y)} \times 28 \text{ (x)}$
+ `test_y`: $10000 \text{ (labels)}$


### Install pytorch

In [2]:
!pip install torch torchvision transformers spacy ftfy==4.4.3 pandas

Looking in indexes: https://pypi.org/simple, https://pypi.idi.ntnu.no

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.2.1[0m[39;49m -> [0m[32;49m23.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


### Transformers

In [18]:

import re
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from collections import Counter
from typing import List, Tuple
import pandas
from transformers import OpenAIGPTTokenizer
import pickle
class BeowulfDataset(Dataset):
    def __init__(self, file_path: str, sequence_length: int, sequence_length_enc: int, picklefile=None):
        # Read the text file
        text = pandas.read_csv("./gutenberg-poetry-dataset.csv")
        text = text[["title", "content"]]
        text = text.dropna()
        # Remove non-alphanumeric characters and tokenize

        # Build vocabulary
        self.tokenizer = OpenAIGPTTokenizer.from_pretrained("openai-gpt")
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        
        # Create input-output pairs using sliding window
        self.sequence_length = sequence_length
        self.pairs = [(x_a[0], x_a[1][:10000]) for x_a in text.to_numpy()]
        if picklefile is None:
            for i in range(len(self.pairs)):
                if len(self.tokenizer(self.pairs[i][0],return_tensors="pt", truncation=True, max_length=self.sequence_length, padding="max_length")["input_ids"].squeeze()) == 0 or len(self.tokenizer(self.pairs[i][0],return_tensors="pt", truncation=True, max_length=self.sequence_length, padding="max_length")["input_ids"].squeeze()) == 0:
                    self.pairs.pop(i)
                i -= 1
            self.threes = []
            o = 0
            for pair in self.pairs:
                o+=1
                print("\r",o, end="")
                x, y = pair
                x = self.tokenizer(x, return_tensors="pt", truncation=True, max_length=sequence_length_enc, padding="max_length")["input_ids"].squeeze()
                y = self.tokenizer(y, return_tensors="pt", truncation=True, padding="max_length", max_length=self.sequence_length+1)["input_ids"].squeeze()
                if y is None or len(y) == 0:
                    continue
                for i in range(min(len(y), self.sequence_length+1)):
                    z = y[i]
                    y_z = y.clone()
                    y_z[i:] = self.tokenizer.pad_token_id
                    y_z = y_z[:self.sequence_length]
                    self.threes.append((x,y_z,z))
            with open("./threes.pickle", "wb") as f:
                pickle.dump(self.threes, f)
        else:
            with open(picklefile, "rb") as f:
                self.threes = pickle.load(f)

    def __len__(self):
        return len(self.threes)

    def __getitem__(self, idx) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        return self.threes[idx]

    def num_unique_tokens(self):
        return len(self.tokenizer)
    

    def tokens_to_text(self, tokens:torch.Tensor) -> str:
        """
        Convert a list of token IDs back to a string of text.
        """
        if tokens.shape[0] == 1:
            tokens = [tokens.item()]
        else: 
            tokens = tokens.tolist()
        
        txt = [self.tokenizer.convert_tokens_to_string(self.tokenizer.convert_ids_to_tokens(token)) for token in tokens]
        return ' '.join(txt)
    
    def text_to_tokens(self, text: str) -> torch.Tensor:
        """
        Convert a list of token IDs back to a string of text.
        """
        out = self.tokenizer(text)
        return out


file_path = 'gutenberg-poetry-dataset.csv'
batch_size = 128
shuffle = True
sequence_length = 100  # for example, predicting 10th word based on previous 9 words
sequence_length_enc = 25

dataset = BeowulfDataset(file_path, sequence_length, sequence_length_enc)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

# Example usage:
print(f"Number of unique tokens in the dataset: {dataset.num_unique_tokens()}")

print(dataset[128])
uniques = dataset.num_unique_tokens()
print(uniques)



 1150Number of unique tokens in the dataset: 40479
(tensor([13329, 30488,   535,  4562,  1024, 14585,   702,  3878, 16262,  1580,
        26621, 40478, 40478, 40478, 40478, 40478, 40478, 40478, 40478, 40478,
        40478, 40478, 40478, 40478, 40478]), tensor([  498,   525, 10519,  2455,   240,  3690,  4665,   632,   523, 40477,
         1739,  1767,   666,   481,  1276,   240,   488,   589,   622, 29992,
          240, 40477,   556,  4250,   498,  7756,   240, 40478, 40478, 40478,
        40478, 40478, 40478, 40478, 40478, 40478, 40478, 40478, 40478, 40478,
        40478, 40478, 40478, 40478, 40478, 40478, 40478, 40478, 40478, 40478,
        40478, 40478, 40478, 40478, 40478, 40478, 40478, 40478, 40478, 40478,
        40478, 40478, 40478, 40478, 40478, 40478, 40478, 40478, 40478, 40478,
        40478, 40478, 40478, 40478, 40478, 40478, 40478, 40478, 40478, 40478,
        40478, 40478, 40478, 40478, 40478, 40478, 40478, 40478, 40478, 40478,
        40478, 40478, 40478, 40478, 40478, 40

In [9]:
import math
from matplotlib import pyplot as plt
#torch.manual_seed(0)
class self_attention_pure(nn.Module):
  def __init__(self, device="cuda:0", p_in=12, mid = 512, dims=0, num_heads=4):
    super().__init__()
    self.num_words = p_in
    self.num_dims = p_in
    self.num_heads = num_heads
    if dims != 0:
      self.num_dims = dims
    if self.num_dims % num_heads  != 0:
      raise "num heads must divide num dims"

    self.Q = nn.Linear(self.num_dims, self.num_dims).to(device)
    self.K = nn.Linear(self.num_dims, self.num_dims).to(device)
    self.V = nn.Linear(self.num_dims, self.num_dims).to(device)
    self.softmax = nn.Softmax(dim=3)
    self.preoutput_0 = nn.Linear(self.num_dims,mid).to(device)
    self.preoutput_1 = nn.Linear(mid, self.num_dims).to(device)
    self.gelu = nn.GELU()
    self.device=device
    self.normalize = nn.LayerNorm([p_in, self.num_dims]).to(device)


  def forward(self, x, device="cuda:0", visualization=False):
    x = x.to(self.device)
    id = x
    x_q = self.Q(x).view(x.shape[0], -1, self.num_heads, self.num_dims//self.num_heads)
    x_k = self.K(x).view(x.shape[0], -1, self.num_heads, self.num_dims//self.num_heads)
    x_v = self.V(x).view(x.shape[0], -1, self.num_heads, self.num_dims//self.num_heads)
    x_q = x_q.permute(0, 2,1,3)
    x_k = x_k.permute(0, 2,1,3)
    x_v = x_v.permute(0, 2,1,3)

    x_qk = x_q.matmul(x_k.transpose(2,3))/math.sqrt(self.num_dims)
    x = self.softmax(x_qk).matmul(x_v)
    if visualization:
      vis_softmax = nn.Softmax(dim=1)
      for i in range(x_qk.shape[1]):
        print(x_qk.shape)
        plt.imshow(vis_softmax(x_qk[0][i]).clone().detach().cpu(), vmin=0, vmax=1)
        plt.show()
      print(self.softmax(x_qk).shape)
    x = x.permute(0, 2, 1, 3)
    x = x.reshape(x.shape[0], -1, self.num_dims)
    x = self.preoutput_0(x)
    x = self.preoutput_1(x)
    x += id
    x = self.normalize(x)
    x = self.gelu(x)

    return x


class self_attention_w_pos_enc(nn.Module):
  def __init__(self, device="cuda:0", words = 12, dims = 12):
    super().__init__()
    self.num_words = words
    self.num_dims = dims
    self.embedding = nn.Embedding(uniques,self.num_dims).to(device)
    self.pos_encoding = torch.zeros((self.num_words,self.num_dims))

    for i in range(self.num_words):
      for j in range(self.num_dims):
        if j % 2 == 0:
          self.pos_encoding[i,j] = math.sin(i/math.pow(10000, j/self.num_dims))
        else:
          self.pos_encoding[i,j] = math.cos(i/math.pow(10000, j-1/self.num_dims))
    self.pos_encoding = self.pos_encoding.to(device)

    self.attention = self_attention_pure(device=device, p_in = self.num_dims, dims = self.num_dims, mid=512)
    self.output = nn.Linear(512, 519).to(device)

  def forward(self, x, device="cuda:0", visualization=False):
    x = x.to(device)
    x = self.embedding(x)
    x += self.pos_encoding
    x = self.attention(x, visualization=visualization)
    x = self.output(x)

    return x

class self_attention_w_pos_enc_pure(nn.Module):
  def __init__(self, device="cuda:0", num_words = 12, num_dims=12, out=512, num_heads=4, emb_size=uniques):
    super().__init__()
    self.num_words = num_words
    self.num_dims = num_dims
    self.embedding = nn.Embedding(emb_size,self.num_dims).to(device)
    self.pos_encoding = torch.zeros((self.num_words,self.num_dims))
    self.device=device

    for i in range(self.num_words):
      for j in range(self.num_dims):
        if j % 2 == 0:
          self.pos_encoding[i,j] = math.sin(i/math.pow(10000, j/self.num_dims))
        else:
          self.pos_encoding[i,j] = math.cos(i/math.pow(10000, (j-1)/self.num_dims))
    self.pos_encoding = self.pos_encoding.to(device)

    self.attention = self_attention_pure(device=device, p_in = self.num_words, dims = self.num_dims, mid=out, num_heads=num_heads)

  def forward(self, x, device="cuda:0", visualization=False):
    x = x.to(self.device)
    x = self.embedding(x)
    x += self.pos_encoding
    x = self.attention(x, visualization=visualization)
    return x


class multi_layer_self_attention(nn.Module):
  def __init__(self, device="cuda:0", outs=[144,256,512,519], seq_length=12, dims=12, num_heads=4):
    super().__init__()
    self.attention_w_enc = self_attention_w_pos_enc_pure(device=device, num_words=seq_length, num_dims=dims, out=outs[0], num_heads=num_heads)
    self.attention_1 = self_attention_pure(device=device, p_in=seq_length, dims=dims, mid=outs[0], num_heads=num_heads)
    self.attention_2 = self_attention_pure(device=device, p_in=seq_length, dims=dims, mid=outs[1], num_heads=num_heads)
    self.attention_3 = self_attention_pure(device=device, p_in=seq_length, dims=dims, mid=outs[2], num_heads=num_heads)
    self.linear_1 = nn.Linear(seq_length*dims, outs[0]).to(device)
    self.linear_2 = nn.Linear(outs[0], outs[1]).to(device)
    self.flatten = nn.Flatten(start_dim=1)
    self.output = nn.Linear(outs[1], outs[3]).to(device)
    self.device=device

  def forward(self, x, visualization=False):
    x = x.to(self.device)
    x = self.attention_w_enc(x, visualization=visualization)
    x = self.attention_1(x, visualization=visualization)
    x = self.attention_2(x, visualization=False)
    x = self.attention_3(x, visualization=False)

    x = self.flatten(x)
    x = self.linear_1(x)
    x = self.linear_2(x)
    x = self.output(x)
    return x

seq_length = sequence_length
self_atten = multi_layer_self_attention(seq_length=seq_length, dims=256, num_heads=8, outs=[768,1024,512,uniques])
#p = self_atten(numdata[:12])

optimizer = torch.optim.Adam(self_atten.parameters(), lr = 0.00001, weight_decay=0.000005)
loss = nn.CrossEntropyLoss()
schedule = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2, cooldown=3, factor=0.333)

"""

for k in range(25):
  loss_n = 0
  acc = 0
  for i, batch in enumerate(dataloader):
    optimizer.zero_grad()
    print(batch[0]["input_ids"].shape)
    x = batch[0]["input_ids"].view((-1, 24))[:, :23]
    print(x.shape)
    y = torch.zeros((batch[0]["input_ids"].shape[0], uniques))
    print(y.shape)
    y = y.to("cuda:0")
    y[:,batch[0]["input_ids"].view((-1, 24))[:,23]] = 1
    
    y_p = self_atten(x, visualization=False)
    l = loss(y_p, y)
    l.backward()
    optimizer.step()
    loss_n += l/batch[0]["input_ids"].shape[0]
    acc += torch.eq(y.argmax(dim=1), y_p.argmax(dim=1)).sum().item()/batch[0]["input_ids"].shape[0]
    print(f"\rit: {i+1}/{len(dataloader)}, loss: {loss_n/(i+1)}, acc: {acc/(i+1)}", end="")
  loss_n /= len(dataloader)
  acc /= len(dataloader)
  schedule.step(loss_n)
  print("epoch:", k ,"loss:", loss_n, "acc:",acc)


#p = self_atten(numdata[:12])

#print(p, words[p.argmax()], data[13])

accuracy = 0
for i, batch in enumerate(dataloader):
  x = batch[0]
  y = torch.zeros((batch[1].shape[0], uniques))
  y[:,batch[1]] = 1
  y_p = self_atten(x, visualization=False)
  y_p = y_p.to("cpu")
  #print(y, words[y_p.argmax()])
  accuracy += torch.eq(y.argmax(dim=1), y_p.argmax(dim=1)).sum().item()/batch[0].shape[0]

accuracy /= len(dataloader)-seq_length-1
print("accuracy:", accuracy)

print(dataset.tokens_to_text(dataset[128][0]), dataset.tokens_to_text(torch.Tensor([dataset[128][1]])))
print(dataset[128][0].shape)
x = torch.Tensor(dataset[128][0]).unsqueeze(0)
self_atten(x, visualization=True)
x = torch.Tensor(dataset[129][0]).unsqueeze(0)
self_atten(x, visualization=True)
"""

'\n\nfor k in range(25):\n  loss_n = 0\n  acc = 0\n  for i, batch in enumerate(dataloader):\n    optimizer.zero_grad()\n    print(batch[0]["input_ids"].shape)\n    x = batch[0]["input_ids"].view((-1, 24))[:, :23]\n    print(x.shape)\n    y = torch.zeros((batch[0]["input_ids"].shape[0], uniques))\n    print(y.shape)\n    y = y.to("cuda:0")\n    y[:,batch[0]["input_ids"].view((-1, 24))[:,23]] = 1\n    \n    y_p = self_atten(x, visualization=False)\n    l = loss(y_p, y)\n    l.backward()\n    optimizer.step()\n    loss_n += l/batch[0]["input_ids"].shape[0]\n    acc += torch.eq(y.argmax(dim=1), y_p.argmax(dim=1)).sum().item()/batch[0]["input_ids"].shape[0]\n    print(f"\rit: {i+1}/{len(dataloader)}, loss: {loss_n/(i+1)}, acc: {acc/(i+1)}", end="")\n  loss_n /= len(dataloader)\n  acc /= len(dataloader)\n  schedule.step(loss_n)\n  print("epoch:", k ,"loss:", loss_n, "acc:",acc)\n\n\n#p = self_atten(numdata[:12])\n\n#print(p, words[p.argmax()], data[13])\n\naccuracy = 0\nfor i, batch in enume

In [None]:
class self_attention_pure_encoder_qk_decoder_v(nn.Module):
  def __init__(self, device="cuda:0", p_in=12, mid = 512, dims=0, num_heads=4):
    super().__init__()
    self.num_words = p_in
    self.num_dims = p_in
    self.num_heads = num_heads
    if dims != 0:
      self.num_dims = dims
    if self.num_dims % num_heads  != 0:
      raise "num heads must divide num dims"

    self.Q = nn.Linear(self.num_dims, self.num_dims).to(device)
    self.K = nn.Linear(self.num_dims, self.num_dims).to(device)
    self.V = nn.Linear(self.num_dims, self.num_dims).to(device)
    self.softmax = nn.Softmax(dim=3)
    self.preoutput_0 = nn.Linear(self.num_dims,mid).to(device)
    self.preoutput_1 = nn.Linear(mid, self.num_dims).to(device)
    self.gelu = nn.GELU()
    self.device=device
    self.normalize = nn.LayerNorm([p_in, self.num_dims]).to(device)


  def forward(self, x, keyvalue,  device="cuda:0", visualization=False):
    x = x.to(self.device)
    keyvalue = keyvalue.to(self.device)
    id = x
    x_q = self.Q(x).view(x.shape[0],-1, self.num_heads, self.num_dims//self.num_heads)
    x_k = self.K(keyvalue).view(x.shape[0], -1, self.num_heads, self.num_dims//self.num_heads)
    x_v = self.V(keyvalue).view(x.shape[0], -1, self.num_heads, self.num_dims//self.num_heads)

    x_q = x_q.permute(0, 2, 1, 3)
    x_k = x_k.permute(0, 2, 1, 3)
    x_v = x_v.permute(0, 2, 1, 3)

    x_qk = x_q.matmul(x_k.transpose(2,3))/math.sqrt(self.num_dims)
    x = self.softmax(x_qk).matmul(x_v)
    if visualization:
      vis_softmax = nn.Softmax(dim=1)
      for i in range(x_qk.shape[1]):
        print(x_qk.shape)
        plt.imshow(vis_softmax(x_qk[0][i]).clone().detach().cpu(), vmin=0, vmax=1)
        plt.show()
      print(self.softmax(x_qk).shape)
    x = x.permute(0, 2, 1, 3)
    x = x.reshape(x.shape[0], -1, self.num_dims)
    x = self.preoutput_0(x)
    x = self.preoutput_1(x)
    x += id
    x = self.normalize(x)
    x = self.gelu(x)

    return x




class multi_layer_self_attention_encoder_decoder_scheme(nn.Module):
  def __init__(self, device="cuda:0", outs=[144,256,512,519], seq_length=12, dims=12, num_heads=4, emb_size=uniques):
    super().__init__()
    self.attention_w_enc = self_attention_w_pos_enc_pure(device=device, num_words=seq_length, num_dims=dims, out=outs[0], num_heads=num_heads, emb_size=emb_size)
    self.attention_1 = self_attention_pure_encoder_qk_decoder_v(device=device, p_in=seq_length, dims=dims, mid=outs[0], num_heads=num_heads)
    self.attention_2 = self_attention_pure_encoder_qk_decoder_v(device=device, p_in=seq_length, dims=dims, mid=outs[1], num_heads=num_heads)
    self.attention_3 = self_attention_pure(device=device, p_in=seq_length, dims=dims, mid=outs[2], num_heads=num_heads)

    self.attention_w_enc_encoder = self_attention_w_pos_enc_pure(device=device, num_words=sequence_length_enc, num_dims=dims, out=outs[0], num_heads=num_heads, emb_size=emb_size)
    self.attention_1_encoder = self_attention_pure(device=device, p_in=sequence_length_enc, dims=dims, mid=outs[0], num_heads=num_heads)
    self.attention_2_encoder = self_attention_pure(device=device, p_in=sequence_length_enc, dims=dims, mid=outs[1], num_heads=num_heads)
    self.attention_3_encoder = self_attention_pure(device=device, p_in=sequence_length_enc, dims=dims, mid=outs[2], num_heads=num_heads)
    self.dims= dims
    self.linear_1 = nn.Linear(seq_length*dims, outs[0]).to(device)
    self.linear_2 = nn.Linear(outs[0], outs[1]).to(device)
    self.flatten = nn.Flatten(start_dim=1)
    self.output = nn.Linear(outs[1], outs[3]).to(device)
    self.device=device

  def forward(self, x_enc, x_dec, visualization=False):
    #Gjør først venstre side av figure 1 for å få query og key
    x_enc = x_enc.to(self.device)
    x_enc = self.attention_w_enc_encoder(x_enc, visualization=visualization)
    x_enc = self.attention_1_encoder(x_enc, visualization=visualization)
    x_enc = self.attention_2_encoder(x_enc, visualization=visualization)
    x_enc = self.attention_3_encoder(x_enc, visualization=visualization)

    x_enc = self.flatten(x_enc)
    querykey_for_dec = x_enc
    querykey_for_dec = querykey_for_dec.reshape(x_enc.shape[0], -1, self.dims)

    x_dec = x_dec.to(self.device)
    x_dec = self.attention_w_enc(x_dec, visualization=visualization)
    x_dec = self.attention_1(x_dec, querykey_for_dec, visualization=visualization)
    x_dec = self.attention_2(x_dec, querykey_for_dec, visualization=visualization)
    x_dec = self.attention_3(x_dec, visualization=visualization)

    x_dec = self.flatten(x_dec)
    x_dec = self.linear_1(x_dec)
    x_dec = self.linear_2(x_dec)
    x_dec = self.output(x_dec)

    return x_dec


seq_length = sequence_length
self_atten = multi_layer_self_attention_encoder_decoder_scheme(seq_length=seq_length, dims=256, num_heads=8, outs=[768,1024,512,uniques], emb_size=uniques)
#p = self_atten(numdata[:12])

optimizer = torch.optim.Adam(self_atten.parameters(), lr = 0.0001, weight_decay=0.000005)
loss = nn.CrossEntropyLoss()
schedule = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2, cooldown=3, factor=0.333)



for k in range(25):
  loss_n = 0
  acc = 0
  for i, batch in enumerate(dataloader):
    optimizer.zero_grad()
    x, y = batch[0].view((-1, sequence_length_enc)), batch[1].view((-1,sequence_length))
    z = torch.zeros((batch[0].shape[0], uniques))
    z = z.to("cuda:0")
    z[:, batch[2].view((-1, 1))[:]] = 1
    
    y_p = self_atten(x, y, visualization=False)
    l = loss(y_p, z)
    l.backward()
    optimizer.step()
    loss_n += l/batch[0].shape[0]
    acc += torch.eq(z.argmax(dim=1), y_p.argmax(dim=1)).sum().item()/batch[0].shape[0]
    print(f"\rit: {i+1}/{len(dataloader)}, loss: {loss_n/(i+1)}, acc: {acc/(i+1)}", end="")
  loss_n /= len(dataloader)
  acc /= len(dataloader)
  schedule.step(loss_n)
  print("epoch:", k ,"loss:", loss_n, "acc:",acc)


#p = self_atten(numdata[:12])

#print(p, words[p.argmax()], data[13])

accuracy = 0
for i, batch in enumerate(dataloader):
  x, y = batch[0].view((-1, sequence_length_enc)), batch[1].view((-1,sequence_length))
  z = torch.zeros((batch[2].shape[0], uniques))
  z[batch[2].view((-1,1))] = 1
  y_p = self_atten(x, y, visualization=False)
  y_p = y_p.to("cpu")
  #print(y, words[y_p.argmax()])
  accuracy += torch.eq(z.argmax(dim=1), y_p.argmax(dim=1)).sum().item()/batch[0].shape[0]

accuracy /= len(dataloader)-seq_length-1
print("accuracy:", accuracy)

print(dataset.tokens_to_text(dataset[128][0]), dataset.tokens_to_text(torch.Tensor([dataset[128][1]])),dataset.tokens_to_text(torch.Tensor([dataset[128][2]])))
print(dataset[128][0].shape)
x = torch.Tensor(dataset[128][0]).unsqueeze(0)
y = torch.Tensor(dataset[128][1]).unsqueeze(0)

self_atten(x, y, visualization=True)
#x = torch.Tensor(dataset[129][0]).unsqueeze(0)
#self_atten(x, visualization=True)

it: 908/908, loss: 5.731567859649658, acc: 0.027730933370044054epoch: 0 loss: tensor(5.7316, device='cuda:0', grad_fn=<DivBackward0>) acc: 0.027730933370044054
it: 908/908, loss: 5.446281909942627, acc: 0.04154047356828194epoch: 1 loss: tensor(5.4463, device='cuda:0', grad_fn=<DivBackward0>) acc: 0.04154047356828194
it: 908/908, loss: 5.4168877601623535, acc: 0.03192972191629956epoch: 2 loss: tensor(5.4169, device='cuda:0', grad_fn=<DivBackward0>) acc: 0.03192972191629956
it: 908/908, loss: 5.401157379150391, acc: 0.023222398127753303epoch: 3 loss: tensor(5.4012, device='cuda:0', grad_fn=<DivBackward0>) acc: 0.023222398127753303
it: 908/908, loss: 5.4023966789245605, acc: 0.04499070759911894epoch: 4 loss: tensor(5.4024, device='cuda:0', grad_fn=<DivBackward0>) acc: 0.04499070759911894
it: 908/908, loss: 5.397100448608398, acc: 0.042116946585903085epoch: 5 loss: tensor(5.3971, device='cuda:0', grad_fn=<DivBackward0>) acc: 0.042116946585903085
it: 908/908, loss: 5.389350891113281, acc: 0