In [9]:
import pandas as pd
import numpy as np
import os
import joblib 
import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from sklearn.model_selection import train_test_split
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM, 
    get_linear_schedule_with_warmup
)
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer
import re
from tqdm.auto import tqdm

DATA_FILE = "conversationfile.xlsx - userAuserB.csv"
MODEL_NAME = 'distilgpt2'
MAX_CONTEXT_TURNS = 5 
MAX_LENGTH = 256
SEED = 42
EPOCHS = 3
BATCH_SIZE = 4 
LEARNING_RATE = 5e-5
WARMUP_STEPS = 100

DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

SPECIAL_TOKENS = {
    'user_a': '<USER_A>',
    'user_b': '<USER_B>',
    'reply_start': '<REPLY>', 
    'bos_token': '<BOS>',
    'eos_token': '<EOS>',
    'pad_token': '<PAD>'
}
TOKEN_LIST = list(SPECIAL_TOKENS.values())

df = pd.read_csv(DATA_FILE)
df['Timestamp'] = pd.to_datetime(df['Timestamp'])
df = df.sort_values(by=['Conversation ID', 'Timestamp']).reset_index(drop=True)

def format_conversation_for_gpt2(df, max_context_turns):
    conversations = []
    
    for conv_id, conv_df in df.groupby('Conversation ID'):
        conv_list = conv_df.to_dict('records')
        
        for i in range(1, len(conv_list)):
            current_msg = conv_list[i]
            prev_msg = conv_list[i-1]
            
            if current_msg['Sender'] == 'User A' and prev_msg['Sender'] == 'User B':
                
                history_start_idx = max(0, i - max_context_turns - 1)
                context_history = conv_list[history_start_idx : i-1]
                
                context_parts = [
                    f"{SPECIAL_TOKENS['user_a'] if msg['Sender'] == 'User A' else SPECIAL_TOKENS['user_b']} {msg['Message']}"
                    for msg in context_history
                ]
                
                full_text = (
                    f"{SPECIAL_TOKENS['bos_token']} "
                    f"{' '.join(context_parts).strip()} " 
                    f"{SPECIAL_TOKENS['user_b']} {prev_msg['Message']} "
                    f"{SPECIAL_TOKENS['reply_start']} " 
                    f"{current_msg['Message']} " 
                    f"{SPECIAL_TOKENS['eos_token']}"
                )
                
                conversations.append(full_text)
            
    return conversations

formatted_conversations = format_conversation_for_gpt2(df, MAX_CONTEXT_TURNS)
train_texts, test_texts = train_test_split(formatted_conversations, test_size=0.1, random_state=SEED)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.add_special_tokens({
    'additional_special_tokens': [SPECIAL_TOKENS['user_a'], SPECIAL_TOKENS['user_b'], SPECIAL_TOKENS['reply_start']],
    'bos_token': SPECIAL_TOKENS['bos_token'],
    'eos_token': SPECIAL_TOKENS['eos_token'],
    'pad_token': SPECIAL_TOKENS['pad_token']
})

def tokenize_texts(texts):
    return tokenizer(
        texts,
        max_length=MAX_LENGTH,
        truncation=True,
        padding="max_length",
        return_tensors='pt'
    )

train_encodings = tokenize_texts(train_texts)
test_encodings = tokenize_texts(test_texts)


class ConversationDataset(Dataset):
    def __init__(self, encodings):
        self.input_ids = encodings['input_ids']
        self.attention_mask = encodings['attention_mask']

    def __getitem__(self, idx):
        return {
            'input_ids': self.input_ids[idx],
            'attention_mask': self.attention_mask[idx],
            'labels': self.input_ids[idx].clone() 
        }

    def __len__(self):
        return len(self.input_ids)

train_dataset = ConversationDataset(train_encodings)
test_dataset = ConversationDataset(test_encodings)

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
model.resize_token_embeddings(len(tokenizer))
model.config.pad_token_id = tokenizer.pad_token_id
model.to(DEVICE)

optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
total_steps = len(train_dataloader) * EPOCHS
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=WARMUP_STEPS, num_training_steps=total_steps)

def train_epoch(model, dataloader, optimizer, scheduler, device):
    model.train()
    total_loss = 0
    
    for batch in tqdm(dataloader, desc="Training"):
        optimizer.zero_grad()
        
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        total_loss += loss.item()

    return total_loss / len(dataloader)

def evaluate_loss(model, dataloader, device):
    model.eval()
    total_loss = 0
    
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluation"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            total_loss += outputs.loss.item()
            
    avg_loss = total_loss / len(dataloader)
    perplexity = np.exp(avg_loss)
    return avg_loss, perplexity

best_val_loss = float('inf')

for epoch in range(EPOCHS):
    train_loss = train_epoch(model, train_dataloader, optimizer, scheduler, DEVICE)
    val_loss, perplexity = evaluate_loss(model, test_dataloader, DEVICE)
    
    print(f"\nEpoch {epoch+1}/{EPOCHS}")
    print(f"  Train Loss: {train_loss:.4f}")
    print(f"  Val Loss: {val_loss:.4f} | Perplexity: {perplexity:.2f}")
    
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'best_model.pt') 
        print("  -> Saved best model checkpoint.")

