In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Sampler
import pandas as pd
import numpy as np
from sklearn.metrics import classification_report

In [2]:
class DiplomacyDataset(Dataset):
    def __init__(self, hdf5_file):
        # Load the processed HDF5 data
        self.data = pd.read_hdf(hdf5_file, key='diplomacy_data')
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        # Get message embedding, context embeddings, and label
        message_emb = torch.tensor(self.data['message_embedding'].iloc[idx], dtype=torch.float32)
        context_emb = torch.tensor(self.data['context_embeddings'].iloc[idx], dtype=torch.float32)
        label = torch.tensor(self.data['label'].iloc[idx], dtype=torch.long)
        
        return {
            'message_emb': message_emb,    # Shape: (300,)
            'context_emb': context_emb,    # Shape: (context_window, 300)
            'label': label                 # Scalar: 0 or 1
        }

# Example DataLoader setup
def create_dataloaders(train_file, val_file, test_file, batch_size=32):
    train_dataset = DiplomacyDataset(train_file)
    val_dataset = DiplomacyDataset(val_file)
    test_dataset = DiplomacyDataset(test_file)
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    
    return train_loader, val_loader, test_loader

In [3]:
class DiplomacyLieDetector(nn.Module):
    def __init__(self, embedding_dim=300, hidden_dim=128, context_window=2, dropout=0.3):
        super(DiplomacyLieDetector, self).__init__()
        
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.context_window = context_window
        
        # Encoder for context messages (LSTM)
        self.context_encoder = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=1,
            batch_first=True,
            bidirectional=True
        )
        
        # Attention mechanism
        self.attention = nn.Linear(hidden_dim * 2 + embedding_dim, hidden_dim)
        
        # Decoder layers
        self.fc1 = nn.Linear(hidden_dim * 2 + embedding_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_dim, 2)  # Binary classification (true/false)
        
    def forward(self, message_emb, context_emb):
        # message_emb: (batch_size, embedding_dim)
        # context_emb: (batch_size, context_window, embedding_dim)
        
        # Encode context messages
        # Shape: (batch_size, context_window, 2 * hidden_dim)
        context_output, (hidden, _) = self.context_encoder(context_emb)
        
        # Use last hidden state from both directions
        # Shape: (batch_size, 2 * hidden_dim)
        context_summary = hidden[-2:].transpose(0, 1).contiguous().view(-1, 2 * self.hidden_dim)
        
        # Combine message embedding with context summary for attention
        # Shape: (batch_size, 2 * hidden_dim + embedding_dim)
        combined = torch.cat((message_emb, context_summary), dim=1)
        
        # Attention weights
        # Shape: (batch_size, hidden_dim)
        attn_weights = F.tanh(self.attention(combined))
        
        # Final classification
        x = F.relu(self.fc1(combined))
        x = self.dropout(x)
        x = self.fc2(x)  # Shape: (batch_size, 2)
        
        return x

In [4]:
import torch.optim as optim
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

In [12]:
from tqdm import tqdm

def evaluate_model(model, data_loader, criterion, device):
    model.eval()
    total_loss = 0
    preds, labels = [], []
    
    with torch.no_grad():
        for batch in data_loader:
            message_emb = batch['message_emb'].to(device)
            context_emb = batch['context_emb'].to(device)
            true_labels = batch['label'].to(device)
            
            outputs = model(message_emb, context_emb)
            loss = criterion(outputs, true_labels)
            
            total_loss += loss.item()
            pred = torch.argmax(outputs, dim=1)
            preds.extend(pred.cpu().numpy())
            labels.extend(true_labels.cpu().numpy())
    
    avg_loss = total_loss / len(data_loader)
    accuracy = accuracy_score(labels, preds)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary')
    
    return avg_loss, accuracy, precision, recall, f1

def train_model(model, train_loader, val_loader, num_epochs=20, learning_rate=0.001):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
    best_val_acc = 0.0
    
    for epoch in range(num_epochs):
        # Training
        model.train()
        train_loss = 0
        train_preds, train_labels = [], []
        
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Training"):
            message_emb = batch['message_emb'].to(device)
            context_emb = batch['context_emb'].to(device)
            labels = batch['label'].to(device)
            
            optimizer.zero_grad()
            outputs = model(message_emb, context_emb)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            preds = torch.argmax(outputs, dim=1)
            train_preds.extend(preds.cpu().numpy())
            train_labels.extend(labels.cpu().numpy())
        
        train_acc = accuracy_score(train_labels, train_preds)
        
        # Validation
        val_loss, val_acc,_,_,_ = evaluate_model(model, val_loader, criterion, device)
        
        print(f"Epoch {epoch+1}/{num_epochs}")
        print(f"Train Loss: {train_loss/len(train_loader):.4f}, Train Acc: {train_acc:.4f}")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
        
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), 'best_model.pt')
    
    return model


