## Imports

In [1]:
import operator
import pickle
import re
from collections import defaultdict, Counter
from itertools import count
from typing import List, Dict, Any, Tuple, Union, Optional, Callable

import torch
import torch.nn as nn
import torch.nn.functional as F
from langchain_text_splitters import RecursiveCharacterTextSplitter
from nltk.tokenize import sent_tokenize
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils import clip_grad_norm

ModuleNotFoundError: No module named 'nltk'

## Data Preparation

In [2]:
def fopen(path):
    with open(path, "r", encoding="utf-8") as f:
        data = f.read()
    return data

In [None]:
def preprocess(data):
    out = data.lower()
    out = out.replace("\n", " ").replace("\r", "").replace("\t", " ") # remove newlines, tabs, and carriage (/n) returns
    #out = re.sub(r"([.,!?;:'\"\(\)\[\]\{\}])", r" \1 ", out)
    out = re.sub(r"\s{2,}", " ", out) # remove multiple spaces
    out = out.replace("_", "") # remove underscores
    return out

In [None]:
dune = fopen("dune.txt")
corpus_cleaned = preprocess(dune)
corpus_cleaned = sent_tokenize(corpus_cleaned) # split into sentences

train_samples, dev_samples = train_test_split(
    corpus_cleaned, 
    train_size=0.8, 
    test_size=0.2, 
    random_state = 137,
    shuffle=True,
)

## Tokenization

In [None]:
class CharTokenizer:
    def __init__(self):
        self.start_token = "[START]"
        self.end_token = "[END]"
        self.unk_token = "[UNK]"
        self.pad_token = "[PAD]"
        
        self.vocab = defaultdict(count().__next__)
        self.freq = defaultdict(int)

        self.__init_special_tokens__()

    def __init_special_tokens__(self):
        self.vocab[self.start_token]
        self.vocab[self.end_token]
        self.vocab[self.pad_token]
        self.vocab[self.unk_token]

    def insert_token(self, token):
        if token not in self.vocab:
            self.vocab[token]

    def train(self, samples: List[str]):

        for sample in tqdm(samples): # for each sentence
            for char_token in sample:  # for each character in the sentence
                self.insert_token(char_token) # add the character to the vocabulary
            
        self.vocab_size = len(self.vocab)
        self.i2c = {v: k for k, v in self.vocab.items()} # invert the vocabulary to get the index to character mapping

    def encode(
        self, 
        input_text: Union[str, List],
        max_length: Optional[int] = None,
        preprocessing_function: Callable = lambda x: x,
        exclude_end_token: bool = False
    ) -> Union[List[int], List[List[int]]]:
        input_ids = []
        
        if type(input_text) == str:
            input_text = preprocessing_function(input_text)
            input_ids.append(self.vocab.get(self.start_token))
            
            for char_token in input_text:
                input_ids.append(self.vocab.get(char_token, self.vocab.get(self.unk_token))) # get the index of the character in the vocabulary
                if max_length is not None and max_length - 1 == len(input_ids):
                    break

            if not exclude_end_token:
                input_ids.append(self.vocab.get(self.end_token)) # add the end token to the input

            if max_length is not None and len(input_ids) < max_length:
                input_ids.extend(
                    [self.vocab.get(self.pad_token) for _ in range(len(input_ids), max_length)] # pad the input with the pad token
                )

        else:
            input_text = list(map(lambda x: preprocessing_function(x), input_text)) # preprocess each input text
            for each_input_text in input_text:
                each_input_ids = []
                for char_token in each_input_text:
                    each_input_ids.append(self.vocab.get(char_token, self.vocab.get(self.unk_token))) 
                    if max_length is not None and max_length - 1 == len(each_input_ids): #
                        break

                    if not exclude_end_token:
                        each_input_ids.append(self.vocab.get(self.end_token))
                        
                    if max_length is not None and len(each_input_ids) < max_length: 
                        each_input_ids.extend( 
                            [self.vocab.get(self.pad_token) for _ in range(len(each_input_ids), max_length)] # pad the input with the pad token
                        )
                input_ids.append(each_input_ids)

        return input_ids

    def decode(
        self, input_ids: Union[List[int], List[List[int]]]
    ) -> Union[str, List[str]]:
        decoded_string = []
        
        if type(input_ids) == list and type(input_ids[0]) == int:
            for id in input_ids:
                decoded_string.append(self.i2c.get(id))

        else:
            for each_input_id in input_ids:
                decoded_string_each = []
                for id in each_input_id:
                    decoded_string_each.append(self.i2c.get(id))
                decoded_string_each.append(decoded_string_each)

        return decoded_string

    def save(self, output_file="model/char.tokenizer"):
        with open(output_file, 'wb') as f:
            pickle.dump(self, f)

    @staticmethod # static method to load the tokenizer || @they belong to the class namespace and can be called directly on the class without needing an instance.
    def load(output_file="model/char.tokenizer"): 
        with open(output_file, "rb") as f: 
            tokenizer = pickle.load(f) 
        return tokenizer

