In [1]:
import json
from collections import defaultdict


# Scan the training split once and record, for every top-level key, which
# Python types its values take and (for list values) which lengths occur.
type_info = defaultdict(set)
list_lengths = defaultdict(set)

with open("/kaggle/input/acl-diplomacy/train.jsonl", 'r') as f:
    records = (json.loads(raw_line) for raw_line in f)
    for record in records:
        for field, field_value in record.items():
            type_info[field].add(type(field_value).__name__)
            if not isinstance(field_value, list):
                continue
            list_lengths[field].add(len(field_value))

# One report paragraph per key, in alphabetical order.
print("Key-wise type and list length info:\n")
for field in sorted(type_info):
    report = [f"{field}:", f"  Types seen: {sorted(type_info[field])}"]
    seen_lengths = list_lengths.get(field)
    if seen_lengths is not None:
        report.append(f"  List lengths seen: {sorted(seen_lengths)}")
    print("\n".join(report))
    print()


Key-wise type and list length info:

absolute_message_index:
  Types seen: ['list']
  List lengths seen: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 28, 29, 30, 31, 32, 33, 34, 35, 36, 38, 44, 46, 47, 48, 49, 50, 51, 55, 56, 58, 62, 63, 64, 65, 66, 67, 68, 69, 70, 75, 78, 81, 86, 87, 90, 95, 96, 98, 99, 104, 113, 119, 120, 123, 130, 133, 134, 135, 136, 139, 145, 148, 150, 151, 155, 157, 161, 166, 189, 197, 205, 208, 215, 283, 321, 366, 435, 457, 471, 480, 511, 656, 675]

game_id:
  Types seen: ['int']

game_score:
  Types seen: ['list']
  List lengths seen: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 28, 29, 30, 31, 32, 33, 34, 35, 36, 38, 44, 46, 47, 48, 49, 50, 51, 55, 56, 58, 62, 63, 64, 65, 66, 67, 68, 69, 70, 75, 78, 81, 86, 87, 90, 95, 96, 98, 99, 104, 113, 119, 120, 123, 130, 133, 134, 135, 136, 139, 145, 148, 150, 151, 155, 157, 161, 166, 189, 197, 205, 208, 215, 283, 321, 366,

In [3]:
import json



# Peek at the first three games: show their keys and confirm that the
# per-message lists (score deltas, sender labels) have matching lengths.
with open("/kaggle/input/acl-diplomacy/train.jsonl", 'r') as f:
    # zip() stops after three lines or at end of file, whichever comes first.
    for i, line in zip(range(3), f):
        entry = json.loads(line)
        print(entry.keys())

        print(f"Entry {i+1}:")
        for caption, field in (("score_delta:", "game_score_delta"),
                               ("sender_labels:", "sender_labels")):
            values = entry.get(field)
            print(caption, "Missing" if values is None else len(values))
        print()


dict_keys(['messages', 'sender_labels', 'receiver_labels', 'speakers', 'receivers', 'absolute_message_index', 'relative_message_index', 'seasons', 'years', 'game_score', 'game_score_delta', 'players', 'game_id'])
Entry 1:
score_delta: 321
sender_labels: 321

dict_keys(['messages', 'sender_labels', 'receiver_labels', 'speakers', 'receivers', 'absolute_message_index', 'relative_message_index', 'seasons', 'years', 'game_score', 'game_score_delta', 'players', 'game_id'])
Entry 2:
score_delta: 155
sender_labels: 155

dict_keys(['messages', 'sender_labels', 'receiver_labels', 'speakers', 'receivers', 'absolute_message_index', 'relative_message_index', 'seasons', 'years', 'game_score', 'game_score_delta', 'players', 'game_id'])
Entry 3:
score_delta: 87
sender_labels: 87



In [4]:
import json
from collections import Counter


# Locations of the three raw JSONL splits (one JSON object per line).
train_file = "/kaggle/input/acl-diplomacy/train.jsonl"
val_file = "/kaggle/input/acl-diplomacy/validation.jsonl"
test_file = "/kaggle/input/acl-diplomacy/test.jsonl"


