In [1]:
import os
import re
import pandas as pd
from pathlib import Path

def parse_classification_report(file_path):
    """Parse classification_report.txt file and extract key metrics"""
    with open(file_path, 'r') as f:
        content = f.read()
    
    # Extract accuracy
    accuracy_match = re.search(r'accuracy\s+(\d+\.\d+)', content)
    accuracy = float(accuracy_match.group(1)) if accuracy_match else None
    
    # Extract macro avg precision, recall, and f1-score
    macro_avg_match = re.search(r'macro avg\s+(\d+\.\d+)\s+(\d+\.\d+)\s+(\d+\.\d+)', content)
    macro_precision = float(macro_avg_match.group(1)) if macro_avg_match else None
    macro_recall = float(macro_avg_match.group(2)) if macro_avg_match else None
    macro_f1 = float(macro_avg_match.group(3)) if macro_avg_match else None
    
    return accuracy, macro_precision, macro_recall, macro_f1


def compare_folds(base_path=".", model_name="YourModel"):
    """Compare results across all folds"""
    results = []
    
    # Process each fold
    for fold_num in range(1, 11):  # fold_1 to fold_10
        fold_dir = Path(base_path) / f"fold_{fold_num}"
        report_file = fold_dir / "nodes_report.txt"
        
        if report_file.exists():
            try:
                accuracy, macro_precision, macro_recall, macro_f1 = parse_classification_report(report_file)

                
                results.append({
                    'model_name': model_name,
                    'fold_number': fold_num,
                    'accuracy': accuracy,
                    'precision_macro_avg': macro_precision,
                    'recall_macro_avg': macro_recall,
                    'f1_score_macro_avg': macro_f1
                })
                
            except Exception as e:
                print(f"Error processing fold_{fold_num}: {e}")
        else:
            print(f"Warning: {report_file} not found")
    
    # Process test_evaluation folder
    test_dir = Path(base_path) / "test_evaluation"
    test_report_file = test_dir / "classification_report.txt"
    
    if test_report_file.exists():
        try:
            accuracy, macro_precision, macro_recall, macro_f1 = parse_classification_report(test_report_file)
            
            results.append({
                'model_name': model_name,
                'fold_number': 'Test',
                'accuracy': accuracy,
                'precision_macro_avg': macro_precision,
                'recall_macro_avg': macro_recall,
                'f1_score_macro_avg': macro_f1
            })
            
            print(f"Processed test_evaluation: Accuracy={accuracy:.4f}, Precision={macro_precision:.4f}, Recall={macro_recall:.4f}, Macro F1={macro_f1:.4f}")
            
        except Exception as e:
            print(f"Error processing test_evaluation: {e}")
    else:
        print(f"Warning: {test_report_file} not found")

    
    return results

def generate_comparison_report(results, model_name):
    """Generate comparison report and statistics"""
    if not results:
        print("No results to analyze")
        return
    
    df = pd.DataFrame(results)
    
    # Separate test results from fold results
    fold_results = df[df['fold_number'] != 'Test']
    test_results = df[df['fold_number'] == 'Test']
    
    # Calculate averages and add average row (only from fold results)
    if len(fold_results) > 0:
        avg_accuracy = fold_results['accuracy'].mean()
        avg_precision = fold_results['precision_macro_avg'].mean()
        avg_recall = fold_results['recall_macro_avg'].mean()
        avg_f1_macro = fold_results['f1_score_macro_avg'].mean()
        
        avg_row = {
            'model_name': model_name,
            'fold_number': 'Avg',
            'accuracy': avg_accuracy,
            'precision_macro_avg': avg_precision,
            'recall_macro_avg': avg_recall,
            'f1_score_macro_avg': avg_f1_macro
        }
        
        # Reconstruct dataframe: folds + avg + test
        df = pd.concat([
            fold_results, 
            pd.DataFrame([avg_row]), 
            test_results
        ], ignore_index=True)


    # Save results to CSV
    output_file = "fold_comparison_results.csv"
    df.to_csv(output_file, index=False, float_format='%.4f')
    print(f"\nResults saved to: {output_file}")
    
    return df

