Instructions

1.   Install libraries in requirements.txt
2.   Download the data train.h5 and val.h5 and update setup cell with the right path.
3.   Hit Run All
4. Remove collab specific instructions (those that start with !)



In [None]:
import torch
import torch.nn as nn
from torch.optim import Adam
from transformers import AutoTokenizer, AutoModel
from transformers import ViTImageProcessor, ViTModel, GPT2Tokenizer, GPT2LMHeadModel
from transformers.models.deberta_v2.modeling_deberta_v2 import DebertaV2Encoder
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.metrics import precision_recall_curve
from torchvision import transforms
import h5py
from google.colab import drive
from tqdm import tqdm
import numpy as np
from networkx import DiGraph, relabel_nodes, all_pairs_shortest_path_length
from torchvision.transforms import ToPILImage
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from torch_geometric.nn import GATv2Conv,GCNConv
from PIL import Image
import networkx as nx

# Setup

In [None]:
BATCH_SIZE = 24
NUM_EPOCHS = 15
# Mount Google Drive


# Define paths to datasets
train_path = './data/train.h5'
val_path = './data/val.h5'
model_path = './models/multimodal_multilabel.pth'
# Create dataset objects
tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-base")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


## Graph and base dataset

In [None]:
# Class Map

label_map = {
    "Root":0,
    "Logos": 1,
    "Repetition": 2,
    "Obfuscation, Intentional vagueness, Confusion": 3,
    "Reasoning": 4,
    "Justification": 5,
    "Slogans": 6,
    "Bandwagon": 7,
    "Appeal to authority": 8,
    "Flag-waving": 9,
    "Appeal to fear/prejudice": 10,
    "Simplification": 11,
    "Causal Oversimplification": 12,
    "Black-and-white Fallacy/Dictatorship": 13,
    "Thought-terminating cliché": 14,
    "Distraction": 15,
    "Misrepresentation of Someone's Position (Straw Man)": 16,
    "Presenting Irrelevant Data (Red Herring)": 17,
    "Whataboutism": 18,
    "Ethos": 19,
    "Glittering generalities (Virtue)": 20,
    "Ad Hominem": 21,
    "Doubt": 22,
    "Name calling/Labeling": 23,
    "Smears": 24,
    "Reductio ad hitlerum": 25,
    "Pathos": 26,
    "Exaggeration/Minimisation": 27,
    "Loaded Language": 28,
    "Transfer": 29,
    "Appeal to (Strong) Emotions":30
}

def get_label_index(label):
    return label_map[label]

def get_label_for_index(index):
    for label, idx in label_map.items():
        if idx == index:
            return label
    assert False, f"Unknown index: {index}"
    return None



# Class hierarchy graph
# used by GAT to embed hierarchy into encoders

G = DiGraph()
# Add top-level categories
G.add_node("Root", index=0)
G.add_edge("Root", "Logos")
G.add_edge("Logos", "Repetition")
G.add_edge("Logos", "Obfuscation, Intentional vagueness, Confusion")
G.add_edge("Logos", "Reasoning")
G.add_edge("Logos", "Justification")
G.add_edge('Justification', "Slogans")
G.add_edge('Justification', "Bandwagon")
G.add_edge('Justification', "Appeal to authority")
G.add_edge('Justification', "Flag-waving")
G.add_edge('Justification', "Appeal to fear/prejudice")
G.add_edge('Reasoning', "Simplification")
G.add_edge('Simplification', "Causal Oversimplification")
G.add_edge('Simplification', "Black-and-white Fallacy/Dictatorship")
G.add_edge('Simplification', "Thought-terminating cliché")
G.add_edge('Reasoning', "Distraction")
G.add_edge('Distraction', "Misrepresentation of Someone's Position (Straw Man)")
G.add_edge('Distraction', "Presenting Irrelevant Data (Red Herring)")
G.add_edge('Distraction', "Whataboutism")
G.add_edge("Root", "Ethos")
G.add_edge('Ethos', "Appeal to authority")
G.add_edge('Ethos', "Glittering generalities (Virtue)")
G.add_edge('Ethos', "Bandwagon")
G.add_edge('Ethos', "Ad Hominem")
G.add_edge('Ethos', "Transfer")
G.add_edge('Ad Hominem', "Doubt")
G.add_edge('Ad Hominem', "Name calling/Labeling")
G.add_edge('Ad Hominem', "Smears")
G.add_edge('Ad Hominem', "Reductio ad hitlerum")
G.add_edge('Ad Hominem', "Whataboutism")
G.add_edge("Root", "Pathos")
G.add_edge('Pathos', "Exaggeration/Minimisation")
G.add_edge('Pathos', "Loaded Language")
G.add_edge('Pathos', "Appeal to (Strong) Emotions")
G.add_edge('Pathos', "Appeal to fear/prejudice")
G.add_edge('Pathos', "Flag-waving")
G.add_edge('Pathos', "Transfer")

# Separate edges by layer
layer1_edges = [
    (get_label_index("Root"), get_label_index("Logos")),
    (get_label_index("Root"), get_label_index("Ethos")),
    (get_label_index("Root"), get_label_index("Pathos")),
]

layer2_edges = [
    (get_label_index("Logos"), get_label_index("Repetition")),
    (get_label_index("Logos"), get_label_index("Obfuscation, Intentional vagueness, Confusion")),
    (get_label_index("Logos"), get_label_index("Reasoning")),
    (get_label_index("Logos"), get_label_index("Justification")),
    (get_label_index("Ethos"), get_label_index("Appeal to authority")),
    (get_label_index("Ethos"), get_label_index("Glittering generalities (Virtue)")),
    (get_label_index("Ethos"), get_label_index("Bandwagon")),
    (get_label_index("Pathos"), get_label_index("Exaggeration/Minimisation")),
    (get_label_index("Pathos"), get_label_index("Loaded Language")),
]

layer3_edges = [
    (get_label_index("Reasoning"), get_label_index("Simplification")),
    (get_label_index("Reasoning"), get_label_index("Distraction")),
    (get_label_index("Simplification"), get_label_index("Causal Oversimplification")),
    (get_label_index("Distraction"), get_label_index("Whataboutism")),
]

layers = [layer1_edges, layer2_edges, layer3_edges]

def depth_to_label(depth):
    if depth == 0:
        return  layer1_edges
    elif depth == 1:
        return layer2_edges
    elif depth == 2:
        return layer3_edges
    else:
        assert False, f"Unknown depth: {depth}"

# Dataset that returns tokenized string and meme
class HDF5Dataset(torch.utils.data.Dataset):
    def __init__(self, hdf5_path, label_map, hierarchy, transform=None):
        """
        Args:
            hdf5_path (str): Path to the HDF5 file.
            label_map (dict): Mapping from label text to index.
            hierarchy (DiGraph): Directed graph representing the hierarchy.
            transform (Callable, optional): Optional image transformation.
        """
        self.hdf5_path = hdf5_path
        self.label_map = label_map
        self.hierarchy = hierarchy
        self.transform = transform
        self.hf = None  # File handler to be opened on-demand

    def __len__(self):
        with h5py.File(self.hdf5_path, 'r') as hf:
            return hf['labels'].shape[0]

    def __getitem__(self, idx):
        global tokenizer
        if self.hf is None:
            self.hf = h5py.File(self.hdf5_path, 'r')
        text = self.hf["text"][idx].decode("utf-8")
        # Load data
        # get tokens from tokenizer
        text_token_ids, token_type_ids, attention_mask = tokenizer(text)
        image = torch.tensor(self.hf['images'][idx], dtype=torch.float32)

        # Decode and convert text labels to binary vector
        label_bytes = self.hf['labels'][idx]
        label_texts = label_bytes.decode("utf-8").strip().split("<?>")  # Decode bytes to string
        label_vector = torch.zeros(len(self.label_map), dtype=torch.float32)

        # Set leaf labels
        leaf_indices = [self.label_map[label] for label in label_texts if label in self.label_map]

        # Propagate labels up the hierarchy
        all_indices = set(leaf_indices)
        for leaf in label_texts:
            if leaf in self.label_map:
                node = leaf
                while node in self.hierarchy:
                    all_indices.add(self.label_map[node])
                    parent_nodes = list(self.hierarchy.predecessors(node))
                    if not parent_nodes or parent_nodes[0] == "Root":
                        break
                    node = parent_nodes[0]
            elif leaf == "":
              return text_token_ids,token_type_ids, text_attention_masks, image, label_vector
            else:
              assert leaf == "Root", f"Unknown label: {label_texts}"

        # Set binary vector for all labels
        for idx in all_indices:
            label_vector[idx] = 1.0

        # Apply optional image transformation
        if self.transform:
            image = self.transform(image)
        return text_token_ids, token_type_ids, text_attention_masks, image, label_vector

    def close(self):
        if self.hf is not None:
            self.hf.close()
            self.hf = None




