<a href="https://colab.research.google.com/github/rajuX75/Basic-Transformer-Chat-Demo/blob/main/Basic_Transformer_Chat_Component_Demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [34]:
# @title Cell 1: Imports
import tensorflow as tf
import numpy as np
import zipfile
import time
import os
import re

print("TensorFlow version:", tf.__version__)

# Ensure GPU is available
gpu_available = tf.config.list_physical_devices('GPU')
if gpu_available:
    print("GPU is available.")
    tf.config.experimental.set_memory_growth(gpu_available[0], True)
else:
    print("No GPU available, using CPU. Training will be slow.")

# Set up a basic random seed for reproducibility (optional)
tf.random.set_seed(42)
np.random.seed(42)

TensorFlow version: 2.18.0
GPU is available.


In [36]:
# @title Cell 2: Data Preparation (Using Local Cornell Corpus Zip)

# --- Use Locally Uploaded Dataset ---
# Assuming you have manually uploaded 'cornell_movie_dialogs_corpus.zip'
# to your Colab working directory (/content/)
zip_file_path = '/content/cornell_movie_dialogs_corpus.zip'

# Directory to extract the contents into
extraction_base_dir = '/content/cornell_dialogs_extracted/' # Choose a descriptive name

# Create the extraction directory if it doesn't exist
os.makedirs(extraction_base_dir, exist_ok=True)

print(f"Extracting dataset from: {zip_file_path}")
try:
    with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
        zip_ref.extractall(extraction_base_dir)
    print(f"Dataset extracted to: {extraction_base_dir}")

except FileNotFoundError:
    print(f"Error: Zip file not found at {zip_file_path}")
    print("Please ensure 'cornell_movie_dialogs_corpus.zip' is uploaded to /content/")
    # Exit or handle error
    raise # Re-raise the error to stop execution
except zipfile.BadZipFile:
    print(f"Error: Bad zip file - could not extract {zip_file_path}")
    print("Please ensure the zip file is not corrupted.")
    raise


# The actual data files are typically located inside the folder named 'cornell movie-dialogs corpus'
# within the extracted contents. Note the spaces in the folder name.
actual_data_folder_name_in_zip = "cornell movie-dialogs corpus"

# Construct the full path to the directory containing the data files
dataset_dir = os.path.join(extraction_base_dir, actual_data_folder_name_in_zip)

print(f"Actual data files are located in: {dataset_dir}")

# --- Load and Parse the Data ---

# File paths within the actual data directory
lines_filepath = os.path.join(dataset_dir, 'movie_lines.txt')
conversations_filepath = os.path.join(dataset_dir, 'movie_conversations.txt')

# Dictionary to map line IDs to text
id2line = {}
# Use 'latin-1' encoding as specified for this dataset
try:
    with open(lines_filepath, 'r', encoding='latin-1') as f:
        for line in f:
            parts = line.split(' +++$+++ ')
            if len(parts) == 5:
                # Format: lineID +++$+++ characterID ++++$+++ movieID ++++$+++ character ++++$+++ text
                # Store line ID and the text
                line_id = parts[0].strip()
                text = parts[4].strip()
                id2line[line_id] = text
except FileNotFoundError:
     print(f"Error: Could not find movie_lines.txt at {lines_filepath}")
     print("Please check the extraction path and the folder name 'cornell movie-dialogs corpus' inside the zip.")
     # Exit or handle error appropriately if file isn't found after expected extraction
     raise # Re-raise the error to stop execution


# List of conversation line IDs
conversations = []
try:
    with open(conversations_filepath, 'r', encoding='latin-1') as f:
        for line in f:
            parts = line.split(' +++$+++ ')
            if len(parts) == 4:
                # Format: characterID1 +++$+++ characterID2 ++++$+++ movieID ++++$+++ utterance sequence
                # The last part is a string representation of a list of line IDs, e.g., "['L1', 'L2', 'L3', ...]"
                line_ids_str = parts[3].strip()
                # Safely convert string representation of list to a Python list
                try:
                    # Ensure we are parsing a string that looks like a Python list
                    if line_ids_str.startswith('[') and line_ids_str.endswith(']'):
                        line_ids = eval(line_ids_str) # Example: ['L194', 'L195', 'L196']
                        # Ensure the result is actually a list before appending
                        if isinstance(line_ids, list):
                            conversations.append(line_ids)
                        else:
                             print(f"Warning: eval did not return a list for '{line_ids_str}'. Skipping.")
                    else:
                         print(f"Warning: Conversation line IDs string did not start/end with brackets: '{line_ids_str}'. Skipping.")

                except Exception as e:
                     print(f"Warning: Could not parse line IDs in conversation: '{line_ids_str}' - {e}")
                     # Skip this conversation if parsing fails
                     continue
except FileNotFoundError:
    print(f"Error: Could not find movie_conversations.txt at {conversations_filepath}")
    print("Please check the extraction path and the folder name 'cornell movie-dialogs corpus' inside the zip.")
    # Exit or handle error appropriately
    raise # Re-raise the error to stop execution


print(f"Loaded {len(id2line)} lines and {len(conversations)} conversations.")

# --- Create Question-Answer Pairs ---

# We will create pairs of (current turn, next turn)
qa_pairs = []
for conversation in conversations:
    # Iterate through the conversation, creating pairs (line i, line i+1)
    # Only iterate up to the second to last line to have a valid pair
    for i in range(len(conversation) - 1):
        input_line_id = conversation[i]
        target_line_id = conversation[i+1]

        # Get the text using the IDs, use .get with a default None to handle missing IDs gracefully
        input_text = id2line.get(input_line_id)
        target_text = id2line.get(target_line_id)

        # Ensure both lines exist and are not empty strings after stripping
        if input_text and target_text and input_text.strip() and target_text.strip():
            qa_pairs.append([input_text, target_text])
        # Optional: Log if a pair is skipped (can be noisy)
        # else:
        #     print(f"Skipping pair for line IDs {input_line_id}, {target_line_id} due to missing or empty text.")


