# graph data preparation

In [36]:
import os
from pathlib import Path
import pandas as pd
from transformers import AutoTokenizer, AutoModel
import torch
from torch_geometric.data import Data
from tqdm import tqdm

In [37]:
CSV_PATH = Path("7 updated_events PC_SANR_final.csv")
OUTPUT_DIR = Path("BERT_processed_updated_PC_graph_data_for_edge_prediction_csv")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)


In [38]:
tokenizer = AutoTokenizer.from_pretrained('suyamoonpathak/bert-pc-updated_events-finetuned')
model = AutoModel.from_pretrained('suyamoonpathak/bert-pc-updated_events-finetuned')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)


In [39]:
# Edge type mapping from CSV relations to integers compatible with previous code convention
relation_to_edge_type = {
    "support": 0,
    "attack": 1,
    "no-relation": 2
}

# Node label mapping for node classification
node_type_to_label = {
    "prem": 0,
    "conc": 1,
}

# Priority order for node type when a node appears as both source and target with different types
node_type_priority = {
    "conc": 3,
    "prem": 2,
    
}


In [40]:
def generate_embeddings(texts, batch_size=4):
    embeddings = []
    for i in tqdm(range(0, len(texts), batch_size), desc="Generating embeddings"):
        batch = texts[i:i+batch_size]
        inputs = tokenizer(batch,
                           padding=True,
                           truncation=True,
                           max_length=512,
                           return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        embeddings.append(outputs.last_hidden_state[:, 0, :].cpu())
    return torch.cat(embeddings, dim=0)

def generate_raw_embeddings_from_word_embeddings(texts, batch_size=8):
    embeddings = []
    
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i+batch_size]
        inputs = tokenizer(
            batch, 
            padding=True, 
            truncation=True, 
            max_length=512, 
            return_tensors="pt"
        ).to(device)
        
        # Extract token IDs from the tokenizer output
        input_ids = inputs['input_ids']
        
        with torch.no_grad():
            # Get the embeddings directly from the embedding layer
            word_embeddings = model.embeddings.word_embeddings(input_ids)
            
            cls_embeddings = word_embeddings[:, 0, :]  
            
            embeddings.append(cls_embeddings.cpu())
    
    return torch.cat(embeddings, dim=0)


def generate_gaussian_embeddings(texts, batch_size=4):
    hidden_size = model.config.hidden_size
    # hidden_size should be set to match model.config.hidden_size
    if hidden_size is None:
        raise ValueError("hidden_size must be specified to match the model's output dimensionality.")
    
    embeddings = []
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    for i in tqdm(range(0, len(texts), batch_size)):
        batch_len = len(texts[i:i+batch_size])
        # Generate gaussian noise with shape [batch_len, hidden_size]
        noise = torch.randn(batch_len, hidden_size).to(device)
        embeddings.append(noise.cpu())
    
    return torch.cat(embeddings, dim=0)

In [41]:
def determine_node_types(key_type_map):
    """
    Given mapping from node keys to set of types, determine by priority.
    """
    node_types = {}
    for key, types in key_type_map.items():
        best_type = None
        best_priority = 0
        for t in types:
            prio = node_type_priority.get(t, 0)
            if prio > best_priority:
                best_priority = prio
                best_type = t
        node_types[key] = best_type
    return node_types


In [42]:
df = pd.read_csv(CSV_PATH)
file_names = df['file_name'].unique()
assert len(file_names) == 40, f"Expected 40 files, found {len(file_names)}"


In [43]:
all_data = []
for file_name in tqdm(file_names, desc="Processing cases"):
    sub_df = df[df['file_name'] == file_name]
    node_key_texts = {}
    key_type_map = {}

    # First pass: collect all unique node keys and texts, types
    for _, row in sub_df.iterrows():
        for role in ["source", "target"]:
            node_ID = row[f"{role}_ID"]
            node_text = str(row[f"{role}_text"]) if pd.notnull(row[f"{role}_text"]) else ""
            node_type = row[f"{role}_type"].strip().lower()
            key = (node_ID, node_text, file_name) if node_text else (node_ID, file_name)
            node_key_texts[key] = node_text
            key_type_map.setdefault(key, set()).add(node_type)

    # Assign node type per priority
    node_types = determine_node_types(key_type_map)

    # Build node index
    node_keys = list(node_key_texts.keys())
    key_to_idx = {key: idx for idx, key in enumerate(node_keys)}
    texts_for_embedding = [node_key_texts[key] if node_key_texts[key]!="" else "empty" for key in node_keys]
    embeddings = generate_embeddings(texts_for_embedding)

    node_features_type = torch.zeros((len(node_keys), 2))
    node_labels = []
    for i, key in enumerate(node_keys):
        ntype = node_types[key]
        label = node_type_to_label.get(ntype, 2)  # Default non-argumentative if missing
        node_labels.append(label)
        if ntype == "prem":
            node_features_type[i, 0] = 1
        elif ntype == "conc":
            node_features_type[i, 1] = 1

    node_labels = torch.tensor(node_labels, dtype=torch.long)
    node_features = torch.cat([embeddings, node_features_type], dim=1)

    edge_indices = []
    edge_types = []
    for _, row in sub_df.iterrows():
        src_ID = row['source_ID']
        src_text = str(row['source_text']) if pd.notnull(row['source_text']) else ""
        tgt_ID = row['target_ID']
        tgt_text = str(row['target_text']) if pd.notnull(row['target_text']) else ""
        rel = row['relation'].strip().lower()

        src_key = (src_ID, src_text, file_name) if src_text else (src_ID, file_name)
        tgt_key = (tgt_ID, tgt_text, file_name) if tgt_text else (tgt_ID, file_name)
        if src_key in key_to_idx and tgt_key in key_to_idx:
            edge_indices.append([key_to_idx[src_key], key_to_idx[tgt_key]])
            edge_types.append(relation_to_edge_type[rel])
        else:
            print(f"Warning: Missing node index for edge {src_key} -> {tgt_key} in file {file_name}")

    edge_index = torch.tensor(edge_indices, dtype=torch.long).t().contiguous()
    edge_type = torch.tensor(edge_types, dtype=torch.long)

    data = Data(
        x=node_features,
        edge_index=edge_index,
        edge_type=edge_type,
        y=node_labels,
        xml_file=file_name
    )
    output_path = OUTPUT_DIR / f"{file_name}.pt"
    torch.save(data, output_path)
    all_data.append(data)