## Hierarchy dataset with soft prompts

In [None]:

# Dataset that returns tokenized text that is appended with soft prompt and meme
class HDPureDataset(torch.utils.data.Dataset):
    def __init__(self, hdf5_path, label_map, hierarchy, transform=None):
        """
        Args:
            hdf5_path (str): Path to the HDF5 file.
            label_map (dict): Mapping from label text to index.
            hierarchy (DiGraph): Directed graph representing the hierarchy.
            transform (Callable, optional): Optional image transformation.
        """
        self.hdf5_path = hdf5_path
        self.label_map = label_map
        self.hierarchy = hierarchy
        self.transform = transform
        self.hf = None  # File handler to be opened on-demand
        self.hierarchy_levels = 3
        print(self.hierarchy_levels)

    def construct_prompt(self):
        """
        Construct the prompt based on the hierarchy depth.
        """
        prompt_tokens = ["[V{}] [MASK]".format(i) for i in range(1, self.hierarchy_levels + 1)]
        return " ".join(prompt_tokens)

    def __len__(self):
        with h5py.File(self.hdf5_path, 'r') as hf:
            return hf['labels'].shape[0]

    def __getitem__(self, idx):
        if self.hf is None:
            self.hf = h5py.File(self.hdf5_path, 'r')
        text = self.hf["text"][idx].decode("utf-8")
        # Tokenize the combined input
        tokenized = tokenizer(text, padding="max_length", max_length=128,truncation=True, return_tensors="pt")

        text_token_ids = tokenized["input_ids"].squeeze(0)
        token_type_ids = tokenized["token_type_ids"].squeeze(0)
        text_attention_masks = tokenized["attention_mask"].squeeze(0)

        image = torch.tensor(self.hf['images'][idx], dtype=torch.float32)

        # Decode and convert text labels to binary vector
        label_bytes = self.hf['labels'][idx]
        label_texts = label_bytes.decode("utf-8").strip().split("<?>")  # Decode bytes to string
        label_vector = torch.zeros(len(self.label_map), dtype=torch.float32)

        # Set leaf labels
        leaf_indices = [self.label_map[label] for label in label_texts if label in self.label_map]

        # Propagate labels up the hierarchy
        all_indices = set(leaf_indices)
        for leaf in label_texts:
            if leaf in self.label_map:
                node = leaf
                while node in self.hierarchy:
                    all_indices.add(self.label_map[node])
                    parent_nodes = list(self.hierarchy.predecessors(node))
                    if not parent_nodes or parent_nodes[0] == "Root":
                        break
                    node = parent_nodes[0]
            elif leaf == "":
              return text_token_ids, token_type_ids, text_attention_masks, image, label_vector
            else:
              assert leaf == "Root", f"Unknown label: {label_texts}"

        # Set binary vector for all labels
        for idx in all_indices:
            label_vector[idx] = 1.0

        # Apply optional image transformation
        if self.transform:
            image = self.transform(image)

        return text_token_ids, token_type_ids, text_attention_masks, image, label_vector

    def close(self):
        if self.hf is not None:
            self.hf.close()
            self.hf = None

## Load data

In [None]:
train_dataset = HDPureDataset(train_path, label_map, G)
val_dataset = HDPureDataset(val_path, label_map, G)
best_f1_score_g = None
best_thresholds = None
# Create DataLoaders
all_val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
all_train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
3
3


# Transformer

In [None]:
class TransformerFusion(nn.Module):
    def __init__(self, input_dim=768, hidden_dim=256, num_heads=8, num_layers=6, ff_dim=512, dropout=0.3, mlp_hidden_dim=512):
        super(TransformerFusion, self).__init__()
        # Transformer layers for text and image separately
        self.text_transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=input_dim, nhead=num_heads, dim_feedforward=ff_dim, dropout=dropout),
            num_layers=num_layers
        )
        self.image_transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=input_dim, nhead=num_heads, dim_feedforward=ff_dim, dropout=dropout),
            num_layers=num_layers
        )

        # Cross-attention mechanism
        self.cross_attn = nn.MultiheadAttention(embed_dim=input_dim, num_heads=num_heads, dropout=dropout)

        # Learnable attention pooling
        self.text_pool_query = nn.Parameter(torch.empty(1, input_dim))  # Query for text
        self.image_pool_query = nn.Parameter(torch.empty(1, input_dim))  # Query for image
        self.cross_pool_query = nn.Parameter(torch.empty(1, input_dim))  # Query for cross-attention

        # Initialize pooling queries using Xavier Initialization
        nn.init.xavier_uniform_(self.text_pool_query)
        nn.init.xavier_uniform_(self.image_pool_query)
        nn.init.xavier_uniform_(self.cross_pool_query)

        # MLP for feature fusion
        self.mlp = nn.Sequential(
            nn.Linear(3 * input_dim, mlp_hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden_dim, hidden_dim)
        )

        self.text_norm = nn.LayerNorm(input_dim)
        self.image_norm = nn.LayerNorm(input_dim)

    def attention_pooling(self, features, query):
        """
        Perform attention-based pooling.
        Args:
            features: Tensor of shape [seq_len, batch_size, input_dim]
            query: Tensor of shape [1, input_dim]
        Returns:
            Pooled output: Tensor of shape [batch_size, input_dim]
        """
        attn_weights = torch.softmax(torch.matmul(features.transpose(0, 1), query.T), dim=1)  # [batch_size, seq_len, 1]
        pooled = torch.sum(attn_weights * features.transpose(0, 1), dim=1)  # [batch_size, input_dim]
        return pooled

    def forward(self, text_features, image_features):
        text_features = self.text_norm(text_features)
        image_features = self.image_norm(image_features)

        # Self-attention
        text_features = self.text_transformer(text_features)
        image_features = self.image_transformer(image_features)

        # Cross-attention
        cross_attn_output, _ = self.cross_attn(
            query=text_features,
            key=image_features,
            value=image_features
        )

        # Attention-based pooling
        text_pooled = self.attention_pooling(text_features, self.text_pool_query)
        image_pooled = self.attention_pooling(image_features, self.image_pool_query)
        cross_attn_pooled = self.attention_pooling(cross_attn_output, self.cross_pool_query)

        # Concatenate pooled features and fuse with MLP
        combined_features = torch.cat([text_pooled, text_pooled, text_pooled], dim=-1)
        fused_features = self.mlp(combined_features)

        return fused_features



# Model

## Base model


This is a model without enhancements. It consists of embedder and an optional  transformer fusion.

