# Installation

In [None]:
# pip install pandas transformers 

In [None]:
# pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

In [20]:
import torch

print("torch + cuda:", torch.__version__) # Check PyTorch version
print("Is cuda avialable:", torch.cuda.is_available())

torch + cuda: 2.6.0+cu124
Is cuda avialable: True


# Dataset Pre-processing

In [21]:
DATASET = "FoCus"

In [22]:
import pandas as pd
import json

with open(f'./datasets/{DATASET}/train.json') as f:
    valid_data = json.load(f)

In [23]:
def convertToDialogue(my_list):
    formatted_string = ""
    for index, item in enumerate(my_list):
        if index % 2 == 0:
            user = "User1"
        else:
            user = "User2"
        formatted_string += f"{user}: {item}\n"
    formatted_string = formatted_string.rstrip("\n")
    return formatted_string

flattened_data = []
data_list = valid_data['data']
for entry in data_list:
    persona =  "".join(entry['persona'])
    list_length = len(entry["utterance"])
    last_utterance = entry["utterance"][-1]
    dialogue_key = f"dialogue{list_length}"
    last_item = last_utterance[dialogue_key]
    flattened_data.append({
                'dialogID': entry['dialogID'],
                'persona': persona,
                'utterance': convertToDialogue(last_item)
            })

df = pd.DataFrame(flattened_data)

In [24]:
df = df.replace(r'\*\*', '', regex=True)
df = df.replace(r'\r', '', regex=True)
df = df.replace("'", "", regex=True)

df.dropna(inplace=True)

# Function to split the conversation
def split_conversation(conv_str):
    utterances = conv_str.split("\n")
    context = "\n".join(utterances[:-1])
    response = utterances[-1]
    return context, response

new_rows = []
for index, row in df.iterrows():
    context, response = split_conversation(row['utterance'])
    new_row = {
        'personas': row['persona'],
        'context': context,
        'act_response': response
    }
    new_rows.append(new_row)

new_df = pd.DataFrame(new_rows)

new_df.head(4)

Unnamed: 0,personas,context,act_response
0,I like to go to Church.I am Roman Catholic.I w...,"User1: Wow, this is amazing! What is this?\nUs...",User2: It is in Texas state.
1,I like living in a city.I dont hope to ever vi...,"User1: I know this place, but I dont remember ...","User2: Of course, Captain William Cornwallis S..."
2,I want to visit Mexico.I am interested in the ...,"User1: I know this place, but I dont remember ...","User2: Well, in this rainforest, especially th..."
3,I am afraid of bears.I like valleys.I live in ...,User1: Where is this place?\nUser2: This place...,"User2: Not only it, but there are various othe..."


In [25]:
# Calculate minimum and maximum number of words in each column
min_persona_length = new_df['personas'].apply(lambda x: len(x.split())).min()
max_persona_length = new_df['personas'].apply(lambda x: len(x.split())).max()

min_context_length = new_df['context'].apply(lambda x: len(x.split())).min()
max_context_length = new_df['context'].apply(lambda x: len(x.split())).max()

min_response_length = new_df['act_response'].apply(lambda x: len(x.split())).min()
max_response_length = new_df['act_response'].apply(lambda x: len(x.split())).max()

# Print the lengths in min-max format
print(f"Persona Length (in words): {min_persona_length}-{max_persona_length}")
print(f"Context Length (in words): {min_context_length}-{max_context_length}")
print(f"Response Length (in words): {min_response_length}-{max_response_length}")

Persona Length (in words): 11-92
Context Length (in words): 46-705
Response Length (in words): 2-159


In [26]:
# Save the new DataFrame to a CSV file
new_df.to_csv(f'./datasets/{DATASET}/ds_cleaned.csv', index=False)

# PAA

**The model consists of the following key components:**

- Persona Encoder - Applies self-attention to persona embeddings.
- Context Encoder - Applies self-attention to dialogue history embeddings.
- Persona-Adaptive Attention (PAA) - Uses cross-attention between persona and context, dynamically adjusting their influence.
- Dialogue Decoder - Generates responses using a transformer-based model, incorporating persona-aware representations.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2Tokenizer, GPT2LMHeadModel

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
DATASET = "FoCus"
BATCH_SIZE = 16
EPOCS = 100
EMBEDDING_SIZE = 768
MAX_LENGTH_TOKEN = 50

In [3]:
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Tokenizer


