In [1]:
# ------------------- Import Necessary Libraries -------------------
# Import libraries for data handling, machine learning, and neural network operations
import json
import os
import pickle
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from transformers import BertTokenizer, BertModel
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import f1_score, precision_recall_curve, precision_score, recall_score
from torch.utils.data import Dataset
from torch_geometric.nn import GCNConv, knn_graph
from torch_geometric.data import Data
from torch_geometric.utils import to_undirected

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# ------------------- Device Configuration -------------------
# Choose the compute backend once, up front: Apple's Metal Performance Shaders (MPS)
# when PyTorch reports it as available, plain CPU otherwise.

if torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"
print(f"Using device: {DEVICE}")

Using device: mps


In [3]:
# ------------------- Load Dataset -------------------
def load_data(file_path):
    """Read a UTF-8 encoded JSON file and return its parsed content.

    Args:
        file_path (str): Location of the JSON file on disk.

    Returns:
        The deserialized JSON document. The dataset files loaded by this
        notebook are treated downstream as a list of dictionaries.
    """
    with open(file_path, mode="r", encoding="utf-8") as handle:
        parsed = json.load(handle)
    return parsed

# Read the three splits (train, then dev, then test) from their JSON files
train_data, dev_data, test_data = [
    load_data(split_path)
    for split_path in ("train_revised.json", "dev_revised.json", "test_revised.json")
]

# Report how many entries each split holds as a quick sanity check
print(f"\nLoaded datasets: Train({len(train_data)}), Dev({len(dev_data)}), Test({len(test_data)})\n")

# ------------------- Extract and Filter Relations -------------------
def get_unique_relations(data):
    """Collect the distinct relation identifiers used anywhere in a dataset.

    Args:
        data (list): Dataset entries; each may carry a 'labels' list of
            relation dictionaries. Entries without 'labels' contribute nothing.

    Returns:
        set: Every value found under the 'r' key of a label. A label that has
        no 'r' key is recorded as 'UNKNOWN'.
    """
    found = set()
    for entry in data:
        for label in entry.get("labels", []):
            found.add(label.get("r", "UNKNOWN"))
    return found

# Extract unique relations from each dataset
train_relations = get_unique_relations(train_data)
dev_relations = get_unique_relations(dev_data)
test_relations = get_unique_relations(test_data)

# For each split: relations that occur in the other two splits but not in this one
missing_in_train = (dev_relations | test_relations) - train_relations
missing_in_dev = (train_relations | test_relations) - dev_relations
missing_in_test = (train_relations | dev_relations) - test_relations

# Print debugging information about missing relations
print("\nDebug: Checking for missing relations")
print("No missing relations in Train" if not missing_in_train else f"Missing in Train: {missing_in_train}")
print("No missing relations in Dev" if not missing_in_dev else f"Missing in Dev: {missing_in_dev}")
print("No missing relations in Test" if not missing_in_test else f"Missing in Test: {missing_in_test}")

# Remove every relation that is absent from at least one split so that all three
# splits share one label set. Relations missing from *train* must be included too:
# the label vocabulary is later built from the training split only, so a relation
# seen only in dev/test would have no column in the encoded label matrix.
relations_to_remove = missing_in_train | missing_in_dev | missing_in_test
print(f"\nRelations removed due to being missing: {relations_to_remove}\n")

def filter_data_fixed(data, remove_rels):
    """Drop unwanted relations from every entry and discard entries left without labels.

    The input is not modified: each kept entry is a shallow copy of the original
    with a freshly built 'labels' list, so calling this function again on the same
    raw data gives the same result.

    Args:
        data (list): List of data entries (dicts) to filter.
        remove_rels (set): Relation identifiers to remove. A label without an
            'r' key is treated as 'UNKNOWN', matching get_unique_relations.

    Returns:
        list: New list holding only the entries that still have at least one
        label after filtering, in their original order.
    """
    new_data = []
    removed_entries_count = 0
    for entry in data:
        # Keep only labels whose relation is not in the remove list
        kept_labels = [
            rel for rel in entry.get("labels", [])
            if rel.get("r", "UNKNOWN") not in remove_rels
        ]
        if kept_labels:
            # Shallow copy so the caller's entry keeps its original labels
            new_data.append({**entry, "labels": kept_labels})
        else:
            # Nothing left to predict for this entry; count it as removed
            removed_entries_count += 1
    print(f"Removed {removed_entries_count} entries due to missing relations: {remove_rels}")
    return new_data

