In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, classification_report
from sklearn.model_selection import train_test_split
from collections import Counter
from tqdm import tqdm
import numpy as np
import json
import re
import os
import math
from sklearn.metrics import precision_recall_curve

# Check GPU availability
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU device:", torch.cuda.get_device_name(0))
    print("GPU memory:", torch.cuda.get_device_properties(0).total_memory / 1e9, "GB")

# -------------------------r
# 1. Tokenization and Vocab
# -------------------------
def diplomacy_tokenizer(text):
    """Tokenize text, preserving punctuation as separate tokens."""
    text = text.lower()
    tokens = re.findall(r"\w+|[.,!?;]", text)
    return tokens

class Vocab:
    def __init__(self, min_freq=2):
        self.token_to_idx = {'<PAD>': 0, '<UNK>': 1}
        self.idx_to_token = ['<PAD>', '<UNK>']
        self.min_freq = min_freq
        self.harbinger_indices = set()  # Store harbinger token indices

    def build_vocab(self, texts, harbingers):
        counter = Counter()
        for text in texts:
            counter.update(diplomacy_tokenizer(text))
        for token, freq in counter.items():
            if freq >= self.min_freq:
                self.idx_to_token.append(token)
                self.token_to_idx[token] = len(self.idx_to_token) - 1
        # Identify harbinger indices
        self.harbinger_indices = set([self.token_to_idx[token] for token in harbingers
                                    if token in self.token_to_idx])

    def encode(self, tokens):
        return [self.token_to_idx.get(t, 1) for t in tokens]  # 1 = <UNK>

    def __len__(self):
        return len(self.idx_to_token)

# -------------------
# 2. Custom Dataset
# -------------------
class DiplomacyDataset(Dataset):
    def __init__(self, messages, labels, vocab, context_size=2, max_len=300):
        self.messages = messages
        self.labels = labels
        self.vocab = vocab
        self.context_size = context_size
        self.max_len = max_len
        self.encoded = [vocab.encode(diplomacy_tokenizer(msg)) for msg in messages]

    def __len__(self):
        return len(self.messages)

    def __getitem__(self, idx):
        context = []
        for i in range(max(0, idx - self.context_size), idx + 1):
            context.extend(self.encoded[i])
        if len(context) < self.max_len:
            context += [0] * (self.max_len - len(context))
        else:
            context = context[:self.max_len]
        # Create harbinger mask
        harbinger_mask = [1 if token in self.vocab.harbinger_indices else 0
                         for token in context]
        context_tensor = torch.tensor(context, dtype=torch.long)
        harbinger_mask_tensor = torch.tensor(harbinger_mask, dtype=torch.float)
        label = torch.tensor(self.labels[idx], dtype=torch.float)
        return context_tensor, harbinger_mask_tensor, label

# -------------------
# 3. Positional Encoding (Fixed)
# -------------------
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        # Handle both even and odd dimensions
        pe[:, 0::2] = torch.sin(position * div_term)
        if d_model % 2 == 0:
            pe[:, 1::2] = torch.cos(position * div_term)
        else:
            pe[:, 1::2] = torch.cos(position * div_term[:-1])  # Skip last when odd
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1), :]
        return x

# -------------------
# 4. Transformer-based Model with Harbingers (Fixed)
# -------------------
class TransformerClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim=128, num_heads=4, num_layers=2, hidden_dim=256, dropout=0.3, max_len=300):
        super(TransformerClassifier, self).__init__()
        # Ensure embed_dim + 1 is divisible by num_heads
        assert (embed_dim + 1) % num_heads == 0, "embed_dim + 1 must be divisible by num_heads"

        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        d_model = embed_dim + 1  # Add 1 for harbinger feature
        self.pos_encoder = PositionalEncoding(d_model, max_len)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=hidden_dim,
            dropout=dropout,
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(d_model, 1)

    def forward(self, x, harb_mask):
        # x: (batch_size, max_len)
        embedded = self.embedding(x)  # (batch_size, max_len, embed_dim)
        harb_mask = harb_mask.unsqueeze(2)  # (batch_size, max_len, 1)
        combined = torch.cat([embedded, harb_mask], dim=2)  # (batch_size, max_len, d_model)
        combined = self.pos_encoder(combined)  # (batch_size, max_len, d_model)
        # Create mask for padding tokens
        mask = (x == 0)  # (batch_size, max_len)
        output = self.transformer_encoder(combined, src_key_padding_mask=mask)  # (batch_size, max_len, d_model)
        output = output.mean(dim=1)  # (batch_size, d_model)
        output = self.dropout(output)
        output = self.fc(output).squeeze(1)  # (batch_size,)
        return output