csv_path = f"datasets/{DATASET}/ds_cleaned.csv"

df = pd.read_csv(csv_path)

assert "personas" in df.columns and "context" in df.columns and "act_response" in df.columns, "Missing columns in dataset!"


In [4]:
df.head()

Unnamed: 0,personas,context,act_response
0,I like to go to Church.I am Roman Catholic.I w...,"User1: Wow, this is amazing! What is this?\nUs...",User2: It is in Texas state.
1,I like living in a city.I dont hope to ever vi...,"User1: I know this place, but I dont remember ...","User2: Of course, Captain William Cornwallis S..."
2,I want to visit Mexico.I am interested in the ...,"User1: I know this place, but I dont remember ...","User2: Well, in this rainforest, especially th..."
3,I am afraid of bears.I like valleys.I live in ...,User1: Where is this place?\nUser2: This place...,"User2: Not only it, but there are various othe..."
4,I would like to visit New Zealand.I would like...,User1: I think Ive been there before but I don...,"User2: Prior to a 2006–2008 street upgrade, Co..."


In [5]:
# class SelfAttention(nn.Module):
#     """Single-head self-attention mechanism."""
#     def __init__(self, embed_size):
#         super(SelfAttention, self).__init__()
#         self.query = nn.Linear(embed_size, embed_size)
#         self.key = nn.Linear(embed_size, embed_size)
#         self.value = nn.Linear(embed_size, embed_size)
#         self.softmax = nn.Softmax(dim=-1)

#     def forward(self, x):
#         Q = self.query(x)
#         K = self.key(x)
#         V = self.value(x)

#         attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / (K.shape[-1] ** 0.5)
#         attention_weights = self.softmax(attention_scores)

#         output = torch.matmul(attention_weights, V)
#         return output


In [5]:
import torch
import torch.nn as nn

class TransformerEncoderLayer(nn.Module):
    """Multi-head attention-based Transformer Encoder layer."""
    def __init__(self, embed_size, num_heads=4, ff_hidden_size=1024, dropout=0.1):
        super(TransformerEncoderLayer, self).__init__()
        
        self.self_attention = nn.MultiheadAttention(embed_dim=embed_size, num_heads=num_heads, dropout=dropout)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, ff_hidden_size),
            nn.ReLU(),
            nn.Linear(ff_hidden_size, embed_size)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Self-attention layer
        attn_output, _ = self.self_attention(x, x, x)
        x = self.norm1(x + self.dropout(attn_output))  # Residual connection + LayerNorm
        
        # Feed-forward network
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))  # Residual connection + LayerNorm
        
        return x


In [6]:
# class PersonaEncoder(nn.Module):
#     def __init__(self, embed_size):
#         super(PersonaEncoder, self).__init__()
#         self.attention = SelfAttention(embed_size)
#         self.norm = nn.LayerNorm(embed_size)

#     def forward(self, persona_embeds):
#         attended = self.attention(persona_embeds)
#         return self.norm(attended + persona_embeds)  # Residual Connection

# class ContextEncoder(nn.Module):
#     def __init__(self, embed_size):
#         super(ContextEncoder, self).__init__()
#         self.attention = SelfAttention(embed_size)
#         self.norm = nn.LayerNorm(embed_size)

#     def forward(self, context_embeds):
#         attended = self.attention(context_embeds)
#         return self.norm(attended + context_embeds)  # Residual Connection


In [6]:
class PersonaEncoder(nn.Module):
    def __init__(self, embed_size, num_layers=4):
        super(PersonaEncoder, self).__init__()
        self.layers = nn.ModuleList([TransformerEncoderLayer(embed_size) for _ in range(num_layers)])

    def forward(self, persona_embeds):
        x = persona_embeds  # Shape: (seq_len, batch_size, embed_size)
        for layer in self.layers:
            x = layer(x)
        return x  # Encoded persona representation


class ContextEncoder(nn.Module):
    def __init__(self, embed_size, num_layers=4):
        super(ContextEncoder, self).__init__()
        self.layers = nn.ModuleList([TransformerEncoderLayer(embed_size) for _ in range(num_layers)])

    def forward(self, context_embeds):
        x = context_embeds  # Shape: (seq_len, batch_size, embed_size)
        for layer in self.layers:
            x = layer(x)
        return x  # Encoded context representation