# Run the same relation filter over every split, in train / dev / test order
train_data, dev_data, test_data = [
    filter_data_fixed(split_entries, relations_to_remove)
    for split_entries in (train_data, dev_data, test_data)
]

# Report how many entries survive in each split
print(f"\nAfter filtering: Train({len(train_data)}), Dev({len(dev_data)}), Test({len(test_data)})\n")


Loaded datasets: Train(3053), Dev(500), Test(500)


Debug: Checking for missing relations
No missing relations in Train
Missing in Dev: {'P1198'}
Missing in Test: {'P190'}

Relations removed due to being missing: {'P1198', 'P190'}

Removed 3 entries due to missing relations: {'P1198', 'P190'}
Removed 1 entries due to missing relations: {'P1198', 'P190'}
Removed 1 entries due to missing relations: {'P1198', 'P190'}

After filtering: Train(3050), Dev(499), Test(499)



In [4]:
# ------------------- Multi-Label Encoding -------------------
# Gather every relation id that appears in the training labels and sort them, so the
# column order of the binarized label matrix is fixed and reproducible.
all_relations = sorted({
    rel.get('r', 'UNKNOWN')
    for entry in train_data
    for rel in entry.get('labels', [])
})
mlb = MultiLabelBinarizer(classes=all_relations)

def extract_labels_multi(data):
    """Encode each entry's relation labels as one row of a binary indicator matrix.

    Uses the module-level `mlb` binarizer. A label without an 'r' key is
    encoded as 'UNKNOWN'.

    NOTE(review): `fit_transform` re-fits `mlb` on every call. The column layout
    only stays the same across splits if `mlb` was built with an explicit
    `classes` list - confirm for a binarizer loaded from the cache.

    Args:
        data (list): Data entries, each optionally holding a 'labels' list.

    Returns:
        The matrix returned by `mlb.fit_transform`: one row per entry, one
        column per relation class.
    """
    relation_labels = [
        [rel.get('r', 'UNKNOWN') for rel in entry.get('labels', [])]
        for entry in tqdm(data, desc="Processing Labels")
    ]
    return mlb.fit_transform(relation_labels)

# Directory that caches the BERT embeddings and the MultiLabelBinarizer object
embedding_dir = "bert_embeddings"
os.makedirs(embedding_dir, exist_ok=True)  # no-op when the directory is already there

# Locations of the individual cache files inside that directory
train_embedding_file = os.path.join(embedding_dir, "X_train.npy")
dev_embedding_file = os.path.join(embedding_dir, "X_dev.npy")
test_embedding_file = os.path.join(embedding_dir, "X_test.npy")
mlb_file = os.path.join(embedding_dir, "mlb.pkl")

# The cache counts as present only when all four files exist
embeddings_exist = all(
    os.path.exists(cache_path)
    for cache_path in (mlb_file, train_embedding_file, dev_embedding_file, test_embedding_file)
)

if embeddings_exist:
    # Fast path: reuse the binarizer and embeddings written by an earlier run
    print("Loading existing MultiLabelBinarizer and BERT embeddings...")
    # NOTE(security): pickle.load can execute arbitrary code while unpickling, so
    # this must only ever read a cache file that this notebook wrote itself.
    with open(mlb_file, 'rb') as f:
        mlb = pickle.load(f)
    X_train = np.load(train_embedding_file)
    X_dev = np.load(dev_embedding_file)
    X_test = np.load(test_embedding_file)

    # A cache written for differently filtered data would pair embeddings with the
    # wrong labels, so verify the row counts before using it.
    stale_splits = [
        split_name
        for split_name, cached, split_entries in (
            ("train", X_train, train_data),
            ("dev", X_dev, dev_data),
            ("test", X_test, test_data),
        )
        if cached.shape[0] != len(split_entries)
    ]
    if stale_splits:
        raise ValueError(
            f"Cached embeddings in '{embedding_dir}' do not match the current data "
            f"for split(s) {stale_splits}; delete the directory to recompute them."
        )

    y_train = extract_labels_multi(train_data)  # Compute labels with loaded MultiLabelBinarizer
    y_dev = extract_labels_multi(dev_data)
    y_test = extract_labels_multi(test_data)

    # The label columns must follow all_relations, because evaluation maps column
    # indices back to relation names through that list.
    if list(mlb.classes_) != list(all_relations):
        raise ValueError(
            f"Cached MultiLabelBinarizer in '{mlb_file}' was built for a different "
            "relation set; delete the cache directory to rebuild it."
        )
    print(f"Multi-label encoding done with loaded MLB! BERT embeddings loaded! Shape: {X_train.shape}")