# -------------------
# 5. Training + Eval
# -------------------
# Add these imports at the top
import torch.cuda as cuda
import gc

# Modify the train_epoch function
def train_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    losses = []
    total_batches = len(dataloader)
    
    # Clear GPU memory before training
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        gc.collect()
    
    pbar = tqdm(enumerate(dataloader), total=total_batches, desc='Training')
    
    for batch_idx, (x, harb_mask, y) in pbar:
        try:
            x = x.to(device)
            harb_mask = harb_mask.to(device)
            y = y.to(device)
            
            optimizer.zero_grad()
            logits = model(x, harb_mask)
            loss = criterion(logits, y)
            loss.backward()
            optimizer.step()
            
            current_loss = loss.item()
            losses.append(current_loss)
            
            # Update progress bar with current batch and loss
            pbar.set_postfix({
                'batch': f'{batch_idx+1}/{total_batches}',
                'loss': f'{current_loss:.4f}',
                'avg_loss': f'{np.mean(losses):.4f}'
            })
            
            # Clear some memory
            del x, harb_mask, y, logits, loss
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                
        except Exception as e:
            print(f"Error in batch {batch_idx}: {str(e)}")
            continue
    
    return np.mean(losses)

def evaluate(model, dataloader, device):
    model.eval()
    all_logits = []
    all_labels = []
    with torch.no_grad():
        for x, harb_mask, y in tqdm(dataloader, desc='Evaluating'):
            x = x.to(device)
            harb_mask = harb_mask.to(device)
            y = y.to(device)
            logits = model(x, harb_mask)
            all_logits.extend(logits.cpu().tolist())
            all_labels.extend(y.cpu().tolist())
    return np.array(all_logits), np.array(all_labels)

def print_metrics(true, preds, probs, name=""):
    acc = accuracy_score(true, preds)
    f1 = f1_score(true, preds)
    auc = roc_auc_score(true, probs)
    print(f"\n{name} Set Performance")
    print(f"Accuracy: {acc:.4f}, F1: {f1:.4f}, AUC: {auc:.4f}")
    print(classification_report(true, preds))

# -------------------
# 6. Threshold Adjustment
# -------------------
def find_best_threshold(logits, labels):
    probs = torch.sigmoid(torch.tensor(logits)).numpy()
    precision, recall, thresholds = precision_recall_curve(labels, probs)
    f1_scores = 2 * (precision * recall) / (precision + recall + 1e-10)
    best_idx = np.argmax(f1_scores)
    best_threshold = thresholds[best_idx]
    best_f1 = f1_scores[best_idx]
    return best_threshold, best_f1

# -------------------
# 7. Run Training
# -------------------
def load_jsonl(file_path):
    with open(file_path, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f]

def preprocess_messages(data):
    messages, labels = [], []
    for game in data:
        game_messages = game.get("messages", [])
        sender_labels = game.get("sender_labels", [])
        if not game_messages or len(game_messages) < 2 or len(sender_labels) != len(game_messages):
            continue
        for i in range(len(game_messages)):
            is_deceptive = 0 if sender_labels[i] else 1  # 1 = deceptive, 0 = truthful
            messages.append(game_messages[i])
            labels.append(is_deceptive)

    num_deceptive = sum(labels)
    num_truthful = len(labels) - num_deceptive
    print(f"Processed {len(messages)} messages with {num_deceptive} deceptive and {num_truthful} truthful")
    return messages, labels, num_truthful, num_deceptive

