In [9]:
import torch
import einops
import torch.nn as nn
import polars as pr
import plotly.express as ex

from torch import Tensor
from random import random
from torch.utils.data import Dataset, DataLoader
from polars.dataframe.frame import DataFrame
from polars.series.series import Series
from typing import Optional, Dict, List, NoReturn, Callable, Generic, Any

import seris
import os
# Use the first GPU when one is present, otherwise fall back to the CPU so the
# notebook still runs end-to-end on machines without CUDA. All tensors created
# without an explicit device (including module parameters) land on DEVICE.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
torch.set_default_device(DEVICE)
if DEVICE.startswith("cuda"):
    # Release unoccupied memory held by PyTorch's CUDA caching allocator
    # (e.g. left over from a previous run in this kernel).
    torch.cuda.empty_cache()

In [10]:
# Parquet shards of the MuLD OpenSubtitles dataset, expected under
# ~/Datasets/muld_OpenSubtitles/data. os.path.expanduser resolves the home
# directory even when $HOME is unset (os.getenv("HOME") would return None and
# the string concatenation would raise TypeError).
HOME = os.path.expanduser("~")
DATA_DIR = os.path.join(HOME, "Datasets", "muld_OpenSubtitles", "data")
train1 = os.path.join(DATA_DIR, "train-00000-of-00003-ae10b8591df9b61f.parquet")
train2 = os.path.join(DATA_DIR, "train-00001-of-00003-d297e02053936096.parquet")
train3 = os.path.join(DATA_DIR, "train-00002-of-00003-7c2c5fa1d6ac9938.parquet")
test = os.path.join(DATA_DIR, "test-00000-of-00001-30e0e85c508944e5.parquet")

In [11]:
# Load the first training shard; the cell displays how many items it holds.
dataset = seris.SimpleDataset(train1)
n_documents = len(dataset)
n_documents

9250

In [14]:
# Run the first item of the shard through seris.dataset.GetData and keep the
# result in `chunk` for interactive inspection.
first_document = dataset[0]
chunk = seris.dataset.GetData(first_document)

In [5]:
class Attention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Projects the input into queries, keys and values with three independent
    ``nn.Linear(d_model, d_model)`` layers and returns
    ``softmax(Q @ K^T / sqrt(d_model)) @ V``.

    Args:
        d_model: Size of the last (feature) axis of the input; also the size of
            every projection and of the output.
    """

    def __init__(
        self,
        d_model: int
    ):
        super().__init__()
        self.d_model = d_model
        # Creation order (k, q, v) is kept so parameter initialisation under a
        # fixed seed draws random numbers in the same order as before.
        self.k_proj = nn.Linear(d_model, d_model)
        self.q_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)

    def forward(self, x: Tensor) -> Tensor:
        """Attend over the second-to-last axis of ``x``.

        Args:
            x: Tensor whose last axis has size ``d_model``; the second-to-last
                axis is treated as the token axis.

        Returns:
            Tensor with the same shape as ``x``.
        """
        q: Tensor = self.q_proj(x)
        k: Tensor = self.k_proj(x)
        v: Tensor = self.v_proj(x)
        # Row i holds query i's scores against every key j, so the softmax over
        # the last axis gives each query a distribution over keys.
        # Standard scaled dot-product attention divides by sqrt(d_model);
        # dividing by d_model shrinks the logits so much that the softmax is
        # pushed toward uniform weights.
        scores = q @ k.transpose(-1, -2) / self.d_model ** 0.5
        weights = nn.functional.softmax(scores, dim=-1)
        return weights @ v

class Head(nn.Module):
    """Multi-head self-attention assembled from independent ``Attention`` heads.

    Every head attends over the full sequence (second-to-last axis) with its
    own projections. Head outputs are concatenated on the feature axis and
    mapped back to ``d_model`` features by a linear output projection. With a
    single head the projection is the identity, so ``nheads=1`` is exactly one
    ``Attention`` layer.

    Args:
        dim: Window (sequence) length used by the surrounding model. Stored for
            reference only; the attention itself works for any length.
        d_model: Feature size of the input and of the output.
        nheads: Number of attention heads (must be >= 1).

    Raises:
        ValueError: if ``nheads`` is smaller than 1.
    """

    def __init__(
        self,
        dim :int = 64,
        d_model : int = 64,
        nheads : int = 1
    ):
        if nheads < 1:
            raise ValueError(f"nheads must be >= 1, got {nheads}")
        super().__init__()
        self.dim = dim
        self.nheads = nheads
        self.heads = nn.ModuleList([Attention(d_model=d_model) for _ in range(nheads)])
        # Mixes the concatenated per-head outputs back down to d_model features.
        self.out_proj = nn.Identity() if nheads == 1 else nn.Linear(nheads * d_model, d_model)

    def forward(self, x: Tensor) -> Tensor:
        """Apply every head to ``x`` and combine them.

        Args:
            x: Tensor whose last axis has size ``d_model``; the second-to-last
                axis is the token axis.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Every head sees every token. Splitting the *sequence* axis into
        # nheads chunks instead would restrict each token to attending only
        # within its own chunk.
        x = torch.cat([head(x) for head in self.heads], dim=-1)
        return self.out_proj(x)