model.load_state_dict(torch.load('best_model.pt', map_location=DEVICE))
model.eval()

def generate_reply(model, tokenizer, last_user_b_message, context_history_list, max_new_tokens=50):
    
    model.eval()
    history_string = ' '.join(context_history_list)
    
    prompt = (
        f"{SPECIAL_TOKENS['bos_token']} "
        f"{history_string.strip()} "
        f"{SPECIAL_TOKENS['user_b']} {last_user_b_message} "
        f"{SPECIAL_TOKENS['reply_start']}" 
    )
    
    input_ids = tokenizer.encode(prompt, return_tensors='pt', max_length=MAX_LENGTH, truncation=True).to(DEVICE)
    
    with torch.no_grad():
        output = model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.pad_token_id,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            temperature=0.7,
            num_return_sequences=1
        )

    generated_text = tokenizer.decode(output[0], skip_special_tokens=False)
    
    try:
        reply_start_index = generated_text.find(SPECIAL_TOKENS['reply_start']) + len(SPECIAL_TOKENS['reply_start'])
        reply_end_index = generated_text.find(SPECIAL_TOKENS['eos_token'], reply_start_index)
        
        reply = generated_text[reply_start_index:reply_end_index].strip()
        
        for token in TOKEN_LIST:
            reply = reply.replace(token, '').strip()
            
        reply_parts = re.split(r'[.?!]', reply.strip())
        return reply_parts[0].strip()
        
    except:
        return "Generation Error: Could not parse reply."


def extract_context_and_reference(formatted_text):
    try:
        reply_start = formatted_text.find(SPECIAL_TOKENS['reply_start']) + len(SPECIAL_TOKENS['reply_start'])
        reply_end = formatted_text.find(SPECIAL_TOKENS['eos_token'])
        reference = formatted_text[reply_start:reply_end].strip()

        prompt_text = formatted_text[:formatted_text.find(SPECIAL_TOKENS['reply_start'])].strip()
        prompt_text = prompt_text.replace(SPECIAL_TOKENS['bos_token'], '').strip()
        
        last_b_idx = prompt_text.rfind(SPECIAL_TOKENS['user_b'])
        
        last_user_b_message = prompt_text[last_b_idx + len(SPECIAL_TOKENS['user_b']):].strip()
        context_history_string = prompt_text[:last_b_idx].strip()
        
        history_list = [
            turn.strip() for turn in context_history_string.split() 
            if turn.strip() and turn.strip() not in TOKEN_LIST
        ]

        return last_user_b_message, context_history_string.split(), reference
    except:
        return None, None, None


def evaluate_generation(model, tokenizer, test_texts):
    
    generated_replies = []
    reference_replies = []
    
    for text in test_texts[:50]:
        last_b_msg, history_list, reference = extract_context_and_reference(text)
        
        if last_b_msg and reference and reference.strip():
            generated = generate_reply(model, tokenizer, last_b_msg, history_list)
            
            generated_replies.append(generated)
            reference_replies.append(reference)

    smooth_function = SmoothingFunction().method4
    bleu_scores = [
        sentence_bleu([ref.split()], gen.split(), smoothing_function=smooth_function)
        for ref, gen in zip(reference_replies, generated_replies)
    ]
    avg_bleu = np.mean(bleu_scores) if bleu_scores else 0.0

    scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
    rouge_l_scores = [
        scorer.score(ref, gen)['rougeL'].fmeasure
        for ref, gen in zip(reference_replies, generated_replies)
    ]
    avg_rouge_l = np.mean(rouge_l_scores) if rouge_l_scores else 0.0
    
    final_loss, perplexity = evaluate_loss(model, test_dataloader, DEVICE)

    print(f"--- Final Evaluation Results ---")
    print(f"Number of samples evaluated: {len(generated_replies)}")
    print(f"Average BLEU Score: {avg_bleu:.4f}")
    print(f"Average ROUGE-L F1: {avg_rouge_l:.4f}")
    print(f"Final Perplexity: {perplexity:.2f}")

    return {
        'avg_bleu': avg_bleu,
        'avg_rouge_l': avg_rouge_l,
        'perplexity': perplexity
    }

Training:   0%|          | 0/2 [00:00<?, ?it/s]

Evaluation:   0%|          | 0/1 [00:00<?, ?it/s]


Epoch 1/3
  Train Loss: 13.1087
  Val Loss: 12.7720 | Perplexity: 352230.92
  -> Saved best model checkpoint.


Training:   0%|          | 0/2 [00:00<?, ?it/s]

Evaluation:   0%|          | 0/1 [00:00<?, ?it/s]


Epoch 2/3
  Train Loss: 13.1682
  Val Loss: 12.5049 | Perplexity: 269654.85
  -> Saved best model checkpoint.


Training:   0%|          | 0/2 [00:00<?, ?it/s]

Evaluation:   0%|          | 0/1 [00:00<?, ?it/s]


Epoch 3/3
  Train Loss: 12.9028
  Val Loss: 12.1080 | Perplexity: 181310.87
  -> Saved best model checkpoint.
