## Importing libraries

In [None]:
from collections import Counter
import json
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader as DL
import torch.utils.data
import math
import torch.nn.functional as F

## HyperParameter

In [None]:
batch_size = 64  # number of samples per mini-batch handed to the DataLoader
max_len = 16  # maximum number of words kept per sentence; shorter sentences are padded up to this length
num_heads = 8  # NOTE(review): not referenced in the visible cells — presumably the attention head count; confirm

## Mount Google drive

In [None]:
from google.colab import drive as gdrive
gdrive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os
os.chdir('/content/drive/My Drive/cornell movie-dialogs corpus/')
!ls

movie_characters_metadata.tsv  movie_titles_metadata.tsv  README.txt
movie_conversations.tsv        pairs_encoded.json	  WORDMAP_corpus.json
movie_lines.tsv		       raw_script_urls.tsv


In [None]:
corpus_movie_conv = '/content/drive/My Drive/cornell movie-dialogs corpus/movie_conversations.tsv'
corpus_movie_lines = '/content/drive/My Drive/cornell movie-dialogs corpus/movie_lines.tsv'

In [None]:
corpus_movie_conv

'/content/drive/My Drive/cornell movie-dialogs corpus/movie_conversations.tsv'

## Data Preparation

* In our approach, we establish a fixed length for our sequences, ensuring consistency in our data processing.
* As we handle data in batches, it's crucial to determine this maximum length beforehand.
* By doing so, we can efficiently store our data in matrices, streamlining the input process for our neural network.
* To accommodate sentences shorter than the designated maximum length, we employ padding.
* In this instance, we've set the maximum length (`max_len`) to 16 words per sentence, providing a standardized framework for our data processing pipeline.

### Reading the Movie Conversation and Lines

### Understanding conversations

### Conversation Grouping

* The conversation data is structured such that consecutive lines form coherent conversations.
* Each group of lines represents a single conversation.

### Example

* For instance, lines 194 to 197 constitute one conversation.
* Similarly, lines 198 and 199 form another conversation.

This grouping approach facilitates the analysis and processing of conversations within the dataset, enabling efficient handling of sequential dialogues.

In [None]:
with open(corpus_movie_conv, 'r') as c:
  conv = c.readlines()

In [None]:
conv

