In [None]:
# Client ID for this run. It is used further down as a zero-based index into
# the sorted list of technologies that remain once the held-out technologies
# have been removed, so each ID maps to exactly one technology.
YOUR_ID = 0  # Set your client ID here

import os
import numpy as np
import scanpy as sc
from sklearn.model_selection import train_test_split

# ----------------------------
# Define paths
# ----------------------------
# All inputs and outputs of this script live in a single directory, which is
# created on first use.
data_dir = "data_input"
os.makedirs(data_dir, exist_ok=True)

# Full dataset, training split and validation split, in that order.
pancreas_adata_path, train_path, valid_path = (
    os.path.join(data_dir, filename)
    for filename in (
        "pancreas_full.h5ad",
        "pancreas_train.h5ad",
        "pancreas_valid.h5ad",
    )
)

# ----------------------------
# Load dataset (download if missing)
# ----------------------------
# `backup_url` is the location the file is fetched from when it is not
# already present at `pancreas_adata_path`.
pancreas_adata = sc.read(
    pancreas_adata_path,
    backup_url="https://figshare.com/ndownloader/files/24539828",
)

# ----------------------------
# Exclude held-out technologies
# ----------------------------
# Cells from these technologies are dropped before a client technology is
# chosen, so they can never end up in the train/validation splits.
held_out_techs = ["smartseq2", "celseq2"]
query_mask = pancreas_adata.obs["tech"].isin(held_out_techs).to_numpy()
pancreas_adata = pancreas_adata[~query_mask].copy()

# Select technology by client ID
# Sorting gives a deterministic ID -> technology mapping that does not depend
# on the order in which technologies appear in the data.
techs = sorted(map(str, pancreas_adata.obs["tech"].unique()))

# Validate the ID explicitly: a negative value would otherwise wrap around and
# silently select a technology from the end of the list, and an out-of-range
# value would fail with an IndexError that does not say what is valid.
if not 0 <= YOUR_ID < len(techs):
    raise IndexError(
        f"YOUR_ID={YOUR_ID!r} is not a valid client ID: there are "
        f"{len(techs)} selectable technologies {techs}, so valid IDs are "
        f"0 to {len(techs) - 1}"
    )

selected_tech = techs[YOUR_ID]
pancreas_adata = pancreas_adata[pancreas_adata.obs["tech"] == selected_tech].copy()

# ----------------------------
# Train/Validation split (80/20)
# ----------------------------
# The split is done on positional row indices; the fixed random_state keeps
# the partition the same from run to run.
indices = np.arange(pancreas_adata.n_obs)
idx_train, idx_valid = train_test_split(
    indices, test_size=0.20, train_size=0.80, random_state=42, shuffle=True
)

# Subset the selected data by each index array and copy the result.
pancreas_train, pancreas_valid = (
    pancreas_adata[row_idx].copy() for row_idx in (idx_train, idx_valid)
)

# ----------------------------
# Save splits
# ----------------------------
# Write the training split first, then the validation split, each to its own
# file.
for split_adata, split_path in (
    (pancreas_train, train_path),
    (pancreas_valid, valid_path),
):
    split_adata.write(split_path)

# One-line summary of how many cells ended up in each split.
print(f"Train: {pancreas_train.n_obs} cells | Valid: {pancreas_valid.n_obs} cells")

# ----------------------------
# Print counts per technology
# ----------------------------
# Report, for each split, how many cells carry each value of the "tech"
# column, ordered by technology name.
print("\nCells per technology:")
splits = {"Train": pancreas_train, "Valid": pancreas_valid}
for name, ad in splits.items():
    counts = ad.obs["tech"].value_counts().sort_index()
    print(f"\n{name} split:")
    for tech, n in counts.items():
        print(f"  {tech}: {n}")

# ----------------------------
# Cleanup: delete the original full dataset
# ----------------------------
del pancreas_adata  # Drop reference to free memory
try:
    # Remove directly rather than checking for existence first: the check
    # could go stale between the test and the removal.
    os.remove(pancreas_adata_path)
except FileNotFoundError:
    # The file is already gone, which is the state we wanted.
    pass
except OSError as e:
    # Cleanup is best-effort: report the problem (e.g. permissions, file in
    # use) but do not abort, since the splits have already been written.
    print(f"[WARN] Could not delete '{pancreas_adata_path}': {e}")
