<a href="https://colab.research.google.com/github/shusank8/SEQUENCEModels/blob/main/LSTMFROMScratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
print('SIMPLE LONG SHORT TERM MEMORY; BY SHUSANKET BASYAL')

SIMPLE LONG SHORT TERM MEMORY; BY SHUSANKET BASYAL


In [3]:
# LOADING THE DATASET
# DATASET IS THE SHORT JOKES FROM KAGGLE
import kagglehub
path = kagglehub.dataset_download("abhinavmoudgil95/short-jokes")

print("Path to dataset files:", path)

Downloading from https://www.kaggle.com/api/v1/datasets/download/abhinavmoudgil95/short-jokes?dataset_version_number=1...


100%|██████████| 9.82M/9.82M [00:00<00:00, 48.1MB/s]

Extracting files...





Path to dataset files: /root/.cache/kagglehub/datasets/abhinavmoudgil95/short-jokes/versions/1


In [4]:
# IMPORTING THE NECESSARY LIBARIES

import os
import pandas as pd
import torch
import torch.nn as nn
from torch.nn import functional as F

In [5]:
# LOOKING WHERE THE FILES HAS BEEN DOWNLOADED
os.listdir(path)

['shortjokes.csv']

In [6]:
# LOADING THE FILE INTO DF
df = pd.read_csv(path+"/shortjokes.csv")
# GETTING ALL THE VALUES IN JOKE COLUMN => RETURNS A LIST
text = df['Joke'].values
# JOINING ALL THE STR VAL IN THE LIST TO GET A SINGLE STR
text = "".join(text)
# GETTING THE UNIQUE CHAR PRESENT IN THE DATASET AND CREATING A VARIABLE VOCAB_SIZE THAT STORES THE LEN OF THE UNIQUE ELEMENTS
char = sorted(list(set(text)))
vocab_size = len(char)
# SIMPLE ENCODER, DECODER
# CREATING A HASMAP THAT MAPS STRING TO ID AND VICE VERSA
stringtoid = {sti:i for i,sti in enumerate(char)}
idtostring = {i:sti for i, sti in enumerate(char)}
# USING THE CREATED HASMAP TO CREATER ENCODER AND DECODER
encode = lambda x : [stringtoid[i] for i in x]
decode = lambda x: "".join([idtostring[i] for i in x])
# ENCODING THE TEXT
text = torch.tensor(encode(text), dtype=torch.long)
# CREATING TRAIN AND VAL SIZE
n = int(0.8*len(text))
train = text[0:n]
val = text[n:]

In [7]:
# This function creates batches of data for training or validation.
# It selects random starting points, extracts sequences of a given length (block_size), and prepares input (x) and target (y) tensors for a model.

def generate_batch(split, batch_size, block_size):
  data = train if split =='train' else val
  idx = torch.randint(0, len(data)-block_size, (batch_size, ))
  x = torch.stack([data[i:i+block_size] for i in idx])
  y = torch.stack([data[i+1:i+1+block_size] for i in idx])
  return x,y


In [8]:
# This function estimates the model's loss on the validation set by running 64 mini-batches through it.
# It calculates cross-entropy loss for each batch and returns the average loss, temporarily switching the model to evaluation mode for accurate assessment.

def estimate_loss(model, vocab_size, batch_size, block_size):

  model.eval()

  losses = torch.zeros(64)
  for _ in range(64):
    x,y = generate_batch('val', batch_size, block_size)
    x = x.to('cuda')
    y = y.to('cuda')
    logits = model(x)
    logits = logits.reshape(-1, vocab_size)
    y = y.view(-1)
    loss = F.cross_entropy(logits, y)
    losses[_] = loss.item()
  model.train()
  return losses.mean()


In [9]:
embdim = 64
block_size = 64
hidim = 64
# outdim = 32
batch_size = 128
vocab_size

97