def count_labels(file_path):
    """Tally the sender labels of every message in a JSONL file.

    Each line is parsed as a JSON object; the entries of its
    ``sender_labels`` list (an empty list when the key is absent) are
    stringified, stripped and lower-cased, then counted under ``"true"``,
    ``"false"`` or ``"other"``.

    Args:
        file_path: Path of the JSONL file to read.

    Returns:
        A ``Counter`` holding only the buckets that occurred at least once.
    """
    counts = Counter()
    with open(file_path, 'r') as f:
        for line in f:
            labels = json.loads(line).get("sender_labels", [])
            for raw_label in labels:
                normalized = str(raw_label).strip().lower()
                # Anything that is not literally true/false lands in "other".
                bucket = normalized if normalized in ("true", "false") else "other"
                counts[bucket] += 1
    return counts

# Count labels per split, then combine them (Counter addition sums per key).
train_counts = count_labels(train_file)
val_counts = count_labels(val_file)
test_counts = count_labels(test_file)

overall_counts = train_counts + val_counts + test_counts

# Same three-line report (plus a blank separator) for every split.
for heading, split_counts in (
    ("Train label count:", train_counts),
    ("Validation label count:", val_counts),
    ("Test label count:", test_counts),
    ("Overall label count:", overall_counts),
):
    print(heading)
    print(f"true: {split_counts.get('true', 0)}")
    print(f"false: {split_counts.get('false', 0)}")
    print()


Train label count:
true: 12541
false: 591

Validation label count:
true: 1360
false: 56

Test label count:
true: 2501
false: 240

Overall label count:
true: 16402
false: 887



In [5]:
import json
import re
from collections import defaultdict
from tqdm import tqdm
import json
import re
from collections import defaultdict
from tqdm import tqdm

def clean_message(text):
    """Normalize a raw message into lower-case alphanumeric words.

    Steps, in order: lower-case the text, drop URLs (``http(s)://...`` and
    ``www....``), drop every character that is not ``a-z``, ``0-9`` or
    whitespace, collapse whitespace runs to a single space, and trim.

    Args:
        text: The raw message string.

    Returns:
        The cleaned string (possibly empty).
    """
    cleaned = text.lower()
    # (pattern, replacement) pairs applied sequentially; order matters because
    # URLs must be removed before their punctuation is stripped.
    substitutions = (
        (r'https?://\S+|www\.\S+', ''),
        (r'[^a-z0-9\s]', ''),
        (r'\s+', ' '),
    )
    for pattern, replacement in substitutions:
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned.strip()

def add_history_to_jsonl(input_path, output_path, k=5):
    """Clean every message and attach a per-message conversation history.

    Reads ``input_path`` line by line (one JSON game per line, blank lines
    skipped). For each game the messages are replaced by their
    ``clean_message`` form, and a new ``history`` list is added whose i-th
    entry holds the cleaned messages previously sent in the same game by the
    same speaker to the same receiver (at most the ``k`` most recent ones, in
    chronological order). The augmented game is written as one JSON line to
    ``output_path``.

    Args:
        input_path: Source JSONL file; each game needs ``messages``,
            ``speakers`` and ``receivers`` lists.
        output_path: Destination JSONL file (overwritten).
        k: Maximum number of previous messages to keep per sample. Values
            ``<= 0`` yield an empty history for every message.

    Raises:
        KeyError: If a game lacks ``messages``, ``speakers`` or ``receivers``.
    """
    with open(input_path, 'r') as f_in, open(output_path, 'w') as f_out:
        for line in tqdm(f_in, desc="Processing games"):
            if not line.strip():
                continue

            game = json.loads(line)
            game_id = game.get('game_id', 'UNKNOWN')

            messages = game['messages']
            speakers = game['speakers']
            receivers = game['receivers']

            # Earlier cleaned messages, keyed by the directed (sender -> receiver)
            # pair; rebuilt for every game so histories never cross games.
            history_lookup = defaultdict(list)
            game['history'] = []

            for i, raw_message in enumerate(messages):
                pair_key = (game_id, speakers[i], receivers[i])

                cleaned = clean_message(raw_message)
                messages[i] = cleaned  # same list object as game['messages']

                previous = history_lookup[pair_key]
                # Guard k <= 0 explicitly: previous[-0:] is previous[0:], i.e.
                # the whole list rather than "no history".
                game['history'].append(previous[-k:] if k > 0 else [])

                # Record the current message only after snapshotting, so a
                # message never appears in its own history.
                previous.append(cleaned)

            f_out.write(json.dumps(game) + "\n")


# Preprocess the training split, keeping up to 10 earlier messages per sample.
add_history_to_jsonl("/kaggle/input/acl-diplomacy/train.jsonl", "train_with_history_10.jsonl", k=10)


