In [1]:
import json
from typing import Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader, Dataset
%matplotlib inline

In [2]:
# JSON interchange text is UTF-8 (RFC 8259); don't rely on the platform's
# locale encoding, which differs e.g. on Windows.
with open('../data/epirecipes/full_format_recipes.json', encoding='utf-8') as json_data:
    recipe_data = json.load(json_data)

# One training string per recipe: "Recipe for <title> | <directions>".
# Records missing a title or directions (absent key or null) are dropped.
# Direction steps are joined with a space so the last word of one step is
# never fused with the first word of the next.
text = [
    'Recipe for ' + x['title'] + ' | ' + ' '.join(x['directions'])
    for x in recipe_data
    if x.get('title') is not None
    and x.get('directions') is not None
]

In [3]:
import string, re
def pad_punctuation(s: str) -> str:
    r"""Surround every ASCII punctuation character with spaces, then lowercase.

    Runs of spaces are collapsed to one, so the result can be tokenized with
    ``str.split()``.

    ``string.punctuation`` must be passed through ``re.escape`` before being
    placed inside a character class. Inserted raw, its backslash followed by
    ``]`` is read as an escaped ``]``, so the backslash itself would never be
    padded.

    Args:
        s: raw text.

    Returns:
        Lowercased text with punctuation split into standalone tokens.
    """
    s = re.sub(f"([{re.escape(string.punctuation)}])", r' \1 ', s)
    s = re.sub(' +', ' ', s)
    return s.lower()

# Normalize every recipe string: punctuation becomes its own token, all lowercase.
text = list(map(pad_punctuation, text))

In [4]:
from collections import Counter
VOCAB_SIZE = 10000  # total vocabulary size, special tokens included
SPECIAL_TOKENS = ['<pad>', '<unk>']

# Word frequencies over the whitespace-split, punctuation-padded corpus.
word_count = Counter(word for recipe in text for word in recipe.split())

# Reserve slots for the special tokens so that len(vocab) <= VOCAB_SIZE.
# Taking VOCAB_SIZE corpus words *plus* the specials would produce
# VOCAB_SIZE + 2 ids and overflow anything sized by VOCAB_SIZE.
# '<pad>' / '<unk>' cannot collide with corpus words, because pad_punctuation
# splits '<' and '>' off as separate tokens.
vocab = SPECIAL_TOKENS + [
    word for word, _ in word_count.most_common(VOCAB_SIZE - len(SPECIAL_TOKENS))
]
word2idx = {word: index for index, word in enumerate(vocab)}
UNK_IDX = word2idx['<unk>']


def tokenizer(x: str) -> List[int]:
    """Map a padded recipe string to token ids.

    Out-of-vocabulary words map to the '<unk>' id.
    """
    return [word2idx.get(word, UNK_IDX) for word in x.split()]

In [20]:
class Epirecipes(Dataset):
    """Next-token dataset over tokenized recipes.

    Each item is a random ``(X, y)`` pair drawn from one recipe. ``X`` holds
    ``seq_len`` consecutive token ids and ``y`` is the id that follows them.
    Every recipe is left-padded with ``seq_len`` pad ids, so even its first
    token can be a target (its context is then all padding).

    Args:
        data: one list of token ids per recipe.
        seq_len: length of the context window.
        pad_idx: id used for left padding. When omitted, falls back to
            ``word2idx['<pad>']`` from the notebook's vocabulary cell.
    """

    def __init__(self, data: List[List[int]], seq_len: int, pad_idx: Optional[int] = None):
        if pad_idx is None:
            pad_idx = word2idx['<pad>']
        self.data = [[pad_idx] * seq_len + recipe for recipe in data]
        self.seq_len = seq_len

    def __getitem__(self, index) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(X, y)`` for a uniformly random target position in recipe ``index``.

        ``X`` is a 1-D tensor of ``seq_len`` token ids; ``y`` is a 0-d tensor.

        Raises:
            ValueError: if the recipe has no tokens (nothing to predict).
        """
        tokens = self.data[index]
        n_targets = len(tokens) - self.seq_len  # length of the unpadded recipe
        if n_targets <= 0:
            raise ValueError(f"recipe {index} is empty; nothing to predict")
        # Draw from torch's RNG: DataLoader re-seeds it per worker, whereas
        # seeds of other libraries may be duplicated across worker processes.
        start = int(torch.randint(0, n_targets, (1,)))
        X = torch.tensor(tokens[start:start + self.seq_len])
        y = torch.tensor(tokens[start + self.seq_len])
        return X, y

    def __len__(self) -> int:
        return len(self.data)

SEQ_LEN = 256    # context window length, in tokens
BATCH_SIZE = 32  # (context, target) pairs per batch

recipes = Epirecipes([tokenizer(x) for x in text], SEQ_LEN)

# shuffle=True reorders recipes each epoch. Within a recipe, the target
# position is sampled at random on every access by Epirecipes.__getitem__.
dataloader = DataLoader(recipes, batch_size=BATCH_SIZE, shuffle=True)

In [24]:
class Cell(nn.Module):
    """Placeholder for a recurrent cell; not implemented yet.

    NOTE(review): presumably meant to become the per-step LSTM cell. Its
    constructor arguments (input/hidden sizes) are still TBD.
    """

    def __init__(self):
        # nn.Module.__init__ must run before any parameter or submodule is
        # assigned, and before the module can be called or inspected.
        super().__init__()

    def forward(self, input):
        # Fail loudly instead of silently returning None from a stub.
        raise NotImplementedError("Cell.forward is not implemented yet")

class LSTM(nn.Module):
    """Placeholder for the recipe language model; not implemented yet.

    NOTE(review): constructor arguments (vocabulary size, embedding/hidden
    sizes) are still TBD. Confirm against the data pipeline before filling in.
    """

    def __init__(self):
        # nn.Module.__init__ must run before any parameter or submodule is
        # assigned, and before the module can be called or inspected.
        super().__init__()

    def forward(self, input):
        # Fail loudly instead of silently returning None from a stub.
        raise NotImplementedError("LSTM.forward is not implemented yet")