In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torch.optim import Optimizer

import json
from typing import Dict, Tuple, List
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
with open('../data/epirecipes/full_format_recipes.json', encoding='utf-8') as json_data:
    recipe_data = json.load(json_data)

# Flatten each recipe into one training string: "Recipe for <title> | <step> <step> ...".
# Steps are joined with a space so the last word of one step cannot fuse with the first
# word of the next when a step has no trailing punctuation.
# Entries missing a title or directions (absent key or null value) are skipped.
text = [
    'Recipe for ' + x['title'] + ' | ' + ' '.join(x['directions'])
    for x in recipe_data
    if x.get('title') is not None
    and x.get('directions') is not None]

In [None]:
import string, re
def pad_punctuation(s: str) -> str:
    """Make every ASCII punctuation character its own whitespace-separated token.

    Each character of `string.punctuation` is surrounded by spaces, runs of spaces are
    squeezed to one, and the result is lowercased, e.g.
    ``"Hello, World!" -> "hello , world ! "``.
    """
    # Escape the punctuation set: ']', '\\', '^' and '-' are special inside [...] and must
    # not be interpreted as class syntax.
    s = re.sub(f"([{re.escape(string.punctuation)}])", r' \1 ', s)
    s = re.sub(' +', ' ', s)
    return s.lower()

# Normalize every recipe string (lowercase, punctuation split into separate tokens).
text = list(map(pad_punctuation, text))

In [None]:
from collections import Counter

VOCAB_SIZE = 10  # number of most frequent corpus words to keep (special tokens are added on top)
word_count = Counter(word for recipe in text for word in recipe.split())
special_tokens = ['<pad>', '<unk>', '<eos>']
vocab = special_tokens + [word for word, _ in word_count.most_common(VOCAB_SIZE)]
# Take the final size from the vocabulary itself: most_common(k) returns fewer than k words
# when the corpus has fewer distinct words, so adding len(special_tokens) could overstate it.
VOCAB_SIZE = len(vocab)
word2idx = {word: index for index, word in enumerate(vocab)}


def tokenizer(x: str) -> List[int]:
    """Split `x` on whitespace and map each word to its vocab id (<unk> for unknown words)."""
    unk_idx = word2idx['<unk>']
    return [word2idx.get(word, unk_idx) for word in x.split()]

In [None]:
class Epirecipes(Dataset):
    """Next-token dataset: each item is a random (context window, next token) pair from one recipe.

    Every recipe is left-padded with `seq_len` <pad> ids and terminated with <eos>, so the
    first recipe token and the <eos> terminator can both appear as prediction targets.
    """

    def __init__(self, data: List[List[int]], seq_len: int, pad_idx=None, eos_idx=None):
        """
        Args:
            data: tokenized recipes (sequences of vocab ids).
            seq_len: length of the context window returned as X.
            pad_idx: id used for left padding; defaults to the notebook's ``word2idx['<pad>']``.
            eos_idx: id appended to each recipe; defaults to ``word2idx['<eos>']``.
        """
        if pad_idx is None:
            pad_idx = word2idx['<pad>']
        if eos_idx is None:
            eos_idx = word2idx['<eos>']
        self.data = [[pad_idx] * seq_len + list(recipe) + [eos_idx] for recipe in data]
        self.seq_len = seq_len

    def __getitem__(self, index) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(X, y)``: `seq_len` ids from a uniformly random offset, and the id right after them."""
        sequence = self.data[index]
        # randint's upper bound is exclusive, so start + seq_len is always a valid target index;
        # the last possible target is the trailing <eos>.
        start = np.random.randint(0, len(sequence) - self.seq_len)
        X = torch.tensor(sequence[start:start + self.seq_len])
        y = torch.tensor(sequence[start + self.seq_len])
        return X, y

    def __len__(self) -> int:
        """Number of recipes (one random window is drawn per recipe per access)."""
        return len(self.data)

SEQ_LEN = 256
BATCH_SIZE = 32
recipes = Epirecipes([tokenizer(x) for x in text], SEQ_LEN)
# Shuffle so batch composition differs from epoch to epoch instead of always grouping the
# same recipes together in corpus order.
dataloader = DataLoader(recipes, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
class Cell(nn.Module):
    """One LSTM step with the four gates computed by a single fused linear projection.

    `input_size` is the width of ``torch.cat((x, h_state), -1)``, i.e. the input features
    *plus* `hidden_size`, because the projection is applied to that concatenation.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        # Produces the forget / input / candidate / output pre-activations in one matmul.
        self.W_input = nn.Linear(input_size, 4 * hidden_size)
        nn.init.xavier_normal_(self.W_input.weight)
        nn.init.constant_(self.W_input.bias, 0.0)

    def forward(self, x, h_state, c_state):
        """Advance one time step and return ``(new_h_state, new_c_state)``.

        `x` and `h_state` are concatenated on the last dim, so their combined width must
        equal `input_size`; `h_state` and `c_state` have last dim `hidden_size`.
        """
        gate_inputs = self.W_input(torch.cat((x, h_state), dim=-1))
        f_gate, i_gate, c_gate, o_gate = torch.chunk(gate_inputs, chunks=4, dim=-1)
        f_gate = torch.sigmoid(f_gate)
        i_gate = torch.sigmoid(i_gate)
        c_gate = torch.tanh(c_gate)
        o_gate = torch.sigmoid(o_gate)

        new_c_state = i_gate * c_gate + f_gate * c_state
        new_h_state = o_gate * torch.tanh(new_c_state)
        return new_h_state, new_c_state