Processing games: 189it [00:00, 665.82it/s]


In [6]:
# Apply the same preprocessing (history window of 10) to the held-out splits.
for source_path, target_path in (
    ("/kaggle/input/acl-diplomacy/validation.jsonl", "val_with_history_10.jsonl"),
    ("/kaggle/input/acl-diplomacy/test.jsonl", "test_with_history_10.jsonl"),
):
    add_history_to_jsonl(source_path, target_path, k=10)

Processing games: 21it [00:00, 771.12it/s]
Processing games: 42it [00:00, 813.39it/s]


In [7]:
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
from tqdm import tqdm
from sklearn.metrics import f1_score, accuracy_score
import numpy as np


2025-04-15 17:35:31.654753: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1744738531.847163      31 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744738531.901981      31 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [8]:
class PositionalEncoding(nn.Module):
    """Adds a fixed sinusoidal position signal to a batch-first sequence.

    The table is precomputed for ``max_len`` positions and stored in the
    non-trainable buffer ``pe`` of shape ``[1, max_len, d_model]``: even
    feature columns hold sines, odd columns hold cosines.

    Args:
        d_model: Feature width of the sequence elements.
        dropout: Dropout probability applied after the addition.
        max_len: Number of positions available in the table.
    """

    def __init__(self, d_model, dropout=0.1, max_len=50):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # [max_len, 1]
        inv_freq = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        angles = positions * inv_freq  # one column per sine slot

        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer cosine column than sine
        # columns, so the last frequency is simply not used for cosines.
        cosine_columns = d_model // 2
        table[:, 1::2] = torch.cos(angles[:, :cosine_columns])

        self.register_buffer('pe', table.unsqueeze(0))

    def forward(self, x):
        """Add the first ``x.size(1)`` position vectors to ``x`` and apply dropout."""
        seq_len = x.size(1)
        with_positions = x + self.pe[:, :seq_len]
        return self.dropout(with_positions)


class GameStateEncoder(nn.Module):
    """Single linear layer + ReLU over the numeric game-state feature vector.

    The expected input width is ``2 + season_vocab_size + year_vocab_size``.

    Args:
        season_vocab_size: Width of the season part of the input.
        year_vocab_size: Width of the year part of the input.
        out_dim: Width of the encoded output.
    """

    def __init__(self, season_vocab_size, year_vocab_size, out_dim=32):
        super().__init__()
        scalar_features = 2
        in_features = scalar_features + season_vocab_size + year_vocab_size
        layers = [nn.Linear(in_features, out_dim), nn.ReLU()]
        self.fc = nn.Sequential(*layers)

    def forward(self, game_features):
        """Encode ``game_features`` (last dim = input width) to ``out_dim`` features."""
        encoded = self.fc(game_features)
        return encoded

class HistoryEncoderTransformer(nn.Module):
    """Encodes each sample's message history into one fixed-size vector.

    Every history message is embedded with BERT (position-0 hidden state,
    computed without gradients), the per-sample sequences are padded, given
    positional encodings, passed through a Transformer encoder with a padding
    mask, and mean-pooled over the real (non-padded) positions.

    Args:
        hidden_dim: Model width of the Transformer encoder. Presumably must
            equal the BERT hidden size, since BERT outputs are fed in
            directly; verify if changed.
        num_layers: Number of Transformer encoder layers.
        num_heads: Attention heads per layer.
        dropout: Dropout used by the encoder layers and positional encoding.
    """

    def __init__(self, hidden_dim=768, num_layers=2, num_heads=8, dropout=0.1):
        super().__init__()
        self.bert = BertModel.from_pretrained("bert-base-uncased")
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim, nhead=num_heads, dropout=dropout, batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.pos_encoder = PositionalEncoding(hidden_dim, dropout=dropout, max_len=20)

    def _embed_messages(self, texts, device):
        """Return one position-0 BERT vector per string in ``texts`` (no grad)."""
        encoded = self.tokenizer(
            texts, return_tensors="pt", padding=True, truncation=True, return_attention_mask=True
        )
        encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
        with torch.no_grad():
            bert_out = self.bert(**encoded)
        return bert_out.last_hidden_state[:, 0, :]

    def forward(self, history_texts_batch):
        """Encode a batch of histories.

        Args:
            history_texts_batch: One list of message strings per sample; an
                empty list is treated as a single empty-string message.

        Returns:
            Tensor with one pooled vector per sample (rows stacked on dim 0).
        """
        device = next(self.parameters()).device

        per_sample = [
            self._embed_messages(history if len(history) > 0 else [""], device)
            for history in history_texts_batch
        ]
        lengths = [embedded.size(0) for embedded in per_sample]

        padded = nn.utils.rnn.pad_sequence(per_sample, batch_first=True)

        # True marks padded positions that attention must ignore.
        steps = torch.arange(padded.size(1), device=device).unsqueeze(0)
        attn_mask = steps >= torch.tensor(lengths, device=device).unsqueeze(1)

        padded = self.pos_encoder(padded)
        transformer_out = self.transformer_encoder(padded, src_key_padding_mask=attn_mask)

        # Mean over the real positions only, sample by sample.
        pooled = [
            sequence[:real_len].mean(dim=0)
            for sequence, real_len in zip(transformer_out, lengths)
        ]
        return torch.stack(pooled, dim=0)

