# In [None]:
import nltk
nltk.download('punkt')
import torch
from torch import nn
from torch.nn import functional as F
import requests
from nltk.tokenize import word_tokenize
import re
import numpy as np

class BigramLanguageModel(nn.Module):
    """Word-level bigram language model backed by a single embedding table.

    Every token id indexes one row of ``token_embeddings_table`` and that row
    is used directly as the logits over the next token.

    ``prep_tokens`` must be called before ``forward``, ``fit``, ``eval_loss``,
    ``get_batch`` or ``generate``: it builds the vocabulary, the embedding
    table, the encoder/decoder callables and the train/validation tensors.
    """

    def __init__(self, batch_size=4, input_length=8, train_iters=100, eval_iters=100):
        """Store the hyper-parameters and pick the compute device.

        Args:
            batch_size: Number of sequences drawn per batch.
            input_length: Number of tokens per sequence.
            train_iters: Number of optimizer steps performed by ``fit``.
            eval_iters: Number of batches averaged per split in ``eval_loss``.
        """
        super().__init__()
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.input_length = input_length
        self.batch_size = batch_size
        self.train_iters = train_iters
        self.eval_iters = eval_iters

    def forward(self, inputs, targets=None):
        """Compute next-token logits and, if targets are given, the loss.

        Args:
            inputs: Tensor of token ids; the loss branch unpacks the resulting
                logits as ``(batch, length, vocab)``.
            targets: Optional tensor of token ids with ``batch * length``
                elements.

        Returns:
            ``(logits, loss)``. Without targets, ``loss`` is ``None`` and the
            logits keep their three-dimensional shape. With targets, the logits
            are flattened to ``(batch * length, vocab)`` and ``loss`` is the
            cross-entropy against the flattened targets.
        """
        logits = self.token_embeddings_table(inputs)
        if targets is None:
            loss = None
        else:
            batch_size, input_length, vocab_size = logits.shape
            # F.cross_entropy expects (N, C) logits and (N,) class indices.
            logits = logits.view(batch_size * input_length, vocab_size)
            targets = targets.view(batch_size * input_length)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def fit(self, learning_rate=0.001):
        """Train with Adam for ``train_iters`` steps, logging about 20 times.

        The learning rate is multiplied by 0.1 every 1000 steps (StepLR).

        Args:
            learning_rate: Initial learning rate handed to Adam.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.1)
        # Clamp to 1 so that train_iters < 20 does not cause a modulo by zero.
        eval_interval = max(1, self.train_iters // 20)
        for step in range(self.train_iters):
            if step % eval_interval == 0:
                avg_loss = self.eval_loss()
                print(f"Iter {step}: Train Loss = {avg_loss['train']['loss']:.4f}, "
                      f"Train Perplexity = {avg_loss['train']['perplexity']:.4f}, "
                      f"Val Loss = {avg_loss['eval']['loss']:.4f}, "
                      f"Val Perplexity = {avg_loss['eval']['perplexity']:.4f}")
            inputs, targets = self.get_batch(split='train')
            logits, loss = self(inputs, targets)
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()
            scheduler.step()

    @torch.no_grad()
    def generate(self, context, max_new_tokens, temperature=1.0):
        """Sample ``max_new_tokens`` tokens after ``context`` and decode them.

        Args:
            context: Long tensor of token ids, indexed as ``(batch, length)``.
            max_new_tokens: Number of tokens to append.
            temperature: Divisor applied to the logits before the softmax.

        Returns:
            The decoded text of the first sequence in the batch, including the
            initial context tokens.
        """
        inputs = context
        for _ in range(max_new_tokens):
            logits, _ = self(inputs)
            # Only the last position is needed to predict the next token.
            logits = logits[:, -1, :] / temperature
            probs = F.softmax(logits, dim=1)
            sampled_output = torch.multinomial(probs, num_samples=1)
            inputs = torch.cat((inputs, sampled_output), dim=1)
        output_text = self.decoder(inputs[0].tolist())
        return output_text

    @torch.no_grad()
    def eval_loss(self):
        """Estimate mean loss and perplexity on the train and eval splits.

        Returns:
            ``{'train': {...}, 'eval': {...}}`` where each inner dict holds the
            float entries ``'loss'`` (mean over ``eval_iters`` random batches)
            and ``'perplexity'`` (exponential of that mean).
        """
        perf = {}
        # Remember the current mode so evaluation does not silently flip a
        # model that was already in eval mode back to train mode.
        was_training = self.training
        self.eval()
        for split in ['train', 'eval']:
            losses = torch.zeros(self.eval_iters)
            for k in range(self.eval_iters):
                inputs, targets = self.get_batch(split)
                logits, loss = self(inputs, targets)
                losses[k] = loss.item()
            avg_loss = losses.mean()
            perf[split] = {
                'loss': avg_loss.item(),
                'perplexity': torch.exp(avg_loss).item(),
            }
        self.train(was_training)
        return perf

    def prep_tokens(self, text):
        """Build vocabulary, embedding table, codecs and data splits from text.

        Sets ``vocab_size``, ``token_embeddings_table``, ``encoder``,
        ``decoder``, ``train_text``, ``val_text``, ``train_data`` and
        ``val_data``. The first 90% of the tokens form the training split and
        the remainder the validation split.

        Args:
            text: Raw corpus as a single string.
        """
        text = re.sub(r'\s+', ' ', text)  # Normalize spaces
        tokens = word_tokenize(text.lower())
        tokens.append(' ')  # Add space token
        vocab = sorted(set(tokens))  # Create vocabulary
        self.vocab_size = len(vocab)
        # Create the table directly on the target device: if the module was
        # already moved with .to(device) before this call, a CPU-resident
        # embedding would clash with the device batches from get_batch.
        self.token_embeddings_table = nn.Embedding(
            self.vocab_size, self.vocab_size
        ).to(self.device)

        ctoi = {c: i for i, c in enumerate(vocab)}  # Token to integer map
        itoc = {i: c for c, i in ctoi.items()}  # Integer to token map

        # Encoder function: maps text to list of token indices; tokens that
        # are not in the vocabulary are dropped.
        self.encoder = lambda text: [ctoi[c] for c in word_tokenize(text.lower()) if c in ctoi]

        # Decoder function: maps token indices to text
        self.decoder = lambda nums: ' '.join([itoc[i] for i in nums])

        # Split tokens into training and validation sets
        n = len(tokens)
        split_at = int(n * 0.9)
        self.train_text = tokens[:split_at]
        self.val_text = tokens[split_at:]

        # Encode training and validation data (the joined tokens are passed
        # through the encoder, i.e. tokenized again).
        self.train_data = torch.tensor(self.encoder(' '.join(self.train_text)), dtype=torch.long)
        self.val_data = torch.tensor(self.encoder(' '.join(self.val_text)), dtype=torch.long)

    def get_batch(self, split='train'):
        """Draw a random batch of input/target sequences from one split.

        Args:
            split: ``'train'`` selects the training data; any other value
                selects the validation data.

        Returns:
            ``(inputs, targets)``, both of shape ``(batch_size, input_length)``
            and moved to ``self.device``. Targets are the inputs shifted one
            token to the right.

        Raises:
            ValueError: If the split has fewer than ``input_length + 1``
                tokens, so no input/target window fits.
        """
        data = self.train_data if split == 'train' else self.val_data
        num_starts = len(data) - self.input_length
        if num_starts <= 0:
            raise ValueError(
                f"split '{split}' has {len(data)} tokens but at least "
                f"{self.input_length + 1} are needed for input_length="
                f"{self.input_length}"
            )
        ix = torch.randint(num_starts, (self.batch_size,))
        inputs_batch = torch.stack([data[i:i + self.input_length] for i in ix])
        targets_batch = torch.stack([data[i + 1:i + self.input_length + 1] for i in ix])
        inputs_batch = inputs_batch.to(self.device)
        targets_batch = targets_batch.to(self.device)
        return inputs_batch, targets_batch

def read_local_txt(file_path):
    """Return the entire contents of the UTF-8 text file at ``file_path``."""
    with open(file_path, encoding='utf-8') as source:
        return source.read()

# Main execution
if __name__ == "__main__":
    # Load text data
    file_path = "WarrenBuffet.txt"  # Replace with your local file path
    text = read_local_txt(file_path)

    # Initialize and prepare the model.
    # prep_tokens creates the embedding table, so it has to run BEFORE the
    # module is moved to its device; otherwise the table stays on the CPU
    # while the batches are placed on model.device.
    model = BigramLanguageModel(batch_size=32, input_length=8, train_iters=5000)
    model.prep_tokens(text)
    model = model.to(model.device)

    # Generation starts from a single token with id 0 in both runs below.
    start_context = torch.zeros((1, 1), dtype=torch.long, device=model.device)

    # Generate text before training
    print("Generated text before training:")
    outputs = model.generate(context=start_context, max_new_tokens=100)
    print(outputs)

    # Train the model
    model.fit(learning_rate=0.1)

    # Generate text after training
    print("Generated text after training:")
    outputs = model.generate(context=start_context, max_new_tokens=100)
    print(outputs)