else:
    # Slow path: encode the labels, run BERT over every document and write the cache
    print("Computing multi-label encoding and extracting BERT embeddings...")
    y_train = extract_labels_multi(train_data)
    y_dev = extract_labels_multi(dev_data)
    y_test = extract_labels_multi(test_data)
    print("Multi-label encoding done!")

    # ------------------- BERT Feature Extraction -------------------
    # Initialize the BERT tokenizer and model for feature extraction
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = BertModel.from_pretrained("bert-base-uncased").to(DEVICE)
    model.eval()  # make inference mode explicit so dropout cannot perturb the embeddings

    def extract_features(data):
        """Embed every entry with BERT and return one vector per entry.

        The sentences of an entry are joined into one string. The tokenizer
        truncates it to at most 512 tokens, so text past that limit does not
        influence the embedding.

        Args:
            data (list): Data entries with a 'sents' field holding tokenized
                sentences (lists of token strings).

        Returns:
            numpy.ndarray: Row-stacked embeddings, one row per entry, taken
            from position 0 of the last hidden state (the [CLS] position).
        """
        embeddings = []
        for entry in tqdm(data, desc="Extracting BERT Embeddings"):
            # Combine all sentences in the entry into a single string
            text = " ".join([" ".join(sent) for sent in entry.get("sents", [])])
            # Tokenize the text and prepare it for BERT input
            inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512).to(DEVICE)
            with torch.no_grad():
                outputs = model(**inputs)
            # Extract the CLS token embedding from the last hidden state
            embedding = outputs.last_hidden_state[:, 0, :].cpu().numpy()
            embeddings.append(embedding)
        return np.vstack(embeddings)

    # Extract BERT embeddings for all datasets since they weren't loaded
    print("Extracting and saving BERT embeddings...")
    X_train = extract_features(train_data)
    X_dev = extract_features(dev_data)
    X_test = extract_features(test_data)

    # Save the computed embeddings and MultiLabelBinarizer to files for future use
    np.save(train_embedding_file, X_train)
    np.save(dev_embedding_file, X_dev)
    np.save(test_embedding_file, X_test)
    with open(mlb_file, 'wb') as f:
        pickle.dump(mlb, f)
    print(f"BERT embeddings extracted! Shape: {X_train.shape}")

Loading existing MultiLabelBinarizer and BERT embeddings...


Processing Labels: 100%|██████████| 3050/3050 [00:00<00:00, 157686.43it/s]
Processing Labels: 100%|██████████| 499/499 [00:00<00:00, 159281.41it/s]
Processing Labels: 100%|██████████| 499/499 [00:00<00:00, 141082.42it/s]

Multi-label encoding done with loaded MLB! BERT embeddings loaded! Shape: (3050, 768)





In [5]:
# Directory and file paths used to persist the GCN weights and the optimizer state
model_dir  = "GCN_model_optmizers"
os.makedirs(model_dir, exist_ok=True)

model_save_path = os.path.join(model_dir, "GCN_model.pth")
optimizer_save_path = os.path.join(model_dir, "optimizer.pth")

# Fail fast with a clear message if the preprocessing cell has not been run yet
if not all(
    required_name in globals()
    for required_name in ('X_train', 'X_dev', 'X_test', 'y_train', 'y_dev', 'y_test', 'mlb')
):
    raise NameError("Required variables (X_train, X_dev, X_test, y_train, y_dev, y_test, mlb) not found in memory. Ensure preprocessing code has run.")