Generating embeddings: 100%|██████████| 24/24 [00:00<00:00, 97.67it/s]
Generating embeddings: 100%|██████████| 25/25 [00:00<00:00, 90.48it/s]
Generating embeddings: 100%|██████████| 14/14 [00:00<00:00, 97.71it/s]
Generating embeddings: 100%|██████████| 19/19 [00:00<00:00, 85.46it/s]
Generating embeddings: 100%|██████████| 13/13 [00:00<00:00, 97.23it/s]
Generating embeddings: 100%|██████████| 24/24 [00:00<00:00, 107.07it/s]
Generating embeddings: 100%|██████████| 28/28 [00:00<00:00, 85.73it/s]
Generating embeddings: 100%|██████████| 14/14 [00:00<00:00, 101.09it/s]
Generating embeddings: 100%|██████████| 19/19 [00:00<00:00, 91.77it/s]
Generating embeddings: 100%|██████████| 10/10 [00:00<00:00, 89.81it/s]
Generating embeddings: 100%|██████████| 8/8 [00:00<00:00, 82.94it/s]
Generating embeddings: 100%|██████████| 19/19 [00:00<00:00, 98.64it/s]
Generating embeddings: 100%|██████████| 17/17 [00:00<00:00, 105.94it/s]
Generating embeddings: 100%|██████████| 16/16 [00:00<00:00, 99.82it/s]
Gener

In [44]:
print(f"Processed and saved {len(all_data)} files.")

assert len(all_data) == 40, f"Warning: processed file count mismatch, expected 40 but got {len(all_data)}"

Processed and saved 40 files.


In [None]:
import torch
from pathlib import Path

def get_dataset_stats(output_dir="LegalBERT_raw_graph_data_for_joint_prediction_csv"):
    pt_files = list(Path(output_dir).glob("*.pt"))
    if not pt_files:
        return "No processed files found. Run XML processing first."

    total_stats = {
        'total_nodes': 0,
        'total_support': 0,
        'total_attack': 0,
        'total_no_relation': 0,  
        'files': []
    }


    for pt_file in pt_files:
        try:
            data = torch.load(pt_file, weights_only=False)
            support = (data.edge_type == 0).sum().item()
            attack = (data.edge_type == 1).sum().item()
            no_relation = (data.edge_type == 2).sum().item()  

            
            file_stats = {
                'filename': pt_file.name,
                'nodes': data.x.shape[0],
                'support': support,
                'attack': attack,
                'no_relation': no_relation
            }
            
            total_stats['total_nodes'] += file_stats['nodes']
            total_stats['total_support'] += support
            total_stats['total_attack'] += attack
            total_stats['total_no_relation'] += no_relation
            total_stats['files'].append(file_stats)
            
            
        except Exception as e:
            print(f"Error loading {pt_file.name}: {str(e)}")
    
    return total_stats

stats = get_dataset_stats()

print(f"""
## Dataset Summary
- **Total Documents**: {len(stats['files'])}
- **Total Arguments**: {stats['total_nodes']}
- **Support Relationships**: {stats['total_support']}
- **Attack Relationships**: {stats['total_attack']}
- **No-Relation Pairs**: {stats['total_no_relation']}
- **Ratio (Support:Attack:NoRel)**: {stats['total_support']}:{stats['total_attack']}:{stats['total_no_relation']}
""")

for file in stats['files']:
    print(f"""
**{file['filename']}**
- Nodes: {file['nodes']}
- Support: {file['support']}
- Attack: {file['attack']}
- No-Relation: {file['no_relation']}""")



TypeError: string indices must be integers, not 'str'

In [None]:
import pandas as pd

# Load the filtered CSV file
df = pd.read_csv('3 PC_SANR_final.csv')

# Count occurrences of each relation type in the 'relation' column
relation_counts = df['relation'].value_counts()

# Extract counts for 'support', 'attack', and 'no-relation'
support_count = relation_counts.get('support', 0)
attack_count = relation_counts.get('attack', 0)
no_relation_count = relation_counts.get('no-relation', 0)

# Print the counts
print(f"Support count: {support_count}")
print(f"Attack count: {attack_count}")
print(f"No-relation count: {no_relation_count}")


Support count: 2272
Attack count: 145
No-relation count: 5000