def run_training():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load data from local files
    train_data = load_jsonl("data/train.jsonl")
    val_data = load_jsonl("data/validation.jsonl") 
    test_data = load_jsonl("data/test.jsonl")

    print(f"Loaded {len(train_data)} training games")
    
    # Rest of the function remains the same
    train_msgs, train_labels, train_num_truthful, train_num_deceptive = preprocess_messages(train_data)
    val_msgs, val_labels, _, _ = preprocess_messages(val_data) 
    test_msgs, test_labels, _, _ = preprocess_messages(test_data)
    
    # ...existing code...

    train_msgs, train_labels, train_num_truthful, train_num_deceptive = preprocess_messages(train_data)
    val_msgs, val_labels, _, _ = preprocess_messages(val_data)
    test_msgs, test_labels, _, _ = preprocess_messages(test_data)

    # Define harbingers
    harbingers = [
    # 1. Uncertainty & Probability (40 terms)
    "maybe", "perhaps", "possibly", "might", "could", "would", "should",
    "potentially", "presumably", "likely", "unlikely", "probably", "possibly",
    "conceivably", "hopefully", "eventually", "ultimately", "definitely",
    "certainly", "surely", "absolutely", "undoubtedly", "clearly", "obviously",
    "apparently", "seemingly", "allegedly", "supposedly", "reportedly",
    "essentially", "basically", "fundamentally", "significantly", "considerably",
    "virtually", "practically", "nearly", "almost", "marginally", "somewhat",

    # 2. Cognitive & Mental State Verbs (35 terms)
    "think", "believe", "feel", "suppose", "guess", "wonder", "assume",
    "suspect", "estimate", "imagine", "figure", "reckon", "expect", "predict",
    "anticipate", "foresee", "presume", "infer", "deduce", "conclude",
    "gather", "surmise", "speculate", "theorize", "hypothesize", "sense",
    "perceive", "notice", "realize", "recognize", "understand", "know",
    "remember", "recall", "forget",

    # 3. Hedges & Approximators (30 terms)
    "sort", "kind", "rather", "quite", "somewhat", "slightly", "moderately",
    "relatively", "comparatively", "reasonably", "fairly", "pretty", "mostly",
    "mainly", "primarily", "partially", "largely", "substantially", "typically",
    "generally", "usually", "normally", "commonly", "regularly", "often",
    "frequently", "sometimes", "occasionally", "rarely", "seldom",

    # 4. Intensifiers & Emphatics (25 terms)
    "very", "really", "extremely", "absolutely", "totally", "completely",
    "utterly", "perfectly", "entirely", "thoroughly", "fully", "wholly",
    "downright", "positively", "simply", "just", "merely", "only",
    "literally", "actually", "honestly", "truly", "genuinely", "sincerely",
    "frankly",

    # 5. Evasive & Ambiguous Terms (30 terms)
    "about", "around", "approximately", "roughly", "nearly", "close to",
    "in the range of", "something like", "or so", "more or less", "give or take",
    "in general", "on the whole", "all things considered", "by and large",
    "for the most part", "to some extent", "in some ways", "in a sense",
    "in theory", "technically", "strictly speaking", "officially",
    "formally", "nominally", "effectively", "in effect", "in principle",
    "ideally", "theoretically",

    # 6. Time-based Hedges (20 terms)
    "now", "currently", "presently", "at present", "at the moment",
    "these days", "lately", "recently", "in recent times", "over time",
    "with time", "in time", "sooner or later", "eventually", "ultimately",
    "in the end", "at the end of the day", "when all is said and done",
    "in the long run", "in the final analysis",

    # 7. Source Distancing (15 terms)
    "according to", "as per", "based on", "in light of", "in view of",
    "given that", "seeing as", "considering", "taking into account",
    "from what I understand", "from what I gather", "from my perspective",
    "in my opinion", "to my knowledge", "as far as I know",

    # 8. Diplomatic-Specific Terms (25 terms)
    "diplomatically", "strategically", "tactically", "politically",
    "negotiable", "flexible", "adaptable", "revisable", "amendable",
    "subject to", "conditional upon", "dependent on", "contingent on",
    "pending", "awaiting", "considering", "reviewing", "evaluating",
    "assessing", "monitoring", "observing", "watching", "tracking",
    "following", "pursuant to"
]

    # Build vocab with harbingers
    vocab = Vocab(min_freq=2)
    vocab.build_vocab(train_msgs, harbingers)
    print(f"Vocabulary size: {len(vocab)}")

    # Calculate class weights (truthful / deceptive)
    pos_weight = train_num_truthful / train_num_deceptive if train_num_deceptive > 0 else 1.0
    print(f"Positive weight (truthful/deceptive): {pos_weight:.4f}")

    # Datasets and Loaders
    context_size = 2
    max_len = 300
    train_set = DiplomacyDataset(train_msgs, train_labels, vocab, context_size=context_size, max_len=max_len)
    val_set = DiplomacyDataset(val_msgs, val_labels, vocab, context_size=context_size, max_len=max_len)
    test_set = DiplomacyDataset(test_msgs, test_labels, vocab, context_size=context_size, max_len=max_len)

    # Change DataLoader configurations
    batch_size = 32
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0)  # Changed from 2 to 0
    val_loader = DataLoader(val_set, batch_size=batch_size, num_workers=0)  # Changed from 2 to 0
    test_loader = DataLoader(test_set, batch_size=batch_size, num_workers=0)  # Changed from 2 to 0

    # Model setup - using embed_dim=127 to make d_model=128 (divisible by 4 heads)
    model = TransformerClassifier(
        vocab_size=len(vocab),
        embed_dim=127,  # Changed from 128 to 127 to make d_model=128
        num_heads=4,
        num_layers=2,
        hidden_dim=256,
        dropout=0.3,
        max_len=max_len
    ).to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight, dtype=torch.float).to(device))

    # For tracking best model
    best_val_f1 = 0
    best_model_path = "/kaggle/working/best_diplomacy_model.pth" if os.path.exists("/kaggle/working") else "best_diplomacy_model.pth"

    # Training loop (20 epochs)
    for epoch in range(8):
        print(f"\nStarting Epoch {epoch+1}...")
        loss = train_epoch(model, train_loader, optimizer, criterion, device)
        print(f"Epoch {epoch+1}, Loss: {loss:.4f}")

        # Evaluate on validation set
        val_logits, val_true = evaluate(model, val_loader, device)
        val_probs = torch.sigmoid(torch.tensor(val_logits)).numpy()
        val_preds = (val_probs > 0.5).astype(int)
        val_f1 = f1_score(val_true, val_preds)
        print(f"Validation F1 (threshold=0.5): {val_f1:.4f}")

        # Find best threshold for validation set
        best_threshold, best_f1 = find_best_threshold(val_logits, val_true)
        print(f"Best threshold: {best_threshold:.4f}, Best F1: {best_f1:.4f}")

        # Save best model based on F1 with best threshold
        if best_f1 > best_val_f1:
            best_val_f1 = best_f1
            torch.save(model.state_dict(), best_model_path)
            print(f"New best model saved with F1: {best_f1:.4f}")

    # Load best model
    model.load_state_dict(torch.load(best_model_path))

    # Final evaluation on validation and test sets
    for name, loader, labels in [("Validation", val_loader, val_labels), ("Test", test_loader, test_labels)]:
        logits, true = evaluate(model, loader, device)
        probs = torch.sigmoid(torch.tensor(logits)).numpy()
        preds = (probs > best_threshold).astype(int)
        print_metrics(true, preds, probs, name)