class FusionAttention(nn.Module):
    """Fuses a short sequence of feature vectors into a single vector.

    The inputs first attend to each other (multi-head self-attention), then a
    learned query scores every attended vector; the softmax of those scores
    weights the final sum.

    Args:
        input_dim: Width of every feature vector.
        num_heads: Heads used by the self-attention layer.
    """

    def __init__(self, input_dim, num_heads=2):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=input_dim, num_heads=num_heads, batch_first=True)
        # Learned pooling query, shared by all samples.
        self.pooling_query = nn.Parameter(torch.randn(1, 1, input_dim))

    def forward(self, features):
        """Fuse ``features`` ([batch, n_items, input_dim]) into [batch, input_dim]."""
        attended, _ = self.attn(features, features, features)

        batch_size = attended.size(0)
        query = self.pooling_query.expand(batch_size, -1, -1)

        # Dot-product score of the query against every attended item.
        scores = torch.bmm(query, attended.transpose(1, 2))
        weights = torch.softmax(scores, dim=-1)

        return torch.bmm(weights, attended).squeeze(1)


class DeceptionClassifier(nn.Module):
    """Two-layer MLP head producing one raw logit per sample.

    Layout: Linear(fused_dim, 256) -> ReLU -> Dropout(0.3) -> Linear(256, 1).

    Args:
        fused_dim: Width of the fused input vector.
    """

    def __init__(self, fused_dim):
        super().__init__()
        hidden_units = 256
        self.classifier = nn.Sequential(
            nn.Linear(fused_dim, hidden_units),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_units, 1),
        )

    def forward(self, fused_vector):
        """Return un-normalized logits of shape [..., 1] for ``fused_vector``."""
        return self.classifier(fused_vector)