In [35]:
tokenizer = CharTokenizer()
tokenizer.train(train_samples)

  0%|          | 0/10862 [00:00<?, ?it/s]

In [36]:
tokenizer.save()

## DataLoaders

In [8]:
class SentenceSampler(Dataset):
    def __init__(
        self, tokenizer: CharTokenizer, samples: List[str], max_length: int
    ) -> None:
        self.tokenizer = tokenizer
        self.samples = samples
        self.max_length = max_length

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx) -> Dict[str, torch.Tensor]:
        sentence = self.samples[idx]
        input_ids = self.tokenizer.encode(sentence, self.max_length)
        return {
            "input_ids": torch.LongTensor(input_ids),
        }

In [9]:
train_dataset = SentenceSampler(tokenizer, train_samples, max_length=64)
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)

dev_dataset = SentenceSampler(tokenizer, dev_samples, max_length=64)
dev_dataloader = DataLoader(dev_dataset, batch_size=16, shuffle=False)

In [10]:
def test():
    for batch in train_dataloader:
        print(tokenizer.decode(batch["input_ids"][0].tolist()))
        break

#test()

## Modeling

In [None]:
class RNNLanguageModel(nn.Module):
    def __init__(
        self, vocab_size: int, num_layers, embedding_size: int, 
        embedding_dropout_rate: float, hidden_size: int
    ) -> None:
        super(RNNLanguageModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        self.embedding_dropout = nn.Dropout(embedding_dropout_rate)

        self.lstm = nn.LSTM(
            input_size=embedding_size,
            hidden_size=hidden_size,
            bidirectional=False,
            num_layers=num_layers,
            batch_first=True
        )

        self.lm_head = nn.Linear(
            in_features=hidden_size,
            out_features=vocab_size
        )

        self.gelu = nn.GELU()
        self.initialize_weights()

    def forward(self, token_ids):
        # token_ids: [B, max_seq_len]
        out = self.embedding_dropout(self.embedding(token_ids))
        # out: [B, max_seq_len, embed_dim]
        out, (hidden, cell) = self.lstm(out)
        # out: [B, max_seq_len, hidden_size]
        hidden, cell = None, None
        out = self.gelu(out)
        return self.lm_head(out)
        # return: [B, max_seq_len, vocab_size]

    def initialize_weights(self):
        torch.nn.init.normal_(self.embedding.weight, mean=0.0, std=0.1)
        
        for name, param in self.lstm.named_parameters():
            if "weight_ih" in name:
                torch.nn.init.xavier_uniform_(param.data)
            elif "weight_hh" in name:
                torch.nn.init.orthogonal_(param.data)
            elif "bias" in name:
                param.data.fill_(0)
        
        torch.nn.init.xavier_uniform_(self.lm_head.weight)
        if self.lm_head.bias is not None:
            self.lm_head.bias.data.fill_(0)

## Training

In [None]:
pad_token_id = tokenizer.vocab.get(tokenizer.pad_token)
end_token_id = tokenizer.vocab.get(tokenizer.end_token)

criterion = nn.CrossEntropyLoss(ignore_index=pad_token_id)

In [None]:
class TrainerWrapper:
    def __init__(
        self, model: RNNLanguageModel, device: torch.device, 
        norm_threshold: Optional[float] = None
    ):
        self.model = model
        self.device = device
        self.model = self.model.to(self.device)
        
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size = 10, gamma = 0.1)
        self.norm_threshold = norm_threshold

    def train(self, train_dataloader, dev_dataloader, epochs, tokenizer, criterion):
        total = len(train_dataloader) * epochs
        
        with tqdm(total=total, desc="Training Round") as training:
            for epoch in range(epochs):
                train_perplexity = 0 # initialize the training perplexity ||  It measures how well a probabilistic model predicts a sample of text.
                train_loss = 0
                batch_count = 0
                for step, batch in enumerate(train_dataloader):
                    self.model.train()
                    
                    input_ids = batch["input_ids"].to(self.device)

                    out = self.model(input_ids[:, :-1])
                    loss = criterion(out.permute(0, 2, 1), input_ids[:, 1:])

                    #print(tokenizer.decode(token_ids[0, :].tolist()))
                    #print(tokenizer.decode(targets[0, :].tolist()))

                    loss.backward()
                    clip_grad_norm(model.parameters(), 0.5)
                    

                    self.optimizer.step()
                    
                    train_loss += loss.item()
                    batch_count +=1
                
                    training.update()
                
                dev_loss, dev_perplexity = self.evaluate(
                    dev_dataloader=dev_dataloader,
                    tokenizer=tokenizer,
                    criterion=criterion
                )

                print(35*"*")
                print(f"Epoch {epoch+1}/{epochs}")
                print(f"  - Train Loss: {train_loss/batch_count}")
                print(f"  - Train Perplexity: {torch.exp(torch.tensor(train_loss/batch_count))}")
                print(f"  - Eval Loss: {dev_loss}")
                print(f"  - Eval Perplexity: {dev_perplexity}")
                if self.scheduler is not None:
                    self.scheduler.step()

    @torch.no_grad()
    def evaluate(self, dev_dataloader, tokenizer, criterion):
        self.model = self.model.to(self.device)
        total = len(dev_dataloader)

        dev_loss = 0
        batch_count = 0

        with tqdm(total=total, desc="Evaluation Round") as evaluation:
            for step, batch in enumerate(dev_dataloader):
                self.model.eval()
                
                input_ids = batch["input_ids"].to(self.device)

                out = self.model(input_ids[:, :-1])
                loss = criterion(out.permute(0, 2, 1), input_ids[:, 1:])

                dev_loss += loss.item()
                batch_count +=1
                evaluation.update()

        return dev_loss/batch_count, torch.exp(torch.tensor(dev_loss/batch_count))

    @torch.inference_mode() #Gradient computation is disabled to save memory etc.
    def generate(self, tokenizer, end_token_id, max_generation, condition: Optional[str] = None):
        input_ids = tokenizer.encode(
            input_text=condition if condition is not None else tokenizer.start_token,
            max_length=None,
            exclude_end_token=True
        )

        while max_generation:
            input_ids_tensor = torch.LongTensor(input_ids)[None, :].to(self.device)
            out = self.model(input_ids_tensor)
            out_last = out[:, -1, :]
            out_normalized = out_last.softmax(dim=-1)
            o_pred = out_normalized.argmax(dim=-1).flatten().item()
            input_ids.append(o_pred)

            if o_pred == end_token_id:
                break

            max_generation = max_generation - 1

        return tokenizer.decode(input_ids)