class Transformer(nn.Module):
    """Next-token classifier over a fixed-length window of token ids.

    Pipeline: token embedding -> ``Head`` self-attention -> flatten the whole
    window -> one linear layer producing ``vocab_size`` logits.

    Args:
        dim: Window length; ``forward`` expects exactly this many tokens per row.
        nheads: Number of attention heads passed to ``Head``.
        d_model: Embedding / attention feature size.
        vocab_size: Number of token ids (embedding rows and output logits).
        batch_size: Stored as ``self.bs``; not used by ``forward``.
    """

    def __init__(
        self,*,
        dim: int = 64,
        nheads: int = 1,
        d_model: int =32,
        vocab_size: int = 1,
        batch_size: int = 32
    ):
        super().__init__()
        self.dim = dim
        self.nheads = nheads
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.bs = batch_size  # kept for compatibility; forward does not read it
        self.emb_proj = nn.Embedding(vocab_size, d_model)
        self.multihead = Head(dim=dim, d_model=d_model, nheads=nheads)
        # The classifier sees the whole window at once (dim * d_model features),
        # which is what makes the model position-aware without positional encodings.
        self.fully_conn = nn.Sequential(
            nn.Linear(d_model * dim, vocab_size)
        )

    def forward(self, x: Tensor) -> Tensor:
        """Map a batch of token windows to next-token logits.

        Args:
            x: Integer token ids of shape ``(batch, dim)``.

        Returns:
            Logits of shape ``(batch, vocab_size)``.

        Raises:
            ValueError: if ``x`` is not 2-D with exactly ``dim`` columns.
        """
        if x.dim() != 2 or x.shape[1] != self.dim:
            raise ValueError(
                f"expected token ids of shape (batch, {self.dim}), got {tuple(x.shape)}"
            )
        x = self.emb_proj(x)
        x = self.multihead(x)
        # flatten(start_dim=1) keeps the batch axis intact; view(-1, d_model*dim)
        # would silently regroup rows into a different batch size if the
        # window length did not match dim.
        x = self.fully_conn(x.flatten(start_dim=1))
        return x

In [6]:
# Build the tokenizer vocabulary by training it on every item of the shard,
# then display the resulting vocabulary size.
SPECIAL_TOKENS = ["<reserved_1>", "<reserved_2>", "<end>", "<start>"]
tk = seris.Tokenizer(special_token=SPECIAL_TOKENS)
for document in dataset:
    tk.train(document)
tk.vocab_size

398

In [7]:
# Consistency check of the tokenizer's two lookup tables against vocab_size;
# the cell displays the resulting 4-tuple.
(
    len(tk.kv.keys()),
    len(tk.vk.keys()),
    max(tk.kv.values()),
    tk.vocab_size,
)

(398, 398, 397, 398)