print(f"Created {len(qa_pairs)} question-answer pairs.")

# --- Basic Preprocessing (Same as before, applied to new data) ---

# 1. Clean and Normalize Text
# Keep this function definition the same as the previous corrected version
def clean_text(text):
    text = str(text) # Ensure text is string type
    text = text.lower()
    # Keep letters, numbers, and some punctuation (like ? . ! , '), replace others with space
    text = re.sub(r"[^a-zA-Z0-9?.!,']+", " ", text) # Keep apostrophes for contractions
    # Optional: clean up multiple spaces
    text = re.sub(r"\s+", " ", text).strip()
    # Ensure text is not empty after cleaning before adding tokens
    if not text:
        return '' # Return empty string if cleaning results in empty text

    # Add start and end tokens (essential for seq2seq)
    text = '<start> ' + text + ' <end>'
    return text

# Apply cleaning to all pairs
# Create a new list to hold cleaned pairs, skipping pairs where cleaning resulted in empty text
cleaned_data = []
for q, a in qa_pairs:
    cleaned_q = clean_text(q)
    cleaned_a = clean_text(a)
    # Only keep the pair if both cleaned question and answer are not empty
    if cleaned_q and cleaned_a:
        cleaned_data.append([cleaned_q, cleaned_a])
    # else:
    #     print(f"Skipping pair: Original Q='{q}', A='{a}' resulted in empty cleaned text.")


print(f"\nCleaned {len(cleaned_data)} valid pairs after applying text cleaning.")
print("Cleaned Data Example:", cleaned_data[0])

# If no valid data remains, stop here
if not cleaned_data:
    raise ValueError("No valid question-answer pairs found after cleaning. Check data source and cleaning function.")


# 2. Create Vocabulary and Tokenizer
# Use a smaller vocabulary size to manage memory in Colab, or let it be full initially
# num_words=None means keep all words found
# oov_token='<unk>' handles words not in vocabulary during prediction
tokenizer = tf.keras.preprocessing.text.Tokenizer(filters='', split=' ', oov_token='<unk>')

# Fit on all cleaned text
tokenizer.fit_on_texts([q + ' ' + a for q, a in cleaned_data])

# Add special tokens to word_index if not already there (clean_text adds them before fit)
# This ensures their IDs are properly mapped. Keras Tokenizer assigns IDs from 1 up.
# 0 is reserved for padding. <unk> is handled by oov_token.
# Check if special tokens were included in vocabulary during fitting
if '<start>' not in tokenizer.word_index:
    print("Warning: '<start>' not found in vocabulary after fitting.")
if '<end>' not in tokenizer.word_index:
    print("Warning: '<end>' not found in vocabulary after fitting.")
if '<unk>' not in tokenizer.word_index:
    # If oov_token was used, <unk> should be in word_index
     print("Warning: '<unk>' not found in vocabulary after fitting, despite using oov_token.")


# Convert text to sequences of integers
input_sequences = tokenizer.texts_to_sequences([q for q, a in cleaned_data])
target_sequences = tokenizer.texts_to_sequences([a for q, a in cleaned_data])

print("\nOriginal sequences example:", input_sequences[0])
print("Target sequences example:", target_sequences[0])

# 3. Padding sequences to a fixed length
# Determine a reasonable max length (e.g., 95th percentile of sequence lengths)
all_lengths = [len(seq) for seq in input_sequences + target_sequences]
# Handle case where all_lengths is empty or contains only 0s after filtering
if not all_lengths or max(all_lengths) == 0:
     print("Error: No sequences with length > 0 found. Cannot determine MAX_LENGTH.")
     MAX_LENGTH = 10 # Set a default or handle as error (e.g., raise ValueError)
else:
    # Calculate a percentile length, cap at a reasonable value if needed for memory
    MAX_LENGTH = int(np.percentile(all_lengths, 95))
    # Optional: Set a hard cap on max length if 95th percentile is too large
    # MAX_LENGTH = min(MAX_LENGTH, 80) # Example cap at 80 tokens - adjust based on memory

print(f"Calculated MAX_LENGTH (95th percentile capped/uncapped): {MAX_LENGTH}")
print(f"Number of pairs to pad: {len(input_sequences)}")

# Apply padding
input_padded = tf.keras.preprocessing.sequence.pad_sequences(input_sequences,
                                                             maxlen=MAX_LENGTH,
                                                             padding='post')
target_padded = tf.keras.preprocessing.sequence.pad_sequences(target_sequences,
                                                              maxlen=MAX_LENGTH,
                                                              padding='post')

print("\nPadded input sequence example:", input_padded[0])
print("Padded target sequence example:", target_padded[0])

# The vocabulary size is based on the number of unique words found by the tokenizer + 1 for padding (ID 0)
VOCAB_SIZE = len(tokenizer.word_index) + 1
print("Vocabulary size:", VOCAB_SIZE) # Should be much larger now
print("Number of final training pairs:", len(input_padded))

# 4. Create tf.data.Dataset
# Use a smaller buffer size for shuffling if BUFFER_SIZE is too large for memory
BUFFER_SIZE = 20000 # A good balance for shuffling a large dataset without huge memory
BATCH_SIZE = 64 # Use a larger batch size, adjust based on GPU memory

# Ensure input_padded and target_padded are numpy arrays or tensors for from_tensor_slices
input_padded = np.array(input_padded)
target_padded = np.array(target_padded)

dataset = tf.data.Dataset.from_tensor_slices((input_padded, target_padded))
dataset = dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
dataset = dataset.prefetch(tf.data.AUTOTUNE) # Overlap data preprocessing and model execution

