# Installation

In [None]:
# pip install pandas transformers 

In [None]:
# pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

In [20]:
import torch

print("torch + cuda:", torch.__version__) # Check PyTorch version
print("Is cuda avialable:", torch.cuda.is_available())

torch + cuda: 2.6.0+cu124
Is cuda avialable: True


# Dataset Pre-processing

In [21]:
DATASET = "FoCus"

In [22]:
import pandas as pd
import json

with open(f'./datasets/{DATASET}/train.json') as f:
    valid_data = json.load(f)

In [23]:
def convertToDialogue(my_list):
    formatted_string = ""
    for index, item in enumerate(my_list):
        if index % 2 == 0:
            user = "User1"
        else:
            user = "User2"
        formatted_string += f"{user}: {item}\n"
    formatted_string = formatted_string.rstrip("\n")
    return formatted_string

flattened_data = []
data_list = valid_data['data']
for entry in data_list:
    persona =  "".join(entry['persona'])
    list_length = len(entry["utterance"])
    last_utterance = entry["utterance"][-1]
    dialogue_key = f"dialogue{list_length}"
    last_item = last_utterance[dialogue_key]
    flattened_data.append({
                'dialogID': entry['dialogID'],
                'persona': persona,
                'utterance': convertToDialogue(last_item)
            })

df = pd.DataFrame(flattened_data)

In [24]:
df = df.replace(r'\*\*', '', regex=True)
df = df.replace(r'\r', '', regex=True)
df = df.replace("'", "", regex=True)

df.dropna(inplace=True)

# Function to split the conversation
def split_conversation(conv_str):
    utterances = conv_str.split("\n")
    context = "\n".join(utterances[:-1])
    response = utterances[-1]
    return context, response

new_rows = []
for index, row in df.iterrows():
    context, response = split_conversation(row['utterance'])
    new_row = {
        'personas': row['persona'],
        'context': context,
        'act_response': response
    }
    new_rows.append(new_row)

new_df = pd.DataFrame(new_rows)

new_df.head(4)

Unnamed: 0,personas,context,act_response
0,I like to go to Church.I am Roman Catholic.I w...,"User1: Wow, this is amazing! What is this?\nUs...",User2: It is in Texas state.
1,I like living in a city.I dont hope to ever vi...,"User1: I know this place, but I dont remember ...","User2: Of course, Captain William Cornwallis S..."
2,I want to visit Mexico.I am interested in the ...,"User1: I know this place, but I dont remember ...","User2: Well, in this rainforest, especially th..."
3,I am afraid of bears.I like valleys.I live in ...,User1: Where is this place?\nUser2: This place...,"User2: Not only it, but there are various othe..."


In [25]:
# Calculate minimum and maximum number of words in each column
min_persona_length = new_df['personas'].apply(lambda x: len(x.split())).min()
max_persona_length = new_df['personas'].apply(lambda x: len(x.split())).max()

min_context_length = new_df['context'].apply(lambda x: len(x.split())).min()
max_context_length = new_df['context'].apply(lambda x: len(x.split())).max()

min_response_length = new_df['act_response'].apply(lambda x: len(x.split())).min()
max_response_length = new_df['act_response'].apply(lambda x: len(x.split())).max()

# Print the lengths in min-max format
print(f"Persona Length (in words): {min_persona_length}-{max_persona_length}")
print(f"Context Length (in words): {min_context_length}-{max_context_length}")
print(f"Response Length (in words): {min_response_length}-{max_response_length}")

Persona Length (in words): 11-92
Context Length (in words): 46-705
Response Length (in words): 2-159


In [26]:
# Save the new DataFrame to a CSV file
new_df.to_csv(f'./datasets/{DATASET}/ds_cleaned.csv', index=False)

# PAA

**The model consists of the following key components:**

- Persona Encoder - Applies self-attention to persona embeddings.
- Context Encoder - Applies self-attention to dialogue history embeddings.
- Persona-Adaptive Attention (PAA) - Uses cross-attention between persona and context, dynamically adjusting their influence.
- Dialogue Decoder - Generates responses using a transformer-based model, incorporating persona-aware representations.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2Tokenizer, GPT2LMHeadModel

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
DATASET = "FoCus"
BATCH_SIZE = 16
EPOCS = 100
EMBEDDING_SIZE = 768
MAX_LENGTH_TOKEN = 50