In [16]:
train_file = "data/processed/train_processed.h5"
val_file = "data/processed/val_processed.h5"
test_file = "data/processed/test_processed.h5"

# Create data loaders
train_loader, val_loader, test_loader = create_dataloaders(train_file, val_file, test_file)

# Initialize model
model = DiplomacyLieDetector(embedding_dim=300, hidden_dim=128, context_window=2)


In [17]:
# Train model
trained_model = train_model(model, train_loader, val_loader)

# Load best model
trained_model.load_state_dict(torch.load('best_model.pt'))


Epoch 1/20 - Training: 100%|██████████| 378/378 [00:05<00:00, 65.77it/s]


Epoch 1/20
Train Loss: 0.2085, Train Acc: 0.9494
Val Loss: 0.1982, Val Acc: 0.9566


Epoch 2/20 - Training: 100%|██████████| 378/378 [00:05<00:00, 66.92it/s]


Epoch 2/20
Train Loss: 0.1958, Train Acc: 0.9510
Val Loss: 0.1930, Val Acc: 0.9566


Epoch 3/20 - Training: 100%|██████████| 378/378 [00:05<00:00, 64.18it/s]


Epoch 3/20
Train Loss: 0.1901, Train Acc: 0.9510
Val Loss: 0.1998, Val Acc: 0.9566


Epoch 4/20 - Training: 100%|██████████| 378/378 [00:05<00:00, 65.50it/s]


Epoch 4/20
Train Loss: 0.1848, Train Acc: 0.9510
Val Loss: 0.1997, Val Acc: 0.9566


Epoch 5/20 - Training: 100%|██████████| 378/378 [00:05<00:00, 66.75it/s]


Epoch 5/20
Train Loss: 0.1818, Train Acc: 0.9515
Val Loss: 0.1953, Val Acc: 0.9566


Epoch 6/20 - Training: 100%|██████████| 378/378 [00:05<00:00, 65.98it/s]


Epoch 6/20
Train Loss: 0.1775, Train Acc: 0.9513
Val Loss: 0.1997, Val Acc: 0.9566


Epoch 7/20 - Training: 100%|██████████| 378/378 [00:05<00:00, 66.95it/s]


Epoch 7/20
Train Loss: 0.1704, Train Acc: 0.9519
Val Loss: 0.2058, Val Acc: 0.9566


Epoch 8/20 - Training: 100%|██████████| 378/378 [00:05<00:00, 67.25it/s]


Epoch 8/20
Train Loss: 0.1640, Train Acc: 0.9532
Val Loss: 0.2075, Val Acc: 0.9542


Epoch 9/20 - Training: 100%|██████████| 378/378 [00:05<00:00, 66.56it/s]


Epoch 9/20
Train Loss: 0.1596, Train Acc: 0.9544
Val Loss: 0.2484, Val Acc: 0.9566


Epoch 10/20 - Training: 100%|██████████| 378/378 [00:05<00:00, 67.15it/s]


Epoch 10/20
Train Loss: 0.1515, Train Acc: 0.9551
Val Loss: 0.2514, Val Acc: 0.9503


Epoch 11/20 - Training: 100%|██████████| 378/378 [00:05<00:00, 65.56it/s]


Epoch 11/20
Train Loss: 0.1435, Train Acc: 0.9556
Val Loss: 0.2514, Val Acc: 0.9542


Epoch 12/20 - Training: 100%|██████████| 378/378 [00:05<00:00, 67.05it/s]


Epoch 12/20
Train Loss: 0.1357, Train Acc: 0.9577
Val Loss: 0.2710, Val Acc: 0.9527


Epoch 13/20 - Training: 100%|██████████| 378/378 [00:05<00:00, 66.68it/s]


Epoch 13/20
Train Loss: 0.1273, Train Acc: 0.9603
Val Loss: 0.2629, Val Acc: 0.9503