In [None]:
class MulticlassMemeClassifier(nn.Module):
    def __init__(self, hidden_dim, hierarchy):
        super(MulticlassMemeClassifier, self).__init__()
        # Text encoder: Pretrained BERT
        self.text_encoder = AutoModel.from_pretrained("microsoft/deberta-v3-base")
        self.text_fc = nn.Sequential(nn.Dropout(0.1), nn.Linear(768, 768),
                                      nn.Dropout(0.1))

        # Image encoder: Pretrained ViT
        self.image_encoder = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
        self.image_fc = nn.Sequential(nn.Dropout(0.1),
                                      nn.Linear(768, 768),
                                      nn.Dropout(0.1))

        # Fusion mechanism
        self.fusion = TransformerFusion(input_dim=768, hidden_dim=hidden_dim)

        # Create classifier dictionary for hierarchical nodes
        self.classifier_dict = nn.ModuleDict({
            node: nn.Linear(hidden_dim, 1) for node in hierarchy.nodes if node != "Root"
        })

        self.hierarchy = hierarchy

    def forward(self, text_input_ids, token_type_ids, text_attention_mask, images):
        text_hidden_states = self.text_encoder(
            input_ids=text_input_ids,
            attention_mask=text_attention_mask,
            token_type_ids=token_type_ids
        ).last_hidden_state

        # Text encoding
        text_features = self.text_fc(text_hidden_states)

        # Image encoding
        image_hidden_states = self.image_encoder(pixel_values=images).last_hidden_state
        image_features = self.image_fc(image_hidden_states)

        # Transpose to [seq_len, batch_size, input_dim]
        text_features = text_features.permute(1, 0, 2)
        image_features = image_features.permute(1, 0, 2)

        # Fusion
        fused_features = self.fusion(text_features, image_features)

        # Outputs for each node in the hierarchy
        outputs = {node: self.classifier_dict[node](fused_features) for node in self.hierarchy.nodes if node != "Root"}
        return outputs


## Heirarchy model

This model is integrated with GAT and soft prompt tuning

In [None]:
from transformers import AutoTokenizer
from torch_geometric.nn import GATConv
import torch.nn.functional as F

"""
# Verbalizer predictions
        outputs = {}
        for i in range(self.L):
            hidden_state = pred_hidden_states[i]
            label_indices = [self.label2idx[label] for label in self.labels_per_layer[i]]
            logits = torch.matmul(hidden_state, self.label_embeddings[label_indices].t())
            for idx, label in enumerate(self.labels_per_layer[i]):
                outputs[label] = logits[:, idx].unsqueeze(-1)

        # Image encoding
        image_features = self.image_encoder(pixel_values=images).last_hidden_state
        image_features = self.image_fc(image_features)
        print(image_features.shape, text_features.shape)
        image_features = image_features.permute(1, 0, 2)
        text_features = text_features.permute(1, 0, 2)
        # Verbalizer predictions
"""
class HierarchyAwareClassifier(nn.Module):
    def __init__(self, hidden_dim, hierarchy, edge_levels, pretrained_model="microsoft/deberta-v3-base", num_gat_layers=3):
        super(HierarchyAwareClassifier, self).__init__()
        self.hierarchy = hierarchy
        self.hidden_dim = hidden_dim
        self.L = len(edge_levels)  # Hierarchy depth

        # Tokenizer and text encoder
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model)
        self.text_encoder = AutoModel.from_pretrained(pretrained_model, output_hidden_states=True)
        self.deberta_hidden_dim = self.text_encoder.config.hidden_size

        # Add [Vi] tokens to tokenizer
        self.template_tokens = [f"[V{i}]" for i in range(1, self.L + 1)]
        self.tokenizer.add_tokens(self.template_tokens)

        # Add [PRED] token to tokenizer
        self.tokenizer.add_tokens(["[PRED]"])

        # Resize text encoder embeddings for new tokens
        self.text_encoder.resize_token_embeddings(len(self.tokenizer))

        # Initialize template embeddings [V1], [V2], ..., [VL] (randomly initialized, learned during training)
        self.template_embeddings = nn.Parameter(torch.randn(self.L, self.deberta_hidden_dim))

        # Initialize [PRED] token embedding using [MASK] token embedding
        mask_token_id = self.tokenizer.mask_token_id
        self.pred_embedding = nn.Parameter(
            self.text_encoder.embeddings.word_embeddings.weight[mask_token_id].clone().detach()
        )

        # Label embeddings for each label in the hierarchy
        self.labels = [node for node in hierarchy.nodes if node != "Root"]
        self.num_labels = len(self.labels)
        self.label2idx = {label: idx for idx, label in enumerate(self.labels)}
        self.idx2label = {idx: label for label, idx in self.label2idx.items()}

        # Assign levels to labels
        self.levels = self.assign_levels()

        # Labels per layer
        self.labels_per_layer = [[] for _ in range(self.L)]
        for label in self.labels:
            level = self.levels[label]
            if level < self.L:
                self.labels_per_layer[level].append(label)

        # Virtual label embeddings initialized randomly
        self.label_embeddings = nn.Parameter(torch.randn(self.num_labels, self.deberta_hidden_dim))

        # GAT for hierarchy injection
        self.gat_layers = nn.ModuleList([
            GATConv(self.deberta_hidden_dim, self.deberta_hidden_dim, heads=1)
            for _ in range(num_gat_layers)
        ])

        # Build edge index
        self.edge_index = self.build_edge_index()
        self.fusion_model = TransformerFusion(self.deberta_hidden_dim)
        # Image encoder: Pretrained ViT
        self.image_encoder = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
        self.image_fc = nn.Sequential(nn.Dropout(0.1))

        # Fusion mechanism
        self.fusion_fc = nn.Linear(self.deberta_hidden_dim + self.image_encoder.config.hidden_size, self.deberta_hidden_dim)

        # Classifiers for each node
        self.classifier_dict = nn.ModuleDict({
            node: nn.Linear(hidden_dim, 1) for node in self.labels
        })
        self.verbalizer_projection = torch.nn.Linear(14, self.deberta_hidden_dim)

    def assign_levels(self):
        # Assign levels to each label in the hierarchy using BFS
        levels = {}
        queue = [('Root', 0)]
        visited = set()
        while queue:
            node, level = queue.pop(0)
            if node in visited:
                continue
            visited.add(node)
            levels[node] = level
            for child in self.hierarchy.neighbors(node):
                if child not in visited:
                    queue.append((child, level + 1))
        return levels

    def build_edge_index(self):
        # Build edge index for GAT
        edge_index = [[], []]
        node_indices = {label: idx for idx, label in enumerate(self.labels)}
        for i in range(self.L):
            node_indices[f'V{i+1}'] = self.num_labels + i

        # Add hierarchy edges
        for parent, child in self.hierarchy.edges():
            if parent == "Root":
                continue
            if parent in node_indices and child in node_indices:
                edge_index[0].append(node_indices[parent])
                edge_index[1].append(node_indices[child])

        # Add virtual node edges
        for i in range(self.L):
            vi_idx = node_indices[f'V{i+1}']
            for label in self.labels_per_layer[i]:
                label_idx = node_indices[label]
                edge_index[0].extend([vi_idx, label_idx])
                edge_index[1].extend([label_idx, vi_idx])

        return torch.tensor(edge_index, dtype=torch.long)

    def get_gat_embeddings(self):

      num_template_embeddings = self.template_embeddings.size(0)  # L

      num_nodes = self.edge_index.max().item() + 1
      input_embeddings = self.template_embeddings
      if num_template_embeddings < num_nodes:
          num_missing_nodes = num_nodes - num_template_embeddings

          extra_embeddings = self.pred_embedding.unsqueeze(0).repeat(num_missing_nodes, 1)  # Shape: (num_missing_nodes, hidden_dim)

          input_embeddings = torch.cat([input_embeddings, extra_embeddings], dim=0)



      # Check edge_index validity
      num_nodes = input_embeddings.size(1)
      if self.edge_index.max() >= num_nodes or self.edge_index.min() < 0:
          raise ValueError("Invalid node indices in edge_index!")
      # Apply GAT layers
      for layer in self.gat_layers:
          input_embeddings = layer(input_embeddings, self.edge_index)
          input_embeddings = F.relu(input_embeddings)  # Apply non-linearity

      return input_embeddings  # Remove batch dimension


    def weave(self, text_input_ids, token_type_ids, text_attention_mask):
      """
      Arranges tokens in the format:
      [CLS] text [SEP] [V1] [PRED] ... [VL] [PRED], then adds GAT-enhanced embeddings.
      """
      cls_token_id = self.tokenizer.cls_token_id
      sep_token_id = self.tokenizer.sep_token_id
      batch_size, seq_len = text_input_ids.size()
      device = text_input_ids.device

      # Add [CLS] and [SEP] around the input tokens
      cls_tokens = torch.full((batch_size, 1), cls_token_id, device=device)
      sep_tokens = torch.full((batch_size, 1), sep_token_id, device=device)
      combined_input_ids = torch.cat([cls_tokens, text_input_ids, sep_tokens], dim=1)
      seq_len = combined_input_ids.size(1)

      # Extend the attention mask
      extended_attention_mask = torch.cat([torch.ones(batch_size, 1, device=device),
                                          text_attention_mask,
                                          torch.ones(batch_size, 1, device=device)], dim=1)

      # Prepare initial input embeddings for the text
      input_embeddings = self.text_encoder.embeddings(input_ids=combined_input_ids, token_type_ids=None)

      # Get GAT-enhanced embeddings and extend input embeddings
      gat_embeddings = self.get_gat_embeddings()  # Shape: (L*2, hidden_dim)
      gat_embeddings = gat_embeddings.unsqueeze(0).repeat(batch_size, 1, 1)  # Shape: (batch_size, L*2, hidden_dim)
      extended_embeddings = torch.cat([input_embeddings, gat_embeddings], dim=1)  # Shape: (batch_size, seq_len + L*2, hidden_dim)

      # Extend the attention mask for the GAT embeddings
      extended_attention_mask = torch.cat([extended_attention_mask,
                                          torch.ones(batch_size, gat_embeddings.size(1), device=device)], dim=1)

      # Add final [SEP]
      combined_input_ids = torch.cat([combined_input_ids, sep_tokens], dim=1)

      # Return the extended embeddings and attention mask
      return extended_embeddings, extended_attention_mask


    def forward(self, text_input_ids, token_type_ids, text_attention_mask, images):
        # Move everything to device
        device = text_input_ids.device
        self.template_embeddings = self.template_embeddings.to(device)
        self.pred_embedding = self.pred_embedding.to(device)
        self.label_embeddings = self.label_embeddings.to(device)
        self.edge_index = self.edge_index.to(device)

        # Prepare embeddings for the combined sequence
        batch_size, seq_len = text_input_ids.size()

        input_embeddings, extended_attention_mask = self.weave(text_input_ids, token_type_ids, text_attention_mask)
        # Forward pass through the encoder
        encoder_outputs = self.text_encoder.encoder(
            hidden_states=input_embeddings,
            attention_mask=self.text_encoder.get_extended_attention_mask(extended_attention_mask, extended_attention_mask.shape, device)
        )
        sequence_output = encoder_outputs.last_hidden_state

        # Extract hidden states of [PRED] tokens
        pred_hidden_states = [sequence_output[:, seq_len + self.L + i, :] for i in range(self.L)]


        outputs = {}
        verbalizer_features = []  # List to collect verbalizer outputs for fusion
        for i in range(self.L):
            hidden_state = pred_hidden_states[i]
            label_indices = [self.label2idx[label] for label in self.labels_per_layer[i]]
            logits = torch.matmul(hidden_state, self.label_embeddings[label_indices].t())
            verbalizer_features.append(logits)  # Collect logits as verbalizer features
            for idx, label in enumerate(self.labels_per_layer[i]):
                outputs[label] = logits[:, idx].unsqueeze(-1)


        # Stack verbalizer features for fusion
        max_num_labels = max([logit.shape[1] for logit in verbalizer_features])  # Find the maximum number of labels
        padded_verbalizer_features = [
            F.pad(logit, (0, max_num_labels - logit.shape[1]))  # Pad logits to have the same number of labels
            for logit in verbalizer_features
        ]

        # Stack and project verbalizer features
        text_features = torch.stack(padded_verbalizer_features, dim=1)  # Shape: (batch_size, L, max_num_labels)
        text_features = self.verbalizer_projection(text_features)

        # Image encoding
        image_features = self.image_encoder(pixel_values=images).last_hidden_state
        image_features = self.image_fc(image_features)
       # print(image_features.shape, text_features.shape)
        image_features = image_features.permute(1, 0, 2)
        text_features = text_features.permute(1, 0, 2)

        # Fusion
        fused_features = self.fusion_model(text_features, image_features)
        fused_features = torch.relu(fused_features)
        fused_features = torch.dropout(fused_features, p=0.1, train=self.training)

        # Fusion
        #text_features = sequence_output[:, 0, :]
        ##combined_features = torch.cat([text_features, image_features], dim=-1)
        #fused_features = self.fusion_model(text_features, image_features)
        #fused_features = torch.relu(fused_features)
        ##fused_features = torch.dropout(fused_features, p=0.1, train=self.training)

        # Final predictions
        final_outputs = {node: self.classifier_dict[node](fused_features) for node in self.labels}

        return final_outputs


