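# Graph data preparation

Builds one PyTorch Geometric `Data` graph per case (`file_name`) from the event-relation CSV and saves each graph to `OUTPUT_DIR` as `<file_name>.pt`.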

In [31]:
import os
from pathlib import Path
import pandas as pd
from transformers import AutoTokenizer, AutoModel
import torch
from torch_geometric.data import Data
from tqdm import tqdm


In [32]:
# Input: event-pair CSV in which every row is read as one edge of a case graph.
CSV_PATH = Path("./data/filtered_all_updated_events_removed_conclusion_source_and_empty_events.csv")

# Output: one `<file_name>.pt` graph per case is written here (created if absent).
OUTPUT_DIR = Path("InCaseLawBERT_updated_events_raw_graph_data_for_joint_prediction_csv")
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)


In [33]:
# Fine-tuned InCaseLawBERT checkpoint; the same id supplies both tokenizer and weights.
MODEL_NAME = 'suyamoonpathak/incaselawbert-pcna-events-finetuned'

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME).to(device)


In [34]:
# Lower-cased CSV `relation` value -> integer edge type
# (support=0, attack=1, no-relation=2, matching the previous code convention).
relation_to_edge_type = {
    relation: edge_type
    for edge_type, relation in enumerate(("support", "attack", "no-relation"))
}

# Resolved node type -> class label used as the node-classification target
# (premise=0, conclusion=1, non-argumentative=2).
node_type_to_label = {
    node_type: label
    for label, node_type in enumerate(("premise", "conclusion", "non-argumentative"))
}

# Tie-break rank when one text is annotated with several types
# (e.g. as both source and target): a higher number wins.
node_type_priority = dict(zip(
    ("conclusion", "premise", "non-argumentative"),
    (3, 2, 1),
))


In [35]:
def generate_embeddings(texts, batch_size=4):
    """Encode ``texts`` with the full fine-tuned model, one vector per text.

    Texts are tokenized in chunks of ``batch_size`` (padded per chunk and
    truncated to 512 tokens) and run through the transformer under
    ``torch.no_grad()``. For each text, the final-layer hidden state at
    position 0 is kept (the [CLS] token for BERT-style tokenizers).

    Returns a CPU tensor with one row per input text, in input order.
    """
    cls_chunks = []
    chunk_starts = range(0, len(texts), batch_size)
    for start in tqdm(chunk_starts, desc="Generating embeddings"):
        chunk = texts[start:start + batch_size]
        encoded = tokenizer(
            chunk,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt",
        ).to(device)
        with torch.no_grad():
            hidden = model(**encoded).last_hidden_state
        cls_chunks.append(hidden[:, 0, :].cpu())
    return torch.cat(cls_chunks, dim=0)

def generate_raw_embeddings_from_word_embeddings(texts, batch_size=8, pooling="mean"):
    """Build non-contextual text embeddings from the model's word-embedding table.

    Only ``model.embeddings.word_embeddings`` is applied to the token ids (no
    transformer layers run), so the result is a static, context-free
    representation of each text.

    Args:
        texts: list of strings to embed.
        batch_size: number of texts tokenized per step (padded per batch,
            truncated to 512 tokens).
        pooling: "mean" (default) averages the word embeddings of all
            non-padding tokens (``attention_mask == 1``, including the special
            tokens the tokenizer adds). "cls" returns the word embedding at
            position 0 only. That was the previous behaviour; it is kept for
            reproducing old outputs. Position 0 is always the same special
            token for BERT-style tokenizers (which prepend [CLS]), so "cls"
            gives every text an identical vector.

    Returns:
        CPU tensor of shape (len(texts), embedding_dim), rows in input order.

    Raises:
        ValueError: if ``pooling`` is not "mean" or "cls".
    """
    if pooling not in ("mean", "cls"):
        raise ValueError(f"Unknown pooling {pooling!r}; expected 'mean' or 'cls'")

    embedding_layer = model.embeddings.word_embeddings
    if len(texts) == 0:
        # torch.cat([]) would raise; return a well-formed empty matrix instead.
        return torch.empty((0, embedding_layer.embedding_dim))

    embeddings = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]
        inputs = tokenizer(
            batch,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt",
        ).to(device)

        with torch.no_grad():
            # (batch, seq, dim): one static embedding per token id.
            token_embeddings = embedding_layer(inputs["input_ids"])

            if pooling == "cls":
                pooled = token_embeddings[:, 0, :]
            else:
                # Zero out padding positions, then divide by the real token count.
                # NOTE(review): relies on the tokenizer returning `attention_mask`
                # (BERT tokenizers do by default).
                mask = inputs["attention_mask"].unsqueeze(-1).to(token_embeddings.dtype)
                summed = (token_embeddings * mask).sum(dim=1)
                counts = mask.sum(dim=1).clamp(min=1.0)
                pooled = summed / counts

        embeddings.append(pooled.cpu())

    return torch.cat(embeddings, dim=0)


In [36]:
def determine_node_types(source_infos, target_infos):
    """
    Resolve a single node type for every text from the types it was annotated with.

    ``source_infos`` / ``target_infos`` map a text to the set of type strings it
    carried when it appeared as a source / target. All types seen for a text
    are pooled, and the one ranked highest in ``node_type_priority``
    (conclusion > premise > non-argumentative) wins. Types with no positive
    priority are ignored. A text left with no recognised type falls back to
    "non-argumentative".

    Returns a dict mapping every text present in either input to its type.
    """
    def rank(node_type):
        return node_type_priority.get(node_type, 0)

    resolved = {}
    for text in set([*source_infos, *target_infos]):
        candidates = source_infos.get(text, set()) | target_infos.get(text, set())
        winner = max(candidates, key=rank, default=None)
        if winner is None or rank(winner) <= 0:
            # No recognised type for this text.
            winner = "non-argumentative"
        resolved[text] = winner
    return resolved


In [37]:
# Every CSV row is read later as one edge (source_event -> target_event with a
# relation); rows are grouped into cases (one graph each) by `file_name`.
df = pd.read_csv(CSV_PATH)
file_names = df.loc[:, "file_name"].unique()

# Sanity check: the dataset is expected to contain exactly 40 cases.
num_files = len(file_names)
assert num_files == 40, f"Expected 40 files, found {num_files}"


In [38]:
# Whether to append the node-type one-hot to the node features `x`.
# NOTE(review): the one-hot encodes the same node type that is stored as the
# label `y`, so a node classifier trained on these graphs can read its target
# straight from the features (label leakage). It stays True by default so the
# saved feature layout (embedding_dim + 3 columns) is unchanged; set False to
# save embeddings only.
INCLUDE_NODE_TYPE_ONE_HOT = True


def build_case_graph(file_name, sub_df):
    """Build the PyG ``Data`` graph for a single case.

    Nodes are the unique ``source_event``/``target_event`` texts of ``sub_df``
    in first-appearance order. Each CSV row becomes one directed edge
    (source -> target), typed via ``relation_to_edge_type``.

    Returns a ``Data`` with:
        x: node embeddings, followed by a node-type one-hot when
           ``INCLUDE_NODE_TYPE_ONE_HOT`` is set;
        edge_index: long tensor of shape (2, num_edges);
        edge_type: long tensor, one entry per edge;
        y: node-type labels from ``node_type_to_label``;
        xml_file: ``file_name``.

    Raises:
        KeyError: if a row's relation is not in ``relation_to_edge_type``.
    """
    sources = sub_df['source_event'].tolist()
    targets = sub_df['target_event'].tolist()
    source_types = [t.strip().lower() for t in sub_df['source_type']]
    target_types = [t.strip().lower() for t in sub_df['target_type']]
    relations = [r.strip().lower() for r in sub_df['relation']]

    # dict.fromkeys keeps first-appearance order. set() order changes between
    # interpreter runs (string hash randomisation), which made node indices,
    # and therefore the saved .pt files, non-reproducible.
    unique_texts = list(dict.fromkeys(sources + targets))
    text_to_idx = {text: idx for idx, text in enumerate(unique_texts)}

    # A text can occur in several rows (and as both source and target), so
    # collect every type it was annotated with and resolve by priority.
    source_type_map, target_type_map = {}, {}
    for text, s_type in zip(sources, source_types):
        source_type_map.setdefault(text, set()).add(s_type)
    for text, t_type in zip(targets, target_types):
        target_type_map.setdefault(text, set()).add(t_type)
    node_types = determine_node_types(source_type_map, target_type_map)

    # Unknown types default to label 2 (non-argumentative).
    node_labels = torch.tensor(
        [node_type_to_label.get(node_types[text], 2) for text in unique_texts],
        dtype=torch.long,
    )

    embeddings = generate_raw_embeddings_from_word_embeddings(unique_texts)
    if INCLUDE_NODE_TYPE_ONE_HOT:
        # Derived from the labels, so the column order always matches
        # node_type_to_label (premise, conclusion, non-argumentative).
        type_one_hot = torch.nn.functional.one_hot(
            node_labels, num_classes=len(node_type_to_label)
        ).float()
        node_features = torch.cat([embeddings, type_one_hot], dim=1)
    else:
        node_features = embeddings

    # Every endpoint is in text_to_idx because the node list was built from
    # the same columns. view(-1, 2) keeps the (2, 0) shape when there are no edges.
    edge_index = torch.tensor(
        [[text_to_idx[src], text_to_idx[tgt]] for src, tgt in zip(sources, targets)],
        dtype=torch.long,
    ).view(-1, 2).t().contiguous()
    edge_type = torch.tensor(
        [relation_to_edge_type[rel] for rel in relations], dtype=torch.long
    )

    return Data(
        x=node_features,
        edge_index=edge_index,
        edge_type=edge_type,
        y=node_labels,
        xml_file=file_name,
    )


all_data = []
for file_name in tqdm(file_names, desc="Processing cases"):
    data = build_case_graph(file_name, df[df['file_name'] == file_name])
    # Saved as exactly "<file_name>.pt" so each graph maps back to its case.
    torch.save(data, OUTPUT_DIR / f"{file_name}.pt")
    all_data.append(data)




Processing cases: 100%|██████████| 40/40 [00:00<00:00, 65.78it/s]


In [39]:
print(f"Processed and saved {len(all_data)} files.")

# Every case found in the CSV must have produced exactly one saved graph.
# Compare against the discovered case list rather than a second hard-coded 40.
assert len(all_data) == len(file_names), (
    f"Processed file count mismatch: expected {len(file_names)} but got {len(all_data)}"
)

Processed and saved 40 files.