def main():
    # Configuration
    BASE_PATH = "./10 folds results"  # Current directory, change if needed
    MODEL_NAME = "Joint RGCN"  # Change to your actual model name
    
    print("Starting fold comparison analysis...")
    
    # Parse all fold results
    results = compare_folds(BASE_PATH, MODEL_NAME)
    
    if results:
        # Generate comparison report
        df = generate_comparison_report(results, MODEL_NAME)

    else:
        print("No valid results found. Please check your file paths and formats.")

if __name__ == "__main__":
    main()


Starting fold comparison analysis...

Results saved to: fold_comparison_results.csv


In [1]:
import os
import torch
from pathlib import Path

# Path to the directory containing the graph data files
GRAPH_DATA_DIR = Path("graph_data_processed_for_joint_prediction")

# Edge type mapping
EDGE_TYPE_MAP = {0: "Support", 1: "Attack", 2: "No Relation"}
NODE_TYPE_MAP = {0: "Premise", 1: "Conclusion"}

def load_graph_data(graph_data_dir):
    """Load all preprocessed graph data files."""
    all_files = sorted([f for f in graph_data_dir.glob("*.pt") if f.is_file()])
    if not all_files:
        raise FileNotFoundError(f"No files found in {graph_data_dir}")
    return [torch.load(f, weights_only=False) for f in all_files]

def calculate_counts(data_list):
    """Count edge types and node types in the dataset."""
    edge_counts = torch.zeros(len(EDGE_TYPE_MAP), dtype=torch.long)
    node_counts = torch.zeros(len(NODE_TYPE_MAP), dtype=torch.long)

    for data in data_list:
        # Count edges
        edge_counts += torch.bincount(data.edge_type, minlength=len(EDGE_TYPE_MAP))

        # Count nodes
        node_counts += torch.bincount(data.y, minlength=len(NODE_TYPE_MAP))

    return edge_counts, node_counts

if __name__ == "__main__":
    print("Loading graph data...")
    all_data = load_graph_data(GRAPH_DATA_DIR)
    print(f"Loaded {len(all_data)} graph files.")

    # Calculate counts
    edge_counts, node_counts = calculate_counts(all_data)

    # Print results
    print("\n=== Dataset Statistics ===")
    print("Edge Counts:")
    for edge_type, count in enumerate(edge_counts):
        print(f"  {EDGE_TYPE_MAP[edge_type]}: {count}")

    print("\nNode Counts:")
    for node_type, count in enumerate(node_counts):
        print(f"  {NODE_TYPE_MAP[node_type]}: {count}")


Loading graph data...
Loaded 40 graph files.

=== Dataset Statistics ===
Edge Counts:
  Support: 2291
  Attack: 145
  No Relation: 3016

Node Counts:
  Premise: 2393
  Conclusion: 161


In [3]:
import torch
import numpy as np
from sklearn.model_selection import KFold
from pathlib import Path
from torch_geometric.data import Data  # Ensure this is imported


# Constants
GRAPH_DATA_DIR = Path("graph_data_processed_for_joint_prediction")
NUM_FOLDS = 10
SEED = 42

EDGE_TYPE_MAP = {0: "Support", 1: "Attack", 2: "No Relation"}
NODE_TYPE_MAP = {0: "Premise", 1: "Conclusion"}

def load_graph_data(graph_data_dir):
    """Load all graph data files with safe deserialization."""
    files = sorted([f for f in graph_data_dir.glob("*.pt") if f.is_file()])
    if not files:
        raise FileNotFoundError(f"No graph data files found in {graph_data_dir}")

    # Allowlist the Data class for safe unpickling
    torch.serialization.add_safe_globals([Data])

    return [torch.load(f, weights_only=False) for f in files]


def calculate_graph_counts(data_list):
    """Calculate edge and node type counts for each graph."""
    edge_counts = []
    node_counts = []
    for data in data_list:
        # Edge counts
        edges = torch.bincount(data.edge_type, minlength=len(EDGE_TYPE_MAP))
        edge_counts.append(edges.numpy())
        # Node counts
        nodes = torch.bincount(data.y, minlength=len(NODE_TYPE_MAP))
        node_counts.append(nodes.numpy())
    return np.array(edge_counts), np.array(node_counts)