# Calculate total steps *after* batching
# Use the correct constant for unknown cardinality
total_steps_per_epoch = tf.data.experimental.cardinality(dataset).numpy()
# CORRECTED LINE: Use tf.data.UNKNOWN_CARDINALITY instead of tf.data.experimental.UNKNOWN
if total_steps_per_epoch == tf.data.UNKNOWN_CARDINALITY:
     print("Warning: Dataset cardinality is unknown. Using estimated steps.")
     total_steps_per_epoch = len(input_padded) // BATCH_SIZE
     if len(input_padded) % BATCH_SIZE != 0:
         total_steps_per_epoch += 1 # Account for partial last batch
     # Ensure total_steps_per_epoch is at least 1 if data exists
     total_steps_per_epoch = max(1, total_steps_per_epoch)


print("\nDataset prepared.")

# Pass total_steps_per_epoch to Cell 11 via a global variable (or ensure Cell 11 calculates it)
# Let's make it a global variable accessible to the next cell
global total_steps_per_epoch_global
total_steps_per_epoch_global = total_steps_per_epoch

Extracting dataset from: /content/cornell_movie_dialogs_corpus.zip
Dataset extracted to: /content/cornell_dialogs_extracted/
Actual data files are located in: /content/cornell_dialogs_extracted/cornell movie-dialogs corpus
Loaded 304713 lines and 83097 conversations.
Created 221282 question-answer pairs.

Cleaned 221282 valid pairs after applying text cleaning.
Cleaned Data Example: ['<start> can we make this quick? roxanne korrine and andrew barrett are having an incredibly horrendous public break up on the quad. again. <end>', "<start> well, i thought we'd start with pronunciation, if that's okay with you. <end>"]

Original sequences example: [2, 47, 23, 100, 22, 24158, 84339, 84340, 9, 6338, 20175, 35, 416, 71, 4898, 84341, 1286, 630, 58, 33, 6, 84342, 336, 3]
Target sequences example: [2, 79, 5, 152, 676, 361, 31, 53075, 41, 44, 790, 31, 55, 3]
Calculated MAX_LENGTH (95th percentile capped/uncapped): 33
Number of pairs to pad: 221282

Padded input sequence example: [    2    47    

In [38]:
# @title Cell 3: Hyperparameters and Model Parameters

# These values are now determined in Cell 2 based on the Cornell dataset
# Make sure you run Cell 2 *before* running this cell for the values to be correct.
VOCAB_SIZE = VOCAB_SIZE       # From Cell 2
MAX_LENGTH = MAX_LENGTH       # From Cell 2
BATCH_SIZE = BATCH_SIZE       # From Cell 2

EMBEDDING_DIM = 128
UNITS = 512
NUM_HEADS = 8
D_MODEL = EMBEDDING_DIM # Dimension of the model (embedding size)
FFN_UNITS = UNITS # Dimension of the feed-forward network inner layer

# Training parameters
EPOCHS = 20 # Reduced epochs for quicker demo with larger data, increase if needed

In [39]:
# @title Cell 4: Positional Encoding
def get_angles(pos, i, d_model):
    angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(d_model))
    return pos * angle_rates

def positional_encoding(position, d_model):
    angle_rads = get_angles(np.arange(position)[:, np.newaxis],
                            np.arange(d_model)[np.newaxis, :],
                            d_model)

    # apply sin to even indices in the array; 2i
    angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])

    # apply cos to odd indices in the array; 2i+1
    angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])

    pos_encoding = angle_rads[np.newaxis, ...]

    return tf.cast(pos_encoding, dtype=tf.float32)

# Example visualization (optional)
# pos_encoding_example = positional_encoding(50, 128)
# print(pos_encoding_example.shape) # Expected: (1, 50, 128)

In [40]:
# @title Cell 5: Masking
def create_padding_mask(seq):
    seq = tf.cast(tf.math.equal(seq, 0), tf.float32)
    # add extra dimensions to add the padding to the attention logits.
    return seq[:, tf.newaxis, tf.newaxis, :] # Shape (batch_size, 1, 1, seq_len)

def create_look_ahead_mask(size):
    mask = 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)
    # mask is a lower triangular matrix.
    # We return the *upper* triangular part which becomes 1s
    # where we *want* to mask (look ahead tokens).
    return tf.cast(mask, tf.float32) # Shape (size, size)

# Combine masks for the decoder's first attention layer
def create_masks(inp, tar):
    # Encoder padding mask
    enc_padding_mask = create_padding_mask(inp)

    # Used in the 2nd attention block in the decoder.
    # This padding mask is used to mask the encoder outputs.
    dec_padding_mask = create_padding_mask(inp)

    # Used in the 1st attention block in the decoder.
    # It is used to pad and mask future tokens in the input received by the decoder.
    look_ahead_mask = create_look_ahead_mask(tf.shape(tar)[1])
    dec_target_padding_mask = create_padding_mask(tar)
    combined_mask = tf.maximum(dec_target_padding_mask, look_ahead_mask)

    return enc_padding_mask, combined_mask, dec_padding_mask

