<a href="https://colab.research.google.com/github/tyobeka/recipe-generator/blob/main/training_an_rnn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Purpose of this Notebook

The purpose of this notebook is to provide a template for preparing data and training an RNN for sequence-to-sequence modelling. A lot of the code is based on a lab exercise from the [2025 MIT Introduction to Deep Learning course](https://github.com/MITDeepLearning/introtodeeplearning/tree/master), which I have been working through independently.

## 1. Installing Dependencies

Before we can start working, we need to install the dependencies and import the packages needed for this task.

In [None]:
# comet
!pip install comet_ml > /dev/null 2>&1
import comet_ml
COMET_API_KEY = ""  # paste your own Comet API key here; do not commit real keys

assert COMET_API_KEY != "", "Please insert your Comet API Key"

In [None]:
# pytorch and relevant libraries
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import os
import time
import functools
from IPython import display as ipythondisplay
from tqdm import tqdm

assert torch.cuda.is_available(), "Please enable GPU from runtime settings"

In [None]:
# MIT Introduction to Deep Learning package and relevant packages
# for music generation task
!pip install mitdeeplearning --quiet
import mitdeeplearning as mdl

from scipy.io.wavfile import write
!apt-get install abcmidi timidity > /dev/null 2>&1


## 2. Loading Dataset

The dataset used for the music generation task consists of Irish folk songs represented in ABC notation. Each song is preceded by metadata fields that describe it:
- X: the song's reference number (index)
- T: the title of the song
- Z: a transcription note, which in this dataset holds an identifier for the song
- M: the meter (time signature)
- L: the default note length
- K: the key of the song

In [None]:
# download the dataset
songs = mdl.lab1.load_training_data()

# print an example of a song
example_song = songs[0]
print("\nExample song: ")
print(example_song)

In [None]:
# use the following code to convert the ABC notation to an audio file to listen to
# mdl.lab1.play_song(example_song)

## 3. Tokenizing the Dataset

Tokenizing means deciding on the basic unit of each observation in the dataset and building a vocabulary of these units, so that every observation can be written as a sequence of vocabulary items. Here the basic unit is a single character.

In [None]:
# join list of songs into single string containing all songs
songs_joined = "\n\n".join(songs)

# find all unique characters in the joined string
vocab = sorted(set(songs_joined))
print("There are", len(vocab), "unique characters in the dataset")

## 4. Preprocessing the Dataset

Having decided what the basic unit of an observation is, we can now represent the data in a format that the RNN model can process.

### Vectorizing the text

In order for the model to process the data, we need to create a numerical representation for the text-based data. To do this, two lookup tables are generated:

1. `char2idx`: Maps characters to numbers.
2. `idx2char`: Maps numbers back to characters.

In [None]:
char2idx = {u: i for i, u in enumerate(vocab)}
idx2char = np.array(vocab)

# print an example to show how the tables work
print(f"The index for 'm': {char2idx['m']}.")
print(f"The character corresponding to index 68: '{idx2char[68]}'.")

Vectorize the songs in the `songs_joined` string:

In [None]:
def vectorize_string(string):
  """
  Convert a string of songs to a numeric representation.

  Output: an np.array of N elements, where N is the number of characters
  in the input string.
  """
  output = []
  for char in string:
    output.append(char2idx[char])

  return np.array(output)

vectorized_songs = vectorize_string(songs_joined)

### Creating training inputs and targets

The input to an RNN is a sequence of characters of length `seq_length`. For this task, we also define a target sequence of the same length, shifted one character to the right.

For example, if `seq_length` is 4 and the text is "Hello", then the input sequence is "Hell" and the target sequence is "ello". This means that the text in the data needs to be broken into chunks of `seq_length + 1` characters for training.

It also means that after training, we should be able to generate a single character, or a short sequence of characters.
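As a minimal sketch of the "Hello" example, the snippet below splits a chunk of `seq_length + 1` characters into an input and a target sequence. The variable names are illustrative only and are not used elsewhere in the notebook.

```python
# Toy illustration of input/target construction (names are illustrative only)
text = "Hello"
seq_length = 4

# a training chunk holds seq_length + 1 characters
chunk = text[:seq_length + 1]

input_seq = chunk[:-1]   # first seq_length characters
target_seq = chunk[1:]   # same length, shifted one character to the right

print(f"input:  {input_seq!r}")   # 'Hell'
print(f"target: {target_seq!r}")  # 'ello'

# at each step the model sees one character and should predict the next one
for step, (inp, tgt) in enumerate(zip(input_seq, target_seq)):
    print(f"step {step}: input {inp!r} -> expected output {tgt!r}")
```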

We will also be batching data for training, so a batch function will be written to do so:

In [None]:
def get_batch(vectorized_songs, seq_length, batch_size):

  # highest index in vectorized_songs: 0 to n
  n = vectorized_songs.shape[0] - 1

  # randomly choose the starting indices of input examples in the training batch
  idx = np.random.choice(n - seq_length, batch_size)

  input_batch = []
  output_batch = []
  for i in idx:
    input_batch.append(vectorized_songs[i:i+seq_length])
    output_batch.append(vectorized_songs[i+1:i+seq_length+1])

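  # convert batches to tensors (stacking into one array first avoids the
  # slow path of building a tensor from a list of arrays)
  x_batch = torch.tensor(np.array(input_batch), dtype=torch.long)
  y_batch = torch.tensor(np.array(output_batch), dtype=torch.long)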

  return x_batch, y_batch

x_batch, y_batch = get_batch(vectorized_songs, seq_length=10, batch_size=2)
print(f"x_batch shape: {x_batch.shape}")
print(f"y_batch shape: {y_batch.shape}")

The following code illustrates why the target sequence has the same length as its corresponding input sequence but is shifted one character to the right:

In [None]:
x_batch, y_batch = get_batch(vectorized_songs, seq_length=5, batch_size=1)

for i, (input_idx, target_idx) in enumerate(zip(x_batch[0], y_batch[0])):
    print("Step {:3d}".format(i))
    print("  input: {} ({:s})".format(input_idx, repr(idx2char[input_idx.item()])))
    print("  expected output: {} ({:s})".format(target_idx, repr(idx2char[target_idx.item()])))

**For each input character, we output a character that is then fed back into the model as input: this will enable the model to generate a song, one character at a time.**

## 5. Define the Recurrent Neural Network (RNN) Model

The model that we will be using is based on the **LSTM architecture**, which uses two state vectors to maintain information about the temporal relationships between consecutive characters:
- Cell state:

$$ \mathbf{C}^*_t = \tanh\left(\mathbf{W}_C\left[\mathbf{h}_{t-1}, \mathbf{x}_t\right] + \mathbf{b}_C\right)$$

$$ \mathbf{C}_t = \mathbf{f}_t*\mathbf{C}_{t-1} + \mathbf{i}_t*\mathbf{C}^*_t$$

- Hidden state:

$$ \mathbf{o}_t = \sigma \left(\mathbf{W}_o\left[\mathbf{h}_{t-1}, \mathbf{x}_t\right] + \mathbf{b}_o\right)$$

$$ \mathbf{h}_t = \mathbf{o}_t * \tanh(\mathbf{C}_t) $$

where $*$ denotes element-wise multiplication and,
- $\mathbf{C}^*_t$ is the vector of candidate values that could be added to the cell state.
- $\mathbf{i}_t = \sigma \left(\mathbf{W}_i\left[\mathbf{h}_{t-1}, \mathbf{x}_t\right] + \mathbf{b}_i\right)$ is the **"input gate layer"**, which decides which values in the cell state to update.
- $\mathbf{f}_t = \sigma \left(\mathbf{W}_f\left[\mathbf{h}_{t-1}, \mathbf{x}_t\right] + \mathbf{b}_f\right)$ is the **"forget gate layer"**, which decides which values from the previous cell state to discard.
- $\mathbf{o}_t$ is the **"output gate layer"**, which decides which parts of the cell state are output as the hidden state.

See the following [link](https://colah.github.io/posts/2015-08-Understanding-LSTMs/) for more details.
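The LSTM equations can be sketched as a single time step in NumPy. This is an illustrative sketch only: the function `lstm_step`, the dictionary layout of the weights, and the toy sizes are assumptions made for this example, and `nn.LSTM` stores and orders its weights differently.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, C_prev, W, b):
    """One LSTM time step.

    W and b are dicts keyed by 'f', 'i', 'C', 'o' (a hypothetical layout).
    Each W[k] has shape (hidden_size, hidden_size + input_size).
    """
    hx = np.concatenate([h_prev, x_t])        # [h_{t-1}, x_t]
    f_t = sigmoid(W["f"] @ hx + b["f"])       # forget gate
    i_t = sigmoid(W["i"] @ hx + b["i"])       # input gate
    C_cand = np.tanh(W["C"] @ hx + b["C"])    # candidate cell state C*_t
    C_t = f_t * C_prev + i_t * C_cand         # new cell state
    o_t = sigmoid(W["o"] @ hx + b["o"])       # output gate
    h_t = o_t * np.tanh(C_t)                  # new hidden state
    return h_t, C_t

# toy sizes chosen for illustration
input_size, hidden_size = 3, 4
rng = np.random.default_rng(0)
W = {k: rng.normal(size=(hidden_size, hidden_size + input_size)) for k in "fiCo"}
b = {k: np.zeros(hidden_size) for k in "fiCo"}

# at t=0 both states start as zero vectors
h, C = np.zeros(hidden_size), np.zeros(hidden_size)
for x_t in rng.normal(size=(5, input_size)):  # a sequence of 5 inputs
    h, C = lstm_step(x_t, h, C, W, b)

print("h_t shape:", h.shape)
print("C_t shape:", C.shape)
```

The same states are carried from one step to the next, which is how the network keeps information about earlier characters in the sequence.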

We will use PyTorch's `nn.Module` to define the RNN:

In [None]:
class LSTMModel(nn.Module):
  def __init__(self, vocab_size, embedding_dim, hidden_size):
    super(LSTMModel, self).__init__()

    self.hidden_size = hidden_size

    # define different components of lstm
    self.embedding = nn.Embedding(
        num_embeddings=vocab_size,
        embedding_dim=embedding_dim
        )

    self.lstm = nn.LSTM(
        input_size=embedding_dim,
        hidden_size=hidden_size,
        num_layers=1,
        batch_first=True,
        dropout=0,
        bidirectional=False
        )
    self.fc = nn.Linear(
        in_features=hidden_size,
        out_features=vocab_size
    )

  # for t=0, C_t and h_t are zero vectors
  # (defined at class level so that it is a method, not a function local to __init__)
  def init_hidden(self, batch_size, device):
    return (torch.zeros(1, batch_size, self.hidden_size).to(device),
            torch.zeros(1, batch_size, self.hidden_size).to(device))

  # forward function
  def forward(self, x, state=None, return_state=False):
    # (batch_size, seq_length) -> (batch_size, seq_length, embedding_dim)
    x = self.embedding(x)

    if state is None:
      state = self.init_hidden(x.size(0), x.device)

    out, state = self.lstm(x, state)
    # (batch_size, seq_length, hidden_size) -> (batch_size, seq_length, vocab_size)
    out = self.fc(out)

    if return_state:
      return (out, state)
    else:
      return out

Let's instantiate the model to see what it looks like:

In [None]:
# define parameters
params = dict(
  batch_size = 8,
  embedding_dim = 256,
  hidden_size = 1024,
  vocab_size = len(vocab),
  seq_length = 100,
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  )

# instantiate model
model = LSTMModel(
    vocab_size=params['vocab_size'],
    embedding_dim=params['embedding_dim'],
    hidden_size=params['hidden_size']
    )

# move it to the correct device
model = model.to(params['device'])

# print the model
print(model)

Let us test the model to check whether it performs as expected:

In [None]:
x, y = get_batch(vectorized_songs=vectorized_songs, seq_length=params['seq_length'], batch_size=params['batch_size'])
x = x.to(params['device'])
y = y.to(params['device'])

yhat = model(x)
print("Input shape:      ", x.shape, " # (batch_size, sequence_length)")
print("Model output shape: ", yhat.shape, "# (batch_size, sequence_length, vocab_size)")

### Computing predictions

To get actual predictions from the model, we first define an output distribution by applying `torch.softmax` to the output logits. This is a categorical distribution over the vocabulary, and we sample from it with `torch.multinomial` to obtain a prediction.

**Note:** we sample from the output distribution rather than simply taking the `argmax`, because always picking the most likely character can trap the model in a repetitive loop that outputs the same characters over and over.
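A small sketch of the difference between the two strategies, using made-up logits over a four-symbol vocabulary. The logit values and the seed are assumptions chosen for illustration and do not come from the trained model.

```python
import torch

torch.manual_seed(0)  # seed chosen only to make the sketch repeatable

# made-up logits over a 4-symbol vocabulary
logits = torch.tensor([2.0, 1.0, 0.5, 0.1])
probs = torch.softmax(logits, dim=-1)  # categorical distribution, sums to 1

# argmax always returns the single most likely index
greedy_index = torch.argmax(probs).item()

# multinomial draws indices in proportion to their probabilities
samples = torch.multinomial(probs, num_samples=1000, replacement=True)
counts = torch.bincount(samples, minlength=logits.numel())

print("probabilities:", probs)
print("argmax index: ", greedy_index)
print("sample counts:", counts)
```

Greedy decoding picks the same index every time, whereas sampling also returns the less likely indices some of the time, which is what lets the generated text vary.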

In [None]:
# computing a prediction
example_logit = yhat[0]
sampled_indices = torch.multinomial(torch.softmax(example_logit, dim=-1), num_samples = 1)
sampled_indices = sampled_indices.squeeze(-1).cpu().numpy()
sampled_indices

In [None]:
# decoding the index to see the text produced
print("Input: \n", repr("".join(idx2char[x[0].cpu()])))
print()
print("Next Char Predictions: \n", repr("".join(idx2char[sampled_indices])))