class MultiModalDeceptionModel(nn.Module):
    """Deception classifier fusing message text, game state and history.

    Three vectors of width ``fusion_dim`` are produced per sample:
      * the BERT position-0 hidden state of the current message,
      * a projection of the encoded game state concatenated with sender and
        receiver embeddings,
      * the history encoding from ``HistoryEncoderTransformer``.
    They are stacked, fused by ``FusionAttention`` and classified into a
    single logit.

    Args:
        season_vocab_size: Width of the season part of the game-state input.
        year_vocab_size: Width of the year part of the game-state input.
        game_state_out_dim: Output width of the game-state encoder.
        fusion_dim: Common width of the fused vectors. Presumably has to
            match the BERT hidden size (768 for bert-base) because the text
            vector is stacked unprojected; verify if changed.
    """

    def __init__(self, season_vocab_size=3, year_vocab_size=5,
                 game_state_out_dim=32, fusion_dim=768):
        super(MultiModalDeceptionModel, self).__init__()

        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.bert = BertModel.from_pretrained('bert-base-uncased')

        self.game_encoder = GameStateEncoder(season_vocab_size, year_vocab_size, out_dim=game_state_out_dim)

        country_emb_dim = 16
        self.sender_embedding = nn.Embedding(num_embeddings=7, embedding_dim=country_emb_dim)
        self.receiver_embedding = nn.Embedding(num_embeddings=7, embedding_dim=country_emb_dim)

        # Width of [game_vector | sender_emb | receiver_emb]. Derived instead
        # of hard-coded (was 64) so a non-default game_state_out_dim no longer
        # breaks forward(); identical to 64 for the defaults, which keeps
        # previously saved checkpoints loadable.
        combined_dim = game_state_out_dim + 2 * country_emb_dim

        # NOTE(review): the intermediate width is taken from the tokenizer's
        # model_max_length, which looks unrelated to feature sizes. It is kept
        # as-is because changing it would alter the parameter shapes stored in
        # existing checkpoints.
        self.sender_receiver_proj = nn.Linear(combined_dim, self.tokenizer.model_max_length if hasattr(self.tokenizer, 'model_max_length') else fusion_dim)
        self.proj_to_text_dim = nn.Linear(self.sender_receiver_proj.out_features, fusion_dim)
        self.history_encoder = HistoryEncoderTransformer(hidden_dim=fusion_dim)

        self.text_feature_dim = fusion_dim  # 768
        self.fusion_attention = FusionAttention(input_dim=self.text_feature_dim, num_heads=2)

        self.classifier = DeceptionClassifier(fused_dim=self.text_feature_dim)

    def forward(self, current_message, game_state_features, history_texts, sender_ids, receiver_ids):
        """Compute one deception logit per sample.

        Args:
            current_message: List of message strings (one per sample).
            game_state_features: Float tensor fed to the game-state encoder.
            history_texts: One list of earlier message strings per sample.
            sender_ids: Long tensor of sender indices (one per sample).
            receiver_ids: Long tensor of receiver indices (one per sample).

        Returns:
            Tensor of raw logits with a trailing dimension of size 1.
        """
        device = next(self.parameters()).device

        # Text branch: tokenize on the fly and keep the position-0 vector.
        inputs = self.tokenizer(current_message, return_tensors="pt", truncation=True, padding=True)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        text_outputs = self.bert(**inputs)
        text_vector = text_outputs.last_hidden_state[:, 0, :]

        # Game branch: encoded state + who is talking to whom.
        game_vector_raw = self.game_encoder(game_state_features)
        sender_emb = self.sender_embedding(sender_ids)
        receiver_emb = self.receiver_embedding(receiver_ids)
        combined_game = torch.cat([game_vector_raw, sender_emb, receiver_emb], dim=1)

        game_vector_proj = self.sender_receiver_proj(combined_game)
        game_vector = self.proj_to_text_dim(game_vector_proj)

        # History branch.
        history_vector = self.history_encoder(history_texts)

        # Stack the three modalities as a length-3 sequence and fuse them.
        fusion_input = torch.stack([text_vector, game_vector, history_vector], dim=1)
        fused_vector = self.fusion_attention(fusion_input)

        logits = self.classifier(fused_vector)
        return logits