In [14]:
model = RNNLanguageModel(tokenizer.vocab_size, 2, 256, 0.1, 1024)

In [15]:
trainer = TrainerWrapper(model, torch.device("mps"))

In [16]:
trainer.train(train_dataloader, dev_dataloader, 20, tokenizer, criterion)

Training Round:   0%|          | 0/6800 [00:00<?, ?it/s]

  clip_grad_norm(model.parameters(), 0.5)


Evaluation Round:   0%|          | 0/170 [00:00<?, ?it/s]

***********************************
Epoch 1/20
  - Train Loss: 2.1828316043404974
  - Train Perplexity: 8.871390342712402
  - Eval Loss: 1.6189646482467652
  - Eval Perplexity: 5.047861576080322


Evaluation Round:   0%|          | 0/170 [00:00<?, ?it/s]

***********************************
Epoch 2/20
  - Train Loss: 1.446007208964404
  - Train Perplexity: 4.246127128601074
  - Eval Loss: 1.3855497198946336
  - Eval Perplexity: 3.9970223903656006


Evaluation Round:   0%|          | 0/170 [00:00<?, ?it/s]

***********************************
Epoch 3/20
  - Train Loss: 1.2940423043335185
  - Train Perplexity: 3.647501230239868
  - Eval Loss: 1.3242703507928286
  - Eval Perplexity: 3.759441375732422


Evaluation Round:   0%|          | 0/170 [00:00<?, ?it/s]

***********************************
Epoch 4/20
  - Train Loss: 1.2235511061023263
  - Train Perplexity: 3.399237632751465
  - Eval Loss: 1.292295601087458
  - Eval Perplexity: 3.6411354541778564


Evaluation Round:   0%|          | 0/170 [00:00<?, ?it/s]

***********************************
Epoch 5/20
  - Train Loss: 1.1847949185792137
  - Train Perplexity: 3.2700161933898926
  - Eval Loss: 1.3018337151583503
  - Eval Perplexity: 3.6760313510894775


Evaluation Round:   0%|          | 0/170 [00:00<?, ?it/s]