In [41]:
# @title Cell 6: Core Transformer Components (Layers)
class MultiHeadAttention(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model

        # Ensure d_model is divisible by num_heads
        assert d_model % self.num_heads == 0

        self.depth = d_model // self.num_heads # Depth of each attention head

        self.wq = tf.keras.layers.Dense(d_model) # Query weight matrix
        self.wk = tf.keras.layers.Dense(d_model) # Key weight matrix
        self.wv = tf.keras.layers.Dense(d_model) # Value weight matrix

        self.dense = tf.keras.layers.Dense(d_model) # Output dense layer

    def split_heads(self, x, batch_size):
        """Split the last dimension into (num_heads, depth).
           Transpose the result such that the shape is (batch_size, num_heads, seq_len, depth)
        """
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def scaled_dot_product_attention(self, q, k, v, mask):
        """Calculate the attention weights.
        q, k, v must have matching leading dimensions.
        k, v must have matching penultimate dimension, i.e.: seq_len_k = seq_len_v.
        The mask has different shapes depending on its type (padding or look ahead)
        but both shapes are broadcastable.

        Args:
          q: query shape == (..., seq_len_q, depth)
          k: key shape == (..., seq_len_k, depth)
          v: value shape == (..., seq_len_v, depth_v)
          mask: Float tensor with shape broadcastable
                to (..., seq_len_q, seq_len_k). Defaults to None.

        Returns:
          output, attention_weights
        """
        # Matmul(Q, K_T)
        matmul_qk = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)

        # scale by sqrt(depth)
        dk = tf.cast(tf.shape(k)[-1], tf.float32)
        scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

        # add the mask to the scaled attention logits.
        if mask is not None:
            # Add a large negative number to masked positions.
            # When softmax is applied, these positions get nearly zero weight.
            scaled_attention_logits += (mask * -1e9)

        # softmax is normalized on the last axis (seq_len_k) so that the scores add up to 1.
        attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)  # (..., seq_len_q, seq_len_k)

        output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)

        return output, attention_weights

    def call(self, v, k, q, mask):
        batch_size = tf.shape(q)[0]

        q = self.wq(q)  # (batch_size, seq_len, d_model)
        k = self.wk(k)  # (batch_size, seq_len, d_model)
        v = self.wv(v)  # (batch_size, seq_len, d_model)

        q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
        k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len_k, depth)
        v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len_v, depth_v)

        # scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth)
        # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)
        scaled_attention, attention_weights = self.scaled_dot_product_attention(
            q, k, v, mask)

        # Concatenate heads
        scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])  # (batch_size, seq_len_q, num_heads, depth)
        concat_attention = tf.reshape(scaled_attention,
                                      (batch_size, -1, self.d_model))  # (batch_size, seq_len_q, d_model)

        # Final linear layer
        output = self.dense(concat_attention)  # (batch_size, seq_len_q, d_model)

        return output, attention_weights

class PointWiseFeedForwardNetwork(tf.keras.layers.Layer):
    def __init__(self, d_model, dff):
        super(PointWiseFeedForwardNetwork, self).__init__()
        self.dense1 = tf.keras.layers.Dense(dff, activation='relu') # Inner layer
        self.dense2 = tf.keras.layers.Dense(d_model) # Output layer

    def call(self, x):
        # Applies the network independently to each position
        return self.dense2(self.dense1(x))

In [42]:
# @title Cell 7: Building the Transformer Block
class EncoderLayer(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads, dff, rate=0.1):
        super(EncoderLayer, self).__init__()

        # Self-Attention Layer
        self.mha = MultiHeadAttention(d_model, num_heads)

        # Point-wise Feed-Forward Network
        self.ffn = PointWiseFeedForwardNetwork(d_model, dff)

        # Layer Normalization (applied *after* residual connection and dropout in standard impl)
        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)

        # Dropout for regularization
        self.dropout1 = tf.keras.layers.Dropout(rate)
        self.dropout2 = tf.keras.layers.Dropout(rate)

    def call(self, x, training, mask):
        # 1. Self-Attention Block
        # x = Input embedding + positional encoding
        # Attention(Q, K, V) = Attention(x, x, x)
        attn_output, _ = self.mha(x, x, x, mask)
        attn_output = self.dropout1(attn_output, training=training)

        # 2. Add & Norm 1 (Residual Connection + Layer Norm)
        # x + Dropout(Attention(x)) followed by LayerNorm
        out1 = self.layernorm1(x + attn_output)

        # 3. Feed Forward Block
        ffn_output = self.ffn(out1) # Point-wise FFN applied to output of Add & Norm 1
        ffn_output = self.dropout2(ffn_output, training=training)

        # 4. Add & Norm 2 (Residual Connection + Layer Norm)
        # out1 + Dropout(FFN(out1)) followed by LayerNorm
        out2 = self.layernorm2(out1 + ffn_output)

        return out2 # Output of the encoder layer

class DecoderLayer(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads, dff, rate=0.1):
        super(DecoderLayer, self).__init__()

        # Masked Multi-Head Self-Attention (for decoder input)
        self.mha1 = MultiHeadAttention(d_model, num_heads)

        # Multi-Head Attention (Encoder-Decoder Attention)
        # Queries come from decoder input, Keys/Values come from encoder output
        self.mha2 = MultiHeadAttention(d_model, num_heads)

        # Point-wise Feed-Forward Network
        self.ffn = PointWiseFeedForwardNetwork(d_model, dff)

        # Layer Normalization
        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm3 = tf.keras.layers.LayerNormalization(epsilon=1e-6)

        # Dropout
        self.dropout1 = tf.keras.layers.Dropout(rate)
        self.dropout2 = tf.keras.layers.Dropout(rate)
        self.dropout3 = tf.keras.layers.Dropout(rate)

    def call(self, x, enc_output, training, look_ahead_mask, padding_mask):
        # enc_output.shape == (batch_size, input_seq_len, d_model)

        # 1. Masked Multi-Head Self-Attention (Q=x, K=x, V=x)
        attn1, attn_weights_block1 = self.mha1(x, x, x, look_ahead_mask)
        attn1 = self.dropout1(attn1, training=training)
        out1 = self.layernorm1(x + attn1) # Add & Norm 1

        # 2. Multi-Head Attention (Encoder-Decoder Attention) (Q=out1, K=enc_output, V=enc_output)
        # Here, queries come from the output of the first attention block in the decoder
        # and keys/values come from the encoder output.
        attn2, attn_weights_block2 = self.mha2(enc_output, enc_output, out1, padding_mask)
        attn2 = self.dropout2(attn2, training=training)
        out2 = self.layernorm2(out1 + attn2) # Add & Norm 2

        # 3. Point-wise Feed-Forward Network
        ffn_output = self.ffn(out2)
        ffn_output = self.dropout3(ffn_output, training=training)
        out3 = self.layernorm3(out2 + ffn_output) # Add & Norm 3

        return out3, attn_weights_block1, attn_weights_block2