def create_balanced_splits(data_list, num_folds, seed):
    """Custom splitting to balance edge and node counts across folds."""
    np.random.seed(seed)
    edge_counts, node_counts = calculate_graph_counts(data_list)

    # Aggregate counts
    total_edge_counts = edge_counts.sum(axis=0)
    total_node_counts = node_counts.sum(axis=0)

    print("Total Edge Counts:", dict(zip(EDGE_TYPE_MAP.values(), total_edge_counts)))
    print("Total Node Counts:", dict(zip(NODE_TYPE_MAP.values(), total_node_counts)))

    # Initialize folds
    folds = [[] for _ in range(num_folds)]
    fold_edge_counts = np.zeros((num_folds, len(EDGE_TYPE_MAP)))
    fold_node_counts = np.zeros((num_folds, len(NODE_TYPE_MAP)))

    # Shuffle data indices
    indices = np.arange(len(data_list))
    np.random.shuffle(indices)

    # Distribute graphs into folds
    for idx in indices:
        graph_edges = edge_counts[idx]
        graph_nodes = node_counts[idx]
        # Find the fold that increases balance the least
        scores = []
        for fold in range(num_folds):
            updated_edge_counts = fold_edge_counts[fold] + graph_edges
            updated_node_counts = fold_node_counts[fold] + graph_nodes
            # Calculate imbalance
            edge_imbalance = np.std(updated_edge_counts / total_edge_counts)
            node_imbalance = np.std(updated_node_counts / total_node_counts)
            scores.append(edge_imbalance + node_imbalance)
        # Assign graph to the best fold
        best_fold = np.argmin(scores)
        folds[best_fold].append(idx)
        fold_edge_counts[best_fold] += graph_edges
        fold_node_counts[best_fold] += graph_nodes

    return folds

def run_joint_rgcn_cross_validation():
    """Main function to run cross-validation for the joint RGCN model."""
    # Load data
    data_list = load_graph_data(GRAPH_DATA_DIR)
    print(f"Loaded {len(data_list)} graphs.")

    # Create balanced splits
    folds = create_balanced_splits(data_list, NUM_FOLDS, SEED)

    # Print fold statistics
    for fold_idx, fold in enumerate(folds):
        print(f"\nFold {fold_idx + 1} Statistics:")
        fold_data = [data_list[idx] for idx in fold]
        edge_counts, node_counts = calculate_graph_counts(fold_data)
        print("Edge Counts:", edge_counts.sum(axis=0))
        print("Node Counts:", node_counts.sum(axis=0))

    # Placeholder: Replace this with your RGCN training loop
    for fold_idx, (train_fold, val_fold) in enumerate(folds):
        print(f"\n=== Fold {fold_idx + 1}/{NUM_FOLDS} ===")
        train_data = [data_list[idx] for idx in train_fold]
        val_data = [data_list[idx] for idx in val_fold]

        # Train your model here (train_data, val_data)
        print(f"Training on {len(train_data)} graphs, validating on {len(val_data)} graphs.")

run_joint_rgcn_cross_validation()


Loaded 40 graphs.
Total Edge Counts: {'Support': 2291, 'Attack': 145, 'No Relation': 3016}
Total Node Counts: {'Premise': 2393, 'Conclusion': 161}

Fold 1 Statistics:
Edge Counts: [265  29 410]
Node Counts: [291  17]

Fold 2 Statistics:
Edge Counts: [431  24 551]
Node Counts: [452  30]

Fold 3 Statistics:
Edge Counts: [169   9 214]
Node Counts: [172  12]

Fold 4 Statistics:
Edge Counts: [111   2 121]
Node Counts: [120  11]

Fold 5 Statistics:
Edge Counts: [372  22 482]
Node Counts: [361  25]

Fold 6 Statistics:
Edge Counts: [388  34 558]
Node Counts: [443  29]

Fold 7 Statistics:
Edge Counts: [335  19 430]
Node Counts: [345  24]

Fold 8 Statistics:
Edge Counts: [48  0 48]
Node Counts: [39  2]

Fold 9 Statistics:
Edge Counts: [ 99   4 119]
Node Counts: [101   7]

Fold 10 Statistics:
Edge Counts: [73  2 83]
Node Counts: [69  4]


ValueError: too many values to unpack (expected 2)

In [None]:
import torch
import numpy as np
from pathlib import Path
from sklearn.metrics import classification_report, f1_score
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch_geometric.data import Data
from torch_geometric.nn import RGCNConv
import torch.nn.functional as F
from tqdm import tqdm

