# Big Five Personality Recognition using LUKE on RealPersonaChat

**Project Overview:**
- Dataset: RealPersonaChat (14,000 dialogues, 233 speakers). The full dataset is recommended; the loading cell defaults to `num_dialogues = 1000` for quick testing.
- Task: Big Five personality trait regression prediction
- Model: LUKE (studio-ousia/luke-japanese-base)
- Setting: Monologue (speaker's utterances only)

**Evaluation Metrics:**
- Regression: MAE, RMSE, Pearson and Spearman correlation
- Classification: Accuracy, Balanced Accuracy, Precision, Recall, F1
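
As a rough illustration of how these metrics relate, the sketch below computes the regression metrics for one trait and then derives classification metrics by splitting scores into High/Low at the median of the true labels, the same binarization the confusion-matrix cell uses. The helper names `ranks` and `summarize_trait` and the toy scores are hypothetical, and the rank-based Spearman shortcut assumes no tied values; the notebook's imported `scipy` and `sklearn` functions are the more robust choice in practice.

```python
import numpy as np


def ranks(x):
    """Rank values from 0..n-1 (assumes no ties, unlike scipy.stats.spearmanr)."""
    order = np.argsort(x)
    r = np.empty(len(x), dtype=float)
    r[order] = np.arange(len(x))
    return r


def summarize_trait(true, pred):
    """Regression and median-split classification metrics for one trait."""
    true = np.asarray(true, dtype=float)
    pred = np.asarray(pred, dtype=float)
    err = pred - true
    metrics = {
        "mae": float(np.mean(np.abs(err))),
        "rmse": float(np.sqrt(np.mean(err ** 2))),
        "pearson": float(np.corrcoef(true, pred)[0, 1]),
        # Spearman is the Pearson correlation of the ranks
        "spearman": float(np.corrcoef(ranks(true), ranks(pred))[0, 1]),
    }
    # Binarize into High (True) / Low (False) at the median of the true scores
    median = np.median(true)
    y_true, y_pred = true > median, pred > median
    tp = np.sum(y_true & y_pred)
    tn = np.sum(~y_true & ~y_pred)
    fp = np.sum(~y_true & y_pred)
    fn = np.sum(y_true & ~y_pred)
    metrics["accuracy"] = float((tp + tn) / len(true))
    # Balanced accuracy is the mean of the per-class recalls
    metrics["balanced_accuracy"] = float(0.5 * (tp / (tp + fn) + tn / (tn + fp)))
    return metrics


# Toy normalized scores (hypothetical values, not dataset results)
true_scores = [0.10, 0.30, 0.50, 0.70, 0.90, 0.60]
pred_scores = [0.20, 0.25, 0.45, 0.65, 0.80, 0.40]
result = summarize_trait(true_scores, pred_scores)
for name, value in result.items():
    print(f"{name}: {value:.4f}")
```

Because the threshold comes from the true labels, a model that predicts values close to the mean can score well on MAE yet poorly on balanced accuracy, which is why both metric families are reported.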

**Memory Optimization (T4 GPU):**
- Batch size: 4
- Max Length: 256 tokens
- Gradient accumulation: 8 steps (effective batch size 32)
- Gradient checkpointing: Enabled
- Mixed Precision (FP16): Enabled
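
To make the "effective batch size 32" figure concrete, here is a minimal pure-Python sketch showing that accumulating the gradients of 8 micro-batches of 4 samples, each scaled by 1/8, gives the same gradient as a single batch of 32. The one-parameter model `y = w * x`, the helper `mean_grad`, and the toy data are assumptions made for illustration; they are not part of the notebook.

```python
BATCH_SIZE = 4
ACCUMULATION_STEPS = 8
EFFECTIVE_BATCH_SIZE = BATCH_SIZE * ACCUMULATION_STEPS


def mean_grad(w, xs, ys):
    """Gradient of the mean squared error of y = w * x with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


# Toy data: 32 points on the line y = 3x (hypothetical, for illustration only)
xs = [i / 10 for i in range(EFFECTIVE_BATCH_SIZE)]
ys = [3 * x for x in xs]
w = 0.5

# Gradient from one large batch of 32 samples
full_grad = mean_grad(w, xs, ys)

# Gradient accumulated over eight micro-batches of 4 samples,
# each contribution divided by the number of accumulation steps
accumulated_grad = 0.0
for start in range(0, EFFECTIVE_BATCH_SIZE, BATCH_SIZE):
    micro_x = xs[start:start + BATCH_SIZE]
    micro_y = ys[start:start + BATCH_SIZE]
    accumulated_grad += mean_grad(w, micro_x, micro_y) / ACCUMULATION_STEPS

print(f"full batch: {full_grad:.6f}, accumulated: {accumulated_grad:.6f}")
```

The equivalence holds for mean-reduced losses and equally sized micro-batches. It only reduces peak memory, not total compute, so each optimizer step still processes 32 samples.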

**Execution Instructions:**

Execute all cells in order from top to bottom.

**Data Loading:**
- Speaker data: Download from GitHub (233 speakers)
- Dialogue data: Download from GitHub (14,000 dialogues, takes 30-40 minutes)

**Recommended Environment:**
- Google Colab with T4 GPU (15GB VRAM)

## 1. Environment Setup

In [None]:
# Import libraries
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from transformers import (
    AutoTokenizer,
    LukeModel,
    get_linear_schedule_with_warmup
)
import json
import requests
from pathlib import Path
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    mean_absolute_error,
    mean_squared_error,
    accuracy_score,
    balanced_accuracy_score,
    precision_recall_fscore_support
)
from scipy.stats import pearsonr, spearmanr
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm
import warnings
warnings.filterwarnings('ignore')

# Configure device (GPU/CPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

# Model and data configuration
MODEL_NAME = "studio-ousia/luke-japanese-base"
BIG_FIVE_TRAITS = ['Openness', 'Conscientiousness', 'Extraversion', 'Agreeableness', 'Neuroticism']
CHECKPOINT_DIR = "/content/checkpoints"
Path(CHECKPOINT_DIR).mkdir(parents=True, exist_ok=True)  # checkpoints and plots are saved here

## 2. Data Loading

Load RealPersonaChat dataset directly from GitHub.
- Speaker data: `interlocutors.json` (233 speakers)
- Dialogue data: `dialogues/*.json` (14,000 dialogues)

**Data Size:**
- **Production (Recommended)**: `num_dialogues = 14000` - Full dataset
- Test: `num_dialogues = 1000` - Partial dataset for quick testing

**Note**: 
- Initial data loading takes 30-40 minutes for full dataset (14,000 dialogues)
- Using full data is strongly recommended for better results

In [None]:
# Load dataset directly from GitHub
import json
import requests
from tqdm.auto import tqdm

BASE_URL = "https://raw.githubusercontent.com/nu-dialogue/real-persona-chat/main/real_persona_chat"

# Load speaker data (interlocutors)
print("="*70)
print("Loading speaker data...")
print("="*70)

interlocutors_url = f"{BASE_URL}/interlocutors.json"
response = requests.get(interlocutors_url, timeout=30)
response.raise_for_status()  # fail early if the speaker file cannot be fetched
interlocutors_raw = response.json()

# Convert data structure to dictionary
if isinstance(interlocutors_raw, dict):
    interlocutor_dict = interlocutors_raw
elif isinstance(interlocutors_raw, list):
    interlocutor_dict = {
        item['interlocutor_id']: item
        for item in interlocutors_raw
    }
else:
    print(f"Unexpected data type: {type(interlocutors_raw)}")
    interlocutor_dict = {}

print(f"Loaded {len(interlocutor_dict)} speakers")

# Display sample
if interlocutor_dict:
    first_speaker_id = list(interlocutor_dict.keys())[0]
    sample_speaker = interlocutor_dict[first_speaker_id]
    print(f"\nSample speaker (ID: {first_speaker_id}):")
    print(f"  Keys: {list(sample_speaker.keys())}")
    if 'personality' in sample_speaker:
        print(f"  Personality keys: {list(sample_speaker['personality'].keys())[:10]}")

# Load dialogue data
print("\n" + "="*70)
print("Loading dialogue data...")
print("="*70)

num_dialogues = 1000  # Set to 14000 for full dataset

print(f"Downloading {num_dialogues} dialogues from GitHub...")
print("(This may take 30-40 minutes for full dataset)")

dialogue_data = []
failed_downloads = 0

for i in tqdm(range(1, num_dialogues + 1), desc="Downloading dialogues"):
    dialogue_id = f"{i:05d}"
    dialogue_url = f"{BASE_URL}/dialogues/{dialogue_id}.json"
    
    try:
        response = requests.get(dialogue_url, timeout=30)
        if response.status_code == 200:
            dialogue = response.json()
            dialogue_data.append(dialogue)
        else:
            failed_downloads += 1
    except Exception as e:
        failed_downloads += 1
        if failed_downloads <= 10:  # Only report the first 10 failures
            print(f"Failed to download {dialogue_id}: {e}")

if failed_downloads > 0:
    print(f"\nDownload failed: {failed_downloads} dialogues")

print(f"Loaded {len(dialogue_data)} dialogues")

# Display sample
if dialogue_data:
    sample = dialogue_data[0]
    print(f"\nSample dialogue:")
    print(f"  Keys: {list(sample.keys())}")
    print(f"  Dialogue ID: {sample.get('dialogue_id', 'N/A')}")
    print(f"  Interlocutors: {sample.get('interlocutors', [])}")
    print(f"  Utterances: {len(sample.get('utterances', []))}")

print("="*70)

In [None]:
def create_monologue_dataset(dialogue_data, interlocutor_dict):
    """
    Extract monologues for each speaker from dialogue data

    Args:
        dialogue_data: List of dialogue dictionaries
        interlocutor_dict: Dictionary of speaker ID -> speaker information

    Returns:
        monologues: List[Dict]
            - 'speaker_id': str
            - 'text': str (concatenated utterances)
            - 'personality': Dict[str, float] (Big Five scores)
    """
    monologues = []

    # Process each dialogue
    for dialogue in tqdm(dialogue_data, desc="Processing dialogues"):
        speaker_utterances = {}

        # Collect utterances by speaker
        for utterance in dialogue.get('utterances', []):
            speaker_id = utterance.get('interlocutor_id')
            text = utterance.get('text', '')

            if not speaker_id or not text:
                continue

            if speaker_id not in speaker_utterances:
                speaker_utterances[speaker_id] = []
            speaker_utterances[speaker_id].append(text)

        # Create monologue for each speaker
        for speaker_id, utterances in speaker_utterances.items():
            # Skip if speaker info not available
            if speaker_id not in interlocutor_dict:
                continue

            speaker_info = interlocutor_dict[speaker_id]
            personality_data = speaker_info.get('personality', {})

            # Extract Big Five scores (1-7 scale); missing keys fall back to the midpoint 4.0
            big_five_scores = {
                'Openness': personality_data.get('big_five_openness', personality_data.get('openness', 4.0)),
                'Conscientiousness': personality_data.get('big_five_conscientiousness', personality_data.get('conscientiousness', 4.0)),
                'Extraversion': personality_data.get('big_five_extraversion', personality_data.get('extraversion', 4.0)),
                'Agreeableness': personality_data.get('big_five_agreeableness', personality_data.get('agreeableness', 4.0)),
                'Neuroticism': personality_data.get('big_five_neuroticism', personality_data.get('neuroticism', 4.0)),
            }

            # Check if scores are properly obtained (not all default values)
            if all(score == 4.0 for score in big_five_scores.values()):
                # Try alternative key names
                for key, value in personality_data.items():
                    if 'openness' in key.lower():
                        big_five_scores['Openness'] = value
                    elif 'conscientiousness' in key.lower():
                        big_five_scores['Conscientiousness'] = value
                    elif 'extraversion' in key.lower():
                        big_five_scores['Extraversion'] = value
                    elif 'agreeableness' in key.lower():
                        big_five_scores['Agreeableness'] = value
                    elif 'neuroticism' in key.lower():
                        big_five_scores['Neuroticism'] = value

            # Add monologue sample
            monologues.append({
                'speaker_id': speaker_id,
                'text': ' '.join(utterances),
                'personality': big_five_scores
            })

    return monologues

# Create monologue dataset
monologue_data = create_monologue_dataset(dialogue_data, interlocutor_dict)

print(f"\nMonologue samples: {len(monologue_data)}")
if monologue_data:
    print(f"Sample: Speaker {monologue_data[0]['speaker_id']}, Text length: {len(monologue_data[0]['text'])} chars")

In [None]:
def split_by_speaker(monologue_data, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1, seed=42):
    """
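    Split data so that no speaker appears in more than one of the train/val/test sets
    Following the paper's approach (Train/Val/Test = 8:1:1); the test set receives all
    speakers left over after the train and val cuts, so test_ratio is informational only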
    """
    # Group samples by speaker
    speaker_groups = {}
    for sample in monologue_data:
        speaker_id = sample['speaker_id']
        if speaker_id not in speaker_groups:
            speaker_groups[speaker_id] = []
        speaker_groups[speaker_id].append(sample)

    # Shuffle speakers
    speakers = list(speaker_groups.keys())
    np.random.seed(seed)
    np.random.shuffle(speakers)

    # Split speakers into train/val/test
    n_speakers = len(speakers)
    n_train = int(n_speakers * train_ratio)
    n_val = int(n_speakers * val_ratio)

    train_speakers = speakers[:n_train]
    val_speakers = speakers[n_train:n_train+n_val]
    test_speakers = speakers[n_train+n_val:]

    # Extract samples for each split
    train_data = [s for spk in train_speakers for s in speaker_groups[spk]]
    val_data = [s for spk in val_speakers for s in speaker_groups[spk]]
    test_data = [s for spk in test_speakers for s in speaker_groups[spk]]

    print(f"\nData split:")
    print(f"  Train: {len(train_speakers)} speakers, {len(train_data)} samples")
    print(f"  Val:   {len(val_speakers)} speakers, {len(val_data)} samples")
    print(f"  Test:  {len(test_speakers)} speakers, {len(test_data)} samples")

    return train_data, val_data, test_data

# Split data
train_data, val_data, test_data = split_by_speaker(monologue_data)

## 3. Data Preprocessing

- Encode text with LUKE tokenizer
- Normalize Big Five scores (1-7 → 0-1)
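
The score normalization step can be sketched as follows. This is a minimal example assuming the 1-7 scale noted in the monologue extraction code; the helper names `normalize_scores`, `denormalize_scores`, and `personality_to_labels` are hypothetical, and text encoding with the LUKE tokenizer is left out because it requires downloading the model.

```python
import numpy as np

SCORE_MIN, SCORE_MAX = 1.0, 7.0  # Big Five questionnaire scale
TRAITS = ['Openness', 'Conscientiousness', 'Extraversion', 'Agreeableness', 'Neuroticism']


def normalize_scores(scores):
    """Map scores from [1, 7] to [0, 1] to match the model's sigmoid output."""
    scores = np.asarray(scores, dtype=np.float32)
    return (scores - SCORE_MIN) / (SCORE_MAX - SCORE_MIN)


def denormalize_scores(normalized):
    """Map model outputs from [0, 1] back to the original 1-7 scale."""
    normalized = np.asarray(normalized, dtype=np.float32)
    return normalized * (SCORE_MAX - SCORE_MIN) + SCORE_MIN


def personality_to_labels(personality, traits=TRAITS):
    """Order a {trait: score} dict into a normalized label vector."""
    return normalize_scores([personality[trait] for trait in traits])


# Hypothetical speaker scores, for illustration only
example = {
    'Openness': 1.0,
    'Conscientiousness': 4.0,
    'Extraversion': 7.0,
    'Agreeableness': 5.5,
    'Neuroticism': 2.5,
}
labels = personality_to_labels(example)
restored = denormalize_scores(labels)
print(labels, restored)
```

Keeping an explicit inverse makes it straightforward to report errors such as MAE on the original 1-7 scale rather than on the normalized one.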

## 4. Model Construction

LUKE + 5 regression heads (one for each Big Five trait)

**Note**: Clear GPU memory before model construction.

In [None]:
# Clear GPU memory before model initialization
import gc
torch.cuda.empty_cache()
gc.collect()

class LukePersonalityModel(nn.Module):
    """
    LUKE-based personality prediction model
    Architecture: LUKE encoder + 5 regression heads (one per trait)
    """
    def __init__(self, model_name=MODEL_NAME, num_traits=5):
        super().__init__()
        # Load pre-trained LUKE model
        self.luke = LukeModel.from_pretrained(model_name)
        
        # Enable gradient checkpointing for memory efficiency
        self.luke.gradient_checkpointing_enable()
        
        self.hidden_size = self.luke.config.hidden_size
        
        # Create regression head for each Big Five trait
        self.regression_heads = nn.ModuleList([
            nn.Sequential(
                nn.Linear(self.hidden_size, 256),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(256, 1)
            )
            for _ in range(num_traits)
        ])
    
    def forward(self, input_ids, attention_mask):
        # Encode with LUKE
        outputs = self.luke(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        
        # Extract the first-token representation (LUKE's `<s>`, the [CLS] equivalent)
        pooled_output = outputs.last_hidden_state[:, 0, :]
        
        # Predict each trait
        predictions = []
        for head in self.regression_heads:
            pred = head(pooled_output)
            predictions.append(pred)
        
        # Concatenate predictions and apply sigmoid (0-1 range)
        predictions = torch.cat(predictions, dim=1)
        predictions = torch.sigmoid(predictions)
        
        return predictions

# Initialize model
model = LukePersonalityModel().to(device)

print(f"Model initialized: {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M parameters")
if torch.cuda.is_available():
    print(f"GPU memory: {torch.cuda.memory_allocated(0) / 1e9:.2f}GB allocated")

## 5. Training Configuration
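
The configuration saved later in this section includes an `EARLY_STOPPING_PATIENCE` setting. As one possible reading of how such a setting is typically used, the sketch below stops training once the validation loss has failed to improve for a given number of consecutive epochs. The `EarlyStopping` class, its `min_delta` parameter, and the loss values are hypothetical and may differ from the notebook's actual logic.

```python
class EarlyStopping:
    """Stop training when validation loss has not improved for `patience` epochs."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.epochs_without_improvement = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience


# Hypothetical validation losses, for illustration only
val_losses = [0.20, 0.18, 0.19, 0.185, 0.181, 0.17]
stopper = EarlyStopping(patience=3)
stopped_at = None
for epoch, loss in enumerate(val_losses):
    if stopper.step(loss):
        stopped_at = epoch
        break

print(f"Stopped at epoch index {stopped_at}, best loss {stopper.best_loss}")
```

A checkpoint would normally be written whenever `best_loss` improves, so the best model can be reloaded after training stops.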

In [None]:
def evaluate_model(model, dataloader, device, loss_fn=None):
    """
    Evaluate model on given dataloader
    Returns predictions, labels, and average loss

    loss_fn defaults to the global `criterion` set up for training
    """
    if loss_fn is None:
        loss_fn = criterion  # loss defined with the training configuration

    model.eval()
    all_predictions = []
    all_labels = []
    total_loss = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            # Forward pass
            predictions = model(input_ids, attention_mask)
            loss = loss_fn(predictions, labels)

            # Collect results
            total_loss += loss.item()
            all_predictions.append(predictions.cpu().numpy())
            all_labels.append(labels.cpu().numpy())

    # Stack all batches
    all_predictions = np.vstack(all_predictions)
    all_labels = np.vstack(all_labels)
    avg_loss = total_loss / len(dataloader)

    return all_predictions, all_labels, avg_loss

In [None]:
# Save preprocessed data and model state to Google Drive
# This enables fast resume if training loop crashes
import pickle

# CACHE_DIR is expected to be a pathlib.Path (e.g. a Google Drive folder) defined earlier
PREPROCESSED_CACHE = CACHE_DIR / "preprocessed_training_state.pkl"

# Package all training state
training_state = {
    'train_data': train_data,
    'val_data': val_data,
    'test_data': test_data,
    'model_state': model.state_dict(),
    'config': {
        'MODEL_NAME': MODEL_NAME,
        'BIG_FIVE_TRAITS': BIG_FIVE_TRAITS,
        'BATCH_SIZE': BATCH_SIZE,
        'CHECKPOINT_DIR': CHECKPOINT_DIR,
        'LEARNING_RATE': LEARNING_RATE,
        'NUM_EPOCHS': NUM_EPOCHS,
        'WARMUP_STEPS': WARMUP_STEPS,
        'EARLY_STOPPING_PATIENCE': EARLY_STOPPING_PATIENCE
    }
}

# Save to cache
with open(PREPROCESSED_CACHE, 'wb') as f:
    pickle.dump(training_state, f)

print(f"Saved preprocessed data: {PREPROCESSED_CACHE.stat().st_size / 1e6:.1f} MB")
print("Next time: Run cells 1-3, then cell 20 to restore, then training loop")

## 6. Training Loop

Start training.
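
The training cell itself is not included here. The following is a minimal sketch of what one epoch could look like with the gradient accumulation and mixed-precision settings from the project overview. The function `train_one_epoch`, the `ToyModel` stand-in, the random data, and the choice of `nn.L1Loss` (suggested by the "Loss (MAE)" axis label in the plots) are all assumptions; the batch keys match those read by `evaluate_model`.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader


def train_one_epoch(model, loader, optimizer, criterion, device, accumulation_steps=8):
    """One epoch with gradient accumulation and, on GPU only, FP16 autocast."""
    use_amp = device.type == "cuda"
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    model.train()
    optimizer.zero_grad()
    total_loss = 0.0

    for step, batch in enumerate(loader, start=1):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        with torch.autocast(device_type=device.type, enabled=use_amp):
            predictions = model(input_ids, attention_mask)
            loss = criterion(predictions.float(), labels)

        # Divide by the number of steps so accumulated gradients form an average
        scaler.scale(loss / accumulation_steps).backward()
        total_loss += loss.item()

        # Update weights every `accumulation_steps` micro-batches and at the end
        if step % accumulation_steps == 0 or step == len(loader):
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

    return total_loss / len(loader)


class ToyModel(nn.Module):
    """Tiny stand-in with the same call signature as LukePersonalityModel."""

    def __init__(self, vocab_size=100, hidden=16, num_traits=5):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.head = nn.Linear(hidden, num_traits)

    def forward(self, input_ids, attention_mask):
        mask = attention_mask.unsqueeze(-1).float()
        pooled = (self.embed(input_ids) * mask).sum(1) / mask.sum(1).clamp(min=1)
        return torch.sigmoid(self.head(pooled))


# Random toy data (hypothetical) shaped like the notebook's batches
torch.manual_seed(0)
samples = [
    {
        "input_ids": torch.randint(0, 100, (12,)),
        "attention_mask": torch.ones(12, dtype=torch.long),
        "labels": torch.rand(5),
    }
    for _ in range(32)
]
toy_loader = DataLoader(samples, batch_size=4)
toy_device = torch.device("cpu")
toy_model = ToyModel().to(toy_device)
toy_optimizer = torch.optim.AdamW(toy_model.parameters(), lr=1e-2)
epoch_loss = train_one_epoch(toy_model, toy_loader, toy_optimizer, nn.L1Loss(), toy_device)
print(f"Toy epoch loss: {epoch_loss:.4f}")
```

In the real notebook the toy pieces would be replaced by `model`, the training dataloader, and a learning-rate scheduler such as the imported `get_linear_schedule_with_warmup`, stepped once per optimizer update.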

## 7. Test Set Evaluation

In [None]:
# Load best model and evaluate on test set
checkpoint = torch.load(f"{CHECKPOINT_DIR}/best_model.pt", map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

print(f"Best model loaded: Epoch {checkpoint['epoch']+1}, Val Loss {checkpoint['val_loss']:.4f}")

# Evaluate on test set
test_predictions, test_labels, test_loss = evaluate_model(model, test_loader, device)

print(f"\nTest Loss: {test_loss:.4f}")
print(f"Test MAE: {mean_absolute_error(test_labels, test_predictions):.4f}")

## 8. Visualization

In [None]:
def plot_training_history(history):
    """Plot training and validation loss curves"""
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))

    # Loss curves
    axes[0].plot(history['train_loss'], label='Train Loss', marker='o')
    axes[0].plot(history['val_loss'], label='Val Loss', marker='s')
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Loss (MAE)')
    axes[0].set_title('Training and Validation Loss')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)

    # MAE curve
    axes[1].plot(history['val_mae'], label='Val MAE', marker='o', color='green')
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('MAE')
    axes[1].set_title('Validation MAE')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(f'{CHECKPOINT_DIR}/training_history.png', dpi=300, bbox_inches='tight')
    plt.show()

# Plot training history
plot_training_history(history)

In [None]:
def plot_predictions(predictions, labels, trait_names=BIG_FIVE_TRAITS):
    """Plot prediction vs true value scatter plots for each trait"""
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    axes = axes.flatten()

    for i, trait in enumerate(trait_names):
        ax = axes[i]

        pred = predictions[:, i]
        true = labels[:, i]

        # Scatter plot
        ax.scatter(true, pred, alpha=0.5, s=20)

        # Ideal line (perfect prediction)
        min_val = min(true.min(), pred.min())
        max_val = max(true.max(), pred.max())
        ax.plot([min_val, max_val], [min_val, max_val], 'r--', lw=2, label='Ideal')

        # Compute metrics for title
        pearson, _ = pearsonr(true, pred)
        mae = mean_absolute_error(true, pred)

        ax.set_xlabel('True Score')
        ax.set_ylabel('Predicted Score')
        ax.set_title(f'{trait}\nPearson: {pearson:.3f}, MAE: {mae:.3f}')
        ax.legend()
        ax.grid(True, alpha=0.3)

    # Hide unused subplot
    axes[-1].axis('off')

    plt.tight_layout()
    plt.savefig(f'{CHECKPOINT_DIR}/predictions_scatter.png', dpi=300, bbox_inches='tight')
    plt.show()

# Plot predictions
plot_predictions(test_predictions, test_labels)

In [None]:
from sklearn.metrics import confusion_matrix

def plot_confusion_matrices(predictions, labels, trait_names=BIG_FIVE_TRAITS):
    """Plot confusion matrices for binarized predictions (High/Low by median)"""
    medians = np.median(labels, axis=0)

    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    axes = axes.flatten()

    for i, trait in enumerate(trait_names):
        ax = axes[i]

        pred = predictions[:, i]
        true = labels[:, i]
        median = medians[i]

        # Binarize by median
        pred_binary = (pred > median).astype(int)
        true_binary = (true > median).astype(int)

        # Compute confusion matrix
        cm = confusion_matrix(true_binary, pred_binary)

        # Plot heatmap
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax,
                    xticklabels=['Low', 'High'],
                    yticklabels=['Low', 'High'])
        ax.set_xlabel('Predicted')
        ax.set_ylabel('True')
        ax.set_title(f'{trait} - Confusion Matrix')

    # Hide unused subplot
    axes[-1].axis('off')

    plt.tight_layout()
    plt.savefig(f'{CHECKPOINT_DIR}/confusion_matrices.png', dpi=300, bbox_inches='tight')
    plt.show()

# Plot confusion matrices
plot_confusion_matrices(test_predictions, test_labels)