In [43]:
# @title Cell 8: The Full Transformer Model (Corrected Structure)
# This cell outlines the structure. The full implementation is complex.

class Encoder(tf.keras.layers.Layer):
    def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size,
                 maximum_position_encoding, rate=0.1):
        super(Encoder, self).__init__()

        self.d_model = d_model
        self.num_layers = num_layers

        self.embedding = tf.keras.layers.Embedding(input_vocab_size, d_model)
        self.pos_encoding = positional_encoding(maximum_position_encoding,
                                                self.d_model)

        self.enc_layers = [EncoderLayer(d_model, num_heads, dff, rate)
                           for _ in range(num_layers)]

        self.dropout = tf.keras.layers.Dropout(rate)

    # Corrected call signature to explicitly name 'training' and 'mask'
    def call(self, x, training=False, mask=None): # Added default values for clarity
        seq_len = tf.shape(x)[1]

        # Add embedding and positional encoding.
        x = self.embedding(x)  # (batch_size, input_seq_len, d_model)
        x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32)) # Scale embedding
        # Add positional encoding (broadcasting to batch)
        x += self.pos_encoding[:, :seq_len, :]

        x = self.dropout(x, training=training) # Pass training=training here

        for i in range(self.num_layers):
            # Pass training=training and mask=mask to each encoder layer
            x = self.enc_layers[i](x, training=training, mask=mask)

        return x  # (batch_size, input_seq_len, d_model)

class Decoder(tf.keras.layers.Layer):
    def __init__(self, num_layers, d_model, num_heads, dff, target_vocab_size,
                 maximum_position_encoding, rate=0.1):
        super(Decoder, self).__init__()

        self.d_model = d_model
        self.num_layers = num_layers

        self.embedding = tf.keras.layers.Embedding(target_vocab_size, d_model)
        self.pos_encoding = positional_encoding(maximum_position_encoding,
                                                d_model)

        self.dec_layers = [DecoderLayer(d_model, num_heads, dff, rate)
                           for _ in range(num_layers)]

        self.dropout = tf.keras.layers.Dropout(rate)

    # Corrected call signature to explicitly name 'training' and masks
    def call(self, x, enc_output, training=False, look_ahead_mask=None, padding_mask=None): # Added default values
        seq_len = tf.shape(x)[1]
        attention_weights = {}

        # Add embedding and positional encoding.
        x = self.embedding(x)  # (batch_size, target_seq_len, d_model)
        x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32)) # Scale embedding
        # Add positional encoding (broadcasting to batch)
        x += self.pos_encoding[:, :seq_len, :]

        x = self.dropout(x, training=training) # Pass training=training here

        for i in range(self.num_layers):
            # Pass training and masks as keyword arguments to each decoder layer
            x, block1, block2 = self.dec_layers[i](x, enc_output,
                                                   training=training, # Pass as keyword
                                                   look_ahead_mask=look_ahead_mask, # Pass as keyword
                                                   padding_mask=padding_mask) # Pass as keyword

            attention_weights[f'decoder_layer{i+1}_block1'] = block1
            attention_weights[f'decoder_layer{i+1}_block2'] = block2

        # x.shape == (batch_size, target_seq_len, d_model)
        return x, attention_weights

class Transformer(tf.keras.Model):
    def __init__(self, num_layers, d_model, num_heads, dff,
                 input_vocab_size, target_vocab_size, pe_input, pe_target, rate=0.1):
        super(Transformer, self).__init__()

        self.encoder = Encoder(num_layers, d_model, num_heads, dff,
                               input_vocab_size, pe_input, rate)

        self.decoder = Decoder(num_layers, d_model, num_heads, dff,
                               target_vocab_size, pe_target, rate)

        # Final linear layer to predict vocabulary tokens
        self.final_layer = tf.keras.layers.Dense(target_vocab_size)

    # The call signature of Transformer itself is correct (using keywords already based on previous fix)
    # This error was *within* this call method, when calling self.encoder/decoder
    def call(self, inp, tar, training, enc_padding_mask, look_ahead_mask, dec_padding_mask):
        # inp = input sequence (question)
        # tar = target sequence (answer - shifted right)

        # Encode the input sequence
        # !!! FIX: Pass training and mask as keyword arguments to the encoder !!!
        enc_output = self.encoder(inp, training=training, mask=enc_padding_mask) # (batch_size, inp_seq_len, d_model)

        # Decode the target sequence
        # dec_output.shape == (batch_size, tar_seq_len, d_model)
        # !!! FIX: Pass training and masks as keyword arguments to the decoder !!!
        dec_output, attention_weights = self.decoder(
            tar,
            enc_output,
            training=training,         # Pass as keyword
            look_ahead_mask=look_ahead_mask, # Pass as keyword
            padding_mask=dec_padding_mask) # Pass as keyword (using the param name from Decoder.call)

        # Final linear layer
        final_output = self.final_layer(dec_output) # (batch_size, tar_seq_len, target_vocab_size)

        return final_output, attention_weights

# Instantiate the model (using small number of layers for demo)
NUM_LAYERS = 2 # Number of encoder and decoder layers
# Make sure D_MODEL, NUM_HEADS, FFN_UNITS, VOCAB_SIZE, MAX_LENGTH are defined in Cell 3
transformer = Transformer(NUM_LAYERS, D_MODEL, NUM_HEADS, FFN_UNITS,
                          VOCAB_SIZE, VOCAB_SIZE, # Using same vocab for input/target
                          MAX_LENGTH, MAX_LENGTH) # Max length for positional encoding

print("Transformer model instantiated.")

Transformer model instantiated.