Epoch 14/20 - Training: 100%|██████████| 378/378 [00:05<00:00, 66.80it/s]


Epoch 14/20
Train Loss: 0.1217, Train Acc: 0.9619
Val Loss: 0.2975, Val Acc: 0.9550


Epoch 15/20 - Training: 100%|██████████| 378/378 [00:05<00:00, 66.08it/s]


Epoch 15/20
Train Loss: 0.1134, Train Acc: 0.9623
Val Loss: 0.3227, Val Acc: 0.9496


Epoch 16/20 - Training: 100%|██████████| 378/378 [00:05<00:00, 63.85it/s]


Epoch 16/20
Train Loss: 0.1090, Train Acc: 0.9640
Val Loss: 0.3322, Val Acc: 0.9457


Epoch 17/20 - Training: 100%|██████████| 378/378 [00:05<00:00, 67.81it/s]


Epoch 17/20
Train Loss: 0.0995, Train Acc: 0.9668
Val Loss: 0.3069, Val Acc: 0.9457


Epoch 18/20 - Training: 100%|██████████| 378/378 [00:05<00:00, 66.35it/s]


Epoch 18/20
Train Loss: 0.0912, Train Acc: 0.9690
Val Loss: 0.3833, Val Acc: 0.9395


Epoch 19/20 - Training: 100%|██████████| 378/378 [00:05<00:00, 65.12it/s]


Epoch 19/20
Train Loss: 0.0866, Train Acc: 0.9702
Val Loss: 0.3821, Val Acc: 0.9403


Epoch 20/20 - Training: 100%|██████████| 378/378 [00:05<00:00, 64.97it/s]


Epoch 20/20
Train Loss: 0.0778, Train Acc: 0.9731
Val Loss: 0.4547, Val Acc: 0.9364


<All keys matched successfully>

In [19]:
# Evaluate on test set
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
criterion = nn.CrossEntropyLoss()
test_loss, test_acc, test_precision, test_recall, test_f1 = evaluate_model(trained_model, test_loader, criterion, device)

print("\nTest Results:")
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_acc:.4f}")
print(f"Test Precision: {test_precision:.4f}")
print(f"Test Recall: {test_recall:.4f}")
print(f"Test F1 Score: {test_f1:.4f}")


Test Results:
Test Loss: 0.3161
Test Accuracy: 0.9043
Test Precision: 0.9043
Test Recall: 1.0000
Test F1 Score: 0.9497


In [21]:
def print_test_classification_report(model, test_loader, criterion, device):
    """
    Evaluate the model on the test dataset and print a detailed classification report.
    
    Args:
        model: Trained PyTorch model
        test_loader: DataLoader for test dataset
        criterion: Loss function (e.g., nn.CrossEntropyLoss)
        device: Device to run evaluation on (e.g., 'cuda' or 'cpu')
    
    Returns:
        None (prints the classification report)
    """
    model.eval()
    total_loss = 0
    preds, labels = [], []
    
    # Add tqdm progress bar for test evaluation
    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Evaluating Test Set"):
            message_emb = batch['message_emb'].to(device)
            context_emb = batch['context_emb'].to(device)
            true_labels = batch['label'].to(device)
            
            outputs = model(message_emb, context_emb)
            loss = criterion(outputs, true_labels)
            
            total_loss += loss.item()
            pred = torch.argmax(outputs, dim=1)
            preds.extend(pred.cpu().numpy())
            labels.extend(true_labels.cpu().numpy())
    
    # Calculate metrics
    avg_loss = total_loss / len(test_loader)
    accuracy = accuracy_score(labels, preds)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary')
    
    # Generate and print classification report
    report = classification_report(labels, preds, target_names=['False', 'True'])
    
    print("\nTest Set Evaluation Results:")
    print(f"Average Loss: {avg_loss:.4f}")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Binary Precision: {precision:.4f}")
    print(f"Binary Recall: {recall:.4f}")
    print(f"Binary F1-Score: {f1:.4f}")
    print("\nDetailed Classification Report:")
    print(report)

In [22]:
print_test_classification_report(trained_model, test_loader, criterion, device)

Evaluating Test Set: 100%|██████████| 79/79 [00:00<00:00, 116.64it/s]


Test Set Evaluation Results:
Average Loss: 0.3161
Accuracy: 0.9043
Binary Precision: 0.9043
Binary Recall: 1.0000
Binary F1-Score: 0.9497

