In [None]:
import polars as pl
import numpy as np
from tqdm.auto import tqdm
import torch

In [None]:
# Build a lazy query over every file matched by the raw-dataset glob.
RAW_DATASET_GLOB = "./data/datasets/pl/raw/*"
ds = pl.scan_parquet(RAW_DATASET_GLOB)

# Display the column names of the scanned dataset as the cell output.
ds.columns

In [None]:
# Histogram of how many entries each document has in `text_legal_bases`.
(
    ds
    .select(pl.col("text_legal_bases").list.len().alias("lb_len"))
    .collect()
    .to_pandas()
    .hist()
)

### Legal basis as an edge

In [None]:
# For every document with a non-empty `text_legal_bases` list, keep its `_id`
# and the set of distinct `isap_id` values found in that list. The result is
# materialised as a pandas frame; `reset_index()` adds an explicit `index`
# column next to `_id` and `isap_id`.
isap_docs = (
    ds
    .filter(pl.col("text_legal_bases").list.len() > 0)
    .select(
        pl.col("_id"),
        pl.col("text_legal_bases")
        .map_elements(lambda bases: {base["isap_id"] for base in bases})
        .alias("isap_id"),
    )
    .collect()
    .to_pandas()
    .reset_index()
)

# Number of documents that reference at least one legal base.
num_docs = isap_docs.shape[0]
isap_docs.head()

In [None]:
# Flatten the per-document id sets into one column and keep the distinct ids.
isap_ids = isap_docs["isap_id"].explode().unique()

# Map each distinct ISAP id to its position in `isap_ids`.
isap_id_2_index = dict(zip(isap_ids, range(len(isap_ids))))
print(f"Unique ISAP ids: {len(isap_ids)}")

# Create graph

In [None]:
from torch_geometric.data import HeteroData

In [None]:
# Heterogeneous graph with two node types: one "legal_bases" node per distinct
# ISAP id and one "judgements" node per document in `isap_docs`.
data = HeteroData()

# Placeholder node features: a single zero-valued feature per node.
data["legal_bases"].x = torch.zeros(len(isap_ids), 1)
data["judgements"].x = torch.zeros(num_docs, 1)

# One [judgement index, legal-base index] pair for every ISAP id a document
# references. The lookup table is `isap_id_2_index`; the previous spelling
# `isap_is_2_index` was an undefined name and raised NameError.
# `Series.items()` yields the same (index label, value) pairs that
# `iterrows()` exposed, without building a Series for every row.
edges = [
    [doc_index, isap_id_2_index[iid]]
    for doc_index, doc_isap_ids in isap_docs["isap_id"].items()
    for iid in doc_isap_ids
]

# Shape is (num_edges, 2); the reshape only matters for an empty edge list,
# where it yields (0, 2) instead of (0,).
# NOTE(review): torch_geometric expects `edge_index` as (2, num_edges), so this
# tensor needs `.t().contiguous()` before being attached to `data` — confirm
# whether a later cell already does that.
edge_index = torch.tensor(edges, dtype=torch.long).reshape(-1, 2)
edge_index