In [49]:
# @title Cell 9: Optimizer and Loss Function
# Custom Learning Rate Schedule (as in original Transformer paper)
class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, d_model, warmup_steps=4000):
        super(CustomSchedule, self).__init__()

        self.d_model = tf.cast(d_model, dtype=tf.float32)
        self.warmup_steps = tf.cast(warmup_steps, dtype=tf.float32)

    def __call__(self, step):
        step = tf.cast(step, dtype=tf.float32)
        arg1 = tf.math.rsqrt(step) # 1 / sqrt(step)
        arg2 = step * (self.warmup_steps ** -1.5) # step * warmup_steps^-1.5

        # Learning rate = d_model^-0.5 * min(arg1, arg2)
        return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)

# Initialize Learning Rate Scheduler and Optimizer
learning_rate = CustomSchedule(D_MODEL)
optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9)

# Loss Function (Sparse Categorical Crossentropy)
# We don't want to calculate loss on padding tokens (label 0)
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction='none')

def loss_function(real, pred):
    # real: target sequence shifted left (the actual next tokens)
    # pred: model's output logits for each token position
    mask = tf.math.logical_not(tf.math.equal(real, 0)) # Create mask where real != 0
    loss_ = loss_object(real, pred)

    mask = tf.cast(mask, dtype=loss_.dtype)
    loss_ *= mask # Apply mask: loss is 0 for padded positions

    return tf.reduce_sum(loss_)/tf.reduce_sum(mask) # Compute mean loss over non-padded tokens

# Metrics (Optional but good)
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

