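# Graph data preparation

Builds one PyTorch Geometric `Data` graph per case file from the PCNA/SANR relation CSV. Nodes are the unique argument texts, embedded with the fine-tuned LegalBERT encoder plus a one-hot node type. Node labels are prem/conc/na. Edges are the annotated support/attack/no-relation pairs.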

In [68]:
import os
from pathlib import Path
import pandas as pd
from transformers import AutoTokenizer, AutoModel
import torch
from torch_geometric.data import Data
from tqdm import tqdm


In [69]:
# Input: pairwise relation annotations (source_text / target_text / relation per row).
CSV_PATH = Path("4 PCNA_SANR_final.csv")

# The fine-tuned encoder is loaded from the Hugging Face Hub
# (suyamoonpathak/legalbert-pcna-finetuned) rather than a local
# best_model_legalbert_pc/ directory.

# Output: one <file_name>.pt graph per case; created (with parents) if missing.
OUTPUT_DIR = Path("LegalBERT_PCNA_processed_graph_data_for_joint_prediction_csv")
os.makedirs(OUTPUT_DIR, exist_ok=True)


In [70]:
# Tokenizer and encoder come from the same Hub checkpoint so their vocabularies match.
MODEL_NAME = 'suyamoonpathak/legalbert-pcna-finetuned'

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# from_pretrained already returns the model in eval mode; the explicit .eval()
# only documents that dropout is off while extracting embeddings.
model = AutoModel.from_pretrained(MODEL_NAME).to(device).eval()


In [71]:
# Integer edge type for each CSV 'relation' value; the position in the tuple is
# the edge type (support=0, attack=1, no-relation=2).
relation_to_edge_type = {
    relation: edge_type
    for edge_type, relation in enumerate(("support", "attack", "no-relation"))
}

# Integer class label for each node type (premise, conclusion, non-argumentative).
node_type_to_label = {
    node_type: label
    for label, node_type in enumerate(("prem", "conc", "na"))
}

# When one text carries several node types (e.g. premise as a source, conclusion
# as a target), the type with the highest value here is kept.
node_type_priority = dict(conc=3, prem=2, na=1)


In [72]:
def generate_embeddings(texts, batch_size=4):
    """Encode texts with the global ``model`` and keep the position-0 hidden state.

    Texts are processed in chunks of ``batch_size``. Each chunk is padded to its
    longest member and truncated to 512 tokens, then run through the encoder
    without gradient tracking. Position 0 is the [CLS] slot for BERT-style
    tokenizers.

    Args:
        texts: list of strings.
        batch_size: number of texts per forward pass.

    Returns:
        CPU tensor with one row per input text, in input order.
    """
    chunk_starts = range(0, len(texts), batch_size)
    cls_chunks = []
    for start in tqdm(chunk_starts, desc="Generating embeddings"):
        encoded = tokenizer(
            texts[start:start + batch_size],
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt",
        ).to(device)
        with torch.no_grad():
            hidden = model(**encoded).last_hidden_state
        cls_chunks.append(hidden[:, 0, :].cpu())
    return torch.cat(cls_chunks, dim=0)

def generate_raw_embeddings_from_word_embeddings(texts, batch_size=8):
    """Embed texts with the model's non-contextual input word embeddings.

    Token ids are looked up in the input embedding table; no transformer layers
    run. The vectors are then averaged over each text's real (non-padding) tokens
    using the tokenizer's attention mask.

    The previous version kept only the word embedding at position 0. For this
    BERT tokenizer, position 0 always holds the same special ([CLS]) token id, so
    every text received an identical vector and the features carried no
    information about the input.

    Args:
        texts: list of strings.
        batch_size: number of texts tokenized per step.

    Returns:
        CPU tensor of shape (len(texts), embedding_dim). For an empty ``texts``
        the result has zero rows.
    """
    # get_input_embeddings() is the architecture-agnostic accessor; for BERT it
    # is the same module as model.embeddings.word_embeddings.
    embedding_layer = model.get_input_embeddings()
    if len(texts) == 0:
        return torch.empty((0, embedding_layer.embedding_dim), dtype=embedding_layer.weight.dtype)

    pooled = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        inputs = tokenizer(
            batch,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt",
        ).to(device)

        with torch.no_grad():
            token_embeddings = embedding_layer(inputs["input_ids"])  # (batch, seq, dim)
            # The attention mask is 1 for real tokens and 0 for padding, so padding
            # does not dilute the mean of shorter texts.
            mask = inputs["attention_mask"].unsqueeze(-1).to(token_embeddings.dtype)
            summed = (token_embeddings * mask).sum(dim=1)
            counts = mask.sum(dim=1).clamp(min=1)
            pooled.append((summed / counts).cpu())

    return torch.cat(pooled, dim=0)


def generate_gaussian_embeddings(texts, batch_size=4, seed=None, hidden_size=None):
    """Random-feature baseline: one standard-normal vector per text.

    The text content is ignored and only ``len(texts)`` matters. By default the
    row width matches ``model.config.hidden_size``, so the result has the same
    width as the encoder features.

    Args:
        texts: sequence of texts (only its length is used).
        batch_size: accepted for signature compatibility with the other
            embedding helpers; it does not affect the result.
        seed: optional int. When given, noise is drawn from a dedicated
            ``torch.Generator``, so reruns give identical features without
            touching the global RNG. ``None`` keeps the previous unseeded
            behaviour.
        hidden_size: row width; defaults to ``model.config.hidden_size``.

    Returns:
        CPU tensor of shape (len(texts), hidden_size) with the default float dtype.

    Raises:
        ValueError: if no hidden size can be determined.
    """
    if hidden_size is None:
        hidden_size = model.config.hidden_size
    if hidden_size is None:
        raise ValueError("hidden_size must be specified to match the model's output dimensionality.")

    generator = torch.Generator().manual_seed(seed) if seed is not None else None
    # Sample directly on the CPU. The old version drew on the GPU and copied
    # straight back per batch, which only added transfers. It also shadowed the
    # global `device`.
    return torch.randn(len(texts), hidden_size, generator=generator)

In [73]:
def determine_node_types(source_infos, target_infos, priority=None, default_type="na"):
    """Resolve a single node type per text by priority.

    A text may be annotated with different types in different rows. Its
    source-side and target-side types are pooled, and the recognised type with
    the highest priority wins. With the default priorities, conclusion ("conc")
    beats premise ("prem"), which beats non-argumentative ("na").

    Args:
        source_infos: dict mapping text -> set of types seen when it was a source.
        target_infos: dict mapping text -> set of types seen when it was a target.
        priority: dict type -> int; types missing or <= 0 are ignored.
            Defaults to the module-level ``node_type_priority``.
        default_type: type assigned when none of a text's types is recognised.
            Defaults to "na". Previously such texts got ``None``, which the
            graph cell labelled "na" while leaving an all-zero one-hot feature.

    Returns:
        dict mapping every text in either input to its resolved type.
    """
    if priority is None:
        priority = node_type_priority

    node_types = {}
    for text in source_infos.keys() | target_infos.keys():
        candidates = source_infos.get(text, set()) | target_infos.get(text, set())
        recognised = [t for t in candidates if priority.get(t, 0) > 0]
        node_types[text] = max(recognised, key=priority.get) if recognised else default_type
    return node_types


In [74]:
# One row per (source_text, target_text) pair; 'file_name' identifies the case.
df = pd.read_csv(CSV_PATH)
file_names = df['file_name'].unique()
num_files = len(file_names)
assert num_files == 40, f"Expected 40 files, found {num_files}"


In [75]:
all_data = []

# One sub-frame per case. sort=False yields groups in first-appearance order
# (the same order as df['file_name'].unique()) without re-filtering the whole
# frame once per file.
grouped_cases = df.groupby('file_name', sort=False)

for file_name, sub_df in tqdm(grouped_cases, total=grouped_cases.ngroups, desc="Processing cases"):
    source_texts = sub_df['source_text'].tolist()
    target_texts = sub_df['target_text'].tolist()
    source_types = [t.strip().lower() for t in sub_df['source_type']]
    target_types = [t.strip().lower() for t in sub_df['target_type']]
    relations = [r.strip().lower() for r in sub_df['relation']]

    # 1. Unique node texts in first-appearance order, and their indices.
    #    dict.fromkeys preserves insertion order. The previous list(set(...))
    #    ordered strings by hash, which varies with PYTHONHASHSEED, so node
    #    indices (and the saved .pt files) changed between kernel runs.
    unique_texts = list(dict.fromkeys(source_texts + target_texts))
    text_to_idx = {text: idx for idx, text in enumerate(unique_texts)}

    # 2. A text can occur in many rows, so collect every type it was given as a
    #    source and as a target.
    source_type_map = {}
    target_type_map = {}
    for src_text, s_type, tgt_text, t_type in zip(source_texts, source_types, target_texts, target_types):
        source_type_map.setdefault(src_text, set()).add(s_type)
        target_type_map.setdefault(tgt_text, set()).add(t_type)

    # 3. Resolve one type per node by priority (see node_type_priority).
    node_types = determine_node_types(source_type_map, target_type_map)

    # 4. Text embeddings, one row per entry of unique_texts.
    embeddings = generate_embeddings(unique_texts)

    # 5. Node labels; types missing from node_type_to_label fall back to "na".
    na_label = node_type_to_label["na"]
    node_labels = torch.tensor(
        [node_type_to_label.get(node_types[text], na_label) for text in unique_texts],
        dtype=torch.long,
    )

    # 6. One-hot node-type features derived from the label (label i -> column i,
    #    matching prem/conc/na -> 0/1/2). Feature and label always agree; the old
    #    per-type if/elif left an all-zero row for unrecognised types while
    #    labelling them "na".
    # NOTE(review): these columns encode y (node_labels) exactly. If a model
    # predicts node types from x, the target leaks into the input. Confirm the
    # training code drops or masks the last columns for node classification.
    node_features_type = torch.nn.functional.one_hot(
        node_labels, num_classes=len(node_type_to_label)
    ).float()
    node_features = torch.cat([embeddings, node_features_type], dim=1)

    # 7. Edges: one per CSV row, typed by its relation.
    edge_indices = []
    edge_types = []
    for src_text, tgt_text, rel in zip(source_texts, target_texts, relations):
        if src_text in text_to_idx and tgt_text in text_to_idx:
            edge_indices.append([text_to_idx[src_text], text_to_idx[tgt_text]])
            edge_types.append(relation_to_edge_type[rel])
        else:
            print(f"Warning: Missing node index for edge {src_text} -> {tgt_text} in file {file_name}")

    # view(-1, 2) keeps edge_index shaped (2, num_edges) even if no edge survives.
    edge_index = torch.tensor(edge_indices, dtype=torch.long).view(-1, 2).t().contiguous()
    edge_type = torch.tensor(edge_types, dtype=torch.long)

    # 8. Create Data object
    data = Data(
        x=node_features,
        edge_index=edge_index,
        edge_type=edge_type,
        y=node_labels,
        xml_file=file_name
    )

    # 9. Save as <file_name>.pt in OUTPUT_DIR.
    output_path = OUTPUT_DIR / f"{file_name}.pt"
    torch.save(data, output_path)
    all_data.append(data)




Generating embeddings: 100%|██████████| 46/46 [00:08<00:00,  5.38it/s]
Generating embeddings: 100%|██████████| 46/46 [00:12<00:00,  3.80it/s]
Generating embeddings: 100%|██████████| 35/35 [00:07<00:00,  4.67it/s]
Generating embeddings: 100%|██████████| 40/40 [00:11<00:00,  3.60it/s]
Generating embeddings: 100%|██████████| 29/29 [00:08<00:00,  3.43it/s]
Generating embeddings: 100%|██████████| 46/46 [00:10<00:00,  4.20it/s]
Generating embeddings: 100%|██████████| 44/44 [00:17<00:00,  2.46it/s]
Generating embeddings: 100%|██████████| 34/34 [00:09<00:00,  3.72it/s]
Generating embeddings: 100%|██████████| 31/31 [00:07<00:00,  3.90it/s]
Generating embeddings: 100%|██████████| 29/29 [00:06<00:00,  4.36it/s]
Generating embeddings: 100%|██████████| 31/31 [00:06<00:00,  4.52it/s]
Generating embeddings: 100%|██████████| 40/40 [00:07<00:00,  5.05it/s]
Generating embeddings: 100%|██████████| 32/32 [00:08<00:00,  3.63it/s]
Generating embeddings: 100%|██████████| 38/38 [00:11<00:00,  3.25it/s]
Genera

In [76]:
saved_count = len(all_data)
print(f"Processed and saved {saved_count} files.")

assert saved_count == 40, f"Warning: processed file count mismatch, expected 40 but got {saved_count}"

Processed and saved 40 files.


In [77]:
import torch
from pathlib import Path

def get_dataset_stats(output_dir="LegalBERT_PCNA_processed_graph_data_for_joint_prediction_csv"):
    """Summarise node and edge-type counts over every saved graph in ``output_dir``.

    Edge types are counted with the integer convention used when the graphs
    were built: 0 = support, 1 = attack, 2 = no-relation.

    Args:
        output_dir: directory containing the ``*.pt`` files (str or Path).

    Returns:
        dict with totals ('total_nodes', 'total_support', 'total_attack',
        'total_no_relation') and 'files', a list of per-file dicts ordered by
        file name. If no ``.pt`` file exists, a message is printed and the
        zero-valued dict is returned. Previously a plain string was returned,
        which made callers' ``stats['files']`` fail with a TypeError. Files that
        fail to load are reported and skipped.
    """
    total_stats = {
        'total_nodes': 0,
        'total_support': 0,
        'total_attack': 0,
        'total_no_relation': 0,
        'files': [],
    }

    # glob() order is filesystem-dependent; sorting keeps the report stable.
    pt_files = sorted(Path(output_dir).glob("*.pt"))
    if not pt_files:
        print(f"No processed files found in {output_dir}. Run the graph processing cells first.")
        return total_stats

    for pt_file in pt_files:
        try:
            # weights_only=False unpickles arbitrary objects (required for the
            # saved Data objects). Only use it on files this notebook wrote.
            data = torch.load(pt_file, weights_only=False)
            file_stats = {
                'filename': pt_file.name,
                'nodes': data.x.shape[0],
                'support': (data.edge_type == 0).sum().item(),
                'attack': (data.edge_type == 1).sum().item(),
                'no_relation': (data.edge_type == 2).sum().item(),
            }
        except Exception as e:
            # Best effort: report the unreadable file and keep summarising the rest.
            print(f"Error loading {pt_file.name}: {e}")
            continue

        total_stats['total_nodes'] += file_stats['nodes']
        total_stats['total_support'] += file_stats['support']
        total_stats['total_attack'] += file_stats['attack']
        total_stats['total_no_relation'] += file_stats['no_relation']
        total_stats['files'].append(file_stats)

    return total_stats

stats = get_dataset_stats()

# Dataset-level summary, rendered as markdown-style text.
summary_lines = [
    "",
    "## Dataset Summary",
    f"- **Total Documents**: {len(stats['files'])}",
    f"- **Total Arguments**: {stats['total_nodes']}",
    f"- **Support Relationships**: {stats['total_support']}",
    f"- **Attack Relationships**: {stats['total_attack']}",
    f"- **No-Relation Pairs**: {stats['total_no_relation']}",
    f"- **Ratio (Support:Attack:NoRel)**: {stats['total_support']}:{stats['total_attack']}:{stats['total_no_relation']}",
    "",
]
print("\n".join(summary_lines))

# Per-file breakdown.
for file_stats in stats['files']:
    print("\n".join([
        "",
        f"**{file_stats['filename']}**",
        f"- Nodes: {file_stats['nodes']}",
        f"- Support: {file_stats['support']}",
        f"- Attack: {file_stats['attack']}",
        f"- No-Relation: {file_stats['no_relation']}",
    ]))




## Dataset Summary
- **Total Documents**: 40
- **Total Arguments**: 5867
- **Support Relationships**: 2272
- **Attack Relationships**: 145
- **No-Relation Pairs**: 5000
- **Ratio (Support:Attack:NoRel)**: 2272:145:5000


**A2009_Commission of the European Communities v Koninklijke FrieslandCampina NV_M.xml.pt**
- Nodes: 148
- Support: 48
- Attack: 0
- No-Relation: 125

**A2016_European Commission v Aer Lingus Ltd and Ryanair Designated Activity Company.xml.pt**
- Nodes: 169
- Support: 61
- Attack: 2
- No-Relation: 125

**A2012_BNP Paribas and Banca Nazionale del Lavoro SpA (BNL) v European Commission.xml.pt**
- Nodes: 179
- Support: 28
- Attack: 4
- No-Relation: 125

**R2011_France Télécom SA v European Commission.xml.pt**
- Nodes: 182
- Support: 77
- Attack: 0
- No-Relation: 125

**A2011_European Commission (C-106_09 P) and Kingdom of Spain (C-107_09 P) v Government of Gibraltar and United Kingdom of Great Britain and Northern Ireland.xml.pt**
- Nodes: 203
- Support: 46
- Attack: 2
-

In [78]:
import pandas as pd

# Cross-check the saved graphs against the raw CSV (same CSV_PATH as the
# processing cells). Relation labels are normalised (strip + lower-case) exactly
# as in the graph-building cell, so these counts are comparable with the
# edge_type totals above. A separate name avoids overwriting the `df` that the
# processing cell iterates over.
relations_df = pd.read_csv(CSV_PATH)
relation_counts = relations_df['relation'].str.strip().str.lower().value_counts()

# .get(..., 0): a relation that never occurs is reported as 0 instead of raising.
support_count = relation_counts.get('support', 0)
attack_count = relation_counts.get('attack', 0)
no_relation_count = relation_counts.get('no-relation', 0)

print(f"Support count: {support_count}")
print(f"Attack count: {attack_count}")
print(f"No-relation count: {no_relation_count}")


Support count: 2272
Attack count: 145
No-relation count: 5000