class LSTM(nn.Module):
    """Word-level language model: embedding -> stacked `Cell` layers -> linear head.

    `forward` maps token ids of shape (batch, seq_len) to unnormalized next-token logits of
    shape (batch, vocab_size) for the position after the last input token. `generate`,
    `to_input` and `to_text` rely on the notebook-level `word2idx`, `vocab`,
    `special_tokens`, `tokenizer` and `pad_punctuation`.
    """

    def __init__(self, embd_size: int, hidden_size: Tuple, num_layers: int, vocab_size: int, batch_size: int):
        super().__init__()
        self.embd_size = embd_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.vocab_size = vocab_size
        self.batch_size = batch_size  # stored only; forward() uses the runtime batch size

        self.embedding = nn.Embedding(vocab_size, embd_size)
        # Layer i sees [output of layer i-1 (the embedding for i == 0), its own hidden state].
        prev_widths = [embd_size] + list(hidden_size[:-1])
        self.input_size = [prev + own for prev, own in zip(prev_widths, hidden_size)]
        self.layers = nn.ModuleList([Cell(self.input_size[i], self.hidden_size[i]) for i in range(num_layers)])

        # Raw logits: nn.CrossEntropyLoss applies log-softmax itself, so a Softmax here would
        # normalize twice and squash the gradients. Kept inside nn.Sequential so parameter
        # names (output.0.weight / output.0.bias) stay the same.
        self.output = nn.Sequential(nn.Linear(hidden_size[-1], vocab_size))

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Map (batch, seq_len) token ids to (batch, vocab_size) next-token logits."""
        batch_size = X.shape[0]
        embedded = self.embedding(X)
        # One (batch, width) tensor per time step; each layer overwrites it with its outputs,
        # which become the inputs of the next layer.
        steps = list(embedded.unbind(dim=1))
        for layer, width in zip(self.layers, self.hidden_size):
            h_state = torch.zeros((batch_size, width), device=embedded.device)
            c_state = torch.zeros_like(h_state)
            for t, step in enumerate(steps):
                h_state, c_state = layer(step, h_state, c_state)
                steps[t] = h_state
        # The top layer's final hidden state summarizes the sequence.
        return self.output(h_state)

    def generate(self, X: str, seq_len: int = 256, temperature: float = 1.0,
                 max_new_tokens: int = 1000) -> str:
        """Sample a continuation of prompt `X` until <eos> or `max_new_tokens` sampled tokens.

        Args:
            X: prompt text (normalized the same way as the training corpus).
            seq_len: context length fed to the model at every step; the prompt is
                left-padded / truncated to it.
            temperature: must be > 0; < 1 sharpens, > 1 flattens the distribution.
            max_new_tokens: upper bound on sampled tokens so generation always terminates.

        Returns:
            Prompt plus continuation as text, with special tokens removed.

        Raises:
            ValueError: if `temperature` is not positive.

        The model's previous train/eval mode is restored before returning.
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {temperature}")
        eos_idx = word2idx['<eos>']
        device = next(self.parameters()).device
        tokens = self.to_input(X, seq_len).tolist()
        was_training = self.training
        self.eval()
        try:
            with torch.inference_mode():
                for _ in range(max_new_tokens):
                    if tokens[-1] == eos_idx:
                        break
                    context = torch.tensor(tokens[-seq_len:], device=device).unsqueeze(dim=0)
                    # softmax(logits / T) equals p ** (1 / T) renormalized.
                    probs = torch.softmax(self(context) / temperature, dim=-1)
                    tokens.append(torch.multinomial(probs, 1).item())
        finally:
            self.train(was_training)
        return self.to_text(tokens)

    def to_input(self, X: str, seq_len: int = 256) -> torch.Tensor:
        """Tokenize `X` into exactly `seq_len` ids: its last `seq_len` tokens, left-padded with <pad>."""
        # Same normalization as the training corpus (lowercase, punctuation split off);
        # otherwise capitalized words or attached punctuation in a prompt map to <unk>.
        tokens = tokenizer(pad_punctuation(X))
        tokens = tokens[-seq_len:] if len(tokens) > seq_len else tokens
        padding_size = seq_len - len(tokens)
        return torch.tensor([word2idx['<pad>']] * padding_size + tokens)

    def to_text(self, X) -> str:
        """Join token ids (list or 1-D tensor) into space-separated words, dropping special tokens."""
        words = (vocab[token] for token in X)
        return " ".join(word for word in words if word not in special_tokens)