In [3]:
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Tokenizer


csv_path = f"datasets/{DATASET}/ds_cleaned.csv"

df = pd.read_csv(csv_path)

assert "personas" in df.columns and "context" in df.columns and "act_response" in df.columns, "Missing columns in dataset!"


In [4]:
df.head()

Unnamed: 0,personas,context,act_response
0,I would like to visit the Nazareth House again...,User1: I think Ive been there before but I don...,User2: The history of the house you are intere...
1,I have been to Vermont a few times to go skiin...,"User1: Wow, this is amazing! What is this?\nUs...",User2: This house was use as a stop for slaves...
2,I am fascinated by the Spanish Colonial Reviva...,"User1: Wow, this is amazing! What is this?\nUs...","User2: Sure, you will like to know that this p..."
3,I want to become a college student.I want to s...,User1: Where is this place?\nUser2: Hello! Wel...,User2: Technische Universität Darmstadt in the...
4,I like to visit england.I love church.I would ...,User1: Where is this place?\nUser2: This place...,"User2: I suggest a place, for your wish of see..."


In [5]:
class SelfAttention(nn.Module):
    """Single-head self-attention mechanism."""
    def __init__(self, embed_size):
        super(SelfAttention, self).__init__()
        self.query = nn.Linear(embed_size, embed_size)
        self.key = nn.Linear(embed_size, embed_size)
        self.value = nn.Linear(embed_size, embed_size)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        Q = self.query(x)
        K = self.key(x)
        V = self.value(x)

        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / (K.shape[-1] ** 0.5)
        attention_weights = self.softmax(attention_scores)

        output = torch.matmul(attention_weights, V)
        return output


In [6]:
class PersonaEncoder(nn.Module):
    def __init__(self, embed_size):
        super(PersonaEncoder, self).__init__()
        self.attention = SelfAttention(embed_size)
        self.norm = nn.LayerNorm(embed_size)

    def forward(self, persona_embeds):
        attended = self.attention(persona_embeds)
        return self.norm(attended + persona_embeds)  # Residual Connection


In [7]:
class ContextEncoder(nn.Module):
    def __init__(self, embed_size):
        super(ContextEncoder, self).__init__()
        self.attention = SelfAttention(embed_size)
        self.norm = nn.LayerNorm(embed_size)

    def forward(self, context_embeds):
        attended = self.attention(context_embeds)
        return self.norm(attended + context_embeds)  # Residual Connection


In [8]:
class PersonaAdaptiveAttention(nn.Module):
    def __init__(self, embed_size):
        super(PersonaAdaptiveAttention, self).__init__()
        self.context_attention = SelfAttention(embed_size)
        self.persona_attention = SelfAttention(embed_size)

        self.weighting = nn.Linear(embed_size, 1)  # Adaptive weighting
        self.norm = nn.LayerNorm(embed_size)

    def forward(self, persona_embeds, context_embeds):
        attended_context = self.context_attention(context_embeds)
        attended_persona = self.persona_attention(persona_embeds)

        # Adaptive Weighting (From PAA Mechanism)
        weights = torch.sigmoid(self.weighting(attended_persona))
        weighted_persona = weights * attended_persona

        # Fusion: Combining Persona and Context
        fused_representation = attended_context + weighted_persona
        return self.norm(fused_representation)


In [9]:
class DialogueDecoder(nn.Module):
    def __init__(self, model_name="gpt2", embed_size = EMBEDDING_SIZE):
        super(DialogueDecoder, self).__init__()
        self.gpt2 = GPT2LMHeadModel.from_pretrained(model_name)

    def forward(self, input_embeds, response_tokens):
        outputs = self.gpt2(inputs_embeds=input_embeds, labels=response_tokens)
        return outputs.loss, outputs.logits