## HPT model


Experimental model that does not work :(

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from transformers.models.bert.modeling_bert import BertModel, BertPreTrainedModel, BertOnlyMLMHead
from transformers.modeling_outputs import MaskedLMOutput
from transformers import AutoTokenizer
from torch_geometric.nn import GATConv
from torch_geometric.utils import from_networkx
import networkx as nx

# Assume that G, label_map, depth2label, get_label_index, and get_label_for_index are already defined
# If not, define them here or import them

# Helper functions if not already defined
# get_label_index = ...
# get_label_for_index = ...
# G = ...
# depth2label = ...

def multilabel_categorical_crossentropy(y_true, y_pred):
    loss_mask = y_true != -100
    y_true = y_true.masked_select(loss_mask).view(-1, y_pred.size(-1))
    y_pred = y_pred.masked_select(loss_mask).view(-1, y_true.size(-1))
    y_pred = (1 - 2 * y_true) * y_pred
    y_pred_neg = y_pred - y_true * 1e12
    y_pred_pos = y_pred - (1 - y_true) * 1e12
    zeros = torch.zeros_like(y_pred[:, :1])
    y_pred_neg = torch.cat([y_pred_neg, zeros], dim=-1)
    y_pred_pos = torch.cat([y_pred_pos, zeros], dim=-1)
    neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
    pos_loss = torch.logsumexp(y_pred_pos, dim=-1)
    return (neg_loss + pos_loss).mean()

class GraphEmbedding(nn.Module):
    def __init__(self, config, embedding, new_embedding, label_map, label_to_index, depth_to_labels, layer=1):
        super(GraphEmbedding, self).__init__()
        self.label_map = label_map
        self.label_to_index = label_to_index
        self.depth_to_labels = depth_to_labels

        padding_idx = config.pad_token_id
        self.num_class = config.num_labels

        self.graph = nn.Sequential(*[
            GATConv(in_channels=new_embedding.size(-1), out_channels=new_embedding.size(-1)) for _ in range(layer)
        ])

        self.padding_idx = padding_idx
        self.original_embedding = embedding
        new_embedding = torch.cat(
            [torch.zeros(1, new_embedding.size(-1), device=new_embedding.device, dtype=new_embedding.dtype),
             new_embedding], dim=0)
        self.new_embedding = nn.Embedding.from_pretrained(new_embedding, freeze=False, padding_idx=0)
        self.size = self.original_embedding.num_embeddings + self.new_embedding.num_embeddings - 1
        self.depth = self.new_embedding.num_embeddings - 2 - self.num_class

    @property
    def weight(self):
        def foo():
            edge_features = self.new_embedding.weight[1:, :]
            edge_features = edge_features[:-1, :]
            edge_features = self.graph(edge_features)
            edge_features = torch.cat(
                [edge_features, self.new_embedding.weight[-1:, :]], dim=0)
            return torch.cat([self.original_embedding.weight, edge_features], dim=0)

        return foo

    @property
    def raw_weight(self):
        def foo():
            return torch.cat([self.original_embedding.weight, self.new_embedding.weight[1:, :]], dim=0)

        return foo

    def forward(self, x):
        x = F.embedding(x, self.weight(), self.padding_idx)
        return x


class OutputEmbedding(nn.Module):
    def __init__(self, bias):
        super(OutputEmbedding, self).__init__()
        self.weight = None
        self.bias = bias

    def forward(self, x):
        return F.linear(x, self.weight(), self.bias)


class Prompt(BertPreTrainedModel):
    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]

    def __init__(self, config, graph_type='GAT', layer=1, label_map=None, label_to_index=None, depth_to_labels=None, **kwargs):
        super().__init__(config)

        self.bert = AutoModel.from_pretrained(config, add_pooling_layer=False)
        self.tokenizer = AutoTokenizer.from_pretrained(self.name_or_path)
        self.cls = BertOnlyMLMHead(config)
        self.num_labels = config.num_labels
        self.multiclass_bias = nn.Parameter(torch.zeros(self.num_labels, dtype=torch.float32))
        bound = 1 / math.sqrt(768)
        nn.init.uniform_(self.multiclass_bias, -bound, bound)

        self.graph_type = graph_type
        self.vocab_size = self.tokenizer.vocab_size
        self.layer = layer
        self.label_map = label_map
        self.label_to_index = label_to_index
        self.depth_to_labels = depth_to_labels

        self.init_weights()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings

    def init_embedding(self):
        label_dict = self.label_map
        tokenizer = AutoTokenizer.from_pretrained(self.name_or_path)
        label_emb = []
        input_embeds = self.get_input_embeddings()

        for label in self.label_map.keys():
            encoded = tokenizer.encode(label, add_special_tokens=False)
            label_emb.append(
                input_embeds.weight.index_select(0, torch.tensor(encoded, device=self.device)).mean(dim=0))

        prefix = input_embeds(torch.tensor([tokenizer.mask_token_id],
                                           device=self.device, dtype=torch.long))

        prompt_embedding = nn.Embedding(3 + 1,
                                        input_embeds.weight.size(1), padding_idx=0)

        self._init_weights(prompt_embedding)

        label_emb = torch.cat(
            [torch.stack(label_emb), prompt_embedding.weight[1:, :], prefix], dim=0)

        embedding = GraphEmbedding(self.config, input_embeds, label_emb, self.label_map, self.label_to_index,
                                     self.depth_to_labels, layer=self.layer)
        self.set_input_embeddings(embedding)

        output_embeddings = OutputEmbedding(self.get_output_embeddings().bias)
        self.set_output_embeddings(output_embeddings)

        output_embeddings.weight = embedding.raw_weight
        self.vocab_size = output_embeddings.bias.size(0)
        output_embeddings.bias.data = nn.functional.pad(
            output_embeddings.bias.data,
            (
                0,
                embedding.size - output_embeddings.bias.shape[0],
            ),
            "constant",
            0,
        )

    def forward(
            self,
            input_ids=None,
            attention_mask=None,
            labels=None,
            return_dict=None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        multiclass_pos = input_ids == (self.get_input_embeddings().size - 1)
        single_labels = input_ids.masked_fill(multiclass_pos | (input_ids == self.config.pad_token_id), -100)

        if self.training:
            enable_mask = input_ids < self.tokenizer.vocab_size
            random_mask = torch.rand(input_ids.shape, device=input_ids.device) * attention_mask * enable_mask
            input_ids = input_ids.masked_fill(random_mask > 0.865, self.tokenizer.mask_token_id)
            random_ids = torch.randint_like(input_ids, 104, self.vocab_size)
            mlm_mask = random_mask > 0.985
            input_ids = input_ids * mlm_mask.logical_not() + random_ids * mlm_mask
            mlm_mask = random_mask < 0.85
            single_labels = single_labels.masked_fill(mlm_mask, -100)

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)

        masked_lm_loss = None

        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            masked_lm_loss = loss_fct(prediction_scores.view(-1, prediction_scores.size(-1)),
                                      single_labels.view(-1))
            multiclass_logits = prediction_scores.masked_select(
                multiclass_pos.unsqueeze(-1).expand(-1, -1, prediction_scores.size(-1))).view(-1,
                                                                                              prediction_scores.size(
                                                                                                  -1))
            multiclass_logits = multiclass_logits[:,
                                self.vocab_size:self.vocab_size + self.num_labels] + self.multiclass_bias
            multiclass_loss = multilabel_categorical_crossentropy(labels.view(-1, self.num_labels), multiclass_logits)
            masked_lm_loss += multiclass_loss

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )




# Loss

In [None]:
class HierarchyAwareLoss(nn.Module):
    def __init__(self, label_index):
        """
        Args:
            label_index (dict): Mapping of label names to their indices in the label space.
        """
        super(HierarchyAwareLoss, self).__init__()
        self.label_index = label_index

    @staticmethod
    def multilabel_categorical_crossentropy(y_true, y_pred):
        """
        Compute the multilabel categorical crossentropy loss.

        Args:
            y_true (torch.Tensor): Ground truth binary labels of shape (batch_size, num_labels).
            y_pred (torch.Tensor): Predicted logits of shape (batch_size, num_labels).

        Returns:
            torch.Tensor: Loss value.
        """
        loss_mask = y_true != -100  # Exclude masked values
        y_true = y_true.masked_select(loss_mask).view(-1, y_pred.size(-1))
        y_pred = y_pred.masked_select(loss_mask).view(-1, y_true.size(-1))
        y_pred = (1 - 2 * y_true) * y_pred
        y_pred_neg = y_pred - y_true * 1e12
        y_pred_pos = y_pred - (1 - y_true) * 1e12
        zeros = torch.zeros_like(y_pred[:, :1])
        y_pred_neg = torch.cat([y_pred_neg, zeros], dim=-1)
        y_pred_pos = torch.cat([y_pred_pos, zeros], dim=-1)
        neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
        pos_loss = torch.logsumexp(y_pred_pos, dim=-1)
        return (neg_loss + pos_loss).mean()

    def forward(self, outputs, labels):
        """
        Compute the multi-label categorical crossentropy loss.

        Args:
            outputs (dict): Dictionary of logits for each label node.
            labels (torch.Tensor): Binary ground truth labels of shape (batch_size, num_labels).

        Returns:
            torch.Tensor: Loss value.
        """
        # Combine all logits into a single tensor
        logits = torch.stack([outputs[label].squeeze(-1) for label in self.label_index.keys() if label != "Root"], dim=1)  # Shape: (batch_size, num_labels)
        # Combine labels into a single tensor
        labels_combined = torch.stack([labels[:, idx] for idx in self.label_index.values() if idx != 0], dim=1)  # Shape: (batch_size, num_labels)
        # Compute the multilabel categorical crossentropy loss
        loss = self.multilabel_categorical_crossentropy(labels_combined, logits)
        return loss


In [None]:
def l1_regularization(model, lambda_l1=1e-6):
    l1_loss = 0.0
    for param in model.parameters():
        if param.requires_grad:  # Only include trainable parameters
            l1_loss += torch.sum(torch.abs(param))
    return lambda_l1 * l1_loss