In [10]:
class LSTMFROMScratch(nn.Module):

  def __init__(self):
    super().__init__()

    self.embeddings = nn.Embedding(vocab_size, embdim)

    # forget gate
    self.forget_gate_x = nn.Linear(embdim, hidim, bias = False)
    self.forget_gate_hid = nn.Linear(hidim, hidim, bias = False)

    # input gate
    self.input_gate_x = nn.Linear(embdim, hidim, bias= False)
    self.input_gate_hid = nn.Linear(hidim, hidim, bias= False)

    # candidate gate
    self.candidate_gate_x = nn.Linear(embdim, hidim, bias = False)
    self.candidate_gate_hid = nn.Linear(hidim, hidim, bias = False)

    # output gate
    self.outputgate_x = nn.Linear(embdim,hidim, bias = False)
    self.outputgate_hid = nn.Linear(hidim,hidim, bias = False)

    # self.input_to_hidden = nn.Linear(embdim, hidim,bias = False)

    # self.hidden_to_hidden = nn.Linear(hidim, hidim, bias = False)

    self.out = nn.Linear(hidim, vocab_size, bias = False)

  def forward(self, x, h=None, c = None):
    x = self.embeddings(x)
    # shape of x => (B,T,C)
    x = x.transpose(0,1)
    T,B,C = x.shape
    if h is None:
      h = torch.zeros(B, hidim, device = 'cuda')
      c = torch.zeros(B, hidim, device = 'cuda')
    res = []
    for _ in range(T):

      xi = x[_]

      # a = self.input_to_hidden(xi)

      # b = self.hidden_to_hidden(h)

      # z = a+b

      fg = torch.sigmoid(self.forget_gate_x(xi)+self.forget_gate_hid(h))

      ig = torch.sigmoid(self.input_gate_x(xi)+self.input_gate_hid(h))

      cg = torch.tanh(self.candidate_gate_x(xi)+self.candidate_gate_hid(h))

      og = torch.sigmoid(self.outputgate_x(xi)+self.outputgate_hid(h))

      c = c*fg + ig*cg

      h = torch.tanh(c)*og

      ot = self.out(h)

      res.append(ot)

    res = torch.stack(res)

    res = res.transpose(0,1)

    return res




In [11]:
# class LSTMPY(nn.Module):

#   def __init__(self):
#     super().__init__()
#     self.embeddings = nn.Embedding(vocab_size, embdim)
#     self.lstm = nn.LSTM(embdim, hidim, 1, True, True)
#     self.out = nn.Linear(hidim, vocab_size, bias=False)
#   def forward(self, x):
#     x = self.embeddings(x)
#     out,hid = self.lstm(x)
#     return self.out(out)

In [12]:
# model = LSTMPY()
# for name, p in model.named_parameters():
#   print(name, p.size())

In [13]:
model = LSTMFROMScratch()
for name, param in model.named_parameters():
  if param.dim()>=2:
    torch.nn.init.xavier_normal_(param)
model = model.to("cuda")

optimizer = torch.optim.AdamW(model.parameters(), lr = 1e-3)

In [25]:
epoches = 10000
for _ in range(epoches):

  x,y = generate_batch('train', batch_size, block_size)
  x = x.to("cuda")
  y = y.to("cuda")
  logits = model(x)
  logits = logits.reshape(-1, vocab_size)
  y = y.view(-1)
  loss = F.cross_entropy(logits, y)
  optimizer.zero_grad(set_to_none=True)
  loss.backward()
  optimizer.step()
  if _%200==0:

    l = estimate_loss(model, vocab_size, batch_size, block_size)
    print("step:", _ , "loss=>", l.item())

step: 0 loss=> 1.7409148216247559
step: 200 loss=> 1.7359461784362793
step: 400 loss=> 1.7355518341064453
step: 600 loss=> 1.7298846244812012
step: 800 loss=> 1.7306774854660034
step: 1000 loss=> 1.7333598136901855
step: 1200 loss=> 1.738206148147583
step: 1400 loss=> 1.7317755222320557
step: 1600 loss=> 1.7363221645355225
step: 1800 loss=> 1.7335338592529297
step: 2000 loss=> 1.7332932949066162
step: 2200 loss=> 1.7348014116287231
step: 2400 loss=> 1.7307429313659668
step: 2600 loss=> 1.7306863069534302
step: 2800 loss=> 1.7286218404769897
step: 3000 loss=> 1.7274980545043945
step: 3200 loss=> 1.7333952188491821
step: 3400 loss=> 1.7272214889526367
step: 3600 loss=> 1.7223597764968872
step: 3800 loss=> 1.7258373498916626
step: 4000 loss=> 1.7239192724227905
step: 4200 loss=> 1.727766752243042
step: 4400 loss=> 1.7290436029434204
step: 4600 loss=> 1.7228200435638428
step: 4800 loss=> 1.726190447807312
step: 5000 loss=> 1.722676396369934
step: 5200 loss=> 1.724047303199768
step: 5400 lo