class DeceptionDataset(Dataset):
    """Flattens history-augmented game records into per-message samples.

    Each non-blank line of ``jsonl_file`` is one game whose parallel lists
    (``messages``, ``seasons``, ``years``, ``history``, ``speakers``,
    ``receivers``, ``sender_labels`` and the optional score lists) are indexed
    by message position. Every message becomes a dict with:

      * ``current_message``: the message text,
      * ``game_state_features``: float tensor
        ``[game_score, score_delta] + one-hot season + one-hot year bucket``,
      * ``history``: the message's history list,
      * ``sender`` / ``receiver``: long tensors with the country index
        (unknown names map to index 0),
      * ``label``: float tensor, 0.0 when the sender label is ``True`` or
        ``"true"``, otherwise 1.0.

    Args:
        jsonl_file: Path to the preprocessed JSONL file.
        season_to_idx: Optional season-name -> one-hot position mapping.
        year_buckets: Optional ascending list of year bounds (see
            ``bucket_year``).
        country_to_idx: Optional lower-case country -> index mapping.
    """

    def __init__(self, jsonl_file, season_to_idx=None, year_buckets=None, country_to_idx=None):
        self.samples = []
        self.season_to_idx = season_to_idx or {"Spring": 0, "Fall": 1, "Winter": 2}
        self.year_buckets = year_buckets or [1901, 1906, 1911, 1916, 1921]
        self.country_to_idx = country_to_idx or {"russia":0, "turkey":1, "england":2, "france":3, "germany":4, "italy":5, "austria":6}

        with open(jsonl_file, 'r') as f:
            for line in f:
                if not line.strip():
                    continue
                game = json.loads(line)
                for i in range(len(game['messages'])):
                    self.samples.append(self._build_sample(game, i))

    @staticmethod
    def _score_or_zero(game, field, i):
        """Return ``float(game[field][i])``, or 0.0 when it is missing or not numeric."""
        try:
            return float(game[field][i])
        # Only the data-related failures are tolerated (absent key, short
        # list, None / non-numeric entry); a bare ``except`` would also hide
        # KeyboardInterrupt and genuine programming errors.
        except (KeyError, IndexError, TypeError, ValueError, OverflowError):
            return 0.0

    def _build_sample(self, game, i):
        """Assemble the sample dict for message ``i`` of ``game``."""
        sample = {}
        sample['current_message'] = game['messages'][i]

        game_score = self._score_or_zero(game, 'game_score', i)
        score_delta = self._score_or_zero(game, 'game_score_delta', i)

        # One-hot season; an unknown season leaves the vector all zeros.
        season = game['seasons'][i]
        season_vec = [0] * len(self.season_to_idx)
        if season in self.season_to_idx:
            season_vec[self.season_to_idx[season]] = 1

        # One-hot year bucket.
        year_vec = [0] * len(self.year_buckets)
        year_vec[self.bucket_year(int(game['years'][i]))] = 1

        game_state = [game_score, score_delta] + season_vec + year_vec
        sample['game_state_features'] = torch.tensor(game_state, dtype=torch.float)

        sample['history'] = game['history'][i]

        sender_str = game['speakers'][i].lower()
        receiver_str = game['receivers'][i].lower()
        sample['sender'] = torch.tensor(self.country_to_idx.get(sender_str, 0), dtype=torch.long)
        sample['receiver'] = torch.tensor(self.country_to_idx.get(receiver_str, 0), dtype=torch.long)

        # Positive class (1) = message NOT labelled truthful by its sender.
        label_raw = game['sender_labels'][i]
        label = 0 if label_raw is True or label_raw == "true" else 1
        sample['label'] = torch.tensor(label, dtype=torch.float)
        return sample

    def bucket_year(self, year):
        """Return the index of the first bound strictly greater than ``year``.

        Years at or beyond the last bound map to the last index.

        NOTE(review): with the default bounds a year equal to the first bound
        (1901) lands in bucket 1, so bucket 0 is only used for earlier years.
        Left unchanged because trained checkpoints depend on this encoding;
        confirm whether that is intended.
        """
        for idx, bound in enumerate(self.year_buckets):
            if year < bound:
                return idx
        return len(self.year_buckets) - 1

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

def collate_fn(batch):
    """Merge a list of sample dicts into batch-level containers.

    Text fields stay Python lists (they are tokenized inside the model);
    tensor fields are stacked along a new leading batch dimension.

    Args:
        batch: List of sample dicts with the keys ``current_message``,
            ``game_state_features``, ``history``, ``sender``, ``receiver``
            and ``label``.

    Returns:
        Tuple ``(current_messages, game_states, histories, sender_ids,
        receiver_ids, labels)``.
    """
    def as_list(field):
        return [sample[field] for sample in batch]

    def as_tensor(field):
        return torch.stack(as_list(field))

    return (
        as_list('current_message'),
        as_tensor('game_state_features'),  # [batch, dim]
        as_list('history'),
        as_tensor('sender'),
        as_tensor('receiver'),
        as_tensor('label'),
    )

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [8]:


# --- Model and data ---------------------------------------------------------
model = MultiModalDeceptionModel(season_vocab_size=3, year_vocab_size=5,
                                 game_state_out_dim=32, fusion_dim=768)
model.to(device)

train_dataset = DeceptionDataset("train_with_history_10.jsonl")
val_dataset   = DeceptionDataset("val_with_history_10.jsonl")

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, collate_fn=collate_fn)

# --- Class weighting --------------------------------------------------------
# Derive the class counts from the training samples instead of pasting numbers
# from an earlier cell's output, so the weight follows the data if it changes.
# Label 1 (the positive class) marks messages that are NOT labelled truthful.
num_false = sum(int(sample['label'].item()) for sample in train_dataset.samples)
num_true = len(train_dataset) - num_false
# pos_weight = negatives / positives; guard against a split without positives.
pos_weight_value = num_true / max(num_false, 1)
pos_weight = torch.tensor([pos_weight_value]).to(device)
print(f"Class counts -> True: {num_true}, False: {num_false}")
print(f"Computed pos_weight for BCEWithLogitsLoss: {pos_weight_value:.4f}")

# --- Optimisation setup -----------------------------------------------------
lr = 1e-5
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
optimizer = optim.Adam(model.parameters(), lr=lr)

# NOTE(review): the scheduler is stepped on the *training* loss below while
# checkpoints are selected on validation macro-F1 — confirm this is intended.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=1, verbose=True)