# Utility

In [None]:
def get_model(num_classes, heirarchy):
    model = HierarchyAwareClassifier(256, G, layers)
    return model
    try:
        model.load_state_dict(torch.load(model_path))
    except FileNotFoundError:
        print("Error: Model file not found!")
    return model


def get_loss_function():
    return nn.BCEWithLogitsLoss()  # Multi-label classification

def find_best_thresholds(predictions, targets):
    """
    Find the best threshold for each class that maximizes the F1 score.

    Args:
        predictions (torch.Tensor): Predicted probabilities, shape (num_samples, num_classes).
        targets (torch.Tensor): Binary ground truth labels, shape (num_samples, num_classes).

    Returns:
        List[float]: Best threshold for each class.
    """
    num_classes = predictions.shape[1]
    best_thresholds = []

    for i in range(num_classes):
        # Compute precision, recall, and thresholds
        precision, recall, thresholds = precision_recall_curve(targets[:, i].cpu(), predictions[:, i].cpu())

        # Compute F1 scores
        f1_scores = 2 * (precision * recall) / (precision + recall + 1e-8)  # Avoid division by zero

        # Find the threshold that gives the best F1 score
        best_threshold = thresholds[np.argmax(f1_scores)]
        best_thresholds.append(best_threshold)

    return best_thresholds

def multilabel_accuracy(y_true, y_pred):
    # y_true and y_pred are binary arrays (e.g., one-hot encoded)
    return (y_true == y_pred).mean()

def save_model(model):
    torch.save(model.state_dict(), model_path)

def _h_precision_score(labels, preds, hierarchy):
    """
    Calculate hierarchical precision.
    Args:
        labels (torch.Tensor): Ground truth labels, shape (batch_size, num_labels).
        preds (dict): Dictionary of predicted labels for each category.
        hierarchy (networkx.DiGraph): Hierarchy graph of labels.
    Returns:
        float: Hierarchical precision score.
    """
    correct = 0
    total_pred = 0

    for i in range(labels.shape[0]):  # Iterate over each sample
        pred_row = {label: preds[label][i].item() for label in preds.keys()}
        label_row = labels[i]

        pred_set = _get_active_nodes(pred_row, hierarchy, is_pred=True)
        label_set = _get_active_nodes(label_row, hierarchy, is_pred=False)

        correct += len(pred_set.intersection(label_set))
        total_pred += len(pred_set)

    return correct / total_pred if total_pred > 0 else 0



def _h_recall_score(labels, preds, hierarchy):
    """
    Calculate hierarchical recall.
    Args:
        labels (torch.Tensor): Ground truth labels, shape (batch_size, num_labels).
        preds (dict): Dictionary of predicted labels for each category.
        hierarchy (networkx.DiGraph): Hierarchy graph of labels.
    Returns:
        float: Hierarchical recall score.
    """
    correct = 0
    total_label = 0

    for i in range(labels.shape[0]):  # Iterate over each sample
        pred_row = {label: preds[label][i].item() for label in preds.keys()}
        label_row = labels[i]

        pred_set = _get_active_nodes(pred_row, hierarchy, is_pred=True)
        label_set = _get_active_nodes(label_row, hierarchy, is_pred=False)

        correct += len(pred_set.intersection(label_set))
        total_label += len(label_set)

    return correct / total_label if total_label > 0 else 0




def _h_fbeta_score(labels, preds, hierarchy, beta=1):
    """
    Calculate hierarchical F-beta score.
    Args:
        labels (torch.Tensor): Ground truth labels, shape (batch_size, num_labels).
        preds (dict): Dictionary of predicted labels for each category.
        hierarchy (networkx.DiGraph): Hierarchy graph of labels.
        beta (float): Beta value for F-beta score. Default is 1 for F1 score.
    Returns:
        float: Hierarchical F-beta score.
    """
    precision = _h_precision_score(labels, preds, hierarchy)
    recall = _h_recall_score(labels, preds, hierarchy)

    if precision + recall == 0:
        return 0

    return (1 + beta**2) * (precision * recall) / ((beta**2 * precision) + recall)



def _get_active_nodes(example, hierarchy, is_pred=False):
    """
    Get active nodes (with value 1) from an example and their ancestors.
    Args:
        example (torch.Tensor or dict): Binary vector or prediction dictionary for one example.
        hierarchy (networkx.DiGraph): Hierarchy graph of labels.
        is_pred (bool): Whether the example is from predictions or labels.
    Returns:
        set: Active nodes (names) and their ancestors for predictions; labels directly for ground truth.
    """
    active_nodes = set()

    if is_pred:
        # Example is a dictionary of predictions
        for label, value in example.items():
            if value == 1:  # Active prediction
                active_nodes.update(_get_ancestors(label, hierarchy))
    else:
        # Example is a tensor of ground truth labels
        for idx in torch.nonzero(example, as_tuple=False).squeeze(1).tolist():
            label = get_label_for_index(idx)  # Convert index to label name
            active_nodes.add(label)  # Use labels directly as all ancestors are already included

    return active_nodes



def _get_ancestors(node, hierarchy):
    """
    Get all ancestors of a node, including the node itself.
    Args:
        node (int): Node index to retrieve ancestors for.
        hierarchy (networkx.DiGraph): Hierarchy graph.
    Returns:
        set: Set of ancestor nodes (indices).
    """
    ancestors = set(nx.ancestors(hierarchy, node))
    ancestors.add(node)
    return ancestors




def compute_hierarchical_metrics(labels, preds, hierarchy):
    """
    Compute hierarchical precision, recall, and F1 score.
    Args:
        labels (torch.Tensor): Ground truth labels, shape (batch_size, num_labels).
        preds (torch.Tensor): Predicted labels, shape (batch_size, num_labels).
        hierarchy (networkx.DiGraph): Hierarchy graph of labels.
    Returns:
        dict: Hierarchical precision, recall, and F1 score.
    """
    hierarchical_precision = _h_precision_score(labels, preds, hierarchy)
    hierarchical_recall = _h_recall_score(labels, preds, hierarchy)
    hierarchical_f1 = _h_fbeta_score(labels, preds, hierarchy)

    return {
        "hierarchical_precision": hierarchical_precision,
        "hierarchical_recall": hierarchical_recall,
        "hierarchical_f1": hierarchical_f1,
    }



# Training loop

## Validate