In [8]:
# class PersonaAdaptiveAttention(nn.Module):
#     def __init__(self, embed_size):
#         super(PersonaAdaptiveAttention, self).__init__()
#         self.context_attention = SelfAttention(embed_size)
#         self.persona_attention = SelfAttention(embed_size)

#         self.weighting = nn.Linear(embed_size, 1)  # Adaptive weighting
#         self.norm = nn.LayerNorm(embed_size)

#     def forward(self, persona_embeds, context_embeds):
#         attended_context = self.context_attention(context_embeds)
#         attended_persona = self.persona_attention(persona_embeds)

#         # Adaptive Weighting (From PAA Mechanism)
#         weights = torch.sigmoid(self.weighting(attended_persona))
#         weighted_persona = weights * attended_persona

#         # Fusion: Combining Persona and Context
#         fused_representation = attended_context + weighted_persona
#         return self.norm(fused_representation)


In [7]:
class PersonaAdaptiveAttention(nn.Module):
    def __init__(self, embed_size, num_heads=4):
        super(PersonaAdaptiveAttention, self).__init__()

        # Cross-attention layers
        self.context_cross_attn = nn.MultiheadAttention(embed_dim=embed_size, num_heads=num_heads)
        self.persona_cross_attn = nn.MultiheadAttention(embed_dim=embed_size, num_heads=num_heads)

        # Weight computation for persona importance
        self.weighting = nn.Linear(embed_size, 1)
        
        # Normalization and residual connections
        self.norm = nn.LayerNorm(embed_size)

    def forward(self, persona_embeds, context_embeds, decoder_hidden_states):
        """
        Inputs:
            - persona_embeds: Encoded persona representation (seq_len, batch, embed_size)
            - context_embeds: Encoded context representation (seq_len, batch, embed_size)
            - decoder_hidden_states: Hidden states from decoder (seq_len, batch, embed_size)
        """

        # Compute cross-attention between decoder and persona/context
        attended_context, _ = self.context_cross_attn(decoder_hidden_states, context_embeds, context_embeds)
        attended_persona, _ = self.persona_cross_attn(decoder_hidden_states, persona_embeds, persona_embeds)

        # Compute adaptive weighting for persona importance
        weights = torch.sigmoid(self.weighting(attended_persona))  # (seq_len, batch, 1)

        # Apply Dynamic Masking Mechanism
        seq_len = persona_embeds.shape[0]
        persona_length = persona_embeds.shape[1]
        context_length = context_embeds.shape[1]

        # Compute τ (threshold for masking)
        tau = context_length / (context_length + persona_length)

        # Create binary masks
        persona_mask = (weights > tau).float()  # Keep persona info if weight > τ
        context_mask = (1 - weights > tau).float()  # Keep context info if weight < 1 - τ

        # Apply masks
        masked_persona = persona_mask * attended_persona
        masked_context = context_mask * attended_context

        # Combine masked representations
        fused_representation = masked_persona + masked_context

        return self.norm(fused_representation)


In [9]:
# class DialogueDecoder(nn.Module):
#     def __init__(self, model_name="gpt2", embed_size = EMBEDDING_SIZE):
#         super(DialogueDecoder, self).__init__()
#         self.gpt2 = GPT2LMHeadModel.from_pretrained(model_name)

#     def forward(self, input_embeds, response_tokens):
#         outputs = self.gpt2(inputs_embeds=input_embeds, labels=response_tokens)
#         return outputs.loss, outputs.logits


In [8]:
from transformers import GPT2LMHeadModel

class DialogueDecoder(nn.Module):
    def __init__(self, model_name="gpt2", embed_size=768):
        super(DialogueDecoder, self).__init__()
        self.gpt2 = GPT2LMHeadModel.from_pretrained(model_name)
        
        # Linear layer to match input dimensions
        self.fusion_projection = nn.Linear(embed_size, self.gpt2.config.n_embd)

    def forward(self, fused_representation, response_tokens):
        """
        Inputs:
            - fused_representation: Output from Persona-Adaptive Attention (seq_len, batch, embed_size)
            - response_tokens: Tokenized response input (batch, seq_len)
        """
        
        # Project fused_representation to match GPT-2 hidden size
        fused_representation = self.fusion_projection(fused_representation)

        # Prepend fused persona-context embeddings as a prompt
        inputs_embeds = self.gpt2.transformer.wte(response_tokens)  # GPT-2 embeddings
        inputs_embeds = torch.cat([fused_representation, inputs_embeds], dim=1)  # Prepend PAA output

        # Pass through GPT-2 decoder
        outputs = self.gpt2(inputs_embeds=inputs_embeds, labels=response_tokens)

        return outputs.loss, outputs.logits


