## Extracting Narratives and Sub Narratives

In [1]:
!pip install pdfplumber

Collecting pdfplumber
  Downloading pdfplumber-0.11.5-py3-none-any.whl.metadata (42 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/42.5 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.5/42.5 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting pdfminer.six==20231228 (from pdfplumber)
  Downloading pdfminer.six-20231228-py3-none-any.whl.metadata (4.2 kB)
Collecting pypdfium2>=4.18.0 (from pdfplumber)
  Downloading pypdfium2-4.30.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (48 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m48.2/48.2 kB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
Downloading pdfplumber-0.11.5-py3-none-any.whl (59 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.5/59.5 kB[0m [31m5.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pdfminer.six-20231228-py3-none-any.whl (5.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
import pdfplumber
import re
from collections import defaultdict

def extract_narratives(pdf_path):
    # Separate dictionaries for narratives and sub-narratives
    ukraine_narratives = {}
    climate_narratives = {}

    ukraine_subnarratives = defaultdict(list)
    climate_subnarratives = defaultdict(list)

    with pdfplumber.open(pdf_path) as pdf:
        for page_num, page in enumerate(pdf.pages, start=1):
            text = page.extract_text()
            if not text:
                continue

            # Identify taxonomy based on the page number
            if page_num == 1:
                narrative_dict = ukraine_narratives
                subnarrative_dict = ukraine_subnarratives
            elif page_num == 2:
                narrative_dict = climate_narratives
                subnarrative_dict = climate_subnarratives
            else:
                continue  # Ignore other pages if any

            lines = text.split("\n")
            current_narrative = None

            for line in lines:
                line = line.strip()
                if not line or line.startswith("Figure") or line.isdigit():
                    continue  # Skip unwanted lines

                # If line does not start with '-', it's a narrative
                if not line.startswith("-"):
                    current_narrative = line
                    if current_narrative != "Other":  # Exclude "Other"
                        narrative_dict[current_narrative] = len(narrative_dict)  # Assign sequential number
                # If line starts with '-', it's a sub-narrative
                elif current_narrative and current_narrative != "Other":
                    subnarrative_dict[current_narrative].append(line.lstrip("-").strip())

    return ukraine_narratives, climate_narratives, dict(ukraine_subnarratives), dict(climate_subnarratives)

# Example usage
pdf_path = "NARRATIVE-TAXONOMIES.pdf"  # Replace with your actual PDF path
ukraine_narratives, climate_narratives, ukraine_subnarratives, climate_subnarratives = extract_narratives(pdf_path)

print("Ukraine War Narratives:", ukraine_narratives)
print("Climate Change Narratives:", climate_narratives)
print("Ukraine War Sub-Narratives:", ukraine_subnarratives)
print("Climate Change Sub-Narratives:", climate_subnarratives)

Ukraine War Narratives: {'Blaming the war on others rather than the invader': 0, 'Discrediting Ukraine': 1, 'Russia is the Victim': 2, 'Praise of Russia': 3, 'Overpraising the West': 4, 'Speculating war outcomes': 5, 'Discrediting the West, Diplomacy': 6, 'Negative Consequences for the West': 7, 'Distrust towards Media': 8, 'Amplifying war-related fears': 9, 'Hidden plots by secret schemes of powerful groups': 10}
Climate Change Narratives: {'Criticism of climate policies': 0, 'Criticism of institutions and authorities': 1, 'Climate change is beneficial': 2, 'Downplaying climate change': 3, 'Questioning the measurements and science': 4, 'Criticism of climate movement': 5, 'Controversy about green technologies': 6, 'Hidden plots by secret schemes of powerful groups': 7, 'Amplifying Climate Fears': 8, 'Green policies are geopolitical instruments': 9}
Ukraine War Sub-Narratives: {'Blaming the war on others rather than the invader': ['Ukraine is the aggressor', 'The West are the aggressors

## Data Loading

In [3]:
!unzip /content/train.zip -d /content/Train

Archive:  /content/train.zip
   creating: /content/Train/target_4_December_release/
   creating: /content/Train/target_4_December_release/BG/
   creating: /content/Train/target_4_December_release/BG/raw-documents/
  inflating: /content/Train/target_4_December_release/BG/raw-documents/A6_CC_BG_10015.txt  
  inflating: /content/Train/target_4_December_release/BG/raw-documents/A6_CC_BG_10345.txt  
  inflating: /content/Train/target_4_December_release/BG/raw-documents/A6_CC_BG_10380.txt  
  inflating: /content/Train/target_4_December_release/BG/raw-documents/A6_CC_BG_10468.txt  
  inflating: /content/Train/target_4_December_release/BG/raw-documents/A6_CC_BG_10525.txt  
  inflating: /content/Train/target_4_December_release/BG/raw-documents/A6_CC_BG_10556.txt  
  inflating: /content/Train/target_4_December_release/BG/raw-documents/A6_CC_BG_10565.txt  
  inflating: /content/Train/target_4_December_release/BG/raw-documents/A6_CC_BG_10575.txt  
  inflating: /content/Train/target_4_December_rele

In [4]:
import os
import torch
from transformers import XLMRobertaTokenizer
from torch.utils.data import Dataset

In [5]:
# Paths
DATA_DIR = "/content/Train"  # Change this to the dataset location
LANGUAGES = ["BG", "EN", "HI", "PT", "RU"]

ukraine_narratives, climate_narratives, ukraine_subnarratives, climate_subnarratives = extract_narratives(pdf_path)

# Combine Ukraine and Climate Change narratives into a single dictionary
combined_narratives = {**ukraine_narratives, **climate_narratives}

# Reassign values to ensure they are unique and continuously increasing
narratives = {k: i for i, k in enumerate(combined_narratives.keys())}

# Combine Ukraine and Climate Change sub-narratives into a single dictionary
subnarratives = {**ukraine_subnarratives, **climate_subnarratives}

# Calculate the total number of unique sub-narratives
all_sub_narratives = sorted(set(sub for subs in subnarratives.values() for sub in subs))
sub_narrative_indices = {sub: i for i, sub in enumerate(all_sub_narratives)}  # Assign unique indices

num_sub_narratives = len(all_sub_narratives)
print(f"Total sub-narratives: {num_sub_narratives}")

Total sub-narratives: 74


In [6]:
narratives

{'Blaming the war on others rather than the invader': 0,
 'Discrediting Ukraine': 1,
 'Russia is the Victim': 2,
 'Praise of Russia': 3,
 'Overpraising the West': 4,
 'Speculating war outcomes': 5,
 'Discrediting the West, Diplomacy': 6,
 'Negative Consequences for the West': 7,
 'Distrust towards Media': 8,
 'Amplifying war-related fears': 9,
 'Hidden plots by secret schemes of powerful groups': 10,
 'Criticism of climate policies': 11,
 'Criticism of institutions and authorities': 12,
 'Climate change is beneficial': 13,
 'Downplaying climate change': 14,
 'Questioning the measurements and science': 15,
 'Criticism of climate movement': 16,
 'Controversy about green technologies': 17,
 'Amplifying Climate Fears': 18,
 'Green policies are geopolitical instruments': 19}

In [7]:
subnarratives

{'Blaming the war on others rather than the invader': ['Ukraine is the aggressor',
  'The West are the aggressors'],
 'Discrediting Ukraine': ['Rewriting Ukraine’s history',
  'Discrediting Ukrainian nation and society',
  'Discrediting Ukrainian military',
  'Discrediting Ukrainian government and officials and policies',
  'Ukraine is a puppet of the West',
  'Ukraine is a hub for criminal activities',
  'Ukraine is associated with nazism',
  'Situation in Ukraine is hopeless'],
 'Russia is the Victim': ['The West is russophobic',
  'Russia actions in Ukraine are only self-defence',
  'UA is anti-RU extremists'],
 'Praise of Russia': ['Praise of Russian military might',
  'Praise of Russian President Vladimir Putin',
  'Russia is a guarantor of peace and prosperity',
  'Russia has international support from a number of countries and people',
  'Russian invasion has strong national support'],
 'Overpraising the West': ['NATO will destroy Russia',
  'The West belongs in the right side o

In [8]:
# Load Tokenizer
TOKENIZER = XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")

# Dataset Class
class NarrativeDataset(Dataset):
    def __init__(self, texts, narrative_labels, sub_narrative_labels, tokenizer, max_len=512):
        self.texts = texts
        self.narrative_labels = narrative_labels
        self.sub_narrative_labels = sub_narrative_labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        narrative_label = self.narrative_labels[idx]
        sub_narrative_label = self.sub_narrative_labels[idx]

        # Tokenize text
        inputs = self.tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=self.max_len,
            return_tensors="pt"
        )
        input_ids = inputs["input_ids"].squeeze(0)
        attention_mask = inputs["attention_mask"].squeeze(0)

        # Convert labels to tensors
        narrative_label = torch.tensor(narrative_label, dtype=torch.float)
        sub_narrative_label = torch.tensor(sub_narrative_label, dtype=torch.float)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "narrative_labels": narrative_label,
            "sub_narrative_labels": sub_narrative_label
        }

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.10M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/615 [00:00<?, ?B/s]

In [9]:
# Load and Clean Data
def load_and_clean_data():
    texts, narrative_labels, sub_narrative_labels = [], [], []
    lang_counts = {lang: 0 for lang in LANGUAGES}
    lang_annotations = {lang: [] for lang in LANGUAGES}  # To store annotations per language

    for lang in LANGUAGES:
        lang_path = os.path.join(DATA_DIR, "target_4_December_release", lang)
        annotation_file = os.path.join(lang_path, "subtask-2-annotations.txt")

        # Debug: Print the annotation file path
        #print(f"Checking annotation file for {lang}: {annotation_file}")

        if not os.path.exists(annotation_file):
            print(f"Warning: No annotations found for {lang}")
            continue

        with open(annotation_file, "r", encoding="utf-8") as file:
            for line in file:
                parts = line.strip().split("\t")
                if len(parts) < 3:
                    print(f"Skipping malformed line: {line}")
                    continue

                article_id, narratives_str, subnarratives_str = parts[0], parts[1], parts[2]

                # Skip samples where both narratives and sub-narratives are "Other"
                if narratives_str == "Other" and subnarratives_str == "Other":
                    continue

                # Read corresponding text file
                text_file = os.path.join(lang_path, "raw-documents", article_id)

                # Debug: Print the text file path
                #print(f"Checking text file for {article_id}: {text_file}")

                if not os.path.exists(text_file):
                    print(f"Warning: Missing text file {article_id} in {lang}")
                    continue

                with open(text_file, "r", encoding="utf-8") as f:
                    text = f.read()

                # Split narratives and sub-narratives
                narrative_list = [n.split(": ")[1] if ": " in n else n for n in narratives_str.split(";")]
                sub_narrative_list = [s.split(": ")[1] if ": " in s else s for s in subnarratives_str.split(";")]

                # Encode Narratives (Multi-Label Classification)
                narrative_label = [1 if narrative in narrative_list else 0 for narrative in narratives.keys()]

                # Encode Sub-Narratives (Multi-Label Classification)
                sub_narrative_label = [0] * num_sub_narratives  # Initialize zero vector
                for sub in sub_narrative_list:
                    if sub in sub_narrative_indices:
                        sub_narrative_label[sub_narrative_indices[sub]] = 1  # Set index to 1

                # Store annotation
                annotation = {
                    "text": text,
                    "article_id": article_id,
                    "narratives": narrative_list,
                    "sub_narratives": sub_narrative_list,
                    "narrative_label": narrative_label,
                    "sub_narrative_label": sub_narrative_label
                }

                texts.append(text)
                narrative_labels.append(narrative_label)
                sub_narrative_labels.append(sub_narrative_label)
                lang_annotations[lang].append(annotation)
                lang_counts[lang] += 1

    return texts, narrative_labels, sub_narrative_labels, lang_counts, lang_annotations

In [10]:
# Execution
if __name__ == "__main__":
    print("Loading data...")
    texts, narrative_labels, sub_narrative_labels, lang_counts, lang_annotations = load_and_clean_data()

    # Print statistics
    print("\n### Annotations per Language ###")
    for lang, count in lang_counts.items():
        print(f"{lang}: {count} annotations")

    print("\n### Sample Preprocessed Output per Language ###")
    for lang, lang_anns in lang_annotations.items():
        print(f"\nLanguage: {lang}")
        lang_dataset = NarrativeDataset(
            [ann["text"] for ann in lang_anns],
            [ann["narrative_label"] for ann in lang_anns],
            [ann["sub_narrative_label"] for ann in lang_anns],
            TOKENIZER
        )

        # Display 2 examples per language
        for i in range(min(2, len(lang_dataset))):
            sample = lang_dataset[i]
            print(f"\nExample {i+1}:")
            print("Article ID:", lang_anns[i]["article_id"])
            print("Tokenized Input IDs:", sample["input_ids"][:20])  # Print first 20 tokens
            print("Decoded Text:", TOKENIZER.decode(sample["input_ids"][:100]))  # Decode first 100 tokens
            print("Narrative Labels:", sample["narrative_labels"].numpy())  # Print narrative labels
            #print("Sub-Narrative Labels:", sample["sub_narrative_labels"].numpy())  # Print sub-narrative labels

Loading data...

### Annotations per Language ###
BG: 371 annotations
EN: 230 annotations
HI: 268 annotations
PT: 373 annotations
RU: 133 annotations

### Sample Preprocessed Output per Language ###

Language: BG

Example 1:
Article ID: BG_670.txt
Tokenized Input IDs: tensor([     0,   1089,  22617,   1669,     29,  47829,   2097,  32275,     69,
           137,    197,  35359,  53335,   2827,  40053,    155,    135, 128601,
            29,  12747])
Decoded Text: <s> Опитът на колективния Запад да „обезкърви Русия“ с ръцете на властите в Киев „се провали с гръм и трясък“ и скоро от Украйна ... Опитът на колективния Запад да „обезкърви Русия“ с ръцете на властите в Киев „се провали с гръм и трясък“ и скоро от Украйна няма да остане почти нищо, ако не започне процесът на разрешаване на този въоръжен конфликт
Narrative Labels: [1. 0. 0. 0. 0. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]

Example 2:
Article ID: A7_URW_BG_4793.txt
Tokenized Input IDs: tensor([     0, 160480, 108723,  45653

## Training the model

In [18]:
import torch.nn as nn
from transformers import XLMRobertaModel

class NarrativeClassificationModel(nn.Module):
    def __init__(self, num_narratives, num_sub_narratives):
        super(NarrativeClassificationModel, self).__init__()
        self.xlm_roberta = XLMRobertaModel.from_pretrained("xlm-roberta-base")
        self.dropout = nn.Dropout(0.1)

        # Separate classifiers for narratives and sub-narratives
        self.narrative_classifier = nn.Linear(self.xlm_roberta.config.hidden_size, num_narratives)
        self.sub_narrative_classifier = nn.Linear(self.xlm_roberta.config.hidden_size, num_sub_narratives)

    def forward(self, input_ids, attention_mask):
        outputs = self.xlm_roberta(
            input_ids=input_ids,
            attention_mask=attention_mask
        )

        pooled_output = outputs.last_hidden_state[:, 0, :]  # Using [CLS] token representation
        pooled_output = self.dropout(pooled_output)

        narrative_logits = self.narrative_classifier(pooled_output)
        sub_narrative_logits = self.sub_narrative_classifier(pooled_output)

        return narrative_logits, sub_narrative_logits

In [19]:
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from sklearn.metrics import f1_score, accuracy_score

# Initialize model
num_narratives = len(narratives)
model = NarrativeClassificationModel(num_narratives, num_sub_narratives)

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Initialize dataset
dataset = NarrativeDataset(texts, narrative_labels, sub_narrative_labels, TOKENIZER)

# Split dataset into train and validation
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

# Loss functions and optimizer
criterion = nn.BCEWithLogitsLoss()  # Binary cross-entropy loss with logits
optimizer = optim.AdamW(model.parameters(), lr=2e-5)

# Learning rate scheduler
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

In [22]:
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs=5):
    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0

        # Training phase
        for batch in train_loader:
            optimizer.zero_grad()

            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            narrative_labels = batch['narrative_labels'].to(device)
            sub_narrative_labels = batch['sub_narrative_labels'].to(device)

            # Forward pass
            narrative_logits, sub_narrative_logits = model(input_ids, attention_mask)

            # Calculate losses
            narrative_loss = criterion(narrative_logits, narrative_labels)
            sub_narrative_loss = criterion(sub_narrative_logits, sub_narrative_labels)
            total_loss = narrative_loss + sub_narrative_loss

            # Backward pass and optimize
            total_loss.backward()
            optimizer.step()

            train_loss += total_loss.item()

        # Validation phase
        model.eval()
        val_loss = 0.0
        all_narrative_preds = []
        all_narrative_labels = []
        all_sub_narrative_preds = []
        all_sub_narrative_labels = []

        with torch.no_grad():
            for batch in val_loader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                narrative_labels = batch['narrative_labels'].to(device)
                sub_narrative_labels = batch['sub_narrative_labels'].to(device)

                # Forward pass
                narrative_logits, sub_narrative_logits = model(input_ids, attention_mask)

                # Calculate losses
                narrative_loss = criterion(narrative_logits, narrative_labels)
                sub_narrative_loss = criterion(sub_narrative_logits, sub_narrative_labels)
                total_loss = narrative_loss + sub_narrative_loss
                val_loss += total_loss.item()

                # Store predictions and labels for metrics
                narrative_preds = torch.sigmoid(narrative_logits) > 0.5
                sub_narrative_preds = torch.sigmoid(sub_narrative_logits) > 0.5

                all_narrative_preds.extend(narrative_preds.cpu().numpy())
                all_narrative_labels.extend(narrative_labels.cpu().numpy())
                all_sub_narrative_preds.extend(sub_narrative_preds.cpu().numpy())
                all_sub_narrative_labels.extend(sub_narrative_labels.cpu().numpy())

        # Calculate metrics
        train_loss /= len(train_loader)
        val_loss /= len(val_loader)

        narrative_f1 = f1_score(all_narrative_labels, all_narrative_preds, average='micro')
        sub_narrative_f1 = f1_score(all_sub_narrative_labels, all_sub_narrative_preds, average='micro')

        print(f"Epoch {epoch+1}/{num_epochs}")
        print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
        print(f"Narrative F1: {narrative_f1:.4f}")

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), "best_model.pt")
            print("Saved best model!")

        scheduler.step()

    return model

# Start training
trained_model = train_model(model, train_loader, val_loader, criterion, optimizer, scheduler)
trained_model

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 1/5
Train Loss: 0.2791 | Val Loss: 0.2518
Narrative F1: 0.2971
Saved best model!


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 2/5
Train Loss: 0.2484 | Val Loss: 0.2375
Narrative F1: 0.4892
Saved best model!


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 3/5
Train Loss: 0.2295 | Val Loss: 0.2278
Narrative F1: 0.4600
Saved best model!


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 4/5
Train Loss: 0.2243 | Val Loss: 0.2251
Narrative F1: 0.4558
Saved best model!


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 5/5
Train Loss: 0.2214 | Val Loss: 0.2237
Narrative F1: 0.4817
Saved best model!


NarrativeClassificationModel(
  (xlm_roberta): XLMRobertaModel(
    (embeddings): XLMRobertaEmbeddings(
      (word_embeddings): Embedding(250002, 768, padding_idx=1)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): XLMRobertaEncoder(
      (layer): ModuleList(
        (0-11): 12 x XLMRobertaLayer(
          (attention): XLMRobertaAttention(
            (self): XLMRobertaSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): XLMRobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768

## Prediction

In [26]:
def predict_validation_data(model, val_loader, narratives_dict, sub_narratives_list, tokenizer):
    """
    Predict narratives and sub-narratives for validation data

    Args:
        model: Trained model
        val_loader: Validation DataLoader
        narratives_dict: Dictionary of narrative names to indices
        sub_narratives_list: List of all sub-narrative names
        tokenizer: Tokenizer used for the model

    Returns:
        tuple: (texts, true_narratives, pred_narratives, true_subnarratives, pred_subnarratives)
    """
    model.eval()
    texts = []
    true_narratives = []
    pred_narratives = []
    true_subnarratives = []
    pred_subnarratives = []

    # Reverse mappings for label indices to names
    idx_to_narrative = {v: k for k, v in narratives_dict.items()}
    idx_to_subnarrative = {i: name for i, name in enumerate(sub_narratives_list)}

    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            # Get true labels
            batch_true_narrative = batch['narrative_labels'].cpu().numpy()
            batch_true_subnarrative = batch['sub_narrative_labels'].cpu().numpy()

            # Get original texts (from dataset)
            batch_texts = [val_dataset.dataset.texts[val_dataset.indices[i]]
                          for i in range(len(input_ids))]

            # Get predictions
            narrative_logits, sub_narrative_logits = model(input_ids, attention_mask)

            # Convert to probabilities and binary predictions
            narrative_probs = torch.sigmoid(narrative_logits).cpu().numpy()
            sub_narrative_probs = torch.sigmoid(sub_narrative_logits).cpu().numpy()

            narrative_preds = (narrative_probs > 0.5).astype(int)
            sub_narrative_preds = (sub_narrative_probs > 0.5).astype(int)

            # Process each sample in the batch
            for i in range(len(input_ids)):
                # Store original text
                texts.append(batch_texts[i])

                # Get true labels
                true_narrative_idx = [idx for idx, val in enumerate(batch_true_narrative[i]) if val == 1]
                true_narrative_names = [idx_to_narrative[idx] for idx in true_narrative_idx]
                true_narratives.append(true_narrative_names)

                true_subnarrative_idx = [idx for idx, val in enumerate(batch_true_subnarrative[i]) if val == 1]
                true_subnarrative_names = [idx_to_subnarrative[idx] for idx in true_subnarrative_idx]
                true_subnarratives.append(true_subnarrative_names)

                # Get predicted labels
                pred_narrative_idx = [idx for idx, val in enumerate(narrative_preds[i]) if val == 1]
                pred_narrative_names = [idx_to_narrative[idx] for idx in pred_narrative_idx]
                pred_narratives.append(pred_narrative_names)

                pred_subnarrative_idx = [idx for idx, val in enumerate(sub_narrative_preds[i]) if val == 1]
                pred_subnarrative_names = [idx_to_subnarrative[idx] for idx in pred_subnarrative_idx]
                pred_subnarratives.append(pred_subnarrative_names)

    return texts, true_narratives, pred_narratives, true_subnarratives, pred_subnarratives

# Load the best model
model.load_state_dict(torch.load("best_model.pt"))
model = model.to(device)

# Get predictions for validation data
(val_texts, val_true_narratives, val_pred_narratives,
 val_true_subnarratives, val_pred_subnarratives) = predict_validation_data(
    model, val_loader, narratives, all_sub_narratives, TOKENIZER
)

# Analyze results
def print_validation_results(texts, true_narr, pred_narr, true_sub, pred_sub, num_samples=5):
    """Print sample validation results"""
    print(f"\nValidation Results (showing {num_samples} samples):")
    print("="*80)

    correct_count = 0
    for i in range(min(num_samples, len(texts))):
        narrative_correct = set(true_narr[i]) == set(pred_narr[i])
        subnarrative_correct = set(true_sub[i]) == set(pred_sub[i])

        if narrative_correct and subnarrative_correct:
            correct_count += 1

        print(f"\nSample {i+1}:")
        print(f"Text: {texts[i][:100]}...")
        print(f"\nTrue Narratives: {', '.join(true_narr[i]) if true_narr[i] else 'None'}")
        print(f"Pred Narratives: {', '.join(pred_narr[i]) if pred_narr[i] else 'None'}")
        #print(f"Match: {'✓' if narrative_correct else '✗'}")

        #print(f"\nTrue Sub-narratives: {', '.join(true_sub[i]) if true_sub[i] else 'None'}")
        #print(f"Pred Sub-narratives: {', '.join(pred_sub[i]) if pred_sub[i] else 'None'}")
        #print(f"Match: {'✓' if subnarrative_correct else '✗'}")
        print("-"*60)

    # Calculate accuracy
    total_samples = min(num_samples, len(texts))
    print(f"\nAccuracy in shown samples: {correct_count}/{total_samples} ({correct_count/total_samples:.1%})")

# Print sample results
print_validation_results(val_texts, val_true_narratives, val_pred_narratives,
                       val_true_subnarratives, val_pred_subnarratives, num_samples=10)

# Calculate overall accuracy
def calculate_overall_accuracy(true_narr, pred_narr, true_sub, pred_sub):
    """Calculate overall accuracy metrics"""
    narrative_correct = 0
    subnarrative_correct = 0
    both_correct = 0

    for i in range(len(true_narr)):
        if set(true_narr[i]) == set(pred_narr[i]):
            narrative_correct += 1
        if set(true_sub[i]) == set(pred_sub[i]):
            subnarrative_correct += 1
        if (set(true_narr[i]) == set(pred_narr[i])) and (set(true_sub[i]) == set(pred_sub[i])):
            both_correct += 1

    total = len(true_narr)
    print("\nOverall Validation Accuracy:")
    print("="*60)
    print(f"Narrative Accuracy: {narrative_correct}/{total} ({narrative_correct/total:.2%})")
    #print(f"Sub-narrative Accuracy: {subnarrative_correct}/{total} ({subnarrative_correct/total:.2%})")
    #print(f"Both Correct: {both_correct}/{total} ({both_correct/total:.2%})")

# Calculate overall accuracy
calculate_overall_accuracy(val_true_narratives, val_pred_narratives,
                         val_true_subnarratives, val_pred_subnarratives)


Validation Results (showing 10 samples):

Sample 1:
Text: Climate Expert: ‘Three Strikes’ Against Climate Alarmism 

 Facts and data show “three strikes and c...

True Narratives: Hidden plots by secret schemes of powerful groups, Downplaying climate change, Questioning the measurements and science, Criticism of climate movement
Pred Narratives: Criticism of institutions and authorities
------------------------------------------------------------

Sample 2:
Text: Русия заби още един пирон в ковчега на разлагащата се хегемония на САЩ

▪️Договорът за всеобхватно с...

True Narratives: Praise of Russia, Negative Consequences for the West
Pred Narratives: Discrediting the West, Diplomacy
------------------------------------------------------------

Sample 3:
Text: अमेरिका के राष्ट्रपति जो बायडेन के किस्म-किस्म के वीडियो सामने आते रहते हैं, जिसमें वो खोए-खोए से नज...

True Narratives: Discrediting the West, Diplomacy
Pred Narratives: None
---------------------------------------------------