In [None]:
def validate(model, val_loader, device, loss_fn):
    if isinstance(device, str):
        device = torch.device(device)
    global label_map
    model.eval()
    all_labels = []
    all_probs = {node: [] for node in model.hierarchy.nodes if node != "Root"}  # Initialize probabilities storage
    total_loss = 0
    count = 0
    with torch.no_grad():
        for batch in val_loader:
            text_input_ids, token_type_ids, attention_masks, images, labels = [b.to(device) for b in batch]
            # check if labels has a 1
            non_zero_mask = torch.any(labels != 0, dim=1)

            # Filter out the zero-only rows from all batch components
            text_input_ids = text_input_ids[non_zero_mask]
            token_type_ids = token_type_ids[non_zero_mask]
            attention_masks = attention_masks[non_zero_mask]
            images = images[non_zero_mask]
            labels = labels[non_zero_mask]
            # Forward pass
            outputs = model(text_input_ids, token_type_ids, attention_masks, images)
            probs = {node: torch.sigmoid(outputs[node].squeeze(-1)) for node in outputs.keys()}

            # Compute loss
            loss = loss_fn(outputs, labels)
            total_loss += loss.item()

            # Store probabilities and ground truth labels
            all_labels.append(labels.cpu())
            for node in probs.keys():
                all_probs[node].append(probs[node].cpu())

    # Aggregate labels and probabilities
    all_labels = torch.cat(all_labels, dim=0)  # Ground truth labels
    all_probs = {node: torch.cat(all_probs[node], dim=0) for node in all_probs.keys()}

    # Convert probabilities to binary predictions
    thresholds = find_best_thresholds(torch.stack(list(all_probs.values())).T, all_labels)
    threshold_dict = {node: threshold for node, threshold in zip(all_probs.keys(), thresholds)}
    all_preds = {node: (all_probs[node] > threshold_dict[node]).int() for node in all_probs.keys()}


    # Compute evaluation metrics
    metrics = {}
    precision_sum = 0
    recall_sum = 0
    f1_sum = 0
    total_samples = 0
    for node in all_probs.keys():
        idx = label_map[node]
        precision = precision_score(all_labels[:, idx], all_preds[node], zero_division=0)
        recall = recall_score(all_labels[:, idx], all_preds[node])
        f1 = f1_score(all_labels[:, idx], all_preds[node])
        metrics[node] = {"precision": precision, "recall": recall, "f1": f1}
        num_samples = int(all_labels[:, idx].sum())
        precision_sum += precision * num_samples
        recall_sum += recall * num_samples
        f1 = f1_score(all_labels[:, idx], all_preds[node])
        f1_sum += f1 * num_samples
        total_samples += num_samples

    # Calculate overall weighted metrics
    if total_samples > 0:
        overall_metrics = {
            "precision": precision_sum / total_samples,
            "recall": recall_sum / total_samples,
            "f1": f1_sum / total_samples,
        }
    else:
        overall_metrics = {"precision": 0, "recall": 0, "f1": 0}

    # Compute hierarchical metrics
    hierarchical_metrics = compute_hierarchical_metrics(all_labels, all_preds, model.hierarchy)

    # Log metrics
    print(f"Validation Loss: {total_loss / len(val_loader):.4f}")
    for node, metric in metrics.items():
        print(f"{node}: Precision: {metric['precision']:.4f}, Recall: {metric['recall']:.4f}, F1: {metric['f1']:.4f}")
    print("<------------------------------------>")
    print(f"Overall Metrics: Precision: {overall_metrics['precision']:.4f}, Recall: {overall_metrics['recall']:.4f}, F1: {overall_metrics['f1']:.4f}")
    print(f"Hierarchical Metrics: {hierarchical_metrics}")
    print("<------------------------------------>")
    return total_loss / len(val_loader), hierarchical_metrics, overall_metrics


## Train

### Base train

In [None]:
def train(model, loss_fn, epochs=NUM_EPOCHS):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # Define optimizer, scheduler, and loss function
    optimizer = Adam(model.parameters(), lr=5e-5)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.9)
    for param in model.fusion_model.parameters():
        param.requires_grad = True

    loss_list = []
    metrics_list = []
    metrics_list2 = []

    for epoch in range(epochs):
        model.train()
        train_loss = 0

        train_loader = tqdm(all_train_loader, desc=f"Epoch {epoch+1}/{epochs}", leave=False)

        for batch in train_loader:
            text_input_ids, token_type_ids, attention_masks, images, labels = [b.to(device) for b in batch]
            non_zero_mask = torch.any(labels != 0, dim=1)
            # Filter out the zero-only rows from all batch components
            text_input_ids = text_input_ids[non_zero_mask]
            token_type_ids = token_type_ids[non_zero_mask]
            attention_masks = attention_masks[non_zero_mask]
            images = images[non_zero_mask]
            labels = labels[non_zero_mask]
            # Zero gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = model(text_input_ids, token_type_ids, attention_masks, images)
            loss = loss_fn(outputs, labels)
            #loss += l1_regularization(model)
            # Backward pass
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_loader.set_postfix(loss=loss.item())

        avg_train_loss = train_loss / len(all_train_loader)
        print(f"Epoch {epoch+1}/{epochs}, Training Loss: {avg_train_loss:.4f}")

        # Validation
        val_loss, metrics1, metrics2 = validate(model, all_val_loader, device, loss_fn)
        print(f"Epoch {epoch+1}/{NUM_EPOCHS}, Validation Loss: {val_loss:.4f}")

        loss_list.append(avg_train_loss)
        metrics_list.append(metrics1)
        metrics_list2.append(metrics2)
        # Adjust learning rate
        scheduler.step(val_loss)

    return loss_list, metrics_list, metrics_list2


### Heirarchy aware train

## Run

In [None]:
model = get_model(22,G)
#model = Prompt.from_pretrained("bert-base-uncased", graph_type='GAT', layer=3, label_map=label_map, label_to_index=get_label_for_index, index_to_label=None, depth_to_labels=depth_to_label)
model.to("cuda")