In [10]:
class PersonaAdaptiveChatbot(nn.Module):
    def __init__(self, embed_size = EMBEDDING_SIZE, model_name="gpt2"):
        super(PersonaAdaptiveChatbot, self).__init__()

        self.persona_encoder = PersonaEncoder(embed_size)
        self.context_encoder = ContextEncoder(embed_size)
        self.paa = PersonaAdaptiveAttention(embed_size)
        self.decoder = DialogueDecoder(model_name, embed_size)

        self.tokenizer = GPT2Tokenizer.from_pretrained(model_name)

    def forward(self, persona_tokens, context_tokens, response_tokens):
        # Token Embeddings
        persona_embeds = self.decoder.gpt2.transformer.wte(persona_tokens)
        context_embeds = self.decoder.gpt2.transformer.wte(context_tokens)

        # Encoders
        encoded_persona = self.persona_encoder(persona_embeds)
        encoded_context = self.context_encoder(context_embeds)

        # Persona-Adaptive Attention
        fused_representation = self.paa(encoded_persona, encoded_context)

        # Use PAA Output for GPT-2 Decoder
        response_embeds = self.decoder.gpt2.transformer.wte(response_tokens)
        response_embeds = response_embeds + fused_representation  # Inject PAA info

        return self.decoder(response_embeds, response_tokens)


In [11]:
# Load GPT-2 tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 does not have a padding token