# ------------------- Custom Multi-Label Focal Loss -------------------
class MultiLabelFocalLoss(nn.Module):
    """Focal loss for multi-label targets, built on element-wise sigmoid cross-entropy."""

    def __init__(self, gamma=2.0, weight=None, reduction='mean'):
        """Store the loss configuration.

        Args:
            gamma (float): Exponent applied to (1 - p_t); larger values shrink
                the contribution of elements the model already gets right.
            weight (torch.Tensor, optional): One weight per class, applied to
                positive targets only.
            reduction (str): 'mean' or 'sum' to reduce the loss to a scalar;
                any other value returns the unreduced element-wise loss.
        """
        super().__init__()
        self.gamma = gamma
        self.weight = weight
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Evaluate the focal loss for a batch of logits.

        Args:
            inputs (torch.Tensor): Raw logits.
            targets (torch.Tensor): Binary targets with the same shape as `inputs`.

        Returns:
            torch.Tensor: A scalar for 'mean' / 'sum', otherwise a tensor shaped
            like `inputs`.
        """
        probs = torch.sigmoid(inputs)
        bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        # Probability the model assigns to the true state of each label
        p_t = probs * targets + (1 - probs) * (1 - targets)
        focal = bce * ((1 - p_t) ** self.gamma)

        if self.weight is not None:
            # Positives are scaled by their class weight; negatives keep weight 1
            per_class = self.weight[None, :]
            focal = focal * (targets * per_class + (1 - targets))

        if self.reduction == 'mean':
            return focal.mean()
        if self.reduction == 'sum':
            return focal.sum()
        return focal

# ------------------- Enhanced GCN Model with Dropout -------------------
class GCN(nn.Module):
    """Three-layer graph convolutional network that outputs one logit per class for every node."""

    def __init__(self, input_dim, hidden_dim1, hidden_dim2, output_dim, dropout=0.4):
        """Build the three graph-convolution layers and the shared activation / dropout.

        Args:
            input_dim (int): Size of each node's input feature vector.
            hidden_dim1 (int): Width of the first hidden layer.
            hidden_dim2 (int): Width of the second hidden layer.
            output_dim (int): Number of output logits per node.
            dropout (float): Dropout probability used after each hidden layer.
        """
        super().__init__()
        # Attribute names double as state_dict keys, so they must stay stable for
        # saved checkpoints to load.
        self.conv1 = GCNConv(input_dim, hidden_dim1)
        self.conv2 = GCNConv(hidden_dim1, hidden_dim2)
        self.conv3 = GCNConv(hidden_dim2, output_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, data):
        """Run the network over a graph.

        Args:
            data (Data): Graph object exposing node features as `x` and
                connectivity as `edge_index`.

        Returns:
            torch.Tensor: Raw logits, one row per node.
        """
        hidden = data.x
        edges = data.edge_index
        # Both hidden layers share the pattern: convolution -> ReLU -> dropout
        for conv in (self.conv1, self.conv2):
            hidden = self.dropout(self.relu(conv(hidden, edges)))
        # The output layer returns logits directly (no activation, no dropout)
        return self.conv3(hidden, edges)

# ------------------- Graph Construction with k-NN -------------------
def create_sparse_graph_data(X, y, k=3):
    """Turn a feature matrix and its labels into a k-nearest-neighbour graph.

    Every row of `X` becomes a node. Each node is linked to its `k` nearest
    neighbours (no self-loops), and the edge list is then made undirected.

    Args:
        X (numpy.ndarray): Node feature matrix, one row per node.
        y (numpy.ndarray): Label matrix, one row per node.
        k (int): Number of nearest neighbours used when building edges.

    Returns:
        Data: PyTorch Geometric graph carrying float node features `x`,
        `edge_index`, and float labels `y`.
    """
    node_features = torch.tensor(X, dtype=torch.float)
    node_labels = torch.tensor(y, dtype=torch.float)
    directed_edges = knn_graph(node_features, k=k, loop=False)
    return Data(
        x=node_features,
        edge_index=to_undirected(directed_edges),
        y=node_labels,
    )

# Build a separate k-NN graph (k=3) for each split and move it to the compute device
train_graph = create_sparse_graph_data(X_train, y_train, k=3).to(DEVICE)
dev_graph = create_sparse_graph_data(X_dev, y_dev, k=3).to(DEVICE)
test_graph = create_sparse_graph_data(X_test, y_test, k=3).to(DEVICE)

# Network dimensions: input and output sizes follow the data, hidden sizes are fixed
input_dim = X_train.shape[1]  # width of the cached BERT embeddings
hidden_dim1 = 512
hidden_dim2 = 256
output_dim = y_train.shape[1] # one logit per relation column
dropout_rate = 0.4
gnn_model = GCN(
    input_dim=input_dim,
    hidden_dim1=hidden_dim1,
    hidden_dim2=hidden_dim2,
    output_dim=output_dim,
    dropout=dropout_rate,
).to(DEVICE)

# Per-class weight = negatives / positives in the training labels; the small epsilon
# avoids a division by zero for a class with no positive example
class_counts = y_train.sum(axis=0)
pos_weights = torch.tensor(
    (len(y_train) - class_counts) / (class_counts + 1e-6),
    dtype=torch.float,
).to(DEVICE)

# Loss and optimizer
criterion = MultiLabelFocalLoss(gamma=2.0, weight=pos_weights, reduction='mean')
optimizer = torch.optim.Adam(gnn_model.parameters(), lr=0.001)

# Bookkeeping for the training loop: F1 history and early-stopping state
train_f1_scores, dev_f1_scores = [], []
best_dev_f1, patience, counter = 0, 10, 0

# ------------------- Training Loop with Overfitting Check -------------------
# Reuse a saved checkpoint when one exists; otherwise train from scratch and keep the
# weights that reach the best dev micro-F1.
if os.path.exists(model_save_path):
    print(f"Loading existing model from {model_save_path}")
    # map_location allows a checkpoint written on one backend (e.g. MPS) to be
    # restored on another (e.g. CPU) without failing during deserialization.
    gnn_model.load_state_dict(torch.load(model_save_path, map_location=DEVICE))
    gnn_model.to(DEVICE)
else:
    print("Training GNN Model...")
    for epoch in tqdm(range(55), desc="Training Epochs"):
        # One full-batch optimization step over the whole training graph
        gnn_model.train()
        optimizer.zero_grad()
        out = gnn_model(train_graph)
        loss = criterion(out, train_graph.y)
        loss.backward()
        optimizer.step()

        # Evaluate model performance every 5 epochs
        if epoch % 5 == 0:
            gnn_model.eval()  # switch dropout off for scoring
            with torch.no_grad():
                train_logits = gnn_model(train_graph)
                train_pred_prob = torch.sigmoid(train_logits).cpu().numpy()
                train_preds = (train_pred_prob >= 0.5).astype(int)
                # zero_division=0 matches the final evaluation and silences the
                # warning raised while the model predicts no positives at all
                train_f1 = f1_score(y_train, train_preds, average="micro", zero_division=0)

                dev_logits = gnn_model(dev_graph)
                dev_pred_prob = torch.sigmoid(dev_logits).cpu().numpy()
                dev_preds = (dev_pred_prob >= 0.5).astype(int)
                dev_f1 = f1_score(y_dev, dev_preds, average="micro", zero_division=0)

            train_f1_scores.append(train_f1)
            dev_f1_scores.append(dev_f1)
            print(f"Epoch {epoch}: Train F1 = {train_f1:.4f}, Dev F1 = {dev_f1:.4f}")

            # Checkpoint on improvement; otherwise count towards early stopping.
            # NOTE(review): `counter` advances once per evaluation (every 5 epochs),
            # so with 55 epochs and patience 10 early stopping can only fire at the
            # very last evaluation - confirm this is intended.
            if dev_f1 > best_dev_f1:
                best_dev_f1 = dev_f1
                torch.save(gnn_model.state_dict(), model_save_path)
                counter = 0
            else:
                counter += 1
            if counter >= patience:
                print("Early stopping triggered")
                break

        torch.cuda.empty_cache() # Clear GPU memory if applicable

    # A checkpoint exists only if dev F1 improved at least once
    if os.path.exists(model_save_path):
        print(f"Model saved to {model_save_path}")

# Load best model if it exists
if os.path.exists(model_save_path):
    gnn_model.load_state_dict(torch.load(model_save_path, map_location=DEVICE))
    print("Best model loaded successfully!")
else:
    print("Model file not found. Please ensure the model is trained and saved.")

# ------------------- Threshold Tuning -------------------
print("Optimizing thresholds on dev set...")
gnn_model.eval()
with torch.no_grad():
    dev_logits = gnn_model(dev_graph)
    dev_pred_prob = torch.sigmoid(dev_logits).cpu().numpy()

# Pick, per class, the decision threshold that maximizes F1 on the development set.
# Every class starts at the neutral default of 0.5.
optimal_thresholds = np.full(output_dim, 0.5)
for i in range(output_dim):
    # Without a positive dev example F1 is 0 at every threshold; argmax would then
    # select the lowest score and mark every sample positive. Keep 0.5 instead.
    if y_dev[:, i].sum() == 0:
        continue
    precisions, recalls, thresholds = precision_recall_curve(y_dev[:, i], dev_pred_prob[:, i])
    f1_scores = (2 * precisions * recalls) / (precisions + recalls + 1e-6)
    # precision/recall carry one extra trailing point that has no threshold, so it
    # is dropped to keep the argmax index valid for `thresholds`.
    if len(thresholds) > 0:
        optimal_thresholds[i] = thresholds[np.argmax(f1_scores[:-1])]

# Persist the optimizer state alongside the model weights.
# NOTE(review): the tuned thresholds themselves are not written to disk here; they
# only live in `optimal_thresholds` for the rest of this run.
torch.save(optimizer.state_dict(), optimizer_save_path)
print(f"Optimizer saved to {optimizer_save_path}")

# Relation identifiers whose columns are excluded from the "IGN" scores
IGNORED_RELATIONS = {"NA"}

def compute_ign_f1(y_true, y_pred, relation_list):
    """Score predictions after dropping the columns of ignored relations.

    Columns whose relation name is in IGNORED_RELATIONS are removed from both
    matrices before scoring. If none of the names in `relation_list` is
    ignored, the result equals the F1 over all columns.

    Args:
        y_true (numpy.ndarray): True binary labels, one column per relation.
        y_pred (numpy.ndarray): Predicted binary labels, same layout as `y_true`.
        relation_list (list): Relation name for each column, in column order.

    Returns:
        tuple: (micro F1, weighted F1) over the remaining columns.
    """
    kept_columns = [
        column
        for column, relation in enumerate(relation_list)
        if relation not in IGNORED_RELATIONS
    ]
    true_kept = y_true[:, kept_columns]
    pred_kept = y_pred[:, kept_columns]

    return tuple(
        f1_score(true_kept, pred_kept, average=averaging, zero_division=0)
        for averaging in ("micro", "weighted")
    )

# ------------------- Evaluation -------------------

print("Evaluating on test set...")
with torch.no_grad():
    # Forward both held-out graphs through the model
    dev_logits, test_logits = gnn_model(dev_graph), gnn_model(test_graph)

    dev_pred_prob = torch.sigmoid(dev_logits).cpu().numpy()
    test_pred_prob = torch.sigmoid(test_logits).cpu().numpy()

    # Binarize with the per-class thresholds tuned on the dev set
    y_dev_pred = (dev_pred_prob >= optimal_thresholds).astype(int)
    y_test_pred = (test_pred_prob >= optimal_thresholds).astype(int)

# True and predicted label matrices for each split that gets a report
metrics = {
    "Dev": {"y_true": y_dev, "y_pred": y_dev_pred},
    "Test": {"y_true": y_test, "y_pred": y_test_pred}
}

# Compute the scores for each split and print them as one report
for dataset, data in metrics.items():
    y_true = data["y_true"]
    y_pred = data["y_pred"]

    ign_f1_micro, ign_f1_weighted = compute_ign_f1(y_true, y_pred, all_relations)

    f1_micro = f1_score(y_true, y_pred, average="micro", zero_division=0)
    f1_weighted = f1_score(y_true, y_pred, average="weighted", zero_division=0)
    precision = precision_score(y_true, y_pred, average="micro", zero_division=0)
    recall = recall_score(y_true, y_pred, average="micro", zero_division=0)

    print(
        f"\n{dataset} Metrics:",
        f"Micro F1: {f1_micro:.4f}",
        f"Weighted F1: {f1_weighted:.4f}",
        f"Micro IGN F1: {ign_f1_micro:.4f}",
        f"Weighted IGN F1: {ign_f1_weighted:.4f}",
        f"Precision: {precision:.4f}",
        f"Recall: {recall:.4f}",
        sep="\n",
    )

Loading existing model from GCN_model_optmizers/GCN_model.pth
Best model loaded successfully!
Optimizing thresholds on dev set...
Optimizer saved to GCN_model_optmizers/optimizer.pth
Evaluating on test set...

Dev Metrics:
Micro F1: 0.6012
Weighted F1: 0.6495
Micro IGN F1: 0.6012
Weighted IGN F1: 0.6495
Precision: 0.5156
Recall: 0.7209

Test Metrics:
Micro F1: 0.5546
Weighted F1: 0.6046
Micro IGN F1: 0.5546
Weighted IGN F1: 0.6046
Precision: 0.4687
Recall: 0.6789