HierarchyAwareClassifier(
  (text_encoder): DebertaV2Model(
    (embeddings): DebertaV2Embeddings(
      (word_embeddings): Embedding(128005, 768, padding_idx=0)
      (LayerNorm): LayerNorm((768,), eps=1e-07, elementwise_affine=True)
      (dropout): StableDropout()
    )
    (encoder): DebertaV2Encoder(
      (layer): ModuleList(
        (0-11): 12 x DebertaV2Layer(
          (attention): DebertaV2Attention(
            (self): DisentangledSelfAttention(
              (query_proj): Linear(in_features=768, out_features=768, bias=True)
              (key_proj): Linear(in_features=768, out_features=768, bias=True)
              (value_proj): Linear(in_features=768, out_features=768, bias=True)
              (pos_dropout): StableDropout()
              (dropout): StableDropout()
            )
            (output): DebertaV2SelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-07, elementwise_affine=True

In [None]:
validate(model, all_val_loader, torch.device("cuda"), HierarchyAwareLoss(label_map))



Validation Loss: 5.0300
Logos: Precision: 0.5571, Recall: 1.0000, F1: 0.7156
Repetition: Precision: 0.0448, Recall: 0.9565, F1: 0.0856
Obfuscation, Intentional vagueness, Confusion: Precision: 0.0125, Recall: 0.2000, F1: 0.0235
Reasoning: Precision: 0.0000, Recall: 0.0000, F1: 0.0000
Justification: Precision: 0.4444, Recall: 0.0221, F1: 0.0421
Slogans: Precision: 0.1556, Recall: 0.1346, F1: 0.1443
Bandwagon: Precision: 0.0175, Recall: 0.8750, F1: 0.0342
Appeal to authority: Precision: 0.2667, Recall: 0.0606, F1: 0.0988
Flag-waving: Precision: 0.1659, Recall: 0.6102, F1: 0.2609
Appeal to fear/prejudice: Precision: 0.0000, Recall: 0.0000, F1: 0.0000
Simplification: Precision: 0.2609, Recall: 0.0594, F1: 0.0968
Causal Oversimplification: Precision: 0.0427, Recall: 0.9545, F1: 0.0817
Black-and-white Fallacy/Dictatorship: Precision: 0.2222, Recall: 0.0364, F1: 0.0625
Thought-terminating cliché: Precision: 0.0754, Recall: 0.9737, F1: 0.1399
Distraction: Precision: 0.1111, Recall: 0.0294, F1:

(5.029999074481783,
 {'hierarchical_precision': 0.24405014874628134,
  'hierarchical_recall': 0.8244795405599425,
  'hierarchical_f1': 0.37661911788817837},
 {'precision': 0.3306728372535554,
  'recall': 0.5574300071787509,
  'f1': 0.361451866177274})

In [None]:
h_metrics_list = []
metrics2_list = []
loss_list = []
for param in model.text_encoder.parameters():
    param.requires_grad = False
for param in model.image_encoder.parameters():
    param.requires_grad = False
for i in range(1):
  loss, metrics_list, metrics_list2 = train(model,HierarchyAwareLoss(label_map),  5)
  h_metrics_list.extend([
        {
            "precision": metrics["hierarchical_precision"],
            "recall": metrics["hierarchical_recall"],
            "f1": metrics["hierarchical_f1"]
        } for metrics in metrics_list
  ])

  # Append loss and metrics2 for this iteration
  loss_list.extend(loss)
  metrics2_list.extend(metrics_list2)
save_model(model)



Epoch 1/5, Training Loss: 4.4823




Validation Loss: 4.4372
Logos: Precision: 0.5517, Recall: 0.9784, F1: 0.7056
Repetition: Precision: 0.0460, Recall: 0.9565, F1: 0.0878
Obfuscation, Intentional vagueness, Confusion: Precision: 0.0556, Recall: 0.2000, F1: 0.0870
Reasoning: Precision: 0.0000, Recall: 0.0000, F1: 0.0000
Justification: Precision: 0.3602, Recall: 0.9890, F1: 0.5280
Slogans: Precision: 0.1130, Recall: 0.8846, F1: 0.2004
Bandwagon: Precision: 0.0138, Recall: 0.2500, F1: 0.0261
Appeal to authority: Precision: 0.1282, Recall: 0.9091, F1: 0.2247
Flag-waving: Precision: 0.1330, Recall: 0.4915, F1: 0.2094
Appeal to fear/prejudice: Precision: 0.0217, Recall: 0.0294, F1: 0.0250
Simplification: Precision: 0.2340, Recall: 0.3267, F1: 0.2727
Causal Oversimplification: Precision: 0.0365, Recall: 0.3636, F1: 0.0664
Black-and-white Fallacy/Dictatorship: Precision: 0.2000, Recall: 0.0364, F1: 0.0615
Thought-terminating cliché: Precision: 0.0709, Recall: 0.5000, F1: 0.1242
Distraction: Precision: 0.0528, Recall: 0.4412, F1:



Epoch 2/5, Training Loss: 4.4213




Validation Loss: 4.4316
Logos: Precision: 0.7778, Recall: 0.0252, F1: 0.0488
Repetition: Precision: 0.0434, Recall: 0.9130, F1: 0.0828
Obfuscation, Intentional vagueness, Confusion: Precision: 0.0714, Recall: 0.2000, F1: 0.1053
Reasoning: Precision: 0.3333, Recall: 0.0155, F1: 0.0296
Justification: Precision: 0.3592, Recall: 0.4862, F1: 0.4131
Slogans: Precision: 0.1062, Recall: 0.9808, F1: 0.1917
Bandwagon: Precision: 0.0000, Recall: 0.0000, F1: 0.0000
Appeal to authority: Precision: 0.2000, Recall: 0.0455, F1: 0.0741
Flag-waving: Precision: 0.1259, Recall: 0.8983, F1: 0.2208
Appeal to fear/prejudice: Precision: 0.0707, Recall: 0.7941, F1: 0.1298
Simplification: Precision: 0.2261, Recall: 0.5842, F1: 0.3260
Causal Oversimplification: Precision: 0.0407, Recall: 0.8182, F1: 0.0776
Black-and-white Fallacy/Dictatorship: Precision: 0.2174, Recall: 0.0909, F1: 0.1282
Thought-terminating cliché: Precision: 0.0751, Recall: 0.9737, F1: 0.1394
Distraction: Precision: 0.0698, Recall: 1.0000, F1:



Epoch 3/5, Training Loss: 4.4112




Validation Loss: 4.4226
Logos: Precision: 0.7778, Recall: 0.0252, F1: 0.0488
Repetition: Precision: 0.0477, Recall: 0.8261, F1: 0.0903
Obfuscation, Intentional vagueness, Confusion: Precision: 0.0092, Recall: 0.2000, F1: 0.0175
Reasoning: Precision: 0.0000, Recall: 0.0000, F1: 0.0000
Justification: Precision: 0.3291, Recall: 0.1436, F1: 0.2000
Slogans: Precision: 0.1017, Recall: 0.1154, F1: 0.1081
Bandwagon: Precision: 0.0000, Recall: 0.0000, F1: 0.0000
Appeal to authority: Precision: 0.1408, Recall: 0.4394, F1: 0.2132
Flag-waving: Precision: 0.1134, Recall: 0.9492, F1: 0.2025
Appeal to fear/prejudice: Precision: 0.0681, Recall: 1.0000, F1: 0.1276
Simplification: Precision: 0.3000, Recall: 0.0297, F1: 0.0541
Causal Oversimplification: Precision: 0.0473, Recall: 0.6364, F1: 0.0881
Black-and-white Fallacy/Dictatorship: Precision: 0.1111, Recall: 0.0364, F1: 0.0548
Thought-terminating cliché: Precision: 0.1212, Recall: 0.1053, F1: 0.1127
Distraction: Precision: 0.1111, Recall: 0.0294, F1:



Epoch 4/5, Training Loss: 4.4024




Validation Loss: 4.4260
Logos: Precision: 0.7333, Recall: 0.0396, F1: 0.0751
Repetition: Precision: 0.0444, Recall: 0.2609, F1: 0.0759
Obfuscation, Intentional vagueness, Confusion: Precision: 0.0222, Recall: 0.2000, F1: 0.0400
Reasoning: Precision: 0.2000, Recall: 0.0078, F1: 0.0149
Justification: Precision: 0.3434, Recall: 0.1878, F1: 0.2429
Slogans: Precision: 0.1037, Recall: 0.9808, F1: 0.1875
Bandwagon: Precision: 0.0000, Recall: 0.0000, F1: 0.0000
Appeal to authority: Precision: 0.1365, Recall: 0.5606, F1: 0.2196
Flag-waving: Precision: 0.3333, Recall: 0.0508, F1: 0.0882
Appeal to fear/prejudice: Precision: 0.0671, Recall: 0.3235, F1: 0.1111
Simplification: Precision: 0.3030, Recall: 0.0990, F1: 0.1493
Causal Oversimplification: Precision: 0.0385, Recall: 0.0909, F1: 0.0541
Black-and-white Fallacy/Dictatorship: Precision: 0.1053, Recall: 0.0364, F1: 0.0541
Thought-terminating cliché: Precision: 0.0760, Recall: 0.6579, F1: 0.1362
Distraction: Precision: 0.0638, Recall: 0.1765, F1:



Epoch 5/5, Training Loss: 4.4025




Validation Loss: 4.4231
Logos: Precision: 0.5558, Recall: 0.9676, F1: 0.7060
Repetition: Precision: 0.0648, Recall: 0.8261, F1: 0.1203
Obfuscation, Intentional vagueness, Confusion: Precision: 0.1000, Recall: 0.2000, F1: 0.1333
Reasoning: Precision: 0.5000, Recall: 0.0233, F1: 0.0444
Justification: Precision: 0.3621, Recall: 0.8564, F1: 0.5090
Slogans: Precision: 0.1683, Recall: 0.3269, F1: 0.2222
Bandwagon: Precision: 0.0085, Recall: 0.1250, F1: 0.0160
Appeal to authority: Precision: 0.1444, Recall: 0.4091, F1: 0.2134
Flag-waving: Precision: 0.1294, Recall: 0.3729, F1: 0.1921
Appeal to fear/prejudice: Precision: 0.0000, Recall: 0.0000, F1: 0.0000
Simplification: Precision: 0.2073, Recall: 0.1683, F1: 0.1858
Causal Oversimplification: Precision: 0.1000, Recall: 0.0455, F1: 0.0625
Black-and-white Fallacy/Dictatorship: Precision: 0.1319, Recall: 0.5636, F1: 0.2138
Thought-terminating cliché: Precision: 0.0755, Recall: 0.1053, F1: 0.0879
Distraction: Precision: 0.1111, Recall: 0.0294, F1:

In [None]:
save_model(model)