Detailed Classification Report:
              precision    recall  f1-score   support

       False       0.00      0.00      0.00       240
        True       0.90      1.00      0.95      2268

    accuracy                           0.90      2508
   macro avg       0.45      0.50      0.47      2508
weighted avg       0.82      0.90      0.86      2508




  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [51]:
class DiplomacyDataset(Dataset):
    def __init__(self, hdf5_file):
        self.data = pd.read_hdf(hdf5_file, key='diplomacy_data')
        self.message_embeddings = np.array(self.data['message_embedding'].tolist())
        self.context_embeddings = np.array(self.data['context_embeddings'].tolist())
        self.labels = self.data['label'].values
    
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        return {
            'message_embedding': torch.tensor(self.message_embeddings[idx], dtype=torch.float32),
            'context_embeddings': torch.tensor(self.context_embeddings[idx], dtype=torch.float32),
            'label': torch.tensor(self.labels[idx], dtype=torch.float32)
        }

In [52]:
class BalancedBatchSampler(Sampler):
    def __init__(self, dataset, batch_size, false_ratio=0.75):  # New param: false_ratio
        self.dataset = dataset
        self.batch_size = batch_size
        self.false_ratio = false_ratio  # Fraction of batch that should be False
        self.pos_indices = np.where(dataset.labels == 1)[0]
        self.neg_indices = np.where(dataset.labels == 0)[0]
        self.num_pos = len(self.pos_indices)
        self.num_neg = len(self.neg_indices)
        self.neg_per_batch = int(batch_size * false_ratio)
        self.pos_per_batch = batch_size - self.neg_per_batch
        self.batches_per_epoch = max(self.num_pos // self.pos_per_batch, self.num_neg // self.neg_per_batch)
    
    def __iter__(self):
        for _ in range(self.batches_per_epoch):
            neg_batch = np.random.choice(self.neg_indices, min(self.neg_per_batch, self.num_neg), replace=True)
            pos_batch = np.random.choice(self.pos_indices, min(self.pos_per_batch, self.num_pos), replace=True)
            batch = np.concatenate([neg_batch, pos_batch])
            np.random.shuffle(batch)
            yield batch
    
    def __len__(self):
        return self.batches_per_epoch

In [53]:
class FocalLoss(nn.Module):
    def __init__(self, alpha, gamma=2.0, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha  # Scalar or tensor for class weights
        self.gamma = gamma
        self.reduction = reduction
    
    def forward(self, inputs, targets):
        BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        pt = torch.exp(-BCE_loss)
        alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)  # alpha for True, 1-alpha for False
        focal_loss = alpha_t * (1 - pt) ** self.gamma * BCE_loss
        
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss

In [54]:
class DiplomacyLieDetector(nn.Module):
    def __init__(self, embedding_dim=300, hidden_dim=256, context_window=2, 
                 num_heads=4, num_encoder_layers=2, dropout=0.1):
        super(DiplomacyLieDetector, self).__init__()
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.context_window = context_window
        
        self.context_lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=1,
            batch_first=True,
            bidirectional=True
        )
        
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=hidden_dim * 2 + embedding_dim,
                nhead=num_heads,
                dim_feedforward=hidden_dim * 4,
                dropout=dropout,
                batch_first=True
            ),
            num_layers=num_encoder_layers
        )
        
        self.fc1 = nn.Linear(hidden_dim * 2 + embedding_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_dim, 1)  # Ensure single output for binary classification
        
    def forward(self, message_embedding, context_embeddings):
        batch_size = message_embedding.size(0)
        lstm_out, _ = self.context_lstm(context_embeddings)
        context_vector = lstm_out[:, -1, :]  # (batch_size, hidden_dim * 2)
        combined = torch.cat([message_embedding, context_vector], dim=1)  # (batch_size, embedding_dim + hidden_dim * 2)
        combined = combined.unsqueeze(1)  # (batch_size, 1, embedding_dim + hidden_dim * 2)
        
        transformer_out = self.transformer_encoder(combined)  # (batch_size, 1, embedding_dim + hidden_dim * 2)
        features = transformer_out.squeeze(1)  # (batch_size, embedding_dim + hidden_dim * 2)
        
        x = F.relu(self.fc1(features))  # (batch_size, hidden_dim)
        x = self.dropout(x)
        output = self.fc2(x)  # (batch_size, 1)
        return output