In [10]:
# class PersonaAdaptiveChatbot(nn.Module):
#     def __init__(self, embed_size = EMBEDDING_SIZE, model_name="gpt2"):
#         super(PersonaAdaptiveChatbot, self).__init__()

#         self.persona_encoder = PersonaEncoder(embed_size)
#         self.context_encoder = ContextEncoder(embed_size)
#         self.paa = PersonaAdaptiveAttention(embed_size)
#         self.decoder = DialogueDecoder(model_name, embed_size)

#         self.tokenizer = GPT2Tokenizer.from_pretrained(model_name)

#     def forward(self, persona_tokens, context_tokens, response_tokens):
#         # Token Embeddings
#         persona_embeds = self.decoder.gpt2.transformer.wte(persona_tokens)
#         context_embeds = self.decoder.gpt2.transformer.wte(context_tokens)

#         # Encoders
#         encoded_persona = self.persona_encoder(persona_embeds)
#         encoded_context = self.context_encoder(context_embeds)

#         # Persona-Adaptive Attention
#         fused_representation = self.paa(encoded_persona, encoded_context)

#         # Use PAA Output for GPT-2 Decoder
#         response_embeds = self.decoder.gpt2.transformer.wte(response_tokens)
#         response_embeds = response_embeds + fused_representation  # Inject PAA info

#         return self.decoder(response_embeds, response_tokens)


In [9]:
from transformers import GPT2Tokenizer

class PersonaAdaptiveChatbot(nn.Module):
    def __init__(self, embed_size=768, model_name="gpt2"):
        super(PersonaAdaptiveChatbot, self).__init__()

        # Encoders
        self.persona_encoder = PersonaEncoder(embed_size)
        self.context_encoder = ContextEncoder(embed_size)

        # Persona-Adaptive Attention
        self.paa = PersonaAdaptiveAttention(embed_size)

        # Decoder
        self.decoder = DialogueDecoder(model_name, embed_size)

        # Tokenizer for text processing
        self.tokenizer = GPT2Tokenizer.from_pretrained(model_name)

    def forward(self, persona_tokens, context_tokens, response_tokens):
        """
        Inputs:
            - persona_tokens: Tokenized persona input (batch, seq_len)
            - context_tokens: Tokenized context input (batch, seq_len)
            - response_tokens: Tokenized response input (batch, seq_len)
        """

        # Convert token IDs to embeddings using GPT-2 word embeddings
        persona_embeds = self.decoder.gpt2.transformer.wte(persona_tokens)
        context_embeds = self.decoder.gpt2.transformer.wte(context_tokens)

        # Encode persona and context
        encoded_persona = self.persona_encoder(persona_embeds.permute(1, 0, 2))  # (seq_len, batch, embed)
        encoded_context = self.context_encoder(context_embeds.permute(1, 0, 2))  # (seq_len, batch, embed)

        # Persona-Adaptive Attention (PAA) to fuse persona and context
        fused_representation = self.paa(encoded_persona, encoded_context, decoder_hidden_states=None)

        # Use PAA output in GPT-2 Decoder
        return self.decoder(fused_representation, response_tokens)


In [11]:
# # Load GPT-2 tokenizer
# tokenizer = GPT2Tokenizer.from_pretrained("gpt2", cache_dir="downloaded_LM")  # GPT-2 tokenizer with caching
# tokenizer.pad_token = tokenizer.eos_token  # GPT-2 does not have a padding token

# # Tokenization function
# def tokenize_texts(persona, context, response, max_length = MAX_LENGTH_TOKEN):
#     persona_tokens = tokenizer(persona, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt")["input_ids"]
#     context_tokens = tokenizer(context, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt")["input_ids"]
#     response_tokens = tokenizer(response, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt")["input_ids"]
#     return persona_tokens.squeeze(0), context_tokens.squeeze(0), response_tokens.squeeze(0)

# # Apply tokenization to the entire dataset
# df["tokenized"] = df.apply(lambda row: tokenize_texts(row["personas"], row["context"], row["act_response"]), axis=1)