In [None]:
def trainer(model: nn.Module, dataloader: DataLoader,
            optimizer: Optimizer, loss_fn: nn.Module,
            device: torch.device, EPOCHS: int,
            test: List[List[int]], metric: callable = None) -> Dict[str, List[float]]:
    """Train `model` for `EPOCHS` epochs, evaluating ``metric(model, test)`` after each one.

    Args:
        model: network mapping a batch X to outputs accepted by `loss_fn`; moved to `device`.
        dataloader: yields ``(X, y)`` batches.
        optimizer: optimizer over `model`'s parameters.
        loss_fn: called as ``loss_fn(model(X), y)``.
        device: device for the model and each batch.
        EPOCHS: number of passes over `dataloader`.
        test: held-out data, passed unchanged to `metric`.
        metric: optional callable ``metric(model, test) -> float``; run in eval mode without
            autograd. Its ``__name__`` is used as the report key and in the printed log.

    Returns:
        Report dict: ``"loss"`` -> mean training loss per epoch, plus
        ``metric.__name__`` -> metric value per epoch when `metric` is given.
    """
    model = model.to(device)
    report = {"loss": []}
    if metric is not None:
        report[metric.__name__] = []

    for epoch in range(EPOCHS):
        # Re-enter train mode each epoch: the metric evaluation below switches to eval mode.
        model.train()
        print(f"Epoch: ---------------> {epoch + 1}/{EPOCHS}")
        epoch_loss = 0.0
        n_batches = 0

        for X, y in tqdm(dataloader):
            X, y = X.to(device), y.to(device)
            logit = model(X)
            loss = loss_fn(logit, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            n_batches += 1

        # Guard against an empty dataloader.
        epoch_loss /= max(n_batches, 1)
        print(f"Loss: ---------------> {epoch_loss:.4f}")
        report['loss'].append(epoch_loss)

        if metric is not None:
            model.eval()
            with torch.no_grad():
                epoch_metric = metric(model, test)
            print(f"{metric.__name__} ---------> {epoch_metric:.4f}")
            report[metric.__name__].append(epoch_metric)

    return report

                
# Hyper-parameters for the training run below.
# NOTE: SEQ_LEN is reassigned here, but `recipes` / `dataloader` were already constructed with
# the earlier SEQ_LEN value; re-create them if this value is meant to apply to training.
SEQ_LEN = 16
HIDDEN_SIZE = (10, 20, 30)     # hidden width of each stacked layer
EMBD_SIZE = 10
NUM_LAYERS = len(HIDDEN_SIZE)  # one Cell per HIDDEN_SIZE entry; derived so the two cannot drift apart
LEARNING_RATE = .0001
BETAS = (0.9, 0.99)
EPOCHS = 5

# Seed the RNGs used from here on (np.random for the test split and Epirecipes window
# sampling, torch for weight initialization and generation) so runs are reproducible.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

def perplexity(model: nn.Module, test: List[List[int]], context_len: int = 16) -> float:
    """Corpus-level perplexity of `model` on tokenized recipes `test`.

    Each recipe is laid out like the training data in `Epirecipes`: left-padded with
    `context_len` <pad> ids and terminated by <eos>. Every recipe token and the <eos> are
    predicted from the `context_len` ids preceding them. The result is
    ``exp(total negative log-likelihood / number of predicted tokens)``.

    Assumes `model` maps (batch, context_len) ids to unnormalized (batch, vocab) logits,
    the same contract `nn.CrossEntropyLoss` relies on during training.

    Args:
        model: language model to evaluate; its train/eval mode is restored afterwards.
        test: recipes as lists of vocab ids.
        context_len: preceding ids fed to the model per prediction (the default matches the
            SEQ_LEN = 16 set in this cell).

    Returns:
        Perplexity as a float, or NaN when `test` is empty.

    Raises:
        ValueError: if `context_len` < 1.
    """
    if context_len < 1:
        raise ValueError(f"context_len must be >= 1, got {context_len}")
    pad_idx, eos_idx = word2idx['<pad>'], word2idx['<eos>']
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    total_nll, total_tokens = 0.0, 0
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for recipe in test:
                padded = torch.tensor([pad_idx] * context_len + list(recipe) + [eos_idx], device=device)
                # Row k of `windows` is padded[k : k + context_len]; its target is padded[k + context_len].
                windows = padded.unfold(0, context_len, 1)[:-1]
                targets = padded[context_len:]
                logits = model(windows)
                total_nll += nn.functional.cross_entropy(logits, targets, reduction='sum').item()
                total_tokens += targets.numel()
    finally:
        model.train(was_training)

    if total_tokens == 0:
        return float('nan')
    return float(np.exp(total_nll / total_tokens))
        

# Hold out 10% of recipes for evaluation. Sample indices rather than the strings so numpy
# does not build a fixed-width unicode array of the entire corpus.
n_test_samples = int(len(text) * 0.1)
test_indices = set(np.random.choice(len(text), size=n_test_samples, replace=False).tolist())
test_text = [recipe for i, recipe in enumerate(text) if i in test_indices]
test = [tokenizer(recipe) for recipe in test_text]

# Train only on the remaining recipes so the per-epoch metric is measured on unseen data.
# Window length and batch size of the previously built dataset/loader are kept.
train_tokens = [tokenizer(recipe) for i, recipe in enumerate(text) if i not in test_indices]
recipes = Epirecipes(train_tokens, recipes.seq_len)
dataloader = DataLoader(recipes, batch_size=dataloader.batch_size, shuffle=True)

model = LSTM(EMBD_SIZE, HIDDEN_SIZE, NUM_LAYERS, VOCAB_SIZE, BATCH_SIZE)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, betas=BETAS)
loss_fn = nn.CrossEntropyLoss()
device = 'cuda' if torch.cuda.is_available() else 'cpu'

report = trainer(model, dataloader, optimizer, loss_fn, device, EPOCHS, test, perplexity)