In [55]:
from tqdm import tqdm

def train_model(model, train_loader, val_loader,alpha, num_epochs=10, learning_rate=0.001, use_focal_loss=True):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    
    # Compute class weights for weighted BCE
    dataset = train_loader.dataset
    num_pos = np.sum(dataset.labels)
    num_neg = len(dataset.labels) - num_pos
    pos_weight = torch.tensor([num_neg / num_pos], device=device) if num_pos > 0 else torch.tensor([1.0], device=device)
    
    # Choose loss function
    if use_focal_loss:
        criterion = FocalLoss(alpha=1-alpha, gamma=2.0)
    else:
        criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0
        train_correct = 0
        train_samples = 0
        
        train_loader_tqdm = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Training")
        for batch in train_loader_tqdm:
            message_emb = batch['message_embedding'].to(device)
            context_emb = batch['context_embeddings'].to(device)
            labels = batch['label'].to(device).unsqueeze(1)
            
            optimizer.zero_grad()
            outputs = model(message_emb, context_emb)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item() * labels.size(0)
            train_samples += labels.size(0)
            preds = (torch.sigmoid(outputs) > 0.5).float()
            train_correct += (preds == labels).sum().item()
            
            train_loader_tqdm.set_postfix(loss=loss.item())
        
        train_loss = train_loss / train_samples
        train_acc = train_correct / train_samples
        
        # Validation phase
        model.eval()
        val_loss = 0
        val_correct = 0
        val_samples = 0
        
        val_loader_tqdm = tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Validation")
        with torch.no_grad():
            for batch in val_loader_tqdm:
                message_emb = batch['message_embedding'].to(device)
                context_emb = batch['context_embeddings'].to(device)
                labels = batch['label'].to(device).unsqueeze(1)
                
                outputs = model(message_emb, context_emb)
                loss = criterion(outputs, labels)
                
                val_loss += loss.item() * labels.size(0)
                val_samples += labels.size(0)
                preds = (torch.sigmoid(outputs) > 0.5).float()
                val_correct += (preds == labels).sum().item()
                
                val_loader_tqdm.set_postfix(loss=loss.item())
        
        val_loss = val_loss / val_samples
        val_acc = val_correct / val_samples
        
        print(f"Epoch {epoch+1}/{num_epochs}, "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
    
    return model

In [60]:
def evaluate_model(model, test_loader, threshold=0.5):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.eval()
    
    all_probs = []  # Store probabilities instead of binary preds
    all_labels = []
    total_samples = 0
    
    with torch.no_grad():
        for batch in test_loader:
            message_emb = batch['message_embedding'].to(device)
            context_emb = batch['context_embeddings'].to(device)
            labels = batch['label'].to(device).unsqueeze(1)
            
            outputs = model(message_emb, context_emb)
            probs = torch.sigmoid(outputs)
            
            batch_size = labels.size(0)
            total_samples += batch_size
            
            all_probs.extend(torch.argmax(probs, dim=1))
            all_labels.extend(labels.cpu().numpy().flatten())
    
    all_probs = np.array(all_probs)
    all_labels = np.array(all_labels, dtype=int)
    
    # Try multiple thresholds to find best F1 for False
    thresholds = [0.5, 0.3, 0.1]
    best_f1_false = 0
    best_threshold = 0.5
    best_report = None
    
    for thresh in thresholds:
        all_preds = (all_probs > thresh).astype(int)
        report = classification_report(all_labels, all_preds, target_names=['False', 'True'], output_dict=True)
        f1_false = report['False']['f1-score']
        print(f"\nThreshold: {thresh}")
        print(classification_report(all_labels, all_preds, target_names=['False', 'True']))
        
        if f1_false > best_f1_false:
            best_f1_false = f1_false
            best_threshold = thresh
            best_report = classification_report(all_labels, all_preds, target_names=['False', 'True'])
    
    print(f"\nBest Threshold: {best_threshold}")
    print("Best Classification Report:\n", best_report)
    accuracy = np.mean((all_probs > best_threshold).astype(int) == all_labels)
    print(f"Test Accuracy at Best Threshold: {accuracy:.4f}")
    
    return all_probs, all_labels, best_report

In [57]:
# def evaluate_model(model, test_loader, threshold=0.5):
#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#     model.to(device)
#     model.eval()
    
#     all_preds = []
#     all_labels = []
#     total_samples = 0
    
#     with torch.no_grad():
#         for i, batch in enumerate(test_loader):
#             message_emb = batch['message_embedding'].to(device)
#             context_emb = batch['context_embeddings'].to(device)
#             labels = batch['label'].to(device).unsqueeze(1)  # Shape: (batch_size, 1)
            
#             outputs = model(message_emb, context_emb)  # Shape: (batch_size, 1)
#             probs = torch.sigmoid(outputs)  # Shape: (batch_size, 1)
#             preds = (probs > threshold).float()  # Shape: (batch_size, 1)
            
#             batch_size = labels.size(0)
#             total_samples += batch_size
            
#             # Debug per batch
#             print(f"Batch {i}:")
#             print(f"  Labels shape: {labels.shape}, Preds shape: {preds.shape}")
#             print(f"  Labels flattened: {labels.cpu().numpy().flatten().shape}")
#             preds = torch.argmax(probs, dim=1)  # Shape: (batch_size,)
#             print(f"  Preds shape after argmax: {preds.shape}")
            
#             # Append batch data
#             all_labels.extend(labels.cpu().numpy().flatten())
#             all_preds.extend(preds.cpu().numpy())
    
#     # Convert to numpy arrays
#     all_preds = np.array(all_preds, dtype=int)
#     all_labels = np.array(all_labels, dtype=int)
    
#     # Final debug output
#     print(f"Total samples processed: {total_samples}")
#     print(f"Number of true labels: {len(all_labels)}")
#     print(f"Number of predicted labels: {len(all_preds)}")
    
#     if len(all_labels) != len(all_preds):
#         raise ValueError(f"Mismatch in sample counts: {len(all_labels)} labels vs {len(all_preds)} predictions")
    
#     # Print classification report
#     report = classification_report(all_labels, all_preds, target_names=['False', 'True'])
#     print("Classification Report:\n", report)
    
#     # Calculate and print accuracy
#     accuracy = np.mean(all_preds == all_labels)
#     print(f"Test Accuracy: {accuracy:.4f}")
    
#     return all_preds, all_labels, report

In [58]:
train_dataset = DiplomacyDataset('data/processed/train_processed.h5')
val_dataset = DiplomacyDataset('data/processed/val_processed.h5')
test_dataset = DiplomacyDataset('data/processed/test_processed.h5')
    
# Create balanced batch sampler
batch_size = 32
train_sampler = BalancedBatchSampler(train_dataset, batch_size, false_ratio=0.75)
train_loader = DataLoader(train_dataset, batch_sampler=train_sampler)

val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

num_pos = np.sum(train_dataset.labels)
num_neg = len(train_dataset.labels) - num_pos
alpha = num_neg / (num_pos + num_neg)  # Proportion of False
print(f"Alpha for Focal Loss: {alpha:.4f}")

# Initialize model
model = DiplomacyLieDetector(
    embedding_dim=300,
    hidden_dim=256,
    context_window=2,
    num_heads=4,
    num_encoder_layers=2
)

# criterion = FocalLoss(alpha=1-alpha)  # Use Focal Loss with alpha

# Train with Focal Loss
train_model(model, train_loader,val_loader, alpha, num_epochs=10, use_focal_loss=True)

Alpha for Focal Loss: 0.0490


Epoch 1/10 - Training: 100%|██████████| 1435/1435 [01:59<00:00, 12.02it/s, loss=0.029] 
Epoch 1/10 - Validation: 100%|██████████| 41/41 [00:00<00:00, 92.76it/s, loss=0.035] 


Epoch 1/10, Train Loss: 0.0290, Train Acc: 0.2503, Val Loss: 0.0393, Val Acc: 0.9566


Epoch 2/10 - Training: 100%|██████████| 1435/1435 [02:42<00:00,  8.84it/s, loss=0.0284]
Epoch 2/10 - Validation: 100%|██████████| 41/41 [00:00<00:00, 91.52it/s, loss=0.0375]


Epoch 2/10, Train Loss: 0.0288, Train Acc: 0.2500, Val Loss: 0.0436, Val Acc: 0.9566


Epoch 3/10 - Training: 100%|██████████| 1435/1435 [02:53<00:00,  8.25it/s, loss=0.0286]
Epoch 3/10 - Validation: 100%|██████████| 41/41 [00:00<00:00, 91.65it/s, loss=0.0362]


Epoch 3/10, Train Loss: 0.0287, Train Acc: 0.2500, Val Loss: 0.0415, Val Acc: 0.9566


Epoch 4/10 - Training: 100%|██████████| 1435/1435 [02:51<00:00,  8.37it/s, loss=0.0286]
Epoch 4/10 - Validation: 100%|██████████| 41/41 [00:00<00:00, 91.92it/s, loss=0.0363]


Epoch 4/10, Train Loss: 0.0286, Train Acc: 0.2500, Val Loss: 0.0417, Val Acc: 0.9566


Epoch 5/10 - Training: 100%|██████████| 1435/1435 [02:56<00:00,  8.12it/s, loss=0.0287]
Epoch 5/10 - Validation: 100%|██████████| 41/41 [00:00<00:00, 90.49it/s, loss=0.0363]


Epoch 5/10, Train Loss: 0.0286, Train Acc: 0.2500, Val Loss: 0.0416, Val Acc: 0.9566


Epoch 6/10 - Training: 100%|██████████| 1435/1435 [02:56<00:00,  8.11it/s, loss=0.0286]
Epoch 6/10 - Validation: 100%|██████████| 41/41 [00:00<00:00, 90.50it/s, loss=0.0363]


Epoch 6/10, Train Loss: 0.0286, Train Acc: 0.2500, Val Loss: 0.0417, Val Acc: 0.9566


Epoch 7/10 - Training: 100%|██████████| 1435/1435 [04:31<00:00,  5.29it/s, loss=0.0286]
Epoch 7/10 - Validation: 100%|██████████| 41/41 [00:00<00:00, 88.83it/s, loss=0.0363]


Epoch 7/10, Train Loss: 0.0286, Train Acc: 0.2500, Val Loss: 0.0417, Val Acc: 0.9566


Epoch 8/10 - Training: 100%|██████████| 1435/1435 [04:39<00:00,  5.14it/s, loss=0.0286]
Epoch 8/10 - Validation: 100%|██████████| 41/41 [00:00<00:00, 90.81it/s, loss=0.0363]


Epoch 8/10, Train Loss: 0.0286, Train Acc: 0.2500, Val Loss: 0.0417, Val Acc: 0.9566


Epoch 9/10 - Training: 100%|██████████| 1435/1435 [05:34<00:00,  4.29it/s, loss=0.0286]
Epoch 9/10 - Validation: 100%|██████████| 41/41 [00:01<00:00, 39.87it/s, loss=0.0363]


Epoch 9/10, Train Loss: 0.0286, Train Acc: 0.2500, Val Loss: 0.0417, Val Acc: 0.9566


Epoch 10/10 - Training: 100%|██████████| 1435/1435 [09:40<00:00,  2.47it/s, loss=0.0286]
Epoch 10/10 - Validation: 100%|██████████| 41/41 [00:00<00:00, 92.96it/s, loss=0.0363]

Epoch 10/10, Train Loss: 0.0286, Train Acc: 0.2500, Val Loss: 0.0417, Val Acc: 0.9566





DiplomacyLieDetector(
  (context_lstm): LSTM(300, 256, batch_first=True, bidirectional=True)
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-1): 2 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=812, out_features=812, bias=True)
        )
        (linear1): Linear(in_features=812, out_features=1024, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=1024, out_features=812, bias=True)
        (norm1): LayerNorm((812,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((812,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (fc1): Linear(in_features=812, out_features=256, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (fc2): Linear(in_features=256, out_features=1, bias=True)
)

In [61]:
preds, labels, report = evaluate_model(trained_model, test_loader)


Threshold: 0.5
              precision    recall  f1-score   support

       False       0.00      0.00      0.00       240
        True       0.90      1.00      0.95      2268

    accuracy                           0.90      2508
   macro avg       0.45      0.50      0.47      2508
weighted avg       0.82      0.90      0.86      2508


Threshold: 0.3
              precision    recall  f1-score   support

       False       0.00      0.00      0.00       240
        True       0.90      1.00      0.95      2268

    accuracy                           0.90      2508
   macro avg       0.45      0.50      0.47      2508
weighted avg       0.82      0.90      0.86      2508


Threshold: 0.1
              precision    recall  f1-score   support

       False       0.00      0.00      0.00       240
        True       0.90      1.00      0.95      2268

    accuracy                           0.90      2508
   macro avg       0.45      0.50      0.47      2508
weighted avg       0.82   

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize