# LSTM on Recipe Data

**The notebook has been adapted from the notebook provided in David Foster's Generative Deep Learning, 2nd Edition.**

- Book: [Amazon](https://www.amazon.com/Generative-Deep-Learning-Teaching-Machines/dp/1098134184/ref=sr_1_1?keywords=generative+deep+learning%2C+2nd+edition&qid=1684708209&sprefix=generative+de%2Caps%2C93&sr=8-1)
- Original notebook (TensorFlow and Keras): [GitHub](https://github.com/davidADSP/Generative_Deep_Learning_2nd_Edition/blob/main/notebooks/05_autoregressive/01_lstm/lstm.ipynb)
- Dataset: [Kaggle](https://www.kaggle.com/datasets/hugodarwood/epirecipes)

In [1]:
import re
import string
import json
from collections import defaultdict
import time

import numpy as np
import jax
import jax.numpy as jnp
from tensorflow.data import Dataset
from tensorflow.keras.layers import TextVectorization
from tensorflow.keras import utils

from flax import struct
from flax.training import train_state
import flax.linen as nn
import optax

from clu import metrics

## 0. Train parameters

In [2]:
DATA_DIR = '../../data/epirecipes/full_format_recipes.json'

EMBEDDING_DIM = 100  # not used below: the embedding layer in section 4 is built with HIDDEN_DIM
HIDDEN_DIM = 128
NUM_LSTM_LAYERS = 2
VALIDATION_SPLIT = 0.2
BATCH_SIZE = 32
EPOCHS = 30
VOCAB_SIZE = 10000
LR = 1e-3

MAX_PAD_LEN = 200
MAX_VAL_TOKENS = 100  # max number of tokens when generating text

## 1. Load dataset

In [3]:
def pad_punctuation(sentence):
    sentence = re.sub(f'([{string.punctuation}])', r' \1 ', sentence)
    sentence = re.sub(' +', ' ', sentence)
    return sentence
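
To see what this preprocessing does, here is a small self-contained check. The function body is the same as above; the example sentence is made up for illustration. Each punctuation mark ends up as its own whitespace-separated token, which is what later lets a whitespace tokenizer treat `,` and `.` as words.

```python
import re
import string

def pad_punctuation(sentence):
    # surround every punctuation character with spaces
    sentence = re.sub(f'([{string.punctuation}])', r' \1 ', sentence)
    # collapse runs of spaces into a single space
    sentence = re.sub(' +', ' ', sentence)
    return sentence

# made-up example sentence
padded = pad_punctuation('Mix well, then bake.')
print(repr(padded))  # 'Mix well , then bake . '
```

Note the trailing space after the final period, which also appears at the end of the sample recipe printed further down.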

In [4]:
# load dataset
with open(DATA_DIR, 'r') as f:
    recipe_data = json.load(f)

In [5]:
# preprocess dataset
filtered_data = [
    'Recipe for ' + x['title'] + ' | ' + ' '.join(x['directions'])
    for x in recipe_data
    if 'title' in x and x['title']
    and 'directions' in x and x['directions']
]

text_ds = [pad_punctuation(sentence) for sentence in filtered_data]
print(f'Total recipe loaded: {len(text_ds)}')

Total recipe loaded: 20098


In [6]:
print('Sample data:')
sample_data = np.random.choice(text_ds)
print(sample_data)

Sample data:
Recipe for Shrimp Toasts With Sesame Seeds and Scallions | Pulse shrimp , chili paste , lemongrass , fish sauce , and ginger in a food processor until smooth . Season with salt and pulse again to combine . Transfer mixture to a medium bowl ; stir in scallion whites . Place sesame seeds on a plate . Spread shrimp mixture over bread slices , extending all the way to edges . Press bread , shrimp side down , into sesame seeds to coat evenly . Pour oil into a large skillet to come 1 / 4 " up sides and heat over medium - high until a small pinch of shrimp mixture sizzles when added to oil . Working in 2 batches , fry toasts , shrimp side down , until golden and crisp , about 2 minutes ; turn and cook until other sides are golden and crisp , about 1 minute . Transfer to a paper towel–lined wire rack to drain . Cut each toast diagonally into quarters and top with scallion greens . 


## 2. Build vocabularies

In [18]:
# convert the list of texts to a tf.data Dataset
text_ds_tf = Dataset.from_tensor_slices(text_ds)

vectorize_layer = TextVectorization(
    standardize='lower',
    max_tokens=VOCAB_SIZE,
    output_mode='int',
    output_sequence_length=MAX_PAD_LEN+1
)

In [19]:
vectorize_layer.adapt(text_ds_tf)
vocab = vectorize_layer.get_vocabulary()
index_to_word = {index : word for index, word in enumerate(vocab)}
word_to_index = {word : index for index, word in enumerate(vocab)}

# First 10 items in the vocabulary
for i, word in enumerate(vocab[:10]):
    print(f'{i}: {word}')

0: 
1: [UNK]
2: .
3: ,
4: and
5: to
6: in
7: the
8: with
9: a
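
For intuition about what `adapt` and `get_vocabulary` produce, the sketch below builds a frequency-ordered vocabulary in plain Python. It is an approximation for illustration, not the Keras implementation; `build_vocab` and `encode` are hypothetical helper names and the toy text is made up. It mirrors the conventions visible in the listing above: index 0 is the padding token, index 1 is `[UNK]`, and the remaining entries run from most to least frequent.

```python
from collections import Counter

def build_vocab(texts, max_tokens):
    # count lowercased, whitespace-separated tokens
    counts = Counter(tok for text in texts for tok in text.lower().split())
    # reserve two slots: '' for padding (index 0) and '[UNK]' (index 1)
    most_common = [word for word, _ in counts.most_common(max_tokens - 2)]
    return ['', '[UNK]'] + most_common

def encode(text, word_to_index, seq_len):
    # unknown words map to index 1, short sequences are padded with 0
    ids = [word_to_index.get(tok, 1) for tok in text.lower().split()]
    ids = ids[:seq_len]
    return ids + [0] * (seq_len - len(ids))

# toy corpus, chosen so that every token has a distinct count
toy_vocab = build_vocab(['a a a b b c'], max_tokens=4)
toy_word_to_index = {word: index for index, word in enumerate(toy_vocab)}

print(toy_vocab)                                      # ['', '[UNK]', 'a', 'b']
print(encode('A c b', toy_word_to_index, seq_len=5))  # [2, 1, 3, 0, 0]
```

With `max_tokens=4` only two real words fit, so `c` falls outside the vocabulary and is encoded as `[UNK]`.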


In [20]:
sample_data_tokenized = vectorize_layer(sample_data)
print('Source text:')
print(sample_data)
print('\n')
print('Mapped sample:')
print(sample_data_tokenized.numpy())

Source text:
Recipe for Shrimp Toasts With Sesame Seeds and Scallions | Pulse shrimp , chili paste , lemongrass , fish sauce , and ginger in a food processor until smooth . Season with salt and pulse again to combine . Transfer mixture to a medium bowl ; stir in scallion whites . Place sesame seeds on a plate . Spread shrimp mixture over bread slices , extending all the way to edges . Press bread , shrimp side down , into sesame seeds to coat evenly . Pour oil into a large skillet to come 1 / 4 " up sides and heat over medium - high until a small pinch of shrimp mixture sizzles when added to oil . Working in 2 batches , fry toasts , shrimp side down , until golden and crisp , about 2 minutes ; turn and cook until other sides are golden and crisp , about 1 minute . Transfer to a paper towel–lined wire rack to drain . Cut each toast diagonally into quarters and top with scallion greens . 


Mapped sample:
[  26   16  261  939    8  517  234    4  546   27  437  261    3  543
  345    3 1

## 3. Create train/validation datasets

In [21]:
def map_src_tgt(text):
    tokenized_sentence = vectorize_layer(text)
    src = tokenized_sentence[:-1]
    tgt = tokenized_sentence[1:]
    return src, tgt
    

def get_datasets(input_ds):
    train_size = int(len(input_ds) * (1 - VALIDATION_SPLIT))
    train_ds = input_ds.take(train_size)
    valid_ds = input_ds.skip(train_size)
    print(f'train size: {train_ds.cardinality()}, valid size: {valid_ds.cardinality()}')

    train_ds = train_ds.map(map_src_tgt)
    valid_ds = valid_ds.map(map_src_tgt)
    
    # note: shuffle() after batch() reorders whole batches; the recipes inside each batch stay fixed
    train_ds = train_ds.batch(BATCH_SIZE).shuffle(1024).prefetch(1)
    valid_ds = valid_ds.batch(BATCH_SIZE).prefetch(1)

    print(f'train batch: {train_ds.cardinality()}, valid batch: {valid_ds.cardinality()}')
    return train_ds, valid_ds
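
The slicing in `map_src_tgt` is the usual next-token setup: the target is the input shifted left by one position. Because the vectorizer pads to `MAX_PAD_LEN + 1` tokens, both slices have length `MAX_PAD_LEN`. A minimal illustration with made-up token ids:

```python
# made-up token ids for one padded sequence (0 is padding)
tokens = [26, 16, 261, 939, 2, 0]

src = tokens[:-1]  # model input: everything except the last token
tgt = tokens[1:]   # training target: the same sequence shifted left by one

# at each position the model sees src[t] (and earlier tokens) and should predict tgt[t]
for s, t in zip(src, tgt):
    print(f'input {s:>4} -> target {t:>4}')
```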

In [22]:
train_ds, valid_ds = get_datasets(text_ds_tf)

train size: 16078, valid size: 4020
train batch: 503, valid batch: 126


## 4. Build LSTM model

In [23]:
class LSTM_model(nn.Module):

    num_lstm_layers: int
    
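    def setup(self):
        # the embedding layer uses HIDDEN_DIM features (EMBEDDING_DIM is not used)
        self.embed = nn.Embed(num_embeddings=VOCAB_SIZE, features=HIDDEN_DIM)
        self.lstm_layers = [nn.LSTMCell(name=f'lstm_{i}') for i in range(self.num_lstm_layers)]
        self.dense = nn.Dense(features=VOCAB_SIZE)

    def __call__(self, x):
        # Embedding: (batch, seq_len) -> (batch, seq_len, HIDDEN_DIM)
        x = self.embed(x)

        # LSTM
        # Note: initialize_carry(rng, batch_dims, size) is the signature of older Flax releases.
        # With batch_dims=(batch, seq_len), each cell sees every position at once with a zero
        # carry, so no state is passed from one time step to the next.
        for i in range(self.num_lstm_layers):
            hidden_state = self.lstm_layers[i].initialize_carry(jax.random.PRNGKey(0), (x.shape[0], x.shape[1]), HIDDEN_DIM)
            _, x = self.lstm_layers[i](hidden_state, x)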
        
        # Dense layer
        x = self.dense(x)
        return x

lstm_model = LSTM_model(NUM_LSTM_LAYERS)
rng = jax.random.PRNGKey(0)

print(lstm_model.tabulate(rng, jnp.ones((BATCH_SIZE, MAX_PAD_LEN), dtype=int)))


                                              LSTM_model Summary
┏━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ path      ┃ module     ┃ inputs                  ┃ outputs                 ┃ params                        ┃
┡━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│           │ LSTM_model │ int32[32,200]           │ float32[32,200,10000]   │                               │
├───────────┼────────────┼─────────────────────────┼─────────────────────────┼───────────────────────────────┤
│ embed     │ Embed      │ int32[32,200]           │ float32[32,200,128]     │ embedding: float32[10000,128] │
│           │            │                         │  
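
In the `__call__` above, each `LSTMCell` is applied once to the full `(batch, seq_len, HIDDEN_DIM)` tensor with a freshly initialized carry, so the hidden state computed at one position is never handed to the next. For comparison, the sketch below shows in plain NumPy what stepping an LSTM cell through time looks like. It is an illustration only, not Flax's API and not the notebook's model: `lstm_step`, `run_lstm`, the parameter layout, and the toy sizes are all hypothetical.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(params, carry, x_t):
    """One LSTM step for a batch; x_t has shape (batch, in_dim)."""
    c, h = carry
    # a single affine map produces all four gate pre-activations
    z = np.concatenate([x_t, h], axis=-1) @ params['W'] + params['b']
    i, f, g, o = np.split(z, 4, axis=-1)
    c = sigmoid(f) * c + sigmoid(i) * np.tanh(g)  # updated cell state
    h = sigmoid(o) * np.tanh(c)                   # updated hidden state
    return (c, h), h

def run_lstm(params, x):
    """Run the cell over time; x has shape (batch, seq_len, in_dim)."""
    batch, seq_len, _ = x.shape
    hidden = params['b'].shape[0] // 4
    carry = (np.zeros((batch, hidden)), np.zeros((batch, hidden)))
    outputs = []
    for t in range(seq_len):
        # the carry produced at step t-1 is the input carry of step t
        carry, h = lstm_step(params, carry, x[:, t, :])
        outputs.append(h)
    return np.stack(outputs, axis=1)  # (batch, seq_len, hidden)

# toy sizes, chosen only for the demonstration
rng = np.random.default_rng(0)
in_dim, hidden = 3, 4
params = {
    'W': rng.normal(scale=0.5, size=(in_dim + hidden, 4 * hidden)),
    'b': np.zeros(4 * hidden),
}
x = rng.normal(size=(2, 5, in_dim))
out = run_lstm(params, x)
print(out.shape)  # (2, 5, 4)
```

Recent Flax releases provide `nn.RNN`, which wraps a cell and performs this kind of scan over the time axis; the exact constructor arguments differ between versions, so check the documentation of the installed release before adapting the model.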

## 5. Create `TrainState`

In [24]:
@struct.dataclass
class Metrics(metrics.Collection):
    loss: metrics.Average.from_output('loss')


class TrainState(train_state.TrainState):
    metrics: Metrics


def create_train_state(model, param_key, learning_rate):
    # initialize model
    params = model.init(param_key, jnp.ones((BATCH_SIZE, MAX_PAD_LEN), dtype=int))['params']
    # initialize optimizer
    tx = optax.adam(learning_rate=learning_rate)
    return TrainState.create(
            apply_fn=model.apply,
            params=params,
            tx=tx,
            metrics=Metrics.empty())

## 6. Train step functions

In [25]:
# train step
@jax.jit
def train_step(state, batch):
    def loss_fn(params):
        preds = state.apply_fn({'params': params}, batch[0])
        loss = optax.softmax_cross_entropy_with_integer_labels(preds, batch[1]).mean()
        return loss

    # compute loss and apply gradients
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)

    # Update metrics
    metric_updates = state.metrics.single_from_model_output(loss=loss)
    metrics = state.metrics.merge(metric_updates)
    state = state.replace(metrics=metrics)
    return state 

# evaluation
@jax.jit
def validation(state, batch):
    preds = state.apply_fn({'params': state.params}, batch[0])
    loss = optax.softmax_cross_entropy_with_integer_labels(preds, batch[1]).mean()

    # Update metrics
    metric_updates = state.metrics.single_from_model_output(loss=loss)
    metrics = state.metrics.merge(metric_updates)
    state = state.replace(metrics=metrics)
    return state
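
Both `train_step` and `validation` average the cross-entropy over every position, including the padding positions (token id 0) that the vectorizer appends to short recipes. One common alternative, shown here as a NumPy sketch rather than as a change to the notebook, is to average only over non-padding targets. `masked_cross_entropy` is a hypothetical helper, and the logits and targets are made up.

```python
import numpy as np

def log_softmax(logits):
    # subtract the max for numerical stability
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

def masked_cross_entropy(logits, targets, pad_id=0):
    """Mean cross-entropy over positions whose target is not padding."""
    log_probs = log_softmax(logits)
    # pick the log-probability assigned to each target token
    token_ll = np.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]
    mask = (targets != pad_id).astype(log_probs.dtype)
    return -(token_ll * mask).sum() / np.maximum(mask.sum(), 1.0)

# toy batch: 1 sequence, 4 positions, vocabulary of 3; the last two targets are padding
logits = np.zeros((1, 4, 3))
logits[0, 2:, 0] = 10.0  # padding positions are trivially easy to predict
targets = np.array([[1, 2, 0, 0]])

plain = -np.take_along_axis(log_softmax(logits), targets[..., None], axis=-1).mean()
masked = masked_cross_entropy(logits, targets)
print(f'plain mean: {plain:.4f}, masked mean: {masked:.4f}')
```

In this toy case the plain mean is pulled down by the easy padding positions, while the masked mean reflects only the real tokens. In the JAX step functions the same idea would amount to multiplying the per-token losses by a mask before averaging.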

In [26]:
# get the next-word logits at the last position of the (single) input sequence
@jax.jit
def get_probs(state, input_tokens):
    return state.apply_fn({'params': state.params}, input_tokens)[0][-1]

# Text generator
class TextGenerator():
    def __init__(self, index_to_word):
        self.index_to_word = index_to_word
        self.word_to_index = {word : index for index, word in index_to_word.items()}

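    # sharpen (temperature < 1) or flatten (temperature > 1) the model's output probabilities
    def sample_from(self, probs, temperature):
        # use float64 so the renormalized probabilities pass np.random.choice's sum-to-1 check
        probs = np.asarray(probs, dtype=np.float64) ** (1 / temperature)
        probs = probs / np.sum(probs)
        return np.random.choice(VOCAB_SIZE, p=probs), probs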
    
    # generate text
    def generate(self, state, start_prompt, max_tokens, temperature, output_info=False):
        
        # lowercase to match the vectorizer; words outside the vocabulary map to 1 ([UNK])
        start_tokens = [self.word_to_index.get(word, 1) for word in start_prompt.lower().split()]
        sample_token = None
        info = []

        while len(start_tokens) < max_tokens and sample_token != 0:
            input_tokens = np.array(start_tokens).reshape(1, -1)
            logits = get_probs(state, input_tokens)
            # convert logits to probabilities before temperature scaling
            probs = np.asarray(nn.softmax(logits))
            sample_token, probs = self.sample_from(probs, temperature)
            start_tokens.append(sample_token)
            if output_info:
                info.append({'tokens': np.copy(start_tokens), 'word_probs': probs})
            
        output_text = [self.index_to_word[token] for token in start_tokens if token != 0]
        print(' '.join(output_text))

        return info
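
`sample_from` raises each probability to the power `1 / temperature` and renormalizes. A toy distribution (made up for illustration) shows the effect: temperatures below 1 concentrate mass on the most likely token, and temperatures above 1 flatten the distribution. `apply_temperature` is a hypothetical helper that repeats the scaling without the sampling step.

```python
import numpy as np

def apply_temperature(probs, temperature):
    # same scaling as sample_from, without drawing a sample
    scaled = np.asarray(probs, dtype=np.float64) ** (1 / temperature)
    return scaled / scaled.sum()

probs = [0.5, 0.3, 0.2]  # made-up next-token distribution
for temperature in (0.5, 1.0, 2.0):
    print(temperature, np.round(apply_temperature(probs, temperature), 3))
```

A temperature of 1.0 leaves the distribution unchanged, which is the setting used for the samples printed during training.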

## 7. Training

In [27]:
lstm_model = LSTM_model(NUM_LSTM_LAYERS)
state = create_train_state(lstm_model, jax.random.PRNGKey(0), learning_rate=LR)
text_generator = TextGenerator(index_to_word)

In [28]:
loss_hist = defaultdict(list)

for i in range(EPOCHS):
    prev_time = time.time()
    
    #training
    for batch in train_ds.as_numpy_iterator():
        state = train_step(state, batch)

    train_loss = state.metrics.compute()['loss']
    state = state.replace(metrics=state.metrics.empty())

    #validation
    test_state = state
    for batch in valid_ds.as_numpy_iterator():
        test_state = validation(test_state, batch)

    valid_loss = test_state.metrics.compute()['loss']
    
    
    loss_hist['train_loss'].append(train_loss)
    loss_hist['valid_loss'].append(valid_loss)

    curr_time = time.time()
    print(f'Epoch: {i+1}\tepoch time {(curr_time - prev_time) / 60:.2f} min')
    print(f'\ttrain loss: {train_loss:.4f}, valid loss: {valid_loss:.4f}')

    if (i + 1) % 10 == 0:
        # generate text
        print('\nGenerated text:')
        info = text_generator.generate(state, 'recipe for', MAX_VAL_TOKENS, 1.0)
        print('\n')

Epoch: 1	epoch time 0.09 min
	train loss: 4.4139, valid loss: 3.2973
Epoch: 2	epoch time 0.06 min
	train loss: 3.0422, valid loss: 2.8731
Epoch: 3	epoch time 0.06 min
	train loss: 2.8226, valid loss: 2.7693
Epoch: 4	epoch time 0.06 min
	train loss: 2.7436, valid loss: 2.7220
Epoch: 5	epoch time 0.06 min
	train loss: 2.7007, valid loss: 2.6938
Epoch: 6	epoch time 0.05 min
	train loss: 2.6722, valid loss: 2.6759
Epoch: 7	epoch time 0.06 min
	train loss: 2.6520, valid loss: 2.6625
Epoch: 8	epoch time 0.06 min
	train loss: 2.6369, valid loss: 2.6531
Epoch: 9	epoch time 0.06 min
	train loss: 2.6248, valid loss: 2.6453
Epoch: 10	epoch time 0.06 min
	train loss: 2.6149, valid loss: 2.6398

Generated text:
recipe for duck breasts , paprika , until coated and add chicken is heating pie dish ) chill for fish sauce , and mix 1 minute . place shiso . in single layer . finely chop chocolate - high speed until the grill for pasta in oats , uncovered until it , cinnamon , whisking until tender , knoc