["u0\tu2\tm0\t['L194' 'L195' 'L196' 'L197']\n",
 "u0\tu2\tm0\t['L198' 'L199']\n",
 "u0\tu2\tm0\t['L200' 'L201' 'L202' 'L203']\n",
 "u0\tu2\tm0\t['L204' 'L205' 'L206']\n",
 "u0\tu2\tm0\t['L207' 'L208']\n",
 "u0\tu2\tm0\t['L271' 'L272' 'L273' 'L274' 'L275']\n",
 "u0\tu2\tm0\t['L276' 'L277']\n",
 "u0\tu2\tm0\t['L280' 'L281']\n",
 "u0\tu2\tm0\t['L363' 'L364']\n",
 "u0\tu2\tm0\t['L365' 'L366']\n",
 "u0\tu2\tm0\t['L367' 'L368']\n",
 "u0\tu2\tm0\t['L401' 'L402' 'L403']\n",
 "u0\tu2\tm0\t['L404' 'L405' 'L406' 'L407']\n",
 "u0\tu2\tm0\t['L575' 'L576']\n",
 "u0\tu2\tm0\t['L577' 'L578']\n",
 "u0\tu2\tm0\t['L662' 'L663']\n",
 "u0\tu2\tm0\t['L693' 'L694' 'L695']\n",
 "u0\tu2\tm0\t['L696' 'L697' 'L698' 'L699']\n",
 "u0\tu2\tm0\t['L860' 'L861']\n",
 "u0\tu2\tm0\t['L862' 'L863' 'L864' 'L865']\n",
 "u0\tu2\tm0\t['L866' 'L867' 'L868' 'L869']\n",
 "u0\tu2\tm0\t['L870' 'L871' 'L872']\n",
 "u0\tu2\tm0\t['L924' 'L925']\n",
 "u0\tu2\tm0\t['L984' 'L985']\n",
 "u0\tu2\tm0\t['L1044' 'L1045']\n",
 "u0\tu3\tm0\t[

## Understanding lines

### Explanation of line content

* Each line in the dataset corresponds to a specific utterance within a conversation.
* The content of each line includes the actual saying, either a question or a reply, along with the associated character.

### Example Illustration

* Line number 1045 contains the saying "they do not."
* The subsequent line provides the continuation of the conversation.
* For instance, if we examine the first conversation:
  * The initial line represents the question posed.
  * The following line serves as the reply to that question.
  * This pattern continues throughout the conversation.
* To access a specific question, one can refer to the line number corresponding to the start of that question.
* Similarly, the subsequent line contains the reply to the preceding question.

This organization of the dataset enables easy identification and extraction of both questions and replies within the conversations.

In [None]:
with open(corpus_movie_lines, 'r', encoding='ISO-8859-1') as l:
  lines = l.readlines()

lines

['L1045\tu0\tm0\tBIANCA\tThey do not!\n',
 'L1044\tu2\tm0\tCAMERON\tThey do to!\n',
 'L985\tu0\tm0\tBIANCA\tI hope so.\n',
 'L984\tu2\tm0\tCAMERON\tShe okay?\n',
 "L925\tu0\tm0\tBIANCA\tLet's go.\n",
 'L924\tu2\tm0\tCAMERON\tWow\n',
 "L872\tu0\tm0\tBIANCA\tOkay -- you're gonna need to learn how to lie.\n",
 'L871\tu2\tm0\tCAMERON\tNo\n',
 '"L870\tu0\tm0\tBIANCA\tI\'m kidding.  You know how sometimes you just become this ""persona""?  And you don\'t know how to quit?"\n',
 'L869\tu0\tm0\tBIANCA\tLike my fear of wearing pastels?\n',
 '"L868\tu2\tm0\tCAMERON\tThe ""real you""."\n',
 'L867\tu0\tm0\tBIANCA\tWhat good stuff?\n',
 "L866\tu2\tm0\tCAMERON\tI figured you'd get to the good stuff eventually.\n",
 'L865\tu2\tm0\tCAMERON\tThank God!  If I had to hear one more story about your coiffure...\n',
 "L864\tu0\tm0\tBIANCA\tMe.  This endless ...blonde babble. I'm like boring myself.\n",
 'L863\tu2\tm0\tCAMERON\tWhat crap?\n',
 'L862\tu0\tm0\tBIANCA\tdo you listen to this crap?\n',
 'L861\tu2

## Data to dictionary

In [None]:
lines[0].split('\t')
# We need to index and what was said

['L1045', 'u0', 'm0', 'BIANCA', 'They do not!\n']

In [None]:
# Map every line id (e.g. 'L1045') to the text spoken on that line.
lines_dict = {}
for line in lines:
  objects = line.split('\t')
  # Rows whose text contains double quotes are wrapped in CSV-style quotes in the
  # TSV file, so their id field arrives as '"L870'. Strip that leading quote so the
  # key matches the plain ids ('L870') used in the conversations file.
  line_idx = objects[0].lstrip('"')
  # The utterance is the last tab-separated field (trailing newline is kept).
  lines_dict[line_idx] = objects[-1]

In [None]:
lines_dict

{'L1045': 'They do not!\n',
 'L1044': 'They do to!\n',
 'L985': 'I hope so.\n',
 'L984': 'She okay?\n',
 'L925': "Let's go.\n",
 'L924': 'Wow\n',
 'L872': "Okay -- you're gonna need to learn how to lie.\n",
 'L871': 'No\n',
 '"L870': 'I\'m kidding.  You know how sometimes you just become this ""persona""?  And you don\'t know how to quit?"\n',
 'L869': 'Like my fear of wearing pastels?\n',
 '"L868': 'The ""real you""."\n',
 'L867': 'What good stuff?\n',
 'L866': "I figured you'd get to the good stuff eventually.\n",
 'L865': 'Thank God!  If I had to hear one more story about your coiffure...\n',
 'L864': "Me.  This endless ...blonde babble. I'm like boring myself.\n",
 'L863': 'What crap?\n',
 'L862': 'do you listen to this crap?\n',
 'L861': 'No...\n',
 '"L860': 'Then Guillermo says ""If you go any lighter you\'re gonna look like an extra on 90210."""\n',
 'L699': 'You always been this selfish?\n',
 'L698': 'But\n',
 'L697': "Then that's all you had to say.\n",
 'L696': 'Well no...\n'

## Cleaning the conversation

In [None]:
def remove_punc(string):
  """
  Remove selected punctuation characters from the input string and lowercase it.

  Only the characters listed in `punctuations` below are dropped. Periods,
  commas, question marks, apostrophes and whitespace are NOT in that list and
  are therefore kept.

  Parameters:
  string (str): The input string.

  Returns:
  str: The lowercased input with the listed punctuation characters removed.
  """

  # Raw string so that the backslash is a literal character to strip rather than
  # the start of an (invalid) escape sequence.
  punctuations = r'''!()-[]{};:"\<>/@#$%^&*_~'''

  # Keep every character that is not listed above, then lowercase the result.
  return "".join(char for char in string if char not in punctuations).lower()

The `remove_punc` helper above strips punctuation and lowercases a string. The code below iterates over the conversations in the dataset, extracts the line IDs of each conversation, and creates question-answer pairs from consecutive IDs. For every pair it removes punctuation and leading/trailing whitespace from the two lines, splits them into words, and limits each to the specified maximum length (`max_len`). Finally, it appends the question-answer pair to a list of pairs.

In [None]:
conv[0]

"u0\tu2\tm0\t['L194' 'L195' 'L196' 'L197']\n"

In [None]:
#this is string and we need to convert this to a python list.
conv[0].split('\t')[-1]

"['L194' 'L195' 'L196' 'L197']\n"

In [None]:
conv[0].split('\t')[-1].replace(' ', ',')

"['L194','L195','L196','L197']\n"

In [None]:
eval(conv[0].split('\t')[-1].replace(' ', ','))

['L194', 'L195', 'L196', 'L197']

In [None]:
# Initialize an empty list to store question-answer pairs
pairs = []

# Number of adjacent pairs that could not be built because a line id is missing
# from lines_dict (kept for inspection instead of silently swallowing the error)
skipped_pairs = 0

# Iterate over each conversation in the dataset
for con in conv:
    # The last tab-separated field looks like "['L194' 'L195' 'L196']\n".
    # Parse it with plain string operations rather than eval(), which would execute
    # whatever expression happens to be stored in the data file.
    ids = con.split('\t')[-1].strip().strip('"[]').replace("'", "").split()

    # Every two consecutive lines of a conversation form one (question, reply) pair
    for first_id, second_id in zip(ids, ids[1:]):
        try:
            first_line = lines_dict[first_id]
            second_line = lines_dict[second_id]
        except KeyError:
            # Only this pair is unusable; later pairs of the conversation are still built
            skipped_pairs += 1
            continue

        # Remove punctuation and leading/trailing whitespace from both lines
        first = remove_punc(first_line.strip())
        second = remove_punc(second_line.strip())

        # Split the lines into words and limit the length of each to 'max_len'
        pairs.append([first.split()[:max_len], second.split()[:max_len]])

In [None]:
len(pairs)

210531

In [None]:
pairs[0]

[['can',
  'we',
  'make',
  'this',
  'quick?',
  'roxanne',
  'korrine',
  'and',
  'andrew',
  'barrett',
  'are',
  'having',
  'an',
  'incredibly',
  'horrendous',
  'public'],
 ['well',
  'i',
  'thought',
  "we'd",
  'start',
  'with',
  'pronunciation',
  'if',
  "that's",
  'okay',
  'with',
  'you.']]

In [None]:
lines_dict["L194"]

'Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.\n'

In [None]:
lines_dict["L194"].strip()

'Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.'

In [None]:
question = lines_dict["L194"].strip()
reply = lines_dict["L195"].strip()

q_list = question.split()
r_list = reply.split()

In [None]:
q_list, r_list

(['Can',
  'we',
  'make',
  'this',
  'quick?',
  'Roxanne',
  'Korrine',
  'and',
  'Andrew',
  'Barrett',
  'are',
  'having',
  'an',
  'incredibly',
  'horrendous',
  'public',
  'break-',
  'up',
  'on',
  'the',
  'quad.',
  'Again.'],
 ['Well',
  'I',
  'thought',
  "we'd",
  'start',
  'with',
  'pronunciation',
  'if',
  "that's",
  'okay',
  'with',
  'you.'])

In [None]:
qa_pair = [q_list, r_list]
qa_pair

[['Can',
  'we',
  'make',
  'this',
  'quick?',
  'Roxanne',
  'Korrine',
  'and',
  'Andrew',
  'Barrett',
  'are',
  'having',
  'an',
  'incredibly',
  'horrendous',
  'public',
  'break-',
  'up',
  'on',
  'the',
  'quad.',
  'Again.'],
 ['Well',
  'I',
  'thought',
  "we'd",
  'start',
  'with',
  'pronunciation',
  'if',
  "that's",
  'okay',
  'with',
  'you.']]

In [None]:
# Confirming that all the pairs have 2 list
for p in pairs:
  if len(p) != 2:
    print(len(p))
print("All pairs have 2 list")

All pairs have 2 list


## Word-to-Index Dictionary for Word Embeddings

### Introduction

Now we'll focus on constructing a word-to-index dictionary, an essential step in utilizing word embeddings. Word embeddings represent each word in a vocabulary as a dense vector, typically obtained from a one-hot encoding followed by an embedding layer. This process allows for more efficient representation and processing of textual data.

### Process Overview

* **Mapping Words to Indices**: Each unique word in the dataset will be assigned a unique index. This index will serve as the basis for creating one-hot vectors.
* **Generating One-Hot Vectors**: We will use PyTorch to automatically convert these indices into one-hot vectors.
* **Utilizing Embedding Layers**: The one-hot vectors will then be inserted into an embedding layer. This layer transforms one-hot vectors into dense word embeddings, capturing semantic relationships between words.

### Steps

1. **Collecting Unique Words**: The first step involves gathering all the unique words present in the datasets.
2. **Calculating Word Frequencies**: We need to determine how often each word occurs in our dataset.
3. **Filtering Low-Frequency Words**: Words that occur infrequently, less than five times for instance, will be removed. This helps streamline the vocabulary size and reduces the complexity of the output layer in our model.

By following these steps, we ensure that our word-to-index dictionary effectively represents the vocabulary of our dataset while maintaining efficiency in computational resources.

## Creating Word Frequency Dictionary using collections

This code iterates over each question-answer pair in the list of pairs and updates a Counter object called word_freq with the frequencies of words appearing in both the questions and answers. The update() method increments the counts for each word encountered in the pairs.

In [None]:
# Initialize a Counter object to store word frequencies
word_freq = Counter()

# Iterate over each question-answer pair in the list of pairs
for pair in pairs:
    # Update the word frequencies with the words from both the question and the answer
    word_freq.update(pair[0])  # Update word frequencies with words from the question
    word_freq.update(pair[1])  # Update word frequencies with words from the answer

In [None]:
word_freq

Counter({'can': 11286,
         'we': 21574,
         'make': 4493,
         'this': 21853,
         'quick?': 6,
         'roxanne': 1,
         'korrine': 1,
         'and': 37699,
         'andrew': 28,
         'barrett': 8,
         'are': 18035,
         'having': 862,
         'an': 6611,
         'incredibly': 39,
         'horrendous': 1,
         'public': 187,
         'well': 10445,
         'i': 117031,
         'thought': 3408,
         "we'd": 445,
         'start': 991,
         'with': 16356,
         'pronunciation': 2,
         'if': 13812,
         "that's": 12652,
         'okay': 1971,
         'you.': 9595,
         'not': 20583,
         'the': 95972,
         'hacking': 10,
         'gagging': 5,
         'spitting': 14,
         'part.': 100,
         'please.': 799,
         'okay...': 192,
         'then': 4943,
         'how': 11802,
         "'bout": 290,
         'try': 1456,
         'out': 10225,
         'some': 6337,
         'french': 185,
         '

## Filtering Words by Frequency

Words whose frequency does not exceed the specified threshold (`min_word_freq`) are filtered out from the word frequency dictionary (`word_freq`); only words that occur more than `min_word_freq` times are kept.

### Creating Word-to-Index Mapping

- Each remaining word is assigned a unique index in the `word_map` dictionary, starting from 1.
* The index is incremented for each word in the list of filtered words, creating a word-to-index mapping.


### Adding Special Tokens

* Special tokens such as `<unk>` (unknown), `<start>` (start-of-sequence), `<end>` (end-of-sequence), and `<pad>` (padding) are added to the `word_map` dictionary with unique indices.
- These tokens are crucial for data preprocessing and model training, allowing for handling of out-of-vocabulary words, marking sequence boundaries, and managing variable-length sequences.

The resulting `word_map` dictionary provides a comprehensive mapping of words to indices, including special tokens, facilitating efficient data processing and model training.

In [None]:
# Set the minimum word frequency threshold
min_word_freq = 8

# Filter words based on their frequency to exclude those occurring less frequently than the threshold
words = [w for w in word_freq.keys() if word_freq[w] > min_word_freq]

# Create a word-to-index mapping dictionary
word_map = {k: v + 1 for v, k in enumerate(words)}  # Assign unique indices to each word, starting from 1

# Add special tokens to the word map with unique indices
word_map['<unk>'] = len(word_map) + 1  # Unknown token for out-of-vocabulary words
word_map['<start>'] = len(word_map) + 1  # Start-of-sequence token
word_map['<end>'] = len(word_map) + 1  # End-of-sequence token
word_map['<pad>'] = 0  # Padding token with index 0

In [None]:
word_map

{'can': 1,
 'we': 2,
 'make': 3,
 'this': 4,
 'and': 5,
 'andrew': 6,
 'are': 7,
 'having': 8,
 'an': 9,
 'incredibly': 10,
 'public': 11,
 'well': 12,
 'i': 13,
 'thought': 14,
 "we'd": 15,
 'start': 16,
 'with': 17,
 'if': 18,
 "that's": 19,
 'okay': 20,
 'you.': 21,
 'not': 22,
 'the': 23,
 'hacking': 24,
 'spitting': 25,
 'part.': 26,
 'please.': 27,
 'okay...': 28,
 'then': 29,
 'how': 30,
 "'bout": 31,
 'try': 32,
 'out': 33,
 'some': 34,
 'french': 35,
 'saturday?': 36,
 'night?': 37,
 "you're": 38,
 'asking': 39,
 'me': 40,
 'out.': 41,
 'so': 42,
 'cute.': 43,
 "what's": 44,
 'your': 45,
 'name': 46,
 'again?': 47,
 'forget': 48,
 'it.': 49,
 'no': 50,
 "it's": 51,
 'my': 52,
 'fault': 53,
 "didn't": 54,
 'have': 55,
 'a': 56,
 'proper': 57,
 'introduction': 58,
 'cameron.': 59,
 'thing': 60,
 'is': 61,
 'cameron': 62,
 "i'm": 63,
 'at': 64,
 'mercy': 65,
 'of': 66,
 'particularly': 67,
 'hideous': 68,
 'breed': 69,
 'loser.': 70,
 'seems': 71,
 'like': 72,
 'she': 73,
 'could

In [None]:
print(f"The total words are {len(word_map)}.")

The total words are 15729.


## Saving the WordMap

In [None]:
# Write the word -> index mapping to WORDMAP_corpus.json (relative to the current working directory)
with open('WORDMAP_corpus.json', 'w') as j:
    json.dump(word_map, j)

## IMPORTANT: I can improve on this by using WordPiece or BPE tokenizer instead of using word-level tokens

## Encoding Words Using Word Mapping (Tokenization)

After creating the `word_map`, the next step is to encode the words using this mapping. Since neural networks require numerical inputs rather than strings, we need to represent words as indices in the `word_map`.


### Function Definitions

Two functions will be created for encoding: one for questions and one for replies.

### Function: `encode_enc_inp`
- **Input Arguments:**
  - `words`: List of words in the question.
  - `word_map`: Mapping of words to indices (`word_map`).

- **Explanation:**
  - This function, `encode_enc_inp`, converts each word in the question into its corresponding index using the provided `word_map` (unknown words map to `<unk>`) and pads the result with `<pad>` up to `max_len`.

### Function: `encode_dec_inp`
- **Input Arguments:**
  - `words`: List of words in the reply.
  - `word_map`: Mapping of words to indices (`word_map`).

- **Explanation:**
  - Similarly, the `encode_dec_inp` function converts each word in the reply into its corresponding index using the `word_map`, wraps the sequence in `<start>` and `<end>` tokens, and pads it with `<pad>`.


In [None]:
def encode_enc_inp(words, word_map):
    """
    Turn the words of an encoder input (the question) into a padded list of indices.

    Relies on the notebook-level `max_len` for the padded length.

    Parameters:
    words (list): Words of the question.
    word_map (dict): Word -> index mapping containing '<unk>' and '<pad>'.

    Returns:
    list: One index per word, followed by '<pad>' indices up to `max_len` entries.
    """

    # Look every word up, substituting the '<unk>' index when it is not in the vocabulary
    token_ids = []
    for word in words:
        token_ids.append(word_map.get(word, word_map['<unk>']))

    # Fill the remaining positions with the padding index
    padding = [word_map['<pad>']] * (max_len - len(words))

    return token_ids + padding

In [None]:
def encode_dec_inp(words, word_map):
    """
    Turn the words of a reply into a padded list of indices framed by '<start>' and '<end>'.

    Relies on the notebook-level `max_len`; the result has `max_len + 2` entries
    when the reply has at most `max_len` words.

    Parameters:
    words (list): Words of the reply.
    word_map (dict): Word -> index mapping containing '<start>', '<end>', '<unk>' and '<pad>'.

    Returns:
    list: '<start>', the word indices, '<end>', then '<pad>' indices.
    """

    # Begin with the start-of-sequence marker
    token_ids = [word_map['<start>']]

    # Word indices, falling back to '<unk>' for out-of-vocabulary words
    token_ids.extend(word_map.get(word, word_map['<unk>']) for word in words)

    # Close the reply, then pad out the remaining positions
    token_ids.append(word_map['<end>'])
    token_ids.extend([word_map['<pad>']] * (max_len - len(words)))

    return token_ids

In [None]:
# Initialize an empty list to store encoded question-answer pairs
pairs_encoded = []

# Iterate over each question-answer pair in the list of pairs
for pair in pairs:
    # Encode the question and the reply using the provided word-to-index mapping
    qus = encode_enc_inp(pair[0], word_map)  # Encode the question
    ans = encode_dec_inp(pair[1], word_map)  # Encode the reply

    # Append the encoded question-answer pair to the list of encoded pairs
    pairs_encoded.append([qus, ans])


In [None]:
pairs[10]

[["c'esc", 'ma', 'tete.', 'this', 'is', 'my', 'head'],
 ['right.', 'see?', "you're", 'ready', 'for', 'the', 'quiz.']]

In [None]:
pairs_encoded[10]

[[15726, 101, 15726, 4, 61, 52, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [15727, 103, 104, 38, 105, 106, 23, 15726, 15728, 0, 0, 0, 0, 0, 0, 0, 0, 0]]

## Saving Number coded WordMap (tokenized words)

In [None]:
# Store the encoded pairs as JSON; MovieDataset later reads 'pairs_encoded.json' back by name
fname = "pairs_encoded.json"
with open(fname, 'w') as p:
    json.dump(pairs_encoded, p)

## Custom Dataset Class

In [None]:
class MovieDataset(Dataset):
  """
  Custom PyTorch dataset class for loading encoded question-reply pairs.

  Args:
  ----
  path (str): JSON file holding a list of [question_ids, reply_ids] pairs.
      Defaults to 'pairs_encoded.json' in the current working directory.

  Attributes:
  ----------
  pairs (list): List of encoded question-reply pairs.
  dataset_size (int): Total number of question-reply pairs in the dataset.

  Methods:
  --------
  __init__(path): Initializes the dataset by loading encoded pairs from a JSON file.
  __getitem__(i): Retrieves the encoder input, decoder input and decoder target at index i.
  __len__(): Returns the total number of question-reply pairs in the dataset.
  """
  def __init__(self, path='pairs_encoded.json'):
      """
      Initialize the dataset by loading encoded pairs from a JSON file.
      Sets the total number of pairs in the dataset.
      """
      # Use a context manager so the file handle is closed as soon as loading finishes
      with open(path) as pairs_file:
          self.pairs = json.load(pairs_file)
      self.dataset_size = len(self.pairs)  # Set the total number of pairs in the dataset

  def __getitem__(self, i):
      """
      Retrieve the encoded question-reply pair at index i.

      Args:
      -----
      i (int): Index of the pair to retrieve.

      Returns:
      --------
      tuple: (enc_inp, dec_inp, dec_out) LongTensors, where dec_inp is the reply
          without its last token and dec_out is the reply without its first token.
      """
      # Convert the encoded question and reply to PyTorch LongTensors
      enc_inp = torch.LongTensor(self.pairs[i][0])
      dec = torch.LongTensor(self.pairs[i][1])

      # Teacher forcing: the decoder reads tokens [0, n-1) and must predict tokens [1, n)
      dec_inp = dec[:-1]
      dec_out = dec[1:]

      return enc_inp, dec_inp, dec_out

  def __len__(self):
      """
      Return the total number of question-reply pairs in the dataset.

      Returns:
      --------
      int: Total number of pairs in the dataset.
      """
      return self.dataset_size

In [None]:
train_data = MovieDataset()

In [None]:
q_r = train_data[10]
q_r

(tensor([15726,   101, 15726,     4,    61,    52,   102,     0,     0,     0,
             0,     0,     0,     0,     0,     0]),
 tensor([15727,   103,   104,    38,   105,   106,    23, 15726, 15728,     0,
             0,     0,     0,     0,     0,     0,     0]),
 tensor([  103,   104,    38,   105,   106,    23, 15726, 15728,     0,     0,
             0,     0,     0,     0,     0,     0,     0]))

In [None]:
rev_word_map = {v: k for k, v in word_map.items()}

In [None]:
rev_word_map[104]

'see?'

In [None]:
def tensor_to_sentence(t, clean=False):
  """
  Decode a 1-D tensor of token indices into a space-separated sentence.

  Uses the notebook-level `rev_word_map` (index -> word) for the lookup.

  Parameters:
  t (torch.Tensor): 1-D tensor of token indices; may live on CPU or GPU.
  clean (bool): When True, every "<pad>" substring is removed from the result
      (the separating spaces are left in place).

  Returns:
  str: The decoded words joined by single spaces.
  """
  # .cpu() makes this work for CUDA tensors too; .tolist() yields plain Python ints
  token_ids = t.detach().cpu().tolist()
  q_words = " ".join([rev_word_map[v] for v in token_ids])

  if clean:
    q_words = q_words.replace("<pad>", "")

  return q_words

In [None]:
q_words = tensor_to_sentence(q_r[0])
r_words = tensor_to_sentence(q_r[1])
q_words, r_words

('<unk> ma <unk> this is my head <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>',
 "<start> right. see? you're ready for the <unk> <end> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>")

## Custom DataLoader

In [None]:
# Serve the dataset in shuffled mini-batches of `batch_size` samples.
# pin_memory=True asks the DataLoader to return batches in pinned (page-locked) host memory.
train_loader = DL(train_data,
                  batch_size = batch_size,
                  shuffle=True,
                  pin_memory=True)

In [None]:
# Peek at a single batch to check shapes and decoded contents.
# Questions (enc_inp) are max_len = 16 tokens long because sentences are cut/padded to max_len.
# The stored reply is <start> + words + <end> + padding; the dataset splits it into
# dec_inp (last token dropped) and dec_out (first token dropped), 17 tokens each here.
# Shorter sentences are filled up with <pad>.

for i, (enc_inp, dec_inp, dec_out) in enumerate(train_loader):
  print(enc_inp.shape, dec_inp.shape, dec_out.shape)
  print(tensor_to_sentence(enc_inp[0]))
  print(tensor_to_sentence(dec_inp[0]))
  print(tensor_to_sentence(dec_out[0]))

  break

torch.Size([64, 16]) torch.Size([64, 17]) torch.Size([64, 17])
yeah. <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
<start> you lost your control. <end> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
you lost your control. <end> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>


## Setting the device

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

## Mask

This function **create_masks** generates masks for the *input question* and *reply* sequences to facilitate attention mechanisms in the neural network model.
- It first defines a nested function subsequent_mask to create a mask preventing attending to subsequent positions.
- Then, it creates masks for the input question, input reply, and target reply, ensuring proper masking for padding tokens and subsequent positions.
- The masks are returned as a tuple for further use in the model.

### Example
- Sentence: `<start>Hello how are you <end>`
- reply_input: `<start>Hello how are you`
  - reply_input is input to our decoder
- reply_target: `Hello how are you<end>`
  - reply_target is the target to our decoder
- Remember we are doing supervised learning.

In [None]:
# Batched scenario
t = torch.triu(torch.ones((2, 4, 4)))
t.transpose(1, 2)

tensor([[[1., 0., 0., 0.],
         [1., 1., 0., 0.],
         [1., 1., 1., 0.],
         [1., 1., 1., 1.]],

        [[1., 0., 0., 0.],
         [1., 1., 0., 0.],
         [1., 1., 1., 0.],
         [1., 1., 1., 1.]]])

## Embedding class with positional Details

In [None]:
class TokenEmbedding(nn.Module):
    """Learned lookup table mapping token indices to dense vectors.

    Args:
        vocab_size (int): Number of rows in the embedding table.
        embed_size (int): Dimension of each embedding vector.
        pad_id (int or None): Index of the padding token; its row is kept at zero
            after initialization.
    """
    def __init__(self, vocab_size, embed_size, pad_id):
        super(TokenEmbedding, self).__init__()
        self.token_embedding = nn.Embedding(vocab_size,
                                            embed_size,
                                            padding_idx=pad_id)
        self.init_weights()

    def init_weights(self):
        """Draw the weights uniformly from [-0.1, 0.1] and re-zero the padding row."""
        initrange = 0.1
        with torch.no_grad():
            self.token_embedding.weight.uniform_(-initrange, initrange)
            # uniform_ overwrites the whole table, including the row nn.Embedding
            # had zeroed for padding_idx, so restore it explicitly.
            pad_id = self.token_embedding.padding_idx
            if pad_id is not None:
                self.token_embedding.weight[pad_id].zero_()

    def forward(self, x):
        """Return the embedding vectors for the index tensor `x` (one vector per index)."""
        x_embed = self.token_embedding(x)
        return x_embed

In [None]:
class PositionalEmbedding(nn.Module):
    """
    Fixed sinusoidal position encodings.

    ref: https://github.com/codertimo/BERT-pytorch/blob/master/bert_pytorch/model/embedding/position.py

    Args:
        d_model (int): Dimension of each position vector (even or odd).
        max_len (int): Number of positions to precompute.
    """
    def __init__(self, d_model, max_len=512):
        super().__init__()

        # Compute the positional encodings once in log space.
        # The table is stored as a buffer (not a parameter), so it is never trained.
        pe = torch.zeros(max_len, d_model).float()

        position = torch.arange(0, max_len).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()

        # Even columns hold sines, odd columns hold cosines.
        pe[:, 0::2] = torch.sin(position * div_term)
        # For an odd d_model there is one cosine column fewer than sine columns,
        # so trim div_term to the number of odd columns.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return the encodings for the first `x.size(1)` positions, shape (1, x.size(1), d_model)."""
        return self.pe[:, :x.size(1)]

In [None]:
import math
class Embeddings(nn.Module):
    """Token embeddings scaled by sqrt(embed_size) plus positional encodings.

    Args:
        vocab: Token -> index mapping; its length sets the table size and its
            "<pad>" entry is used as the padding index.
        embed_size (int): Embedding dimension.
        max_len (int): The positional table is built for `max_len + 2` positions.
    """
    def __init__(self, vocab, embed_size, max_len):
        super().__init__()
        self.embed_size = embed_size
        self.token_embedding = TokenEmbedding(
            vocab_size=len(vocab),
            embed_size=embed_size,
            pad_id=vocab["<pad>"],
        )
        self.pos_embedding = PositionalEmbedding(
            d_model=embed_size,
            max_len=max_len + 2,
        )

    def forward(self, x):
        """Embed the index tensor `x` and add the position encodings for its second dimension."""
        scale = math.sqrt(self.embed_size)
        scaled_tokens = self.token_embedding(x) * scale
        positions = self.pos_embedding(x)
        return scaled_tokens + positions


## Creating the model

In [None]:
class Transformer(nn.Module):
    """Encoder-decoder model: shared input embeddings, ``torch.nn.Transformer``,
    and a linear projection onto the vocabulary."""

    def __init__(self, vocab,
                 d_model=512,
                 n_head=8,
                 num_encoder_layers=6,
                 num_decoder_layers=6,
                 dim_feedforward=2048,
                 dropout=0.1,
                 max_len=15) -> None:
        """Instantiating Transformer class
        Args:
            vocab: token -> index mapping; must contain "<pad>". Its length is the
                size of the embedding table and of the output layer.
            d_model (int): embedding / model dimension.
            n_head (int): number of attention heads.
            num_encoder_layers (int): number of encoder layers.
            num_decoder_layers (int): number of decoder layers.
            dim_feedforward (int): hidden size of the feed-forward sub-layers.
            dropout (float): dropout probability inside torch.nn.Transformer.
            max_len (int): forwarded to Embeddings, which sizes its positional
                table as max_len + 2.
        """
        super(Transformer, self).__init__()
        self.vocab = vocab

        # The same embedding module is used for encoder and decoder inputs.
        self.input_embedding = Embeddings(vocab, d_model, max_len)

        # NOTE: the attribute keeps its original spelling ("transfomrer"); renaming it
        # would change the parameter names stored in state_dict().
        self.transfomrer = torch.nn.Transformer(d_model=d_model,
                                                nhead=n_head,
                                                num_encoder_layers=num_encoder_layers,
                                                num_decoder_layers=num_decoder_layers,
                                                dim_feedforward=dim_feedforward,
                                                dropout=dropout,
                                                batch_first=True)

        self.proj_vocab_layer = nn.Linear(in_features=d_model,
                                          out_features=len(vocab))

        self.init_weights()

    def init_weights(self):
        """Zero the output bias and draw the output weights uniformly from [-0.1, 0.1]."""
        initrange = 0.1
        self.proj_vocab_layer.bias.data.zero_()
        self.proj_vocab_layer.weight.data.uniform_(-initrange, initrange)

    def forward(self, enc_input: torch.Tensor, dec_input: torch.Tensor) -> torch.Tensor:
        """Compute vocabulary logits for every decoder position.

        Args:
            enc_input: (batch, src_len) token indices of the questions.
            dec_input: (batch, tgt_len) token indices of the shifted replies.

        Returns:
            Logits of shape (batch, tgt_len, len(vocab)).
        """
        pad_id = self.vocab["<pad>"]

        x_enc_embed = self.input_embedding(enc_input.long())
        x_dec_embed = self.input_embedding(dec_input.long())

        # Boolean padding masks: True marks a position attention must ignore.
        # They must stay boolean — a float mask is ADDED to the attention scores,
        # so a 0/1 float mask would favour padding instead of hiding it.
        # NOTE(review): a sequence consisting only of <pad> is fully masked and may
        # produce NaNs; confirm such samples cannot occur.
        src_key_padding_mask = enc_input == pad_id
        tgt_key_padding_mask = dec_input == pad_id

        # Causal mask (True above the diagonal = future positions are hidden), created
        # on the input's device and boolean like the padding masks.
        tgt_len = dec_input.size(1)
        tgt_mask = torch.triu(
            torch.ones(tgt_len, tgt_len, dtype=torch.bool, device=dec_input.device),
            diagonal=1,
        )

        # transformer ref: https://pytorch.org/docs/stable/nn.html#torch.nn.Transformer
        feature = self.transfomrer(src=x_enc_embed,
                                   tgt=x_dec_embed,
                                   src_key_padding_mask=src_key_padding_mask,
                                   tgt_key_padding_mask=tgt_key_padding_mask,
                                   memory_key_padding_mask=src_key_padding_mask,
                                   tgt_mask=tgt_mask)

        logits = self.proj_vocab_layer(feature)

        return logits

In [None]:
word_map["<pad>"], len(word_map)

(0, 15729)

In [None]:
# Smoke test: build the model, push one batch through it and compare output/target shapes.
# NOTE(review): max_len=15 differs from the global max_len (16). Embeddings sizes its
# positional table as max_len + 2 = 17, which equals the decoder input length — confirm this is intended.
model = Transformer(word_map, max_len=15).to(device)
for i, (enc_inp, dec_inp, dec_out) in enumerate(train_loader):
  print(enc_inp.shape, dec_inp.shape, dec_out.shape)
  enc_inp, dec_inp = enc_inp.to(device), dec_inp.to(device)
  out = model(enc_inp, dec_inp)
  print(out.shape, dec_out.shape)

  # for 1 sentence from the batch
  # we have (sequence_length, vocab_size) output
  # hello - [vocab_size tensor with logit values]
  # how - [vocab_size tensor with logit values]
  # are - [vocab_size tensor with logit values]
  # your - [vocab_size tensor with logit values]
  # each decoder position (17 here) has one row of logits over the vocabulary;
  # these rows are compared with dec_out when the loss is calculated
  print(out[0].shape, dec_out[0].shape)
  break

torch.Size([64, 16]) torch.Size([64, 17]) torch.Size([64, 17])
torch.Size([64, 17, 15729]) torch.Size([64, 17])
torch.Size([17, 15729]) torch.Size([17])


## Optimizer

In [None]:
class AdamWarmup:
    """Optimizer wrapper that applies a warm-up / inverse-square-root LR schedule.

    The learning rate follows the schedule from "Attention Is All You Need":

        lr = model_size ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

    i.e. it grows linearly for ``warmup_steps`` updates and then decays with the
    inverse square root of the step number.

    Args:
        model_size: Model dimension used to scale the schedule.
        warmup_steps: Number of updates over which the rate ramps up. A value
            of 0 (or less) disables the ramp and uses the decay term only.
        optimizer: Wrapped optimizer; it must expose ``param_groups`` and ``step()``.

    Attributes:
        current_step: Number of ``step()`` calls made so far.
        lr: Learning rate applied by the most recent ``step()`` (0 before the first).
    """

    def __init__(self, model_size, warmup_steps, optimizer):

        self.model_size = model_size
        self.warmup_steps = warmup_steps
        self.optimizer = optimizer
        self.current_step = 0
        self.lr = 0

    def get_lr(self):
        """Return the scheduled learning rate for ``current_step``.

        Returns 0.0 before the first update instead of raising
        ``ZeroDivisionError`` on ``0 ** -0.5``.
        """
        step = self.current_step
        if step <= 0:
            return 0.0

        scale = self.model_size ** (-0.5)
        decay = step ** (-0.5)
        if self.warmup_steps <= 0:
            # No warm-up phase requested: ``warmup_steps ** -1.5`` would divide by zero.
            return scale * decay
        return scale * min(decay, step * self.warmup_steps ** (-1.5))

    def step(self):
        """Advance the schedule by one update and step the wrapped optimizer."""
        # Increment first so the very first update already uses a non-zero rate.
        self.current_step += 1

        lr = self.get_lr()

        # Push the new rate into every parameter group before the update is applied.
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        self.lr = lr
        self.optimizer.step()

## Loss

In [None]:
class LossWithLS(nn.Module):
    """KL-divergence loss against label-smoothed one-hot targets.

    Each target index is turned into a distribution that puts ``1 - smooth`` on
    the true class and spreads ``smooth`` evenly over the remaining
    ``size - 1`` classes. The logits are passed through ``log_softmax`` and
    compared with that distribution using ``nn.KLDivLoss(reduction='batchmean')``,
    so the summed loss is divided by the size of the first dimension only.

    Args:
        size: Number of classes; must equal the last dimension of the logits.
        smooth: Probability mass moved away from the true class.
        ignore_index: Optional target value (e.g. the padding index) whose
            positions contribute nothing to the loss. ``None`` (the default)
            scores every position.
    """

    def __init__(self, size, smooth, ignore_index=None):
        super(LossWithLS, self).__init__()
        self.criterion = nn.KLDivLoss(reduction='batchmean')

        self.confidence = 1.0 - smooth
        self.smooth = smooth
        self.size = size
        self.ignore_index = ignore_index
        self.true_dist = None  # never populated; kept so existing attribute access still works

    def forward(self, x, target):
        """Compute the smoothed loss.

        Args:
            x: Unnormalised logits whose last dimension has ``size`` entries.
            target: Integer class indices shaped like ``x`` without its last dimension.

        Returns:
            Scalar loss tensor.
        """
        assert x.size(-1) == self.size

        # full_like yields a constant tensor outside the autograd graph, so no detach is needed.
        true_dist = torch.full_like(x, self.smooth / (self.size - 1))
        # Scatter along the last dimension so (N, C) and (B, T, C) logits both work.
        true_dist.scatter_(-1, target.unsqueeze(-1), self.confidence)

        if self.ignore_index is not None:
            # An all-zero target row adds nothing to the KL divergence.
            ignored = (target == self.ignore_index).unsqueeze(-1)
            true_dist.masked_fill_(ignored, 0.0)

        return self.criterion(F.log_softmax(x, dim=-1), true_dist)


# Example usage: run the loss on random tensors to confirm the shapes line up.
# NOTE(review): this rebinds the notebook-level `batch_size`; the value matches the
# one set in the HyperParameter cell (64).
batch_size = 64
max_words = 26
vocab_size = 18243
smooth = 0.1

# Random tensors for demonstration (no seed is set, so the printed loss varies per run)
prediction = torch.randn(batch_size, max_words, vocab_size)
target = torch.randint(0, vocab_size, (batch_size, max_words))

print(prediction.shape, target.shape)
# Initialize and compute loss
loss_fn = LossWithLS(size=vocab_size, smooth=smooth)
loss = loss_fn(prediction, target)
print(f'Loss: {loss.item()}')

torch.Size([64, 26, 18243]) torch.Size([64, 26])
Loss: 233.9447479248047


## Model evaluation (without training)

In [None]:
def evaluate(model, enc_inp, max_len, word_map):
    """Greedily decode a reply for one encoded question.

    Starting from the ``<start>`` token, the model is called repeatedly; at each
    step the highest-scoring token for the last position is appended to the
    decoder input. Decoding stops as soon as ``<end>`` is produced or after
    ``max_len - 1`` steps.

    Args:
        model: Callable ``model(enc_inp, dec_inp)`` returning logits indexed as
            ``[batch, position, vocab]``. It is switched to eval mode and left there.
        enc_inp: Encoded question with a batch dimension of 1 (the stop check
            uses ``.item()``, which requires a single sequence).
        max_len: Upper bound on the decoded length, counting the start token.
        word_map: Mapping from word to token index; must contain ``<start>`` and ``<end>``.

    Returns:
        The decoded words joined by single spaces, without the start/end tokens.
    """
    model.eval()  # Set the model to evaluation mode
    start_symbol = word_map['<start>']
    end_symbol = word_map['<end>']

    # Build the index -> word lookup from the map that was passed in, so the
    # function does not depend on a notebook-level reverse map.
    index_to_word = {index: word for word, index in word_map.items()}

    # Decoder input starts as just the start token, on the same device as the question.
    dec_inp = torch.tensor([[start_symbol]], dtype=torch.long, device=enc_inp.device)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        for _ in range(max_len - 1):
            output = model(enc_inp, dec_inp)

            # Logits for the most recent position -> greedy choice of the next token.
            next_token = output[:, -1, :].argmax(dim=-1, keepdim=True)
            dec_inp = torch.cat([dec_inp, next_token], dim=1)

            if next_token.item() == end_symbol:
                break

    special_tokens = (start_symbol, end_symbol)
    tgt_tokens = dec_inp.squeeze(0).tolist()
    sentence = ' '.join(index_to_word[token] for token in tgt_tokens if token not in special_tokens)

    return sentence

# Inverse of word_map: maps each token index back to its word.
reverse_word_map = {v: k for k, v in word_map.items()}

In [None]:
questions = [
    "Hello how are you?",
    "I like Fruits",
    "Are you hungry?",
]


def getResults(transformer, questions):
    """Print the model's decoded reply for every question, one tab-prefixed line each.

    Each question is split on whitespace and every piece is looked up in the
    notebook-level ``word_map``, falling back to the ``<unk>`` index for unknown
    words. Decoding is delegated to ``evaluate`` with the notebook-level
    ``max_len``; the encoded question is moved to ``device`` first.
    """
    for text in questions:
        token_ids = []
        for piece in text.split():
            token_ids.append(word_map.get(piece, word_map['<unk>']))

        # Wrap in an outer list to get the batch dimension of 1 that evaluate expects.
        encoded = torch.LongTensor([token_ids]).to(device)
        reply = evaluate(transformer, encoded, max_len, word_map)
        print("\t", reply)

In [None]:
# Baseline: a freshly constructed model (only max_len is overridden), not trained yet.
transformer = Transformer(word_map, max_len=15).to(device)

questions = ["Hello how are you?", "I like Fruits", "Are you hungry?"]

# Replies produced before any training, for comparison with the trained model later on.
getResults(transformer, questions)

	 lawson? manager? manager? manager? lawson? lawson? lawson? lawson? lawson? lawson? lawson? princess princess princess princess
	 lawson? lawson? lawson? lawson? lawson? lawson? lawson? lawson? lawson? lawson? lawson? lawson? lawson? lawson? lawson?
	 manager? manager? manager? manager? manager? manager? manager? manager? manager? manager? manager? manager? manager? manager? manager?


In [None]:
!pip install tqdm



## Training the model

In [None]:
from tqdm import tqdm

# ---- Model hyper-parameters ----
d_model = 512
n_head = 2
num_encoder_layers = 2
num_decoder_layers = num_encoder_layers
dim_feedforward = 512
dropout = 0.2

# ---- Training hyper-parameters ----
epochs = 30
warmup_steps = 4000          # updates over which the learning rate ramps up
grad_clip_norm = 0.5         # max gradient norm applied before every update
log_every = batch_size * 5   # print the running loss every this many batches

transformer = Transformer(word_map,
                 d_model=d_model,
                 n_head=n_head,
                 num_encoder_layers=num_encoder_layers,
                 num_decoder_layers=num_decoder_layers,
                 dim_feedforward=dim_feedforward,
                 dropout=dropout,
                 max_len=15).to(device)

# The initial lr is a placeholder: AdamWarmup overwrites it before every update.
adam_optimizer = torch.optim.Adam(transformer.parameters(), lr=0.00, betas=(0.9, 0.98), eps=1e-9)
# Tie the schedule to d_model so it stays correct when the model size is changed.
transformer_optimizer = AdamWarmup(model_size=d_model, warmup_steps=warmup_steps, optimizer=adam_optimizer)
criterion = LossWithLS(len(word_map), 0.1)

for epoch in tqdm(range(epochs)):

    transformer.train()
    sum_loss = 0
    count = 0

    for i, (enc_inp, dec_inp, dec_out) in enumerate(train_loader):

        # Number of sequences in this batch (first dimension). The loss is a
        # per-sequence mean, so weighting by the batch size keeps the running
        # average correct when the last batch is smaller.
        samples = enc_inp.shape[0]

        # Move to device
        enc_inp = enc_inp.to(device)
        dec_inp, dec_out = dec_inp.to(device), dec_out.to(device)

        # Get the transformer outputs
        out = transformer(enc_inp, dec_inp)

        # Compute the loss
        loss = criterion(out, dec_out)

        # Backprop
        transformer_optimizer.optimizer.zero_grad()
        loss.backward()

        torch.nn.utils.clip_grad_norm_(transformer.parameters(), grad_clip_norm)

        # Updates the learning rate and applies the Adam step.
        transformer_optimizer.step()

        sum_loss += loss.item() * samples
        count += samples

        if i % log_every == 0:
            print("Epoch [{}][{}/{}]\tLoss: {:.3f}\tLR: {:.5f}".format(
                epoch,
                i,
                len(train_loader),
                sum_loss/count,
                transformer_optimizer.lr))

            # To sample replies here, call getResults(transformer, questions) and then
            # transformer.train() again: evaluate() switches the model to eval mode.

  0%|          | 0/30 [00:00<?, ?it/s]

Epoch [0][0/3290]	Loss: 148.259	LR: 0.00000
Epoch [0][320/3290]	Loss: 68.099	LR: 0.00006
Epoch [0][640/3290]	Loss: 57.796	LR: 0.00011
Epoch [0][960/3290]	Loss: 53.588	LR: 0.00017
Epoch [0][1280/3290]	Loss: 50.942	LR: 0.00022
Epoch [0][1600/3290]	Loss: 49.211	LR: 0.00028
Epoch [0][1920/3290]	Loss: 47.886	LR: 0.00034
Epoch [0][2240/3290]	Loss: 46.852	LR: 0.00039
Epoch [0][2560/3290]	Loss: 46.052	LR: 0.00045
Epoch [0][2880/3290]	Loss: 45.375	LR: 0.00050
Epoch [0][3200/3290]	Loss: 44.810	LR: 0.00056


  3%|▎         | 1/30 [03:03<1:28:38, 183.39s/it]

Epoch [1][0/3290]	Loss: 32.685	LR: 0.00057
Epoch [1][320/3290]	Loss: 39.332	LR: 0.00063
Epoch [1][640/3290]	Loss: 39.129	LR: 0.00069
Epoch [1][960/3290]	Loss: 39.060	LR: 0.00068
Epoch [1][1280/3290]	Loss: 39.070	LR: 0.00065
Epoch [1][1600/3290]	Loss: 38.996	LR: 0.00063
Epoch [1][1920/3290]	Loss: 38.899	LR: 0.00061
Epoch [1][2240/3290]	Loss: 38.853	LR: 0.00059
Epoch [1][2560/3290]	Loss: 38.820	LR: 0.00058
Epoch [1][2880/3290]	Loss: 38.756	LR: 0.00056
Epoch [1][3200/3290]	Loss: 38.701	LR: 0.00055


  7%|▋         | 2/30 [06:06<1:25:27, 183.14s/it]

Epoch [2][0/3290]	Loss: 42.356	LR: 0.00054
Epoch [2][320/3290]	Loss: 37.638	LR: 0.00053
Epoch [2][640/3290]	Loss: 37.524	LR: 0.00052
Epoch [2][960/3290]	Loss: 37.546	LR: 0.00051
Epoch [2][1280/3290]	Loss: 37.554	LR: 0.00050
Epoch [2][1600/3290]	Loss: 37.562	LR: 0.00049
Epoch [2][1920/3290]	Loss: 37.545	LR: 0.00048
Epoch [2][2240/3290]	Loss: 37.553	LR: 0.00047
Epoch [2][2560/3290]	Loss: 37.535	LR: 0.00046
Epoch [2][2880/3290]	Loss: 37.554	LR: 0.00045
Epoch [2][3200/3290]	Loss: 37.505	LR: 0.00045


 10%|█         | 3/30 [09:09<1:22:24, 183.14s/it]

Epoch [3][0/3290]	Loss: 42.277	LR: 0.00044
Epoch [3][320/3290]	Loss: 36.487	LR: 0.00044
Epoch [3][640/3290]	Loss: 36.607	LR: 0.00043
Epoch [3][960/3290]	Loss: 36.673	LR: 0.00042
Epoch [3][1280/3290]	Loss: 36.716	LR: 0.00042
Epoch [3][1600/3290]	Loss: 36.733	LR: 0.00041
Epoch [3][1920/3290]	Loss: 36.776	LR: 0.00041
Epoch [3][2240/3290]	Loss: 36.770	LR: 0.00040
Epoch [3][2560/3290]	Loss: 36.765	LR: 0.00040
Epoch [3][2880/3290]	Loss: 36.748	LR: 0.00039
Epoch [3][3200/3290]	Loss: 36.761	LR: 0.00039


 13%|█▎        | 4/30 [12:12<1:19:18, 183.03s/it]

Epoch [4][0/3290]	Loss: 30.978	LR: 0.00039
Epoch [4][320/3290]	Loss: 35.907	LR: 0.00038
Epoch [4][640/3290]	Loss: 35.978	LR: 0.00038
Epoch [4][960/3290]	Loss: 36.001	LR: 0.00037
Epoch [4][1280/3290]	Loss: 35.972	LR: 0.00037
Epoch [4][1600/3290]	Loss: 36.041	LR: 0.00036
Epoch [4][1920/3290]	Loss: 36.101	LR: 0.00036
Epoch [4][2240/3290]	Loss: 36.125	LR: 0.00036
Epoch [4][2560/3290]	Loss: 36.150	LR: 0.00035
Epoch [4][2880/3290]	Loss: 36.142	LR: 0.00035
Epoch [4][3200/3290]	Loss: 36.170	LR: 0.00035


 17%|█▋        | 5/30 [15:14<1:16:10, 182.84s/it]

Epoch [5][0/3290]	Loss: 37.698	LR: 0.00034
Epoch [5][320/3290]	Loss: 35.441	LR: 0.00034
Epoch [5][640/3290]	Loss: 35.504	LR: 0.00034
Epoch [5][960/3290]	Loss: 35.512	LR: 0.00033
Epoch [5][1280/3290]	Loss: 35.502	LR: 0.00033
Epoch [5][1600/3290]	Loss: 35.533	LR: 0.00033
Epoch [5][1920/3290]	Loss: 35.570	LR: 0.00033
Epoch [5][2240/3290]	Loss: 35.642	LR: 0.00032
Epoch [5][2560/3290]	Loss: 35.678	LR: 0.00032
Epoch [5][2880/3290]	Loss: 35.703	LR: 0.00032
Epoch [5][3200/3290]	Loss: 35.701	LR: 0.00032


 20%|██        | 6/30 [18:16<1:13:02, 182.60s/it]

Epoch [6][0/3290]	Loss: 35.767	LR: 0.00031
Epoch [6][320/3290]	Loss: 34.848	LR: 0.00031
Epoch [6][640/3290]	Loss: 35.073	LR: 0.00031
Epoch [6][960/3290]	Loss: 35.041	LR: 0.00031
Epoch [6][1280/3290]	Loss: 35.077	LR: 0.00030
Epoch [6][1600/3290]	Loss: 35.130	LR: 0.00030
Epoch [6][1920/3290]	Loss: 35.153	LR: 0.00030
Epoch [6][2240/3290]	Loss: 35.186	LR: 0.00030
Epoch [6][2560/3290]	Loss: 35.217	LR: 0.00030
Epoch [6][2880/3290]	Loss: 35.250	LR: 0.00029
Epoch [6][3200/3290]	Loss: 35.281	LR: 0.00029


 23%|██▎       | 7/30 [21:18<1:09:54, 182.35s/it]

Epoch [7][0/3290]	Loss: 35.781	LR: 0.00029
Epoch [7][320/3290]	Loss: 34.709	LR: 0.00029
Epoch [7][640/3290]	Loss: 34.655	LR: 0.00029
Epoch [7][960/3290]	Loss: 34.732	LR: 0.00029
Epoch [7][1280/3290]	Loss: 34.724	LR: 0.00028
Epoch [7][1600/3290]	Loss: 34.751	LR: 0.00028
Epoch [7][1920/3290]	Loss: 34.780	LR: 0.00028
Epoch [7][2240/3290]	Loss: 34.870	LR: 0.00028
Epoch [7][2560/3290]	Loss: 34.900	LR: 0.00028
Epoch [7][2880/3290]	Loss: 34.905	LR: 0.00027
Epoch [7][3200/3290]	Loss: 34.936	LR: 0.00027


 27%|██▋       | 8/30 [24:21<1:06:52, 182.38s/it]

Epoch [8][0/3290]	Loss: 33.697	LR: 0.00027
Epoch [8][320/3290]	Loss: 34.317	LR: 0.00027
Epoch [8][640/3290]	Loss: 34.509	LR: 0.00027
Epoch [8][960/3290]	Loss: 34.543	LR: 0.00027
Epoch [8][1280/3290]	Loss: 34.537	LR: 0.00027
Epoch [8][1600/3290]	Loss: 34.506	LR: 0.00026
Epoch [8][1920/3290]	Loss: 34.527	LR: 0.00026
Epoch [8][2240/3290]	Loss: 34.547	LR: 0.00026
Epoch [8][2560/3290]	Loss: 34.590	LR: 0.00026
Epoch [8][2880/3290]	Loss: 34.605	LR: 0.00026
Epoch [8][3200/3290]	Loss: 34.624	LR: 0.00026


 30%|███       | 9/30 [27:23<1:03:48, 182.33s/it]

Epoch [9][0/3290]	Loss: 33.071	LR: 0.00026
Epoch [9][320/3290]	Loss: 33.847	LR: 0.00026
Epoch [9][640/3290]	Loss: 33.923	LR: 0.00025
Epoch [9][960/3290]	Loss: 33.972	LR: 0.00025
Epoch [9][1280/3290]	Loss: 34.107	LR: 0.00025
Epoch [9][1600/3290]	Loss: 34.155	LR: 0.00025
Epoch [9][1920/3290]	Loss: 34.213	LR: 0.00025
Epoch [9][2240/3290]	Loss: 34.260	LR: 0.00025
Epoch [9][2560/3290]	Loss: 34.275	LR: 0.00025
Epoch [9][2880/3290]	Loss: 34.310	LR: 0.00025
Epoch [9][3200/3290]	Loss: 34.342	LR: 0.00024


 33%|███▎      | 10/30 [30:25<1:00:45, 182.30s/it]

Epoch [10][0/3290]	Loss: 36.197	LR: 0.00024
Epoch [10][320/3290]	Loss: 33.712	LR: 0.00024
Epoch [10][640/3290]	Loss: 33.694	LR: 0.00024
Epoch [10][960/3290]	Loss: 33.777	LR: 0.00024
Epoch [10][1280/3290]	Loss: 33.888	LR: 0.00024
Epoch [10][1600/3290]	Loss: 33.927	LR: 0.00024
Epoch [10][1920/3290]	Loss: 33.936	LR: 0.00024
Epoch [10][2240/3290]	Loss: 33.970	LR: 0.00024
Epoch [10][2560/3290]	Loss: 33.991	LR: 0.00023
Epoch [10][2880/3290]	Loss: 34.018	LR: 0.00023
Epoch [10][3200/3290]	Loss: 34.079	LR: 0.00023


 37%|███▋      | 11/30 [33:27<57:42, 182.23s/it]  

Epoch [11][0/3290]	Loss: 35.169	LR: 0.00023
Epoch [11][320/3290]	Loss: 33.547	LR: 0.00023
Epoch [11][640/3290]	Loss: 33.394	LR: 0.00023
Epoch [11][960/3290]	Loss: 33.486	LR: 0.00023
Epoch [11][1280/3290]	Loss: 33.586	LR: 0.00023
Epoch [11][1600/3290]	Loss: 33.642	LR: 0.00023
Epoch [11][1920/3290]	Loss: 33.703	LR: 0.00023
Epoch [11][2240/3290]	Loss: 33.753	LR: 0.00023
Epoch [11][2560/3290]	Loss: 33.813	LR: 0.00022
Epoch [11][2880/3290]	Loss: 33.836	LR: 0.00022
Epoch [11][3200/3290]	Loss: 33.884	LR: 0.00022


 40%|████      | 12/30 [36:29<54:39, 182.19s/it]

Epoch [12][0/3290]	Loss: 30.163	LR: 0.00022
Epoch [12][320/3290]	Loss: 33.248	LR: 0.00022
Epoch [12][640/3290]	Loss: 33.163	LR: 0.00022
Epoch [12][960/3290]	Loss: 33.245	LR: 0.00022
Epoch [12][1280/3290]	Loss: 33.357	LR: 0.00022
Epoch [12][1600/3290]	Loss: 33.464	LR: 0.00022
Epoch [12][1920/3290]	Loss: 33.537	LR: 0.00022
Epoch [12][2240/3290]	Loss: 33.555	LR: 0.00022
Epoch [12][2560/3290]	Loss: 33.641	LR: 0.00022
Epoch [12][2880/3290]	Loss: 33.679	LR: 0.00021
Epoch [12][3200/3290]	Loss: 33.686	LR: 0.00021


 43%|████▎     | 13/30 [39:32<51:37, 182.23s/it]

Epoch [13][0/3290]	Loss: 35.962	LR: 0.00021
Epoch [13][320/3290]	Loss: 33.041	LR: 0.00021
Epoch [13][640/3290]	Loss: 33.008	LR: 0.00021
Epoch [13][960/3290]	Loss: 33.080	LR: 0.00021
Epoch [13][1280/3290]	Loss: 33.134	LR: 0.00021
Epoch [13][1600/3290]	Loss: 33.219	LR: 0.00021
Epoch [13][1920/3290]	Loss: 33.316	LR: 0.00021
Epoch [13][2240/3290]	Loss: 33.410	LR: 0.00021
Epoch [13][2560/3290]	Loss: 33.454	LR: 0.00021
Epoch [13][2880/3290]	Loss: 33.471	LR: 0.00021
Epoch [13][3200/3290]	Loss: 33.517	LR: 0.00021


 47%|████▋     | 14/30 [42:34<48:37, 182.34s/it]

Epoch [14][0/3290]	Loss: 32.874	LR: 0.00021
Epoch [14][320/3290]	Loss: 32.679	LR: 0.00021
Epoch [14][640/3290]	Loss: 32.865	LR: 0.00020
Epoch [14][960/3290]	Loss: 33.008	LR: 0.00020
Epoch [14][1280/3290]	Loss: 33.072	LR: 0.00020
Epoch [14][1600/3290]	Loss: 33.132	LR: 0.00020
Epoch [14][1920/3290]	Loss: 33.177	LR: 0.00020
Epoch [14][2240/3290]	Loss: 33.210	LR: 0.00020
Epoch [14][2560/3290]	Loss: 33.248	LR: 0.00020
Epoch [14][2880/3290]	Loss: 33.291	LR: 0.00020
Epoch [14][3200/3290]	Loss: 33.330	LR: 0.00020


 50%|█████     | 15/30 [45:38<45:39, 182.65s/it]

Epoch [15][0/3290]	Loss: 31.868	LR: 0.00020
Epoch [15][320/3290]	Loss: 32.552	LR: 0.00020
Epoch [15][640/3290]	Loss: 32.709	LR: 0.00020
Epoch [15][960/3290]	Loss: 32.705	LR: 0.00020
Epoch [15][1280/3290]	Loss: 32.818	LR: 0.00020
Epoch [15][1600/3290]	Loss: 32.868	LR: 0.00020
Epoch [15][1920/3290]	Loss: 32.984	LR: 0.00020
Epoch [15][2240/3290]	Loss: 33.008	LR: 0.00019
Epoch [15][2560/3290]	Loss: 33.054	LR: 0.00019
Epoch [15][2880/3290]	Loss: 33.105	LR: 0.00019
Epoch [15][3200/3290]	Loss: 33.163	LR: 0.00019


 53%|█████▎    | 16/30 [48:40<42:34, 182.50s/it]

Epoch [16][0/3290]	Loss: 30.473	LR: 0.00019
Epoch [16][320/3290]	Loss: 32.489	LR: 0.00019
Epoch [16][640/3290]	Loss: 32.500	LR: 0.00019
Epoch [16][960/3290]	Loss: 32.629	LR: 0.00019
Epoch [16][1280/3290]	Loss: 32.680	LR: 0.00019
Epoch [16][1600/3290]	Loss: 32.732	LR: 0.00019
Epoch [16][1920/3290]	Loss: 32.766	LR: 0.00019
Epoch [16][2240/3290]	Loss: 32.849	LR: 0.00019
Epoch [16][2560/3290]	Loss: 32.880	LR: 0.00019
Epoch [16][2880/3290]	Loss: 32.921	LR: 0.00019
Epoch [16][3200/3290]	Loss: 32.978	LR: 0.00019


 57%|█████▋    | 17/30 [51:42<39:31, 182.42s/it]

Epoch [17][0/3290]	Loss: 37.485	LR: 0.00019
Epoch [17][320/3290]	Loss: 32.432	LR: 0.00019
Epoch [17][640/3290]	Loss: 32.551	LR: 0.00019
Epoch [17][960/3290]	Loss: 32.461	LR: 0.00019
Epoch [17][1280/3290]	Loss: 32.587	LR: 0.00018
Epoch [17][1600/3290]	Loss: 32.621	LR: 0.00018
Epoch [17][1920/3290]	Loss: 32.646	LR: 0.00018
Epoch [17][2240/3290]	Loss: 32.689	LR: 0.00018
Epoch [17][2560/3290]	Loss: 32.743	LR: 0.00018
Epoch [17][2880/3290]	Loss: 32.783	LR: 0.00018
Epoch [17][3200/3290]	Loss: 32.841	LR: 0.00018


 60%|██████    | 18/30 [54:45<36:30, 182.52s/it]

Epoch [18][0/3290]	Loss: 31.799	LR: 0.00018
Epoch [18][320/3290]	Loss: 32.084	LR: 0.00018
Epoch [18][640/3290]	Loss: 32.212	LR: 0.00018
Epoch [18][960/3290]	Loss: 32.327	LR: 0.00018
Epoch [18][1280/3290]	Loss: 32.420	LR: 0.00018
Epoch [18][1600/3290]	Loss: 32.487	LR: 0.00018
Epoch [18][1920/3290]	Loss: 32.532	LR: 0.00018
Epoch [18][2240/3290]	Loss: 32.567	LR: 0.00018
Epoch [18][2560/3290]	Loss: 32.616	LR: 0.00018
Epoch [18][2880/3290]	Loss: 32.679	LR: 0.00018
Epoch [18][3200/3290]	Loss: 32.715	LR: 0.00018


 63%|██████▎   | 19/30 [57:48<33:29, 182.64s/it]

Epoch [19][0/3290]	Loss: 26.567	LR: 0.00018
Epoch [19][320/3290]	Loss: 31.984	LR: 0.00018
Epoch [19][640/3290]	Loss: 32.116	LR: 0.00018
Epoch [19][960/3290]	Loss: 32.227	LR: 0.00018
Epoch [19][1280/3290]	Loss: 32.325	LR: 0.00017
Epoch [19][1600/3290]	Loss: 32.339	LR: 0.00017
Epoch [19][1920/3290]	Loss: 32.356	LR: 0.00017
Epoch [19][2240/3290]	Loss: 32.422	LR: 0.00017
Epoch [19][2560/3290]	Loss: 32.464	LR: 0.00017
Epoch [19][2880/3290]	Loss: 32.525	LR: 0.00017
Epoch [19][3200/3290]	Loss: 32.584	LR: 0.00017


 67%|██████▋   | 20/30 [1:00:50<30:24, 182.48s/it]

Epoch [20][0/3290]	Loss: 32.957	LR: 0.00017
Epoch [20][320/3290]	Loss: 31.717	LR: 0.00017
Epoch [20][640/3290]	Loss: 31.934	LR: 0.00017
Epoch [20][960/3290]	Loss: 32.030	LR: 0.00017
Epoch [20][1280/3290]	Loss: 32.140	LR: 0.00017
Epoch [20][1600/3290]	Loss: 32.213	LR: 0.00017
Epoch [20][1920/3290]	Loss: 32.264	LR: 0.00017
Epoch [20][2240/3290]	Loss: 32.322	LR: 0.00017
Epoch [20][2560/3290]	Loss: 32.366	LR: 0.00017
Epoch [20][2880/3290]	Loss: 32.404	LR: 0.00017
Epoch [20][3200/3290]	Loss: 32.453	LR: 0.00017


 70%|███████   | 21/30 [1:03:52<27:22, 182.48s/it]

Epoch [21][0/3290]	Loss: 35.439	LR: 0.00017
Epoch [21][320/3290]	Loss: 31.869	LR: 0.00017
Epoch [21][640/3290]	Loss: 31.836	LR: 0.00017
Epoch [21][960/3290]	Loss: 31.858	LR: 0.00017
Epoch [21][1280/3290]	Loss: 31.980	LR: 0.00017
Epoch [21][1600/3290]	Loss: 32.028	LR: 0.00017
Epoch [21][1920/3290]	Loss: 32.107	LR: 0.00017
Epoch [21][2240/3290]	Loss: 32.177	LR: 0.00017
Epoch [21][2560/3290]	Loss: 32.245	LR: 0.00017
Epoch [21][2880/3290]	Loss: 32.314	LR: 0.00016
Epoch [21][3200/3290]	Loss: 32.353	LR: 0.00016


 73%|███████▎  | 22/30 [1:06:55<24:19, 182.43s/it]

Epoch [22][0/3290]	Loss: 31.700	LR: 0.00016
Epoch [22][320/3290]	Loss: 31.694	LR: 0.00016
Epoch [22][640/3290]	Loss: 31.755	LR: 0.00016
Epoch [22][960/3290]	Loss: 31.819	LR: 0.00016
Epoch [22][1280/3290]	Loss: 31.914	LR: 0.00016
Epoch [22][1600/3290]	Loss: 31.968	LR: 0.00016
Epoch [22][1920/3290]	Loss: 32.050	LR: 0.00016
Epoch [22][2240/3290]	Loss: 32.117	LR: 0.00016
Epoch [22][2560/3290]	Loss: 32.157	LR: 0.00016
Epoch [22][2880/3290]	Loss: 32.199	LR: 0.00016
Epoch [22][3200/3290]	Loss: 32.250	LR: 0.00016


 77%|███████▋  | 23/30 [1:09:57<21:17, 182.51s/it]

Epoch [23][0/3290]	Loss: 29.525	LR: 0.00016
Epoch [23][320/3290]	Loss: 31.485	LR: 0.00016
Epoch [23][640/3290]	Loss: 31.601	LR: 0.00016
Epoch [23][960/3290]	Loss: 31.683	LR: 0.00016
Epoch [23][1280/3290]	Loss: 31.800	LR: 0.00016
Epoch [23][1600/3290]	Loss: 31.894	LR: 0.00016
Epoch [23][1920/3290]	Loss: 31.964	LR: 0.00016
Epoch [23][2240/3290]	Loss: 32.015	LR: 0.00016
Epoch [23][2560/3290]	Loss: 32.048	LR: 0.00016
Epoch [23][2880/3290]	Loss: 32.092	LR: 0.00016
Epoch [23][3200/3290]	Loss: 32.131	LR: 0.00016


 80%|████████  | 24/30 [1:13:00<18:15, 182.50s/it]

Epoch [24][0/3290]	Loss: 27.958	LR: 0.00016
Epoch [24][320/3290]	Loss: 31.568	LR: 0.00016
Epoch [24][640/3290]	Loss: 31.762	LR: 0.00016
Epoch [24][960/3290]	Loss: 31.772	LR: 0.00016
Epoch [24][1280/3290]	Loss: 31.846	LR: 0.00016
Epoch [24][1600/3290]	Loss: 31.850	LR: 0.00016
Epoch [24][1920/3290]	Loss: 31.892	LR: 0.00016
Epoch [24][2240/3290]	Loss: 31.906	LR: 0.00016
Epoch [24][2560/3290]	Loss: 31.970	LR: 0.00015
Epoch [24][2880/3290]	Loss: 32.027	LR: 0.00015
Epoch [24][3200/3290]	Loss: 32.048	LR: 0.00015


 83%|████████▎ | 25/30 [1:16:03<15:13, 182.62s/it]

Epoch [25][0/3290]	Loss: 32.321	LR: 0.00015
Epoch [25][320/3290]	Loss: 31.258	LR: 0.00015
Epoch [25][640/3290]	Loss: 31.476	LR: 0.00015
Epoch [25][960/3290]	Loss: 31.550	LR: 0.00015
Epoch [25][1280/3290]	Loss: 31.610	LR: 0.00015
Epoch [25][1600/3290]	Loss: 31.665	LR: 0.00015
Epoch [25][1920/3290]	Loss: 31.706	LR: 0.00015
Epoch [25][2240/3290]	Loss: 31.776	LR: 0.00015
Epoch [25][2560/3290]	Loss: 31.827	LR: 0.00015
Epoch [25][2880/3290]	Loss: 31.860	LR: 0.00015
Epoch [25][3200/3290]	Loss: 31.940	LR: 0.00015


 87%|████████▋ | 26/30 [1:19:06<12:11, 182.94s/it]

Epoch [26][0/3290]	Loss: 35.237	LR: 0.00015
Epoch [26][320/3290]	Loss: 31.370	LR: 0.00015
Epoch [26][640/3290]	Loss: 31.355	LR: 0.00015
Epoch [26][960/3290]	Loss: 31.440	LR: 0.00015
Epoch [26][1280/3290]	Loss: 31.497	LR: 0.00015
Epoch [26][1600/3290]	Loss: 31.530	LR: 0.00015
Epoch [26][1920/3290]	Loss: 31.602	LR: 0.00015
Epoch [26][2240/3290]	Loss: 31.686	LR: 0.00015
Epoch [26][2560/3290]	Loss: 31.745	LR: 0.00015
Epoch [26][2880/3290]	Loss: 31.788	LR: 0.00015
Epoch [26][3200/3290]	Loss: 31.852	LR: 0.00015


 90%|█████████ | 27/30 [1:22:10<09:09, 183.08s/it]

Epoch [27][0/3290]	Loss: 32.107	LR: 0.00015
Epoch [27][320/3290]	Loss: 31.080	LR: 0.00015
Epoch [27][640/3290]	Loss: 31.208	LR: 0.00015
Epoch [27][960/3290]	Loss: 31.326	LR: 0.00015
Epoch [27][1280/3290]	Loss: 31.362	LR: 0.00015
Epoch [27][1600/3290]	Loss: 31.475	LR: 0.00015
Epoch [27][1920/3290]	Loss: 31.528	LR: 0.00015
Epoch [27][2240/3290]	Loss: 31.617	LR: 0.00015
Epoch [27][2560/3290]	Loss: 31.667	LR: 0.00015
Epoch [27][2880/3290]	Loss: 31.723	LR: 0.00015
Epoch [27][3200/3290]	Loss: 31.762	LR: 0.00015


 93%|█████████▎| 28/30 [1:25:14<06:06, 183.35s/it]

Epoch [28][0/3290]	Loss: 29.930	LR: 0.00015
Epoch [28][320/3290]	Loss: 31.288	LR: 0.00015
Epoch [28][640/3290]	Loss: 31.220	LR: 0.00015
Epoch [28][960/3290]	Loss: 31.243	LR: 0.00014
Epoch [28][1280/3290]	Loss: 31.332	LR: 0.00014
Epoch [28][1600/3290]	Loss: 31.428	LR: 0.00014
Epoch [28][1920/3290]	Loss: 31.471	LR: 0.00014
Epoch [28][2240/3290]	Loss: 31.559	LR: 0.00014
Epoch [28][2560/3290]	Loss: 31.582	LR: 0.00014
Epoch [28][2880/3290]	Loss: 31.640	LR: 0.00014
Epoch [28][3200/3290]	Loss: 31.666	LR: 0.00014


 97%|█████████▋| 29/30 [1:28:18<03:03, 183.55s/it]

Epoch [29][0/3290]	Loss: 28.921	LR: 0.00014
Epoch [29][320/3290]	Loss: 30.999	LR: 0.00014
Epoch [29][640/3290]	Loss: 31.116	LR: 0.00014
Epoch [29][960/3290]	Loss: 31.126	LR: 0.00014
Epoch [29][1280/3290]	Loss: 31.229	LR: 0.00014
Epoch [29][1600/3290]	Loss: 31.285	LR: 0.00014
Epoch [29][1920/3290]	Loss: 31.325	LR: 0.00014
Epoch [29][2240/3290]	Loss: 31.400	LR: 0.00014
Epoch [29][2560/3290]	Loss: 31.458	LR: 0.00014
Epoch [29][2880/3290]	Loss: 31.519	LR: 0.00014
Epoch [29][3200/3290]	Loss: 31.550	LR: 0.00014


100%|██████████| 30/30 [1:31:21<00:00, 182.73s/it]


In [None]:
# Save a resumable checkpoint: model weights, Adam state and the warm-up position.
save_path = f"new_checkpoint_{epoch}.pth"

torch.save({
    'epoch': epoch,
    'model_state_dict': transformer.state_dict(),
    'optimizer_state_dict': transformer_optimizer.optimizer.state_dict(),
    # The learning rate is derived from AdamWarmup.current_step, which lives on the
    # wrapper rather than in the Adam state_dict; store it so a resumed run does not
    # start the warm-up over.
    'warmup_step': transformer_optimizer.current_step,
    'loss': loss.item(),  # loss of the final mini-batch only, not an epoch average
}, save_path)

print(f"Checkpoint saved to {save_path}")

Checkpoint saved to new_checkpoint_29.pth


In [None]:
# NOTE(review): this stores the live `transformer` module and the AdamWarmup wrapper
# objects themselves rather than their state_dicts, so loading the file presumably
# requires the same class definitions to be available — confirm this is intended.
state = {'epoch': epoch, 'transformer': transformer, 'transformer_optimizer': transformer_optimizer}
torch.save(state, 'checkpoint_final' + str(epoch) + '.pth')

In [None]:
# Qualitative check: ask the trained model a handful of hand-written questions.
questions = ["Hello how are you?",
             "Are you hungry?",
             "How is life going?",
             "I am sad",
             "I kiss a girl"]

getResults(transformer, questions)

	 i don't know.
	 i don't know.
	 i don't know.
	 i don't know.
	 i don't know.


In [None]:
# Display the warm-up optimizer wrapper (the output is just its default object repr).
transformer_optimizer

<__main__.AdamWarmup at 0x77fb90584550>

In [None]:
# Checkpoint made of plain state_dicts plus the counters needed to resume training.
state = {
    'epoch': epoch,
    'transformer_state_dict': transformer.state_dict(),
    'transformer_optimizer_state_dict': transformer_optimizer.optimizer.state_dict(),
    # AdamWarmup computes the learning rate from its own step counter, which is not
    # part of the wrapped optimizer's state_dict; save it so resuming continues the
    # schedule instead of repeating the warm-up.
    'warmup_step': transformer_optimizer.current_step,
}
torch.save(state, 'checkpoint_final_' + str(epoch) + '.pth')

In [None]:
# Load the checkpoint. map_location remaps the stored tensors onto the current
# device, so a checkpoint written on a GPU runtime also loads on a CPU-only one.
checkpoint = torch.load('checkpoint_final_29.pth', map_location=device)

In [None]:
# Architecture settings for the model that will receive the checkpoint weights.
d_model, dim_feedforward = 512, 512
n_head = 2
num_encoder_layers = num_decoder_layers = 2
dropout = 0.2


epochs = 10

# Collect the constructor arguments once, then build the model and move it to the device.
model_kwargs = dict(
    d_model=d_model,
    n_head=n_head,
    num_encoder_layers=num_encoder_layers,
    num_decoder_layers=num_decoder_layers,
    dim_feedforward=dim_feedforward,
    dropout=dropout,
    max_len=15,
)
transformer = Transformer(word_map, **model_kwargs)
transformer = transformer.to(device)

In [None]:
# Restore the model and optimizer state.
#
# `transformer` was re-created above, so an optimizer built for the previous model
# would keep updating that old model's parameters. Build a fresh optimizer over the
# new parameters before loading the saved state; this also lets the cell run on a
# fresh kernel where no optimizer exists yet.
adam_optimizer = torch.optim.Adam(transformer.parameters(), lr=0.00, betas=(0.9, 0.98), eps=1e-9)
transformer_optimizer = AdamWarmup(model_size=d_model, warmup_steps=4000, optimizer=adam_optimizer)

transformer.load_state_dict(checkpoint['transformer_state_dict'])
transformer_optimizer.optimizer.load_state_dict(checkpoint['transformer_optimizer_state_dict'])

# Put the warm-up schedule back where training stopped. Prefer an explicitly saved
# counter; otherwise fall back to Adam's per-parameter update count, which advances
# once per AdamWarmup.step() call.
restored_step = checkpoint.get('warmup_step')
if restored_step is None:
    adam_steps = [int(param_state['step'])
                  for param_state in adam_optimizer.state.values()
                  if 'step' in param_state]
    restored_step = max(adam_steps, default=0)
transformer_optimizer.current_step = restored_step
if restored_step > 0:
    transformer_optimizer.lr = transformer_optimizer.get_lr()

# Restore the last epoch
# NOTE(review): this is the index of the last epoch that was saved; a resumed loop
# presumably wants to continue from start_epoch + 1 — confirm against the caller.
start_epoch = checkpoint['epoch']

In [None]:

# Qualitative check of the model rebuilt from the checkpoint, using a new set of prompts.
questions = ["Do you eat Fruits?",
             "Lets go to France?",
             "I am just happy",
             "I kiss a girl"]

getResults(transformer, questions)

	 i don't know.
	 i don't know.
	 i don't know.
	 i don't know.