best_f1 = 0.0

# --- Training loop ----------------------------------------------------------
model.train()
for epoch in range(10):
    train_losses = []
    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1} Training"):
        current_messages, game_states, histories, sender_ids, receiver_ids, labels = batch
        optimizer.zero_grad()

        game_states = game_states.to(device)
        sender_ids = sender_ids.to(device)
        receiver_ids = receiver_ids.to(device)
        labels = labels.to(device)

        logits = model(current_messages, game_states, histories, sender_ids, receiver_ids)
        loss = criterion(logits.view(-1), labels)
        loss.backward()
        # Clip the global gradient norm before the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        train_losses.append(loss.item())

    avg_train_loss = np.mean(train_losses)

    # Validation
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f"Epoch {epoch+1} Validation"):
            current_messages, game_states, histories, sender_ids, receiver_ids, labels = batch
            game_states = game_states.to(device)
            sender_ids = sender_ids.to(device)
            receiver_ids = receiver_ids.to(device)
            labels = labels.to(device)
            logits = model(current_messages, game_states, histories, sender_ids, receiver_ids)
            prob = torch.sigmoid(logits.view(-1))
            preds = (prob > 0.5).long()
            all_preds.extend(preds.tolist())
            all_labels.extend(labels.long().tolist())

    val_acc = accuracy_score(all_labels, all_preds)
    val_f1 = f1_score(all_labels, all_preds, average="macro")
    print(f"\nEpoch {epoch+1}: Train Loss = {avg_train_loss:.4f} | Val Acc = {val_acc:.4f} | Val F1 = {val_f1:.4f}")

    scheduler.step(avg_train_loss)

    # Keep only the checkpoint with the best validation macro-F1 so far.
    if val_f1 > best_f1:
        best_f1 = val_f1
        torch.save(model.state_dict(), "best_model_attention2.pt")
        print(f" Saved new best model (F1 = {val_f1:.4f}) at epoch {epoch+1}")

    # Back to training mode for the next epoch.
    model.train()


Class counts -> True: 12541, False: 591
Computed pos_weight for BCEWithLogitsLoss: 21.2200


Epoch 1 Training: 100%|██████████| 821/821 [12:18<00:00,  1.11it/s]
Epoch 1 Validation: 100%|██████████| 89/89 [00:52<00:00,  1.68it/s]



Epoch 1: Train Loss = 2.0546 | Val Acc = 0.9569 | Val F1 = 0.4890
 Saved new best model (F1 = 0.4890) at epoch 1


Epoch 2 Training: 100%|██████████| 821/821 [12:00<00:00,  1.14it/s]
Epoch 2 Validation: 100%|██████████| 89/89 [00:51<00:00,  1.73it/s]



Epoch 2: Train Loss = 2.3864 | Val Acc = 0.9576 | Val F1 = 0.5053
 Saved new best model (F1 = 0.5053) at epoch 2


Epoch 3 Training: 100%|██████████| 821/821 [11:57<00:00,  1.14it/s]
Epoch 3 Validation: 100%|██████████| 89/89 [00:51<00:00,  1.73it/s]



Epoch 3: Train Loss = 2.3605 | Val Acc = 0.9597 | Val F1 = 0.5225
 Saved new best model (F1 = 0.5225) at epoch 3


Epoch 4 Training: 100%|██████████| 821/821 [11:54<00:00,  1.15it/s]
Epoch 4 Validation: 100%|██████████| 89/89 [00:51<00:00,  1.74it/s]



Epoch 4: Train Loss = 2.2733 | Val Acc = 0.9590 | Val F1 = 0.6079
 Saved new best model (F1 = 0.6079) at epoch 4


Epoch 5 Training: 100%|██████████| 821/821 [12:12<00:00,  1.12it/s]
Epoch 5 Validation: 100%|██████████| 89/89 [00:53<00:00,  1.66it/s]



Epoch 5: Train Loss = 2.0693 | Val Acc = 0.9520 | Val F1 = 0.5626


Epoch 6 Training: 100%|██████████| 821/821 [12:14<00:00,  1.12it/s]
Epoch 6 Validation: 100%|██████████| 89/89 [00:51<00:00,  1.73it/s]



Epoch 6: Train Loss = 1.7672 | Val Acc = 0.9484 | Val F1 = 0.5672