# Constants
GRAPH_DATA_DIR = Path("graph_data_processed_for_joint_prediction")
NUM_FOLDS = 10
SEED = 42
EPOCHS = 10
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 5e-3
DROPOUT_RATE = 0.5

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

EDGE_TYPE_MAP = {0: "Support", 1: "Attack", 2: "No Relation"}
NODE_TYPE_MAP = {0: "Premise", 1: "Conclusion"}

# Helper Functions (from previous code)
def load_graph_data(graph_data_dir):
    """Load graph data with secure deserialization."""
    files = sorted([f for f in graph_data_dir.glob("*.pt") if f.is_file()])
    if not files:
        raise FileNotFoundError(f"No graph data files found in {graph_data_dir}")
    from torch_geometric.data import Data
    torch.serialization.add_safe_globals([Data])
    return [torch.load(f, weights_only=False) for f in files]

def calculate_graph_counts(data_list):
    """Calculate edge and node type counts for each graph."""
    edge_counts, node_counts = [], []
    for data in data_list:
        edge_counts.append(torch.bincount(data.edge_type, minlength=len(EDGE_TYPE_MAP)).numpy())
        node_counts.append(torch.bincount(data.y, minlength=len(NODE_TYPE_MAP)).numpy())
    return np.array(edge_counts), np.array(node_counts)

def create_balanced_splits(data_list, num_folds, seed):
    """Create custom splits to balance edge and node types."""
    np.random.seed(seed)
    edge_counts, node_counts = calculate_graph_counts(data_list)

    total_edge_counts = edge_counts.sum(axis=0)
    total_node_counts = node_counts.sum(axis=0)

    folds = [[] for _ in range(num_folds)]
    fold_edge_counts = np.zeros((num_folds, len(EDGE_TYPE_MAP)))
    fold_node_counts = np.zeros((num_folds, len(NODE_TYPE_MAP)))

    indices = np.arange(len(data_list))
    np.random.shuffle(indices)

    for idx in indices:
        graph_edges, graph_nodes = edge_counts[idx], node_counts[idx]
        scores = []
        for fold in range(num_folds):
            updated_edges = fold_edge_counts[fold] + graph_edges
            updated_nodes = fold_node_counts[fold] + graph_nodes
            edge_imbalance = np.std(updated_edges / total_edge_counts)
            node_imbalance = np.std(updated_nodes / total_node_counts)
            scores.append(edge_imbalance + node_imbalance)
        best_fold = np.argmin(scores)
        folds[best_fold].append(idx)
        fold_edge_counts[best_fold] += graph_edges
        fold_node_counts[best_fold] += graph_nodes

    return folds
class MultiTaskLoss(torch.nn.Module):
    def __init__(self, edge_weight=None, node_weight=None, alpha=0.6):
        super().__init__()
        self.edge_criterion = torch.nn.CrossEntropyLoss(weight=edge_weight)
        self.node_criterion = torch.nn.CrossEntropyLoss(weight=node_weight)
        self.alpha = alpha  # Weight for edge vs. node loss

    def forward(self, edge_pred, edge_true, node_pred, node_true):
        edge_loss = self.edge_criterion(edge_pred, edge_true)
        node_loss = self.node_criterion(node_pred, node_true)
        return self.alpha * edge_loss + (1 - self.alpha) * node_loss