In [8]:
def get_batch_rand(tensor_data:Tensor, batch_size=32, dim=64):
    """Sample random (window, next-token) pairs from a token sequence.

    A window starting at ``s`` is ``tensor_data[s : s + dim]`` and its target
    is ``tensor_data[s + dim]``. Valid starts are ``0 .. len(tensor_data) - dim - 1``,
    and ``batch_size`` starts are drawn uniformly with replacement.

    Args:
        tensor_data: Token sequence, indexed along its first axis.
        batch_size: Number of windows to draw.
        dim: Window length.

    Returns:
        ``(x, y)``. For a 1-D ``tensor_data``, ``x`` has shape
        ``(batch_size, dim)`` and ``y`` has shape ``(batch_size,)``, with
        ``y[b] == tensor_data[start_b + dim]``.

    Raises:
        ValueError: if ``tensor_data`` has at most ``dim`` elements, so no
            window with a following target exists.
    """
    n_windows = len(tensor_data) - dim  # number of valid start positions
    if n_windows < 1:
        raise ValueError(
            f"need more than dim={dim} tokens to form a window, got {len(tensor_data)}"
        )
    # torch.randint's upper bound is exclusive, so n_windows covers every start
    # up to and including the last valid one. Indices are created on the data's
    # device so advanced indexing does not mix devices.
    starts = torch.randint(n_windows, (batch_size,), device=tensor_data.device)
    offsets = torch.arange(dim, device=tensor_data.device)
    x = tensor_data[starts.unsqueeze(1) + offsets]
    y = tensor_data[starts + dim]
    return x, y

def get_batch(tensor_data:Tensor, batch_size=32, dim=64):
    """Yield consecutive, non-overlapping batches of (window, next-token) pairs.

    A window starting at ``s`` is ``tensor_data[s : s + dim]`` and its target
    is ``tensor_data[s + dim]``. Valid starts are ``0 .. len(tensor_data) - dim - 1``.
    They are grouped in order into batches of ``batch_size``, so each window
    is visited once per pass. A trailing group smaller than ``batch_size`` is
    dropped, so every yielded batch has exactly ``batch_size`` rows.

    Args:
        tensor_data: Token sequence, indexed along its first axis.
        batch_size: Windows per batch.
        dim: Window length.

    Yields:
        ``(x, y)``. For a 1-D ``tensor_data``, ``x`` has shape
        ``(batch_size, dim)`` and ``y`` has shape ``(batch_size,)``. Nothing is
        yielded when fewer than ``batch_size`` windows fit in ``tensor_data``.
    """
    n_windows = len(tensor_data) - dim  # number of valid start positions
    offsets = torch.arange(dim, device=tensor_data.device)
    # Advance by batch_size. Advancing by 1 would make consecutive batches
    # share batch_size - 1 windows, so every window would be trained on up to
    # batch_size times per pass.
    for first in range(0, n_windows - batch_size + 1, batch_size):
        starts = torch.arange(first, first + batch_size, device=tensor_data.device)
        x_ = tensor_data[starts.unsqueeze(1) + offsets]
        y_ = tensor_data[starts + dim]
        yield x_, y_

In [None]:
# Training hyper-parameters.
# The learning rate must be non-zero: with lr == 0.0 every AdamW step leaves
# the weights unchanged. 1e-3 is a conventional AdamW starting point; tune as needed.
lr = 1e-3
dim = 32          # window length: tokens per training example
batch_size = 64
nheads = 2
d_model = 32      # embedding / attention feature size
vocab_size = tk.vocab_size

net = Transformer(
    batch_size=batch_size,
    d_model=d_model,
    dim=dim,
    nheads=nheads,
    vocab_size=vocab_size
)
optim = torch.optim.AdamW(params=net.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

net.train()
for corpus in dataset:
    # Accumulate the loss as a detached tensor so there is no host/device sync
    # on every optimisation step; only one .item() per corpus.
    loss_sum, n_batches = 0.0, 0
    for x, y in get_batch(tk.encode(corpus), batch_size=batch_size, dim=dim):
        logits = net(x)
        optim.zero_grad()
        loss = criterion(logits, y)
        loss.backward()
        optim.step()
        loss_sum = loss_sum + loss.detach()
        n_batches += 1
    # A corpus too short to fill a single batch yields no batches. Skip it
    # rather than re-printing the previous corpus's loss (or hitting NameError
    # when it is the very first corpus).
    if n_batches:
        print((loss_sum / n_batches).item())