In [26]:

# This function generates tokens using the trained model.
# Starting from a given input, it predicts the next token, samples from the probability distribution, appends it to the sequence,
# and continues for max_tok steps without updating gradients.

def generatetok(model, start, max_tok):
  with torch.no_grad():
    for _ in range(max_tok):
      # start2 = start[:, -block_size:, :]
      B,T = start.shape
      logits = model(start)

      # logits = logits.reshape(-1, vocab_size)
      prob = logits[:,-1,:]
      prob = F.softmax(prob, dim=-1)
      lo = torch.multinomial(prob, num_samples=1)
      start = torch.cat([start, lo], dim=1)
  return start


In [27]:
start = torch.tensor([2,2,2,2,2,2], device='cuda',dtype=torch.long).reshape(6,1)

In [28]:
# INITIALIZING THE START AS 0
# start = torch.zeros([3,1], device='cuda',dtype=torch.long)
# GENERATING FROM THE MODEL
out = generatetok(model, start, 256)
out.shape
# output
res = []
for _ in range(start.shape[0]):
  o = out[_]
  res.append(decode(o.tolist()))
for x in res:
  print(x)
  print("--------------\n")


 "Patutiture hour? Play tesk!"... poss? Alway down. 6 the these in the's any when a jeand bad: It's a Guy on the My Meopeach atteachess more.... Skread one is a gonning time   Shes? They gaving behinds a toda, not she arms Suchouth. Strits BSARD: People boy
--------------

 hemplonder. I don't come nound stecil get call" I those Mexican doings guilbstwite? Fed she accidentlyThe in deself!Why do diegns? The room? Probb his tadslard, exceris always tooted, I scrimme? Me:3yral solb.I like the Man? Come Mampolat, I'll tone" shitc
--------------

 shouffliler: When you have Hart 2 consecolese.You russia to make my naking sitter long and the bull of Chace, gethinahtally subctives her pains.My telely now you'How do you get a bottor? Dapbase." You can't cildencolow around the Jewwarged mirrieds buttive
--------------

 baby say. But a camps.... Some Elept things me it tried that shit to a disaptoss the Gendon? He time going to leg! ADHAILLEMA rro's world but I fun joke: I'm don't have "You Bre

In [29]:
idtostring

{0: '\x08',
 1: '\x10',
 2: ' ',
 3: '!',
 4: '"',
 5: '#',
 6: '$',
 7: '%',
 8: '&',
 9: "'",
 10: '(',
 11: ')',
 12: '*',
 13: '+',
 14: ',',
 15: '-',
 16: '.',
 17: '/',
 18: '0',
 19: '1',
 20: '2',
 21: '3',
 22: '4',
 23: '5',
 24: '6',
 25: '7',
 26: '8',
 27: '9',
 28: ':',
 29: ';',
 30: '<',
 31: '=',
 32: '>',
 33: '?',
 34: '@',
 35: 'A',
 36: 'B',
 37: 'C',
 38: 'D',
 39: 'E',
 40: 'F',
 41: 'G',
 42: 'H',
 43: 'I',
 44: 'J',
 45: 'K',
 46: 'L',
 47: 'M',
 48: 'N',
 49: 'O',
 50: 'P',
 51: 'Q',
 52: 'R',
 53: 'S',
 54: 'T',
 55: 'U',
 56: 'V',
 57: 'W',
 58: 'X',
 59: 'Y',
 60: 'Z',
 61: '[',
 62: '\\',
 63: ']',
 64: '^',
 65: '_',
 66: '`',
 67: 'a',
 68: 'b',
 69: 'c',
 70: 'd',
 71: 'e',
 72: 'f',
 73: 'g',
 74: 'h',
 75: 'i',
 76: 'j',
 77: 'k',
 78: 'l',
 79: 'm',
 80: 'n',
 81: 'o',
 82: 'p',
 83: 'q',
 84: 'r',
 85: 's',
 86: 't',
 87: 'u',
 88: 'v',
 89: 'w',
 90: 'x',
 91: 'y',
 92: 'z',
 93: '{',
 94: '|',
 95: '}',
 96: '~'}