# Tokenization function
def tokenize_texts(persona, context, response, max_length = MAX_LENGTH_TOKEN):
    persona_tokens = tokenizer(persona, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt")["input_ids"]
    context_tokens = tokenizer(context, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt")["input_ids"]
    response_tokens = tokenizer(response, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt")["input_ids"]
    return persona_tokens.squeeze(0), context_tokens.squeeze(0), response_tokens.squeeze(0)

# Apply tokenization to the entire dataset
df["tokenized"] = df.apply(lambda row: tokenize_texts(row["personas"], row["context"], row["act_response"]), axis=1)

# Extract tokenized tensors
persona_tensors = torch.stack([x[0] for x in df["tokenized"]])
context_tensors = torch.stack([x[1] for x in df["tokenized"]])
response_tensors = torch.stack([x[2] for x in df["tokenized"]])


In [12]:
class FoCusDataset(Dataset):
    def __init__(self, persona_tensors, context_tensors, response_tensors):
        self.persona_tensors = persona_tensors
        self.context_tensors = context_tensors
        self.response_tensors = response_tensors

    def __len__(self):
        return len(self.persona_tensors)

    def __getitem__(self, idx):
        return {
            "personas": self.persona_tensors[idx],
            "context": self.context_tensors[idx],
            "response": self.response_tensors[idx],
        }

# Create dataset
dataset = FoCusDataset(persona_tensors, context_tensors, response_tensors)

# Create DataLoader
train_loader = DataLoader(dataset, batch_size = BATCH_SIZE, shuffle=True)


In [13]:
# Define model
model = PersonaAdaptiveChatbot()

# Define optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)

# Define loss function (GPT-2 uses CrossEntropyLoss)
criterion = nn.CrossEntropyLoss()


In [14]:
MODEL_SAVE_DIR = f"trained_model/{DATASET}"

device="cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [15]:
import torch
from tqdm import tqdm
import os

# Set device for training
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def train(model, dataloader, optimizer, num_epochs = EPOCS):
    model.to(device)  # Move model to GPU
    model.train()

    for epoch in range(num_epochs):
        total_loss = 0
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=True)

        for batch in progress_bar:
            # Move batch tensors to GPU
            persona_tokens = batch["personas"].to(device)
            context_tokens = batch["context"].to(device)
            response_tokens = batch["response"].to(device)

            optimizer.zero_grad()
            loss, logits = model(persona_tokens, context_tokens, response_tokens)

            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            # Update tqdm progress bar with current loss
            progress_bar.set_postfix(loss=total_loss / (progress_bar.n + 1))

        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss:.4f}")

    # Save the model after each epoch
    model_save_path = os.path.join(MODEL_SAVE_DIR, f"model_epoch_{EPOCS}.pth")
    torch.save(model.state_dict(), model_save_path)
    print(f"Model saved at: {model_save_path}")


In [16]:
# Train the model with CSV data
train(model, train_loader, optimizer)


Epoch 1/100: 100%|██████████| 63/63 [00:03<00:00, 20.49it/s, loss=4.88]


Epoch 1/100, Loss: 302.6696


Epoch 2/100: 100%|██████████| 63/63 [00:03<00:00, 20.78it/s, loss=3.91]


Epoch 2/100, Loss: 238.2913


Epoch 3/100: 100%|██████████| 63/63 [00:03<00:00, 20.81it/s, loss=3.11]


Epoch 3/100, Loss: 189.6502


Epoch 4/100: 100%|██████████| 63/63 [00:03<00:00, 20.71it/s, loss=2.16]


Epoch 4/100, Loss: 131.8541


Epoch 5/100: 100%|██████████| 63/63 [00:03<00:00, 20.67it/s, loss=1.96]


Epoch 5/100, Loss: 119.3260


Epoch 6/100: 100%|██████████| 63/63 [00:03<00:00, 20.56it/s, loss=1.81]


Epoch 6/100, Loss: 110.3237


Epoch 7/100: 100%|██████████| 63/63 [00:03<00:00, 20.58it/s, loss=1.67]


Epoch 7/100, Loss: 101.7993


Epoch 8/100: 100%|██████████| 63/63 [00:03<00:00, 20.73it/s, loss=1.55]


Epoch 8/100, Loss: 94.2624


Epoch 9/100: 100%|██████████| 63/63 [00:03<00:00, 20.65it/s, loss=1.42]


Epoch 9/100, Loss: 86.8951


Epoch 10/100: 100%|██████████| 63/63 [00:03<00:00, 20.52it/s, loss=1.29]


Epoch 10/100, Loss: 78.9312


Epoch 11/100: 100%|██████████| 63/63 [00:03<00:00, 20.13it/s, loss=1.15]


Epoch 11/100, Loss: 71.4479


Epoch 12/100: 100%|██████████| 63/63 [00:03<00:00, 20.09it/s, loss=1.02]


Epoch 12/100, Loss: 64.2800


Epoch 13/100: 100%|██████████| 63/63 [00:03<00:00, 19.98it/s, loss=0.942]


Epoch 13/100, Loss: 57.4456


Epoch 14/100: 100%|██████████| 63/63 [00:03<00:00, 19.97it/s, loss=0.816]


Epoch 14/100, Loss: 51.4091


Epoch 15/100: 100%|██████████| 63/63 [00:03<00:00, 20.02it/s, loss=0.791]


Epoch 15/100, Loss: 49.8222


Epoch 16/100: 100%|██████████| 63/63 [00:03<00:00, 19.95it/s, loss=0.662]


Epoch 16/100, Loss: 41.7015


Epoch 17/100: 100%|██████████| 63/63 [00:03<00:00, 19.87it/s, loss=0.583]


Epoch 17/100, Loss: 36.1201


Epoch 18/100: 100%|██████████| 63/63 [00:03<00:00, 20.01it/s, loss=0.524]


Epoch 18/100, Loss: 31.9662


Epoch 19/100: 100%|██████████| 63/63 [00:03<00:00, 19.78it/s, loss=0.472]


Epoch 19/100, Loss: 29.2620


Epoch 20/100: 100%|██████████| 63/63 [00:03<00:00, 19.72it/s, loss=0.422]


Epoch 20/100, Loss: 26.5818


Epoch 21/100: 100%|██████████| 63/63 [00:03<00:00, 19.70it/s, loss=0.386]


Epoch 21/100, Loss: 24.3160


Epoch 22/100: 100%|██████████| 63/63 [00:03<00:00, 19.55it/s, loss=0.357]


Epoch 22/100, Loss: 22.4893


Epoch 23/100: 100%|██████████| 63/63 [00:03<00:00, 19.62it/s, loss=0.333]


Epoch 23/100, Loss: 20.9568


Epoch 24/100: 100%|██████████| 63/63 [00:03<00:00, 19.42it/s, loss=0.31] 


Epoch 24/100, Loss: 19.5279


Epoch 25/100: 100%|██████████| 63/63 [00:03<00:00, 19.37it/s, loss=0.294]


Epoch 25/100, Loss: 18.2568


Epoch 26/100: 100%|██████████| 63/63 [00:03<00:00, 19.32it/s, loss=0.277]


Epoch 26/100, Loss: 17.4329


Epoch 27/100: 100%|██████████| 63/63 [00:03<00:00, 19.33it/s, loss=0.268]


Epoch 27/100, Loss: 16.6279


Epoch 28/100: 100%|██████████| 63/63 [00:03<00:00, 19.24it/s, loss=0.257]


Epoch 28/100, Loss: 15.9571


Epoch 29/100: 100%|██████████| 63/63 [00:03<00:00, 19.30it/s, loss=0.247]


Epoch 29/100, Loss: 15.2867


Epoch 30/100: 100%|██████████| 63/63 [00:03<00:00, 19.26it/s, loss=0.234]


Epoch 30/100, Loss: 14.7220


Epoch 31/100: 100%|██████████| 63/63 [00:03<00:00, 19.22it/s, loss=0.231]


Epoch 31/100, Loss: 14.2977


Epoch 32/100: 100%|██████████| 63/63 [00:03<00:00, 19.18it/s, loss=0.229]


Epoch 32/100, Loss: 14.4224


Epoch 33/100: 100%|██████████| 63/63 [00:03<00:00, 19.18it/s, loss=0.218]


Epoch 33/100, Loss: 13.7181


Epoch 34/100: 100%|██████████| 63/63 [00:03<00:00, 19.15it/s, loss=0.212]


Epoch 34/100, Loss: 13.1723


Epoch 35/100: 100%|██████████| 63/63 [00:03<00:00, 19.13it/s, loss=0.207]


Epoch 35/100, Loss: 13.0640


Epoch 36/100: 100%|██████████| 63/63 [00:03<00:00, 19.12it/s, loss=0.202]


Epoch 36/100, Loss: 12.7253


Epoch 37/100: 100%|██████████| 63/63 [00:03<00:00, 19.07it/s, loss=0.198]


Epoch 37/100, Loss: 12.4702


Epoch 38/100: 100%|██████████| 63/63 [00:03<00:00, 19.08it/s, loss=0.192]


Epoch 38/100, Loss: 12.1125


Epoch 39/100: 100%|██████████| 63/63 [00:03<00:00, 19.03it/s, loss=0.208]


Epoch 39/100, Loss: 13.1341


Epoch 40/100: 100%|██████████| 63/63 [00:03<00:00, 18.97it/s, loss=0.203]


Epoch 40/100, Loss: 12.8004


Epoch 41/100: 100%|██████████| 63/63 [00:03<00:00, 18.95it/s, loss=0.186]


Epoch 41/100, Loss: 11.7216


Epoch 42/100: 100%|██████████| 63/63 [00:03<00:00, 19.12it/s, loss=0.181]


Epoch 42/100, Loss: 11.4138


Epoch 43/100: 100%|██████████| 63/63 [00:03<00:00, 19.11it/s, loss=0.178]


Epoch 43/100, Loss: 11.2412


Epoch 44/100: 100%|██████████| 63/63 [00:03<00:00, 19.10it/s, loss=0.175]


Epoch 44/100, Loss: 11.0270


Epoch 45/100: 100%|██████████| 63/63 [00:03<00:00, 19.06it/s, loss=0.171]


Epoch 45/100, Loss: 10.7454


Epoch 46/100: 100%|██████████| 63/63 [00:03<00:00, 18.99it/s, loss=0.169]


Epoch 46/100, Loss: 10.6403


Epoch 47/100: 100%|██████████| 63/63 [00:03<00:00, 18.93it/s, loss=0.167]


Epoch 47/100, Loss: 10.4903


Epoch 48/100: 100%|██████████| 63/63 [00:03<00:00, 19.03it/s, loss=0.166]


Epoch 48/100, Loss: 10.4460


Epoch 49/100: 100%|██████████| 63/63 [00:03<00:00, 19.02it/s, loss=0.162]


Epoch 49/100, Loss: 10.1974


Epoch 50/100: 100%|██████████| 63/63 [00:03<00:00, 19.00it/s, loss=0.155]


Epoch 50/100, Loss: 9.7645


Epoch 51/100: 100%|██████████| 63/63 [00:03<00:00, 19.00it/s, loss=0.151]


Epoch 51/100, Loss: 9.5301


Epoch 52/100: 100%|██████████| 63/63 [00:03<00:00, 18.96it/s, loss=0.149]


Epoch 52/100, Loss: 9.3722


Epoch 53/100: 100%|██████████| 63/63 [00:03<00:00, 18.96it/s, loss=0.155]


Epoch 53/100, Loss: 9.7497


Epoch 54/100: 100%|██████████| 63/63 [00:03<00:00, 18.89it/s, loss=0.142]


Epoch 54/100, Loss: 8.9542


Epoch 55/100: 100%|██████████| 63/63 [00:03<00:00, 18.95it/s, loss=0.131]


Epoch 55/100, Loss: 8.2543


Epoch 56/100: 100%|██████████| 63/63 [00:03<00:00, 18.92it/s, loss=0.127]


Epoch 56/100, Loss: 7.9707


Epoch 57/100: 100%|██████████| 63/63 [00:03<00:00, 18.89it/s, loss=0.121]


Epoch 57/100, Loss: 7.6230


Epoch 58/100: 100%|██████████| 63/63 [00:03<00:00, 18.90it/s, loss=0.11] 


Epoch 58/100, Loss: 6.9334


Epoch 59/100: 100%|██████████| 63/63 [00:03<00:00, 18.79it/s, loss=0.102] 


Epoch 59/100, Loss: 6.4119


Epoch 60/100: 100%|██████████| 63/63 [00:03<00:00, 18.76it/s, loss=0.0941]


Epoch 60/100, Loss: 5.9267


Epoch 61/100: 100%|██████████| 63/63 [00:03<00:00, 18.73it/s, loss=0.0821]


Epoch 61/100, Loss: 5.1712


Epoch 62/100: 100%|██████████| 63/63 [00:03<00:00, 18.77it/s, loss=0.075] 


Epoch 62/100, Loss: 4.7267


Epoch 63/100: 100%|██████████| 63/63 [00:03<00:00, 18.80it/s, loss=0.0677]


Epoch 63/100, Loss: 4.2664


Epoch 64/100: 100%|██████████| 63/63 [00:03<00:00, 18.90it/s, loss=0.0603]


Epoch 64/100, Loss: 3.7370


Epoch 65/100: 100%|██████████| 63/63 [00:03<00:00, 18.87it/s, loss=0.0479]


Epoch 65/100, Loss: 3.0208


Epoch 66/100: 100%|██████████| 63/63 [00:03<00:00, 18.87it/s, loss=0.0417]


Epoch 66/100, Loss: 2.6261


Epoch 67/100: 100%|██████████| 63/63 [00:03<00:00, 18.83it/s, loss=0.0359]


Epoch 67/100, Loss: 2.2628


Epoch 68/100: 100%|██████████| 63/63 [00:03<00:00, 18.83it/s, loss=0.0305]


Epoch 68/100, Loss: 1.9245


Epoch 69/100: 100%|██████████| 63/63 [00:03<00:00, 18.75it/s, loss=0.03]  


Epoch 69/100, Loss: 1.8889


Epoch 70/100: 100%|██████████| 63/63 [00:03<00:00, 18.64it/s, loss=0.0503]


Epoch 70/100, Loss: 3.1697


Epoch 71/100: 100%|██████████| 63/63 [00:03<00:00, 18.70it/s, loss=0.0275]


Epoch 71/100, Loss: 1.7336


Epoch 72/100: 100%|██████████| 63/63 [00:03<00:00, 18.69it/s, loss=0.0288]


Epoch 72/100, Loss: 1.8115


Epoch 73/100: 100%|██████████| 63/63 [00:03<00:00, 18.74it/s, loss=0.0193]


Epoch 73/100, Loss: 1.2166


Epoch 74/100: 100%|██████████| 63/63 [00:03<00:00, 18.75it/s, loss=0.0167]


Epoch 74/100, Loss: 1.0515


Epoch 75/100: 100%|██████████| 63/63 [00:03<00:00, 18.75it/s, loss=0.0157]


Epoch 75/100, Loss: 0.9902


Epoch 76/100: 100%|██████████| 63/63 [00:03<00:00, 18.79it/s, loss=0.0143]


Epoch 76/100, Loss: 0.9006


Epoch 77/100: 100%|██████████| 63/63 [00:03<00:00, 18.78it/s, loss=0.0135]


Epoch 77/100, Loss: 0.8508


Epoch 78/100: 100%|██████████| 63/63 [00:03<00:00, 18.70it/s, loss=0.013] 


Epoch 78/100, Loss: 0.8172


Epoch 79/100: 100%|██████████| 63/63 [00:03<00:00, 18.73it/s, loss=0.0115] 


Epoch 79/100, Loss: 0.7268


Epoch 80/100: 100%|██████████| 63/63 [00:03<00:00, 18.70it/s, loss=0.0113]


Epoch 80/100, Loss: 0.7090


Epoch 81/100: 100%|██████████| 63/63 [00:03<00:00, 18.72it/s, loss=0.0155] 


Epoch 81/100, Loss: 0.9776


Epoch 82/100: 100%|██████████| 63/63 [00:03<00:00, 18.75it/s, loss=0.0197]


Epoch 82/100, Loss: 1.2399


Epoch 83/100: 100%|██████████| 63/63 [00:03<00:00, 18.75it/s, loss=0.0172]


Epoch 83/100, Loss: 1.0819


Epoch 84/100: 100%|██████████| 63/63 [00:03<00:00, 18.74it/s, loss=0.0122]


Epoch 84/100, Loss: 0.7686


Epoch 85/100: 100%|██████████| 63/63 [00:03<00:00, 18.74it/s, loss=0.0145]


Epoch 85/100, Loss: 0.9113


Epoch 86/100: 100%|██████████| 63/63 [00:03<00:00, 18.72it/s, loss=0.0144]


Epoch 86/100, Loss: 0.9076


Epoch 87/100: 100%|██████████| 63/63 [00:03<00:00, 18.73it/s, loss=0.0113]


Epoch 87/100, Loss: 0.7116


Epoch 88/100: 100%|██████████| 63/63 [00:03<00:00, 18.80it/s, loss=0.01]   


Epoch 88/100, Loss: 0.6298


Epoch 89/100: 100%|██████████| 63/63 [00:03<00:00, 18.88it/s, loss=0.00973]


Epoch 89/100, Loss: 0.6133


Epoch 90/100: 100%|██████████| 63/63 [00:03<00:00, 18.87it/s, loss=0.00883]


Epoch 90/100, Loss: 0.5564


Epoch 91/100: 100%|██████████| 63/63 [00:03<00:00, 18.76it/s, loss=0.00917]


Epoch 91/100, Loss: 0.5780


Epoch 92/100: 100%|██████████| 63/63 [00:03<00:00, 18.69it/s, loss=0.00882]


Epoch 92/100, Loss: 0.5556


Epoch 93/100: 100%|██████████| 63/63 [00:03<00:00, 18.76it/s, loss=0.00858]


Epoch 93/100, Loss: 0.5404


Epoch 94/100: 100%|██████████| 63/63 [00:03<00:00, 18.73it/s, loss=0.00677]


Epoch 94/100, Loss: 0.4263


Epoch 95/100: 100%|██████████| 63/63 [00:03<00:00, 18.58it/s, loss=0.00697]


Epoch 95/100, Loss: 0.4390


Epoch 96/100: 100%|██████████| 63/63 [00:03<00:00, 18.71it/s, loss=0.00661]


Epoch 96/100, Loss: 0.4165


Epoch 97/100: 100%|██████████| 63/63 [00:03<00:00, 18.74it/s, loss=0.00717]


Epoch 97/100, Loss: 0.4517


Epoch 98/100: 100%|██████████| 63/63 [00:03<00:00, 18.66it/s, loss=0.00824]


Epoch 98/100, Loss: 0.5189


Epoch 99/100: 100%|██████████| 63/63 [00:03<00:00, 18.70it/s, loss=0.00767]


Epoch 99/100, Loss: 0.4830


Epoch 100/100: 100%|██████████| 63/63 [00:03<00:00, 18.65it/s, loss=0.00712]


Epoch 100/100, Loss: 0.4488
Model saved at: trained_model/FoCus/model_epoch_100.pth


# Response Generation

In [17]:
DATASET = "FoCus"
MODEL_SAVE_DIR = f"trained_model/{DATASET}"
EPOCS = 100
MAX_LENGTH_TOKEN = 50

In [18]:
model = PersonaAdaptiveChatbot()  # Recreate the model architecture
model.load_state_dict(torch.load(f"{MODEL_SAVE_DIR}/model_epoch_{EPOCS}.pth"))
model.to(device)


PersonaAdaptiveChatbot(
  (persona_encoder): PersonaEncoder(
    (attention): SelfAttention(
      (query): Linear(in_features=768, out_features=768, bias=True)
      (key): Linear(in_features=768, out_features=768, bias=True)
      (value): Linear(in_features=768, out_features=768, bias=True)
      (softmax): Softmax(dim=-1)
    )
    (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (context_encoder): ContextEncoder(
    (attention): SelfAttention(
      (query): Linear(in_features=768, out_features=768, bias=True)
      (key): Linear(in_features=768, out_features=768, bias=True)
      (value): Linear(in_features=768, out_features=768, bias=True)
      (softmax): Softmax(dim=-1)
    )
    (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (paa): PersonaAdaptiveAttention(
    (context_attention): SelfAttention(
      (query): Linear(in_features=768, out_features=768, bias=True)
      (key): Linear(in_features=768, out_features=768, bias=True)
      (

In [19]:
import torch

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def generate_response(model, persona_text, context_text, max_length= MAX_LENGTH_TOKEN):
    model.to(device)  # Ensure model is on GPU
    model.eval()

    # Tokenize and move to GPU, ensuring padding for alignment
    persona_tokens = tokenizer(persona_text, return_tensors="pt", padding=True, truncation=True)["input_ids"].to(device)
    context_tokens = tokenizer(context_text, return_tensors="pt", padding=True, truncation=True)["input_ids"].to(device)

    with torch.no_grad():
        persona_embeds = model.decoder.gpt2.transformer.wte(persona_tokens)
        context_embeds = model.decoder.gpt2.transformer.wte(context_tokens)

        # Determine max sequence length for padding
        max_seq_len = max(persona_embeds.shape[1], context_embeds.shape[1])

        # Pad tensors to the same length
        pad_size_persona = max_seq_len - persona_embeds.shape[1]
        pad_size_context = max_seq_len - context_embeds.shape[1]

        persona_embeds = torch.nn.functional.pad(persona_embeds, (0, 0, 0, pad_size_persona), "constant", 0)
        context_embeds = torch.nn.functional.pad(context_embeds, (0, 0, 0, pad_size_context), "constant", 0)

        # Encode using persona and context encoders
        encoded_persona = model.persona_encoder(persona_embeds)
        encoded_context = model.context_encoder(context_embeds)

        # Apply Persona-Adaptive Attention (PAA)
        fused_representation = model.paa(encoded_persona, encoded_context)

    # Generate response ensuring input is on GPU
    generated = model.decoder.gpt2.generate(input_ids=context_tokens, max_length=max_length).to(device)
    return tokenizer.decode(generated[0], skip_special_tokens=True)

# Example
persona_text = "I love sci-fi movies."
context_text = "user1: What's your favorite movie?"
response = generate_response(model, persona_text, context_text)
print("Generated Response:", response)


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


Generated Response: user1: What's your favorite movie?