***********************************
Epoch 6/20
  - Train Loss: 1.164923092547585
  - Train Perplexity: 3.205676317214966
  - Eval Loss: 1.293555688156801
  - Eval Perplexity: 3.6457266807556152


Evaluation Round:   0%|          | 0/170 [00:00<?, ?it/s]

***********************************
Epoch 7/20
  - Train Loss: 1.1582732933409075
  - Train Perplexity: 3.1844301223754883
  - Eval Loss: 1.3016847687609054
  - Eval Perplexity: 3.6754837036132812


Evaluation Round:   0%|          | 0/170 [00:00<?, ?it/s]

***********************************
Epoch 8/20
  - Train Loss: 1.1597516112467823
  - Train Perplexity: 3.189141273498535
  - Eval Loss: 1.2971061264767367
  - Eval Perplexity: 3.658693552017212


Evaluation Round:   0%|          | 0/170 [00:00<?, ?it/s]

***********************************
Epoch 9/20
  - Train Loss: 1.1359906494617462
  - Train Perplexity: 3.1142570972442627
  - Eval Loss: 1.282497244722703
  - Eval Perplexity: 3.605632781982422


Evaluation Round:   0%|          | 0/170 [00:00<?, ?it/s]

***********************************
Epoch 10/20
  - Train Loss: 1.0973355521174037
  - Train Perplexity: 2.9961721897125244
  - Eval Loss: 1.2774177572306464
  - Eval Perplexity: 3.587364435195923


Evaluation Round:   0%|          | 0/170 [00:00<?, ?it/s]

***********************************
Epoch 11/20
  - Train Loss: 1.0174635720603606
  - Train Perplexity: 2.766169548034668
  - Eval Loss: 1.2609886548098397
  - Eval Perplexity: 3.5289087295532227


Evaluation Round:   0%|          | 0/170 [00:00<?, ?it/s]

***********************************
Epoch 12/20
  - Train Loss: 0.9720761420095668
  - Train Perplexity: 2.6434268951416016
  - Eval Loss: 1.2578525424003602
  - Eval Perplexity: 3.5178589820861816


Evaluation Round:   0%|          | 0/170 [00:00<?, ?it/s]

***********************************
Epoch 13/20
  - Train Loss: 0.9381887288654551
  - Train Perplexity: 2.5553488731384277
  - Eval Loss: 1.259469797330744
  - Eval Perplexity: 3.523552656173706


Evaluation Round:   0%|          | 0/170 [00:00<?, ?it/s]

***********************************
Epoch 14/20
  - Train Loss: 0.9080400414326611
  - Train Perplexity: 2.4794580936431885
  - Eval Loss: 1.2619871910880593
  - Eval Perplexity: 3.5324342250823975


Evaluation Round:   0%|          | 0/170 [00:00<?, ?it/s]

***********************************
Epoch 15/20
  - Train Loss: 0.8800503914847093
  - Train Perplexity: 2.4110212326049805
  - Eval Loss: 1.2664564700687633
  - Eval Perplexity: 3.5482568740844727


Evaluation Round:   0%|          | 0/170 [00:00<?, ?it/s]

***********************************
Epoch 16/20
  - Train Loss: 0.8537456950720619
  - Train Perplexity: 2.3484268188476562
  - Eval Loss: 1.2728309498113743
  - Eval Perplexity: 3.5709474086761475


Evaluation Round:   0%|          | 0/170 [00:00<?, ?it/s]

***********************************
Epoch 17/20
  - Train Loss: 0.8279372599195032
  - Train Perplexity: 2.288593053817749
  - Eval Loss: 1.2799165522350984
  - Eval Perplexity: 3.596339464187622


KeyboardInterrupt: 

In [24]:
torch.save(trainer.model.state_dict(), "model/model_charlm.pt")

## Load Model and Play

In [7]:
def join(trainer, tokenizer, end_token_id, gen_len, condition):
    return "".join(trainer.generate(tokenizer, end_token_id, 100, condition))

In [13]:
tokenizer = CharTokenizer.load()
pad_token_id = tokenizer.vocab.get(tokenizer.pad_token)
end_token_id = tokenizer.vocab.get(tokenizer.end_token)

In [9]:
model = RNNLanguageModel(tokenizer.vocab_size, 2, 256, 0.1, 1024)

In [10]:
model.load_state_dict(torch.load("model/model_charlm.pt", weights_only=True))

  return self.fget.__get__(instance, owner)()


<All keys matched successfully>

In [11]:
trainer = TrainerWrapper(model, torch.device("mps"))

In [17]:
join(trainer, tokenizer, end_token_id, 100, "house")

'[START]house the spice wealth of the sand around them.[END]'