# RGCN Model
class EnhancedLegalRGCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, num_relations, dropout=DROPOUT_RATE):
        super().__init__()
        self.conv1 = RGCNConv(in_channels, hidden_channels, num_relations)
        self.conv2 = RGCNConv(hidden_channels, hidden_channels, num_relations)
        self.conv3 = RGCNConv(hidden_channels, hidden_channels, num_relations)
        self.edge_classifier = torch.nn.Sequential(
            torch.nn.Linear(2 * hidden_channels, hidden_channels),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(hidden_channels, len(EDGE_TYPE_MAP))
        )
        self.node_classifier = torch.nn.Sequential(
            torch.nn.Linear(hidden_channels, hidden_channels // 2),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(hidden_channels // 2, len(NODE_TYPE_MAP))
        )

    def forward(self, x, edge_index, edge_type):
        x1 = F.relu(self.conv1(x, edge_index, edge_type))
        x2 = F.relu(self.conv2(x1, edge_index, edge_type))
        x3 = self.conv3(x2, edge_index, edge_type)

        row, col = edge_index
        edge_features = torch.cat([x3[row], x3[col]], dim=-1)
        edge_out = self.edge_classifier(edge_features)
        node_out = self.node_classifier(x3)

        return edge_out, node_out

# Training and Evaluation
def train_epoch(model, data_loader, optimizer, criterion):
    model.train()
    total_loss = 0
    for data in data_loader:
        data = data.to(device)
        optimizer.zero_grad()
        edge_pred, node_pred = model(data.x, data.edge_index, data.edge_type)
        loss = criterion(edge_pred, data.edge_type, node_pred, data.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(data_loader)


def validate(model, data_loader, criterion):
    model.eval()
    edge_true, edge_pred = [], []
    node_true, node_pred = [], []
    total_loss = 0

    with torch.no_grad():
        for data in data_loader:
            data = data.to(device)
            edge_out, node_out = model(data.x, data.edge_index, data.edge_type)
            loss = criterion(edge_out, data.edge_type, node_out, data.y)
            total_loss += loss.item()

            edge_true.extend(data.edge_type.cpu().numpy())
            edge_pred.extend(edge_out.argmax(dim=1).cpu().numpy())
            node_true.extend(data.y.cpu().numpy())
            node_pred.extend(node_out.argmax(dim=1).cpu().numpy())

    edge_f1 = f1_score(edge_true, edge_pred, average="macro")
    node_f1 = f1_score(node_true, node_pred, average="macro")
    return total_loss / len(data_loader), edge_f1, node_f1

def run_joint_rgcn_cross_validation():
    data_list = load_graph_data(GRAPH_DATA_DIR)
    folds = create_balanced_splits(data_list, NUM_FOLDS, SEED)
    all_f1_scores = []

    for fold_idx, val_indices in enumerate(folds):
        print(f"=== Fold {fold_idx + 1}/{NUM_FOLDS} ===")
        train_indices = [i for i in range(len(data_list)) if i not in val_indices]
        train_data = [data_list[idx] for idx in train_indices]
        val_data = [data_list[idx] for idx in val_indices]

        # Initialize Model
        model = EnhancedLegalRGCN(in_channels=770, hidden_channels=32, num_relations=3).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
        scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=3)

        # Compute class weights for balanced training
        edge_weight = calculate_class_weights(train_data, "edge_type", len(EDGE_TYPE_MAP))
        node_weight = calculate_class_weights(train_data, "y", len(NODE_TYPE_MAP))
        criterion = MultiTaskLoss(edge_weight=edge_weight, node_weight=node_weight)

        best_val_loss = float("inf")
        early_stop_counter = 0
        max_patience = 5

        for epoch in range(EPOCHS):
            train_loss = train_epoch(model, train_data, optimizer, criterion)
            val_loss, edge_f1, node_f1 = validate(model, val_data, criterion)
            scheduler.step(val_loss)

            print(f"Epoch {epoch + 1}/{EPOCHS} | Train Loss: {train_loss:.4f} | "
                  f"Val Loss: {val_loss:.4f} | Edge F1: {edge_f1:.4f} | Node F1: {node_f1:.4f}")

            # Early stopping
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                early_stop_counter = 0
            else:
                early_stop_counter += 1
                if early_stop_counter >= max_patience:
                    print("Early stopping triggered.")
                    break

        # Save F1 scores for fold
        all_f1_scores.append((edge_f1, node_f1))

    avg_edge_f1 = np.mean([score[0] for score in all_f1_scores])
    avg_node_f1 = np.mean([score[1] for score in all_f1_scores])
    print(f"\nFinal Average Edge F1: {avg_edge_f1:.4f}")
    print(f"Final Average Node F1: {avg_node_f1:.4f}")

def calculate_class_weights(data_list, attr, num_classes):
    counts = torch.zeros(num_classes, device=device)  # Ensure counts is on the correct device
    for data in data_list:
        attr_tensor = getattr(data, attr).to(device)  # Move the attribute tensor to the same device
        counts += torch.bincount(attr_tensor, minlength=num_classes)
    weights = 1.0 / (counts + 1e-5)  # Avoid division by zero
    return weights



if __name__ == "__main__":
    run_joint_rgcn_cross_validation()