Epoch 7 Training: 100%|██████████| 821/821 [11:58<00:00,  1.14it/s]
Epoch 7 Validation: 100%|██████████| 89/89 [00:51<00:00,  1.74it/s]



Epoch 7: Train Loss = 1.6262 | Val Acc = 0.9294 | Val F1 = 0.5430


Epoch 8 Training: 100%|██████████| 821/821 [11:57<00:00,  1.14it/s]
Epoch 8 Validation: 100%|██████████| 89/89 [00:51<00:00,  1.73it/s]



Epoch 8: Train Loss = 1.5442 | Val Acc = 0.9477 | Val F1 = 0.5563


Epoch 9 Training: 100%|██████████| 821/821 [11:59<00:00,  1.14it/s]
Epoch 9 Validation: 100%|██████████| 89/89 [00:53<00:00,  1.66it/s]



Epoch 9: Train Loss = 1.3868 | Val Acc = 0.9541 | Val F1 = 0.5661


Epoch 10 Training: 100%|██████████| 821/821 [12:19<00:00,  1.11it/s]
Epoch 10 Validation: 100%|██████████| 89/89 [00:53<00:00,  1.67it/s]


Epoch 10: Train Loss = 1.2362 | Val Acc = 0.9470 | Val F1 = 0.5553





In [None]:
# --- Test-set evaluation: reload the saved checkpoint and score the held-out split ---

In [9]:

from sklearn.metrics import confusion_matrix

# Rebuild the architecture and restore the best checkpoint for evaluation.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
model = MultiModalDeceptionModel(season_vocab_size=3, year_vocab_size=5,
                                 game_state_out_dim=32, fusion_dim=768)
model.to(device)
# map_location makes a checkpoint saved on a GPU loadable on the CPU fallback
# selected above; without it torch.load tries to restore onto the saved device.
state_dict = torch.load("/kaggle/input/attention-tranformer-all-embed/best_model_attention2 (1).pt",
                        map_location=device)
model.load_state_dict(state_dict)
model.eval()
test_dataset   = DeceptionDataset("/kaggle/working/test_with_history_10.jsonl")

test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, collate_fn=collate_fn)

all_preds = []
all_labels = []

with torch.no_grad():
    for batch in tqdm(test_loader, desc="Test Evaluation"):
        current_messages, game_states, histories, sender_ids, receiver_ids, labels = batch

        game_states = game_states.to(device)
        sender_ids = sender_ids.to(device)
        receiver_ids = receiver_ids.to(device)
        labels = labels.to(device)

        logits = model(current_messages, game_states, histories, sender_ids, receiver_ids)
        # Sigmoid probability of the positive class, thresholded at 0.5.
        prob = torch.sigmoid(logits.view(-1))
        preds = (prob > 0.5).long()

        all_preds.extend(preds.tolist())
        all_labels.extend(labels.long().tolist())


# Rows are true classes, columns are predicted classes.
cm = confusion_matrix(all_labels, all_preds)
print("Final Confusion Matrix:")

print(cm)


cuda


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

  model.load_state_dict(torch.load("/kaggle/input/attention-tranformer-all-embed/best_model_attention2 (1).pt"))
  output = torch._nested_tensor_from_mask(
Test Evaluation: 100%|██████████| 172/172 [01:31<00:00,  1.89it/s]

Final Confusion Matrix:
[[2328  173]
 [ 189   51]]





In [11]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Summary metrics on the test predictions collected above.
accuracy = accuracy_score(all_labels, all_preds)

precision_macro = precision_score(all_labels, all_preds, average="macro")
recall_macro = recall_score(all_labels, all_preds, average="macro")

f1_macro = f1_score(all_labels, all_preds, average="macro")
# average=None returns one F1 value per entry of `labels`, in that order.
f1_per_class = f1_score(all_labels, all_preds, average=None, labels=[0, 1])

report_rows = (
    ("Accuracy       ", accuracy),
    ("Precision (avg)", precision_macro),
    ("Recall    (avg)", recall_macro),
    ("F1 Score  (avg)", f1_macro),
    ("F1 Score (class 0)", f1_per_class[0]),
    ("F1 Score (class 1)", f1_per_class[1]),
)
for caption, value in report_rows:
    print(f"{caption}: {value:.4f}")


Accuracy       : 0.8679
Precision (avg): 0.5763
Recall    (avg): 0.5717
F1 Score  (avg): 0.5738
F1 Score (class 0): 0.9279
F1 Score (class 1): 0.2198
