# LSTM on Recipe Data

**The notebook has been adapted from the notebook provided in David Foster's Generative Deep Learning, 2nd Edition.**

- Book: [Amazon](https://www.amazon.com/Generative-Deep-Learning-Teaching-Machines/dp/1098134184/ref=sr_1_1?keywords=generative+deep+learning%2C+2nd+edition&qid=1684708209&sprefix=generative+de%2Caps%2C93&sr=8-1)
- Original notebook (TensorFlow and Keras): [GitHub](https://github.com/davidADSP/Generative_Deep_Learning_2nd_Edition/blob/main/notebooks/05_autoregressive/01_lstm/lstm.ipynb)
- Dataset: [Kaggle](https://www.kaggle.com/datasets/hugodarwood/epirecipes)

In [28]:
import re
import string
import json

import numpy as np
import jax
import jax.numpy as jnp
from tensorflow.data import Dataset
from tensorflow.keras.layers import TextVectorization
from tensorflow.keras import utils

from flax import struct
from flax.training import train_state
import flax.linen as nn
import optax

from clu import metrics

## 0. Training parameters

In [15]:
# --- Data ---
# Path to the recipes JSON file (despite the name, this points at a file, not a directory).
DATA_DIR = '../../data/epirecipes/full_format_recipes.json'
VALIDATION_SPLIT = 0.2  # fraction of the recipes held out for validation
VOCAB_SIZE = 10_000     # vocabulary cap for the tokenizer; also the model's output width
MAX_PAD_LEN = 200       # model input length; texts are vectorized to MAX_PAD_LEN + 1 ids

# --- Model ---
EMBEDDING_DIM = 100
HIDDEN_DIM = 128
NUM_LSTM_LAYERS = 2

# --- Training ---
SEED = 1024
BATCH_SIZE = 32
EPOCHS = 30

# --- Generation ---
MAX_VAL_TOKENS = 100  # Max number of tokens when generating texts

## 1. Load dataset

In [3]:
def pad_punctuation(sentence):
    """Surround every punctuation character with spaces and collapse space runs.

    This makes each punctuation mark a separate whitespace-delimited token.

    Args:
        sentence: input text.

    Returns:
        The text with each ``string.punctuation`` character padded by single
        spaces and every run of spaces reduced to one space.
    """
    # re.escape is required: interpolated raw, the ``\]`` inside
    # string.punctuation is parsed as an escaped ``]`` and the backslash
    # itself would never be matched.
    sentence = re.sub(f'([{re.escape(string.punctuation)}])', r' \1 ', sentence)
    sentence = re.sub(' +', ' ', sentence)
    return sentence

In [4]:
# load dataset
# Read-only mode is sufficient ('r+' would also demand write permission on the
# data file). The encoding is explicit because the recipes contain non-ASCII
# characters and the platform default encoding is not guaranteed to be UTF-8.
with open(DATA_DIR, 'r', encoding='utf-8') as f:
    recipe_data = json.load(f)

In [5]:
# Build one string per recipe, "Recipe for <title> | <directions joined by spaces>",
# keeping only entries that have both a non-empty title and non-empty directions.
filtered_data = [
    'Recipe for ' + recipe['title'] + ' | ' + ' '.join(recipe['directions'])
    for recipe in recipe_data
    if recipe.get('title') and recipe.get('directions')
]

# Space out punctuation so each mark becomes its own token.
text_ds = list(map(pad_punctuation, filtered_data))
print(f'Total recipe loaded: {len(text_ds)}')

Total recipe loaded: 20098


In [10]:
print('Sample data:')
# Draw the example with a generator seeded from SEED so the same recipe is shown
# on every run. Picking an index avoids converting the whole list to an array.
sample_index = np.random.default_rng(SEED).integers(len(text_ds))
sample_data = text_ds[sample_index]
print(sample_data)

Sample data:
Recipe for Greek - Style Shrimp | Heat oil in heavy large skillet over medium heat . Add onion and sauté until translucent , about 8 minutes . Add tomatoes , wine and oregano . Simmer until thickened , stirring occasionally , about 5 minutes . Add shrimp and stir until opaque , 3 minutes . Season with salt and pepper . Serve with rice . 


## 2. Build vocabularies

In [7]:
# Convert the list of texts to a tf.data dataset
text_ds_tf = Dataset.from_tensor_slices(text_ds)

# Tokenizer: lower-cases the text, caps the vocabulary at VOCAB_SIZE entries and
# emits integer ids with a fixed length of MAX_PAD_LEN + 1 (the extra position
# lets source and target be formed by shifting the sequence by one).
vectorize_layer = TextVectorization(
    max_tokens=VOCAB_SIZE,
    standardize='lower',
    output_mode='int',
    output_sequence_length=MAX_PAD_LEN + 1,
)

In [8]:
# Fit the tokenizer on the recipe texts and pull out the resulting vocabulary
vectorize_layer.adapt(text_ds_tf)
vocab = vectorize_layer.get_vocabulary()

# Show the first 10 vocabulary entries next to their integer ids
for token_id, token in enumerate(vocab[:10]):
    print(f'{token_id}: {token}')

0: 
1: [UNK]
2: .
3: ,
4: and
5: to
6: in
7: the
8: with
9: a


In [11]:
# Tokenize the sample and show the text next to its integer encoding
sample_data_tokenized = vectorize_layer(sample_data)
print(
    'Source text:',
    sample_data,
    '\n',
    'Mapped sample:',
    sample_data_tokenized.numpy(),
    sep='\n',
)

Source text:
Recipe for Greek - Style Shrimp | Heat oil in heavy large skillet over medium heat . Add onion and sauté until translucent , about 8 minutes . Add tomatoes , wine and oregano . Simmer until thickened , stirring occasionally , about 5 minutes . Add shrimp and stir until opaque , 3 minutes . Season with salt and pepper . Serve with rice . 


Mapped sample:
[  26   16 1952   13 1032  261   27   17   37    6   78   30   56   20
   29   17    2   18  115    4  130   10  861    3   19  113   12    2
   18  182    3  249    4  639    2   70   10  426    3   48   90    3
   19   59   12    2   18  261    4   42   10  696    3   36   12    2
   63    8   24    4   33    2   68    8  205    2    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0  

## 3. Create train/validation datasets

In [12]:
def map_src_tgt(text):
    """Tokenize ``text`` and return a (source, target) pair for next-token prediction.

    The source is the token sequence without its last element and the target
    is the same sequence without its first element, i.e. the source shifted
    left by one position.
    """
    tokens = vectorize_layer(text)
    return tokens[:-1], tokens[1:]
    

def get_datasets(input_ds, shuffle_buffer=1024):
    """Split a text dataset into batched train/validation (source, target) datasets.

    The first ``1 - VALIDATION_SPLIT`` fraction of ``input_ds`` is used for
    training and the remainder for validation. Both parts are mapped through
    ``map_src_tgt`` and batched with ``BATCH_SIZE``; only the training part is
    shuffled.

    Args:
        input_ds: dataset of raw text elements; must support ``len()``.
        shuffle_buffer: number of training examples held in the shuffle buffer.

    Returns:
        Tuple ``(train_ds, valid_ds)`` of batched, prefetched datasets.
    """
    train_size = int(len(input_ds) * (1 - VALIDATION_SPLIT))
    train_ds = input_ds.take(train_size)
    valid_ds = input_ds.skip(train_size)
    print(f'train size: {train_ds.cardinality()}, valid size: {valid_ds.cardinality()}')

    train_ds = train_ds.map(map_src_tgt)
    valid_ds = valid_ds.map(map_src_tgt)

    # Shuffle individual examples BEFORE batching. Shuffling after batching only
    # reorders fixed batches, so every epoch would see the same examples grouped
    # together. The seed keeps the shuffle order reproducible across runs.
    train_ds = train_ds.shuffle(shuffle_buffer, seed=SEED).batch(BATCH_SIZE).prefetch(1)
    valid_ds = valid_ds.batch(BATCH_SIZE).prefetch(1)

    print(f'train batch: {train_ds.cardinality()}, valid batch: {valid_ds.cardinality()}')
    return train_ds, valid_ds

In [13]:
# Build the batched train/validation datasets from the raw recipe texts
train_ds, valid_ds = get_datasets(input_ds=text_ds_tf)

train size: 16078, valid size: 4020
train batch: 503, valid batch: 126


## 4. Build LSTM model

In [32]:
class LSTM_model(nn.Module):
    """Next-token model: embedding -> stacked LSTM layers -> softmax over the vocabulary.

    Input is an integer array of token ids shaped (batch, time); output holds,
    for every position, a probability distribution over ``VOCAB_SIZE`` tokens.

    Attributes:
        num_lstm_layers: number of LSTM layers stacked on top of the embedding.
    """

    num_lstm_layers: int

    def setup(self):
        # NOTE(review): the embedding width is HIDDEN_DIM while EMBEDDING_DIM is
        # defined in the config but not used here - confirm which is intended.
        self.embed = nn.Embed(num_embeddings=VOCAB_SIZE, features=HIDDEN_DIM)

        # An LSTMCell computes a single time step. Lifting it with nn.scan runs
        # the same cell (shared params) step by step along the time axis
        # (axis 1), threading the carry from one step to the next. Calling the
        # bare cell on the whole (batch, time, feat) tensor would treat time as
        # a batch axis and never propagate state between positions.
        scanned_lstm = nn.scan(
            nn.LSTMCell,
            variable_broadcast='params',
            split_rngs={'params': False},
            in_axes=1,
            out_axes=1,
        )
        self.lstm_layers = [scanned_lstm(name=f'lstm_{i}') for i in range(self.num_lstm_layers)]
        self.dense = nn.Dense(features=VOCAB_SIZE)

    def __call__(self, x):
        # Embedding: (batch, time) -> (batch, time, HIDDEN_DIM)
        x = self.embed(x)

        # LSTM stack. The initial carry is sized from the actual input batch so
        # partial final batches and other batch sizes work.
        batch_dims = (x.shape[0],)
        for lstm_layer in self.lstm_layers:
            initial_carry = nn.LSTMCell.initialize_carry(
                jax.random.PRNGKey(0), batch_dims, HIDDEN_DIM
            )
            _, x = lstm_layer(initial_carry, x)

        # Dense layer projecting every position onto the vocabulary
        x = self.dense(x)
        # NOTE(review): the model returns probabilities, not logits - make sure
        # the loss used with it expects probabilities.
        x = nn.softmax(x) # activation
        return x

lstm_model = LSTM_model(num_lstm_layers=NUM_LSTM_LAYERS)
rng = jax.random.PRNGKey(0)

# Dummy batch of token ids, used only to trace shapes for the summary table
dummy_tokens = jnp.ones((BATCH_SIZE, MAX_PAD_LEN), dtype=jnp.int16)
print(lstm_model.tabulate(rng, dummy_tokens))


[3m                                              LSTM_model Summary                                              [0m
┏━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃[1m [0m[1mpath     [0m[1m [0m┃[1m [0m[1mmodule    [0m[1m [0m┃[1m [0m[1minputs                 [0m[1m [0m┃[1m [0m[1moutputs                [0m[1m [0m┃[1m [0m[1mparams                       [0m[1m [0m┃
┡━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│           │ LSTM_model │ [2mint16[0m[32,200]           │ [2mfloat32[0m[32,200,10000]   │                               │
├───────────┼────────────┼─────────────────────────┼─────────────────────────┼───────────────────────────────┤
│ embed     │ Embed      │ [2mint16[0m[32,200]           │ [2mfloat32[0m[32,200,128]     │ embedding: [2mfloat32[0m[10000,128] │
│           │            │                         │  

## 5. Create `TrainState`

In [33]:
@struct.dataclass
class Metrics(metrics.Collection):
    """Metric collection for training; holds an average built from the 'loss' output."""
    # Average metric fed from the value passed under the name 'loss'.
    loss: metrics.Average.from_output('loss')

class TrainState(train_state.TrainState):
    """Flax train state extended with a ``Metrics`` collection."""
    # Metrics carried alongside the parameters and optimizer state.
    metrics: Metrics


def create_train_state(model, param_key, learning_rate):
    """Initialize ``model`` and wrap it, with an Adam optimizer, in a ``TrainState``.

    Args:
        model: Flax module to initialize.
        param_key: PRNG key passed to ``model.init``.
        learning_rate: learning rate for the Adam optimizer.

    Returns:
        A ``TrainState`` holding the model's apply function, the result of
        ``model.init``, the optimizer and an empty ``Metrics`` collection.
    """
    # A dummy (BATCH_SIZE, MAX_PAD_LEN) batch of token ids drives initialization.
    dummy_tokens = jnp.ones((BATCH_SIZE, MAX_PAD_LEN), dtype=jnp.int16)
    # NOTE(review): this is the full output of model.init (not its 'params'
    # entry); callers of apply_fn must pass state.params accordingly - confirm.
    variables = model.init(param_key, dummy_tokens)
    optimizer = optax.adam(learning_rate=learning_rate)
    return TrainState.create(
        apply_fn=model.apply,
        params=variables,
        tx=optimizer,
        metrics=Metrics.empty(),
    )

In [34]:
# Fresh model instance and its initial training state (Adam, learning rate 1e-3)
lstm_model = LSTM_model(num_lstm_layers=NUM_LSTM_LAYERS)
state = create_train_state(
    model=lstm_model,
    param_key=jax.random.PRNGKey(0),
    learning_rate=1e-3,
)