# # Extract tokenized tensors
# persona_tensors = torch.stack([x[0] for x in df["tokenized"]])
# context_tensors = torch.stack([x[1] for x in df["tokenized"]])
# response_tensors = torch.stack([x[2] for x in df["tokenized"]])


In [10]:
# Load GPT-2 tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", cache_dir="downloaded_LM")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 does not have a padding token

# Tokenization function
def tokenize_texts(persona, context, response, max_length=MAX_LENGTH_TOKEN):
    persona_tokens = tokenizer(persona, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt")["input_ids"]
    context_tokens = tokenizer(context, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt")["input_ids"]
    response_tokens = tokenizer(response, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt")["input_ids"]

    # Ensure correct shape (batch, seq_len)
    persona_tokens = persona_tokens.squeeze(0)
    context_tokens = context_tokens.squeeze(0)
    response_tokens = response_tokens.squeeze(0)

    return persona_tokens, context_tokens, response_tokens

# Apply tokenization to the dataset
df["tokenized"] = df.apply(lambda row: tokenize_texts(row["personas"], row["context"], row["act_response"]), axis=1)

# Extract tokenized tensors
persona_tensors = torch.stack([x[0] for x in df["tokenized"]])
context_tensors = torch.stack([x[1] for x in df["tokenized"]])
response_tensors = torch.stack([x[2] for x in df["tokenized"]])

# Ensure batch-first format
persona_tensors = persona_tensors.permute(1, 0)  # Shape (batch, seq_len)
context_tensors = context_tensors.permute(1, 0)  # Shape (batch, seq_len)
response_tensors = response_tensors.permute(1, 0)  # Shape (batch, seq_len)


In [15]:
class FoCusDataset(Dataset):
    def __init__(self, persona_tensors, context_tensors, response_tensors):
        self.persona_tensors = persona_tensors
        self.context_tensors = context_tensors
        self.response_tensors = response_tensors

    def __len__(self):
        return len(self.persona_tensors)

    def __getitem__(self, idx):
        return {
            "personas": self.persona_tensors[idx],
            "context": self.context_tensors[idx],
            "response": self.response_tensors[idx],
        }

# Create dataset
dataset = FoCusDataset(persona_tensors, context_tensors, response_tensors)

# Create DataLoader
train_loader = DataLoader(dataset, batch_size = BATCH_SIZE, shuffle=True)


In [13]:
# # Define model
# model = PersonaAdaptiveChatbot()

# # Define optimizer
# optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)

# # Define loss function (GPT-2 uses CrossEntropyLoss)
# criterion = nn.CrossEntropyLoss()


In [12]:
MODEL_SAVE_DIR = f"trained_model/{DATASET}"


# Define model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = PersonaAdaptiveChatbot().to(device)

# Define optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)


In [15]:
# import torch
# from tqdm import tqdm
# import os

# # Set device for training
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# def train(model, dataloader, optimizer, num_epochs = EPOCS):
#     model.to(device)  # Move model to GPU
#     model.train()

#     for epoch in range(num_epochs):
#         total_loss = 0
#         progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=True)

#         for batch in progress_bar:
#             # Move batch tensors to GPU
#             persona_tokens = batch["personas"].to(device)
#             context_tokens = batch["context"].to(device)
#             response_tokens = batch["response"].to(device)

#             optimizer.zero_grad()
#             loss, logits = model(persona_tokens, context_tokens, response_tokens)

#             loss.backward()
#             optimizer.step()

#             total_loss += loss.item()

#             # Update tqdm progress bar with current loss
#             progress_bar.set_postfix(loss=total_loss / (progress_bar.n + 1))

#         print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss:.4f}")

#     # Save the model after each epoch
#     model_save_path = os.path.join(MODEL_SAVE_DIR, f"model_epoch_{EPOCS}.pth")
#     torch.save(model.state_dict(), model_save_path)
#     print(f"Model saved at: {model_save_path}")


In [13]:
import torch
from tqdm import tqdm
import os

# Set device for training
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def train(model, dataloader, optimizer, num_epochs=EPOCS):
    model.to(device)  # Move model to GPU
    model.train()

    for epoch in range(num_epochs):
        total_loss = 0
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=True)

        for batch in progress_bar:
            # Move batch tensors to GPU
            persona_tokens = batch["personas"].to(device)
            context_tokens = batch["context"].to(device)
            response_tokens = batch["response"].to(device)

            optimizer.zero_grad()
            
            # Forward pass
            loss, logits = model(persona_tokens, context_tokens, response_tokens)
            
            # Backward pass
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            # Update tqdm progress bar with current loss
            progress_bar.set_postfix(loss=total_loss / (progress_bar.n + 1))

        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss:.4f}")

        # Save the model after each epoch
        model_save_path = os.path.join(MODEL_SAVE_DIR, f"model_epoch_{epoch+1}.pth")
        torch.save(model.state_dict(), model_save_path)
        print(f"Model saved at: {model_save_path}")


In [16]:
# Train the model with CSV data
train(model, train_loader, optimizer)


Epoch 1/100:   0%|          | 0/4 [00:00<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 37.16 GiB. GPU 0 has a total capacity of 47.40 GiB of which 29.46 GiB is free. Process 544030 has 11.51 GiB memory in use. Including non-PyTorch memory, this process has 6.32 GiB memory in use. Of the allocated memory 4.63 GiB is allocated by PyTorch, and 1.20 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

# Response Generation

In [17]:
DATASET = "FoCus"
MODEL_SAVE_DIR = f"trained_model/{DATASET}"
EPOCS = 100
MAX_LENGTH_TOKEN = 50

In [18]:
model = PersonaAdaptiveChatbot()  # Recreate the model architecture
model.load_state_dict(torch.load(f"{MODEL_SAVE_DIR}/model_epoch_{EPOCS}.pth"))
model.to(device)


PersonaAdaptiveChatbot(
  (persona_encoder): PersonaEncoder(
    (attention): SelfAttention(
      (query): Linear(in_features=768, out_features=768, bias=True)
      (key): Linear(in_features=768, out_features=768, bias=True)
      (value): Linear(in_features=768, out_features=768, bias=True)
      (softmax): Softmax(dim=-1)
    )
    (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (context_encoder): ContextEncoder(
    (attention): SelfAttention(
      (query): Linear(in_features=768, out_features=768, bias=True)
      (key): Linear(in_features=768, out_features=768, bias=True)
      (value): Linear(in_features=768, out_features=768, bias=True)
      (softmax): Softmax(dim=-1)
    )
    (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (paa): PersonaAdaptiveAttention(
    (context_attention): SelfAttention(
      (query): Linear(in_features=768, out_features=768, bias=True)
      (key): Linear(in_features=768, out_features=768, bias=True)
      (

In [19]:
import torch

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def generate_response(model, persona_text, context_text, max_length= MAX_LENGTH_TOKEN):
    model.to(device)  # Ensure model is on GPU
    model.eval()

    # Tokenize and move to GPU, ensuring padding for alignment
    persona_tokens = tokenizer(persona_text, return_tensors="pt", padding=True, truncation=True)["input_ids"].to(device)
    context_tokens = tokenizer(context_text, return_tensors="pt", padding=True, truncation=True)["input_ids"].to(device)

    with torch.no_grad():
        persona_embeds = model.decoder.gpt2.transformer.wte(persona_tokens)
        context_embeds = model.decoder.gpt2.transformer.wte(context_tokens)

        # Determine max sequence length for padding
        max_seq_len = max(persona_embeds.shape[1], context_embeds.shape[1])

        # Pad tensors to the same length
        pad_size_persona = max_seq_len - persona_embeds.shape[1]
        pad_size_context = max_seq_len - context_embeds.shape[1]

        persona_embeds = torch.nn.functional.pad(persona_embeds, (0, 0, 0, pad_size_persona), "constant", 0)
        context_embeds = torch.nn.functional.pad(context_embeds, (0, 0, 0, pad_size_context), "constant", 0)

        # Encode using persona and context encoders
        encoded_persona = model.persona_encoder(persona_embeds)
        encoded_context = model.context_encoder(context_embeds)

        # Apply Persona-Adaptive Attention (PAA)
        fused_representation = model.paa(encoded_persona, encoded_context)

    # Generate response ensuring input is on GPU
    generated = model.decoder.gpt2.generate(input_ids=context_tokens, max_length=max_length).to(device)
    return tokenizer.decode(generated[0], skip_special_tokens=True)

# Example
persona_text = "I love sci-fi movies."
context_text = "user1: What's your favorite movie?"
response = generate_response(model, persona_text, context_text)
print("Generated Response:", response)


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


Generated Response: user1: What's your favorite movie?