In [50]:
# @title Cell 10: Training Step and Checkpointing
# Checkpoint manager (for saving/loading model weights)
checkpoint_path = "./checkpoints/transformer_basic"
ckpt = tf.train.Checkpoint(transformer=transformer,
                           optimizer=optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# If a checkpoint exists, restore the latest checkpoint.
if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    print('Latest checkpoint restored!!')
else:
    print('No checkpoint found. Starting fresh.')

# The @tf.function decorator compiles the function into a TensorFlow graph
# for faster execution.
@tf.function
def train_step(inp, tar):
    # tar has shape (batch_size, seq_len)
    # We need to split it into decoder input (tar_inp) and target (tar_real)
    # Decoder input is the target sequence shifted right by 1 (including <start>)
    tar_inp = tar[:, :-1] # Exclude the last token
    # Target is the actual next tokens (from index 1 to the end, including <end>)
    tar_real = tar[:, 1:] # Exclude the first token (<start>)

    # Create masks for this batch
    enc_padding_mask, combined_mask, dec_padding_mask = create_masks(inp, tar_inp)

    with tf.GradientTape() as tape:
        # Forward pass through the model
        predictions, _ = transformer(inp, tar_inp,
                             training=True, # Pass as keyword
                             enc_padding_mask=enc_padding_mask, # Good practice for masks too
                             look_ahead_mask=combined_mask,     # Use the correct name from Transformer.call signature
                             dec_padding_mask=dec_padding_mask)

        # Calculate loss (compare predictions for tar_inp with actual next tokens tar_real)
        loss = loss_function(tar_real, predictions)

    # Calculate gradients
    gradients = tape.gradient(loss, transformer.trainable_variables)

    # Apply gradients
    optimizer.apply_gradients(zip(gradients, transformer.trainable_variables))

    # Update metrics
    train_loss(loss)
    # Accuracy: compare predicted token (argmax) with real token
    train_accuracy(tar_real, predictions)

# Test the train step with a batch (optional)
# for (batch_inp, batch_tar) in dataset.take(1):
#     train_step(batch_inp, batch_tar)
#     print("Initial loss:", train_loss.result())
#     print("Initial accuracy:", train_accuracy.result())

No checkpoint found. Starting fresh.


In [None]:
# @title Cell 11: Training Loop (with fixes)

print("Starting training...")

# Ensure these variables are defined in previous cells and accessible:
# - EPOCHS
# - dataset (the tf.data.Dataset)
# - train_step (the @tf.function for a single training step)
# - train_loss (tf.keras.metrics.Mean instance) - WILL BE RE-INITIALIZED
# - train_accuracy (tf.keras.metrics.SparseCategoricalAccuracy instance) - WILL BE RE-INITIALIZED
# - ckpt_manager (tf.train.CheckpointManager instance)
# - time (imported)

for epoch in range(EPOCHS):
    start = time.time()

    # WORKAROUND FOR AttributeError: 'Metric' object has no attribute 'reset_states'
    # Instead of resetting, re-create the metric objects.
    # This is less efficient but achieves the goal of starting each epoch with fresh metrics.
    # The train_step @tf.function should still be able to reference these if defined
    # at the module level above the loop.
    print(f"Epoch {epoch + 1}: Re-initializing metrics as a workaround.")
    train_loss = tf.keras.metrics.Mean(name='train_loss')
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')


    # Iterate over the dataset batches
    for (batch, (inp, tar)) in enumerate(dataset):
        # Perform one training step defined in Cell 10
        # Make sure the train_step definition in Cell 10 is updated
        # to pass 'training=True' and named mask arguments!
        train_step(inp, tar)

        # Print progress every few batches (adjust frequency based on dataset size)
        # For this very small dataset, printing every batch is reasonable.
        # You can check the batch number against len(list(dataset)) if needed,
        # but iterating over the dataset object directly is usually better.
        if batch % 1 == 0: # Print every batch
            # Note: Accessing len(list(dataset)) here will consume the dataset iterator.
            # For large datasets, avoid this. Just show batch number.
            # print(f'Epoch {epoch + 1} Batch {batch} Loss {train_loss.result():.4f} Accuracy {train_accuracy.result():.4f}')
             print(f'Epoch {epoch + 1} Batch {batch} Loss {train_loss.result():.4f} Accuracy {train_accuracy.result():.4f}')


    # Save a checkpoint after each epoch (optional, but good practice)
    # Using ckpt_manager defined in Cell 10
    if (epoch + 1) % 5 == 0: # Save every N epochs (e.g., 5)
        ckpt_save_path = ckpt_manager.save()
        print(f'Saving checkpoint for epoch {epoch + 1} at {ckpt_save_path}')

    # Print end-of-epoch summary
    print(f'Epoch {epoch + 1} complete - Loss {train_loss.result():.4f} Accuracy {train_accuracy.result():.4f}')
    print(f'Time taken for 1 epoch: {time.time() - start:.2f} secs\n')

print("Training finished.")

Starting training...
Epoch 1: Re-initializing metrics as a workaround.




Epoch 1 Batch 0 Loss 11.7116 Accuracy 0.0000
Epoch 1 Batch 1 Loss 11.7119 Accuracy 0.0000
Epoch 1 Batch 2 Loss 11.7118 Accuracy 0.0000
Epoch 1 Batch 3 Loss 11.7117 Accuracy 0.0000
Epoch 1 Batch 4 Loss 11.7113 Accuracy 0.0000
Epoch 1 Batch 5 Loss 11.7112 Accuracy 0.0000
Epoch 1 Batch 6 Loss 11.7109 Accuracy 0.0000
Epoch 1 Batch 7 Loss 11.7105 Accuracy 0.0000
Epoch 1 Batch 8 Loss 11.7102 Accuracy 0.0000
Epoch 1 Batch 9 Loss 11.7099 Accuracy 0.0000
Epoch 1 Batch 10 Loss 11.7098 Accuracy 0.0000
Epoch 1 Batch 11 Loss 11.7096 Accuracy 0.0000
Epoch 1 Batch 12 Loss 11.7091 Accuracy 0.0000
Epoch 1 Batch 13 Loss 11.7089 Accuracy 0.0000
Epoch 1 Batch 14 Loss 11.7083 Accuracy 0.0000
Epoch 1 Batch 15 Loss 11.7081 Accuracy 0.0000
Epoch 1 Batch 16 Loss 11.7078 Accuracy 0.0000
Epoch 1 Batch 17 Loss 11.7074 Accuracy 0.0000
Epoch 1 Batch 18 Loss 11.7072 Accuracy 0.0000
Epoch 1 Batch 19 Loss 11.7068 Accuracy 0.0000
Epoch 1 Batch 20 Loss 11.7065 Accuracy 0.0000
Epoch 1 Batch 21 Loss 11.7062 Accuracy 0.000

In [None]:
# @title Cell 12: Inference (Generating Responses) with Input Loop

# Ensure these variables/objects are defined in previous cells and accessible:
# - transformer (the instantiated Transformer model from Cell 8)
# - tokenizer (from Cell 2)
# - MAX_LENGTH (from Cell 3)
# - ckpt_manager (from Cell 10)
# - clean_text (from Cell 2)
# - create_masks (from Cell 5)
# - ckpt (from Cell 10) # Needed for ckpt.restore


# Function to convert sequence of tokens back to text
def tokens_to_text(sequence, tokenizer):
    """Converts a sequence of token IDs back to text string."""
    # Initialize list to store words
    words = []
    # Iterate through the sequence of token IDs
    for token in sequence:
        # Check if the token ID is valid (not 0 for padding) and not a special token
        if token != 0:
             # Get the word for the token ID from the tokenizer's index_word mapping
            word = tokenizer.index_word.get(token, '') # Use .get to handle potential missing tokens
            # Append the word to the list if it's not a special start/end token
            if word not in ('<start>', '<end>'):
                words.append(word)
        else:
             # Stop if a padding token is encountered (assuming post-padding)
            break

    # Join the list of words into a single string
    # Handle cases where the sequence might be empty or only contain special tokens
    return ' '.join(words) if words else ''


# Function to generate a response given an input sentence
def evaluate(sentence, transformer, tokenizer, max_length):
    """Generates a response token by token for a given input sentence."""

    # Preprocess the input sentence using the same cleaning as training data
    cleaned_sentence = clean_text(sentence)
    input_sequence = tokenizer.texts_to_sequences([cleaned_sentence])

    # Pad the input sequence to the maximum length used during training
    encoder_input = tf.keras.preprocessing.sequence.pad_sequences(input_sequence,
                                                                maxlen=max_length,
                                                                padding='post')
    encoder_input = tf.constant(encoder_input) # Convert to TensorFlow tensor

    # The decoder input starts with the <start> token
    # We need to know the ID of the <start> token from the tokenizer
    try:
        start_token = tokenizer.word_index['<start>']
        end_token = tokenizer.word_index['<end>']
    except KeyError:
        print("Error: '<start>' or '<end>' token not found in tokenizer vocabulary.")
        print("Please ensure these tokens were added during data preparation (Cell 2).")
        return "Error: Special tokens not found.", {}


    # Initialize the decoder input sequence with the <start> token
    decoder_input = tf.expand_dims([start_token], 0) # Shape (1, 1) -> batch size 1, sequence length 1

    # List to store the output token IDs
    output_tokens = []

    # Generate tokens one by one up to max_length
    # Use a loop limit to prevent infinite generation
    for i in range(max_length): # Use max_length as a safety limit for generated sequence
        # Create masks for the current step of decoding
        # enc_padding_mask is based on the fixed encoder_input
        # combined_mask is based on the growing decoder_input (look-ahead and padding)
        # dec_padding_mask is based on the fixed encoder_input (for cross-attention)
        enc_padding_mask, combined_mask, dec_padding_mask = create_masks(encoder_input, decoder_input)

        # Get the model's predictions for the next token
        # Pass training=False for inference and use keyword arguments for masks
        predictions, attention_weights = transformer(encoder_input,
                                                     decoder_input,
                                                     training=False, # Set training to False for inference
                                                     enc_padding_mask=enc_padding_mask,
                                                     look_ahead_mask=combined_mask, # Use the correct param name
                                                     dec_padding_mask=dec_padding_mask) # Use the correct param name


        # Select the last token prediction from the sequence dimension
        # predictions.shape == (batch_size, current_seq_len, vocab_size)
        # We are only interested in the prediction for the *last* token in decoder_input
        predictions = predictions[:, -1:, :] # Shape (batch_size, 1, vocab_size)

        # Get the predicted token ID by taking the argmax (most likely token)
        # predicted_token_id will have shape (batch_size, 1)
        predicted_token_id = tf.cast(tf.argmax(predictions, axis=-1), tf.int32)

        # Get the scalar token ID (assuming batch size is 1 for inference)
        predicted_token_scalar = predicted_token_id.numpy()[0][0]

        # Append the predicted token ID to the output tokens list
        output_tokens.append(predicted_token_scalar)

        # If the predicted token is the <end> token, stop generation
        if predicted_token_scalar == end_token:
            break

        # Concatenate the predicted token to the decoder input for the next step
        # decoder_input will grow by one token each iteration
        decoder_input = tf.concat([decoder_input, predicted_token_id], axis=-1) # Shape (batch_size, current_seq_len + 1)


    # Convert the list of output token IDs to a tensor and squeeze (remove batch dim)
    # output_tokens is a Python list of ints, convert to numpy array first
    output_tokens_np = np.array(output_tokens, dtype=np.int32)
    # Convert numpy array to tensor
    output_tokens_tensor = tf.constant(output_tokens_np)
    # Squeeze the tensor to remove any extra dimensions (like batch size 1)
    # tf.squeeze is not needed if we are converting a 1D list/array


    # Convert the output token sequence back to text using the helper function
    predicted_sentence = tokens_to_text(output_tokens_tensor.numpy(), tokenizer)

    # Note: We are not returning attention weights in this version of evaluate
    # for simplicity, but you could return the 'attention_weights' variable
    # from the transformer call if needed for visualization.
    return predicted_sentence, attention_weights


# --- Inference Setup ---

print("\n--- Setting up Inference ---")

# Restore latest checkpoint if training was run
# This loads the trained weights into the transformer and optimizer objects
if ckpt_manager.latest_checkpoint:
    # Use expect_partial() because the optimizer state might not fully match
    # if you stopped training early or changed settings slightly.
    # For pure inference, restoring only model weights is often sufficient.
    # Assuming 'ckpt' object was created in Cell 10 and includes the 'transformer' and 'optimizer'
    try:
        ckpt.restore(ckpt_manager.latest_checkpoint).expect_partial()
        print('Latest checkpoint restored for inference.')
    except Exception as e:
        print(f"Error restoring checkpoint: {e}")
        print("Model weights are untrained or from initial load.")
        print("Inference may not produce meaningful results without training.")
else:
     print('No checkpoint found. Model weights are untrained or from initial load.')
     print('Inference may not produce meaningful results without training.')


# --- Interactive Chat Loop ---

print("\n--- Start Chat ---")
print("Type 'quit' or 'exit' to end the chat.")

while True:
    try:
        user_input = input("You: ")

        if user_input.lower() in ['quit', 'exit']:
            print("Chat ended.")
            break

        # Generate response using the evaluate function
        # We discard attention_weights here
        response, _ = evaluate(user_input, transformer, tokenizer, MAX_LENGTH)

        # Print the generated response
        print(f"Bot: {response}")
        print("-" * 20) # Separator for clarity

    except tf.errors.OutOfRangeError:
        print("Error: Dataset finished during evaluation? This is unexpected.")
        break
    except Exception as e:
        print(f"An error occurred during inference: {e}")
        # Continue the loop or break depending on desired error handling
        # break # Uncomment to stop chat on any error
        print("Attempting to continue chat...")

In [None]:
# @title Cell 13: Addressing Advanced Components & Next Steps
print("--- Advanced Components & Next Steps ---")

print("\nKey components mentioned by user:")
print("1. Core transformer architecture: Covered by implementing layers and structure.")
print("2. Advanced training techniques:")
print("   * Pre-training on vast text corpora: Requires *massive* datasets (terabytes), significant computational resources (clusters of GPUs for weeks/months), and complex distributed training setups. Our demo uses tiny data.")
print("   * Fine-tuning with reinforcement learning from human feedback (RLHF): A complex multi-stage process involving training a reward model from human preferences and then fine-tuning the language model using RL (e.g., Proximal Policy Optimization). Beyond the scope of a basic Colab demo.")
print("   * Constitutional AI approaches for safety and alignment: Advanced techniques building on RLHF, using AI feedback based on principles/constitutions. Also highly advanced and resource-intensive.")
print("3. System design components:")
print("   * Efficient inference infrastructure: For large models, requires specialized hardware (TPUs, high-end GPUs), optimized libraries (TensorRT, TensorFlow Lite), and deployment strategies (serving frameworks). Colab provides a single GPU, good for small models/demos.")
print("   * Prompt engineering and context management: Crucial for interacting with large models. Our demo's context is just the current turn. Real systems need robust history tracking and prompt formatting.")
print("   * Memory systems for conversation history: Requires storing and potentially summarizing or selecting relevant past conversation turns to feed into the model's context window. Our demo lacks this.")
print("   * Moderation layers for safety: Additional systems (often separate models or rule-based filters) to detect and prevent harmful, biased, or inappropriate responses. Our demo has no safety mechanisms.")

print("\nSummary:")
print("This notebook provides a foundational implementation of the core Transformer *architecture* components (Self-Attention, FFN, Norm, Residual).")
print("It demonstrates basic data preparation, training loop setup, and inference *concept* on a tiny dataset.")
print("\nTo build a truly effective chat system, you would need:")
print("1.  Vastly more data and computational resources for pre-training/training.")
print("2.  Implementation of advanced techniques like RLHF and potentially Constitutional AI.")
print("3.  Sophisticated system design for context management, efficient serving, and safety.")
print("\nStarting with readily available large pre-trained models (like those from Hugging Face Transformers) and fine-tuning them on a smaller, domain-specific dataset is a far more practical path for building a functional chat system in Colab or with limited resources than training a powerful model from scratch.")