# MedGemma Entity Linking

In [1]:
from pathlib import Path
from rich.console import Console
import polars as pl
import json

# Shared rich console; the cells below print their output through it.
cons = Console()

## Load the MDACE dataset

In [2]:
# Location of the MDACE inpatient parquet data, relative to the notebook.
inpatient_path = Path("../data/mdace/Inpatient/parquet")

# Raise explicitly instead of using `assert`: assertions are skipped when
# Python runs with -O, and the resolved path makes the failure actionable.
if not inpatient_path.exists():
    raise FileNotFoundError(
        f"The specified path does not exist: {inpatient_path.resolve()}"
    )

In [3]:
# The train/val/test files produced at the end of this notebook are written
# into `inpatient_path` itself. Reading the whole directory would pull those
# derived files back in on any re-run, duplicating rows and leaking them
# across splits, so they are excluded from the load explicitly.
derived_split_files = {
    "inpatient_icd10cm_train.parquet",
    "inpatient_icd10cm_val.parquet",
    "inpatient_icd10cm_test.parquet",
}
source_files = sorted(
    path
    for path in inpatient_path.rglob("*")
    if path.is_file() and path.name not in derived_split_files
)
if not source_files:
    raise FileNotFoundError(f"No source parquet files found under {inpatient_path}")

# NOTE(review): assumes the directory holds plain parquet files without
# hive-style partition columns encoded in sub-directory names — confirm.
df_inpatient = pl.read_parquet(source_files)

# Quick look at what was loaded: column types, dimensions, first rows.
cons.print(df_inpatient.schema)
cons.print(df_inpatient.shape)
cons.print(df_inpatient.head())

In [4]:
# Keep only the rows whose code system is ICD-10-CM.
is_icd10cm = pl.col("code_system") == "ICD-10-CM"
df_inpatient_icd10cm = df_inpatient.filter(is_icd10cm)

# Same overview as for the full frame: schema, shape, then the first rows.
for summary in (
    df_inpatient_icd10cm.schema,
    df_inpatient_icd10cm.shape,
    df_inpatient_icd10cm.head(),
):
    cons.print(summary)

## Split into train, validation, and test sets

In [5]:
import random

# A fixed seed makes the shuffle, and therefore the split, repeatable.
random.seed(42)

# Shuffle row positions instead of the frame; the splits are taken by index.
n = df_inpatient_icd10cm.height
indices = list(range(n))
random.shuffle(indices)

# 60% train, 20% validation; the test split receives whatever remains, so
# rounding in the two int() truncations never drops a row.
train_size = int(n * 0.6)
val_size = int(n * 0.2)
val_end = train_size + val_size

train_indices = indices[:train_size]
val_indices = indices[train_size:val_end]
test_indices = indices[val_end:]

# Materialise one dataframe per split from its shuffled row positions.
df_train, df_val, df_test = (
    df_inpatient_icd10cm[rows]
    for rows in (train_indices, val_indices, test_indices)
)

# Report the size of each split and its share of the filtered dataset.
cons.print(f"Train set: {df_train.shape[0]} rows ({df_train.shape[0]/n*100:.1f}%)")
cons.print(f"Validation set: {df_val.shape[0]} rows ({df_val.shape[0]/n*100:.1f}%)")
cons.print(f"Test set: {df_test.shape[0]} rows ({df_test.shape[0]/n*100:.1f}%)")

In [6]:
# Persist each split as its own parquet file inside `inpatient_path`.
split_outputs = (
    (df_train, "inpatient_icd10cm_train.parquet"),
    (df_val, "inpatient_icd10cm_val.parquet"),
    (df_test, "inpatient_icd10cm_test.parquet"),
)
for split_frame, file_name in split_outputs:
    split_frame.write_parquet(inpatient_path / file_name)