if __name__ == "__main__":
    run_training()

CUDA available: True
GPU device: NVIDIA GeForce RTX 4060 Laptop GPU
GPU memory: 8.585216 GB
Using device: cuda:0
Loaded 189 training games
Processed 13128 messages with 590 deceptive and 12538 truthful
Processed 1416 messages with 56 deceptive and 1360 truthful
Processed 2741 messages with 240 deceptive and 2501 truthful
Processed 13128 messages with 590 deceptive and 12538 truthful
Processed 1416 messages with 56 deceptive and 1360 truthful
Processed 2741 messages with 240 deceptive and 2501 truthful
Vocabulary size: 4650
Positive weight (truthful/deceptive): 21.2508

Starting Epoch 1...


Training: 100%|██████████| 411/411 [00:08<00:00, 50.49it/s, batch=411/411, loss=3.3417, avg_loss=1.3294]


Epoch 1, Loss: 1.3294


  output = torch._nested_tensor_from_mask(
Evaluating: 100%|██████████| 45/45 [00:00<00:00, 225.97it/s]


Validation F1 (threshold=0.5): 0.0761
Best threshold: 0.5452, Best F1: 0.1190
New best model saved with F1: 0.1190

Starting Epoch 2...


Training: 100%|██████████| 411/411 [00:07<00:00, 53.44it/s, batch=411/411, loss=0.7202, avg_loss=1.2983]


Epoch 2, Loss: 1.2983


Evaluating: 100%|██████████| 45/45 [00:00<00:00, 287.47it/s]


Validation F1 (threshold=0.5): 0.0761
Best threshold: 0.5864, Best F1: 0.1351
New best model saved with F1: 0.1351

Starting Epoch 3...


Training: 100%|██████████| 411/411 [00:07<00:00, 53.99it/s, batch=411/411, loss=0.7539, avg_loss=1.2847]


Epoch 3, Loss: 1.2847


Evaluating: 100%|██████████| 45/45 [00:00<00:00, 258.32it/s]

Validation F1 (threshold=0.5): 0.0761
Best threshold: 0.5761, Best F1: 0.1370
New best model saved with F1: 0.1370

Starting Epoch 4...



Training: 100%|██████████| 411/411 [00:07<00:00, 53.34it/s, batch=411/411, loss=2.3014, avg_loss=1.2670]


Epoch 4, Loss: 1.2670


Evaluating: 100%|██████████| 45/45 [00:00<00:00, 256.37it/s]


Validation F1 (threshold=0.5): 0.0761
Best threshold: 0.5759, Best F1: 0.1287

Starting Epoch 5...


Training: 100%|██████████| 411/411 [00:07<00:00, 51.70it/s, batch=411/411, loss=1.7109, avg_loss=1.2162]


Epoch 5, Loss: 1.2162


Evaluating: 100%|██████████| 45/45 [00:00<00:00, 196.90it/s]


Validation F1 (threshold=0.5): 0.0784
Best threshold: 0.5738, Best F1: 0.1505
New best model saved with F1: 0.1505

Starting Epoch 6...


Training: 100%|██████████| 411/411 [00:07<00:00, 51.64it/s, batch=411/411, loss=0.8482, avg_loss=1.1583]


Epoch 6, Loss: 1.1583


Evaluating: 100%|██████████| 45/45 [00:00<00:00, 186.12it/s]


Validation F1 (threshold=0.5): 0.0793
Best threshold: 0.6212, Best F1: 0.1618
New best model saved with F1: 0.1618

Starting Epoch 7...


Training: 100%|██████████| 411/411 [00:07<00:00, 51.98it/s, batch=411/411, loss=1.9822, avg_loss=1.1081]


Epoch 7, Loss: 1.1081


Evaluating: 100%|██████████| 45/45 [00:00<00:00, 179.58it/s]


Validation F1 (threshold=0.5): 0.0858
Best threshold: 0.5868, Best F1: 0.1519

Starting Epoch 8...


Training: 100%|██████████| 411/411 [00:08<00:00, 51.26it/s, batch=411/411, loss=0.4386, avg_loss=1.0523]


Epoch 8, Loss: 1.0523


Evaluating: 100%|██████████| 45/45 [00:00<00:00, 190.69it/s]
  model.load_state_dict(torch.load(best_model_path))


Validation F1 (threshold=0.5): 0.0879
Best threshold: 0.5891, Best F1: 0.1529


Evaluating: 100%|██████████| 45/45 [00:00<00:00, 243.61it/s]



Validation Set Performance
Accuracy: 0.8531, F1: 0.1261, AUC: 0.6365
              precision    recall  f1-score   support

         0.0       0.97      0.88      0.92      1360
         1.0       0.08      0.27      0.13        56

    accuracy                           0.85      1416
   macro avg       0.52      0.57      0.52      1416
weighted avg       0.93      0.85      0.89      1416



Evaluating: 100%|██████████| 86/86 [00:00<00:00, 245.21it/s]


Test Set Performance
Accuracy: 0.8179, F1: 0.2067, AUC: 0.6287
              precision    recall  f1-score   support

         0.0       0.93      0.87      0.90      2501
         1.0       0.17      0.27      0.21       240

    accuracy                           0.82      2741
   macro avg       0.55      0.57      0.55      2741
weighted avg       0.86      0.82      0.84      2741




