In [None]:
import pandas as pd
import numpy as np
import anndata
import json
import os
import scanpy as sc
from sklearn.model_selection import train_test_split

## Data download and train/valid/test split

In [None]:

# Every input file and derived split lives under one local directory.
data_dir = "data_input"
os.makedirs(data_dir, exist_ok=True)

# Paths for the full download and the three splits derived from it.
pancreas_adata_path, train_path, valid_path, test_path = (
    os.path.join(data_dir, file_name)
    for file_name in (
        "pancreas_full.h5ad",
        "pancreas_train.h5ad",
        "pancreas_valid.h5ad",
        "pancreas_test.h5ad",
    )
)

# sc.read loads the local file when present and fetches `backup_url` otherwise.
# NOTE: the cleanup step at the end of this cell deletes the full file once the
# splits are written, so every re-run of this cell downloads it again.
pancreas_adata = sc.read(pancreas_adata_path, backup_url="https://figshare.com/ndownloader/files/24539828")

# Hold out whole technologies (smartseq2, celseq2) as the test set; the
# train/valid cells therefore never include those platforms.
held_out = pancreas_adata.obs["tech"].isin(["smartseq2", "celseq2"])
query_mask = held_out.to_numpy()
pancreas_test    = pancreas_adata[query_mask].copy()
pancreas_no_test = pancreas_adata[~query_mask].copy()

# Split the remaining cells 80/20 into train/valid. Stratifying on `tech`
# keeps each technology's proportion the same in both parts.
y = pancreas_no_test.obs["tech"].astype("category")
indices = np.arange(pancreas_no_test.n_obs)

split_kwargs = dict(
    test_size=0.20,
    train_size=0.80,
    random_state=42,
    shuffle=True,
    stratify=y,
)
idx_train, idx_valid = train_test_split(indices, **split_kwargs)

pancreas_train, pancreas_valid = (
    pancreas_no_test[idx].copy() for idx in (idx_train, idx_valid)
)

# Persist each split to its own .h5ad file (train, valid, test in that order).
for split_adata, split_path in (
    (pancreas_train, train_path),
    (pancreas_valid, valid_path),
    (pancreas_test, test_path),
):
    split_adata.write(split_path)

# One-line size summary of the three splits.
summary = (
    f"Train: {pancreas_train.n_obs} cells | "
    f"Valid: {pancreas_valid.n_obs} cells | "
    f"Test: {pancreas_test.n_obs} cells"
)
print(summary)

# Report how many cells of each technology ended up in every split
# (technologies listed in index order via sort_index).
print("\nCells per technology:")
splits = {"Train": pancreas_train, "Valid": pancreas_valid, "Test": pancreas_test}
for split_name, split_adata in splits.items():
    print(f"\n{split_name} split:")
    tech_counts = split_adata.obs["tech"].value_counts().sort_index()
    for tech, n_cells in tech_counts.items():
        print(f"  {tech}: {n_cells}")

# --- Cleanup: delete the original full dataset file ---
# The train/valid/test splits have already been written, so the full file is
# no longer needed. NOTE: this defeats sc.read's local-file reuse, so re-running
# this cell downloads the dataset again.
del pancreas_adata  # release the in-memory object before deleting its source file
try:
    # Remove directly instead of checking os.path.exists first: this avoids a
    # check-then-act race, and a missing file is handled below.
    os.remove(pancreas_adata_path)
    print(f"Deleted '{pancreas_adata_path}'")
except FileNotFoundError:
    pass  # already gone: nothing to do (same silent no-op as before)
except OSError as e:
    # Best-effort cleanup: warn on filesystem errors (permissions, locks, ...)
    # but let unrelated programming errors surface instead of swallowing them.
    print(f"[WARN] Could not delete '{pancreas_adata_path}': {e}")


# --- Save full gene list to JSON ---
# All splits are row-subsets of the same AnnData, so they share var_names;
# the training split serves as the reference gene order.
all_genes = pancreas_train.var_names.tolist()

# Reuse `data_dir` (created at the top of this cell) rather than repeating
# the "data_input" literal, so the output location cannot drift.
genes_json_path = os.path.join(data_dir, "all_genes_list.json")

with open(genes_json_path, "w", encoding="utf-8") as f:
    json.dump(all_genes, f, indent=2)

print(f"Saved {len(all_genes)} genes to {genes_json_path}")

## Data inspection

In [None]:
# Read the training split written by the data-download cell.
# Using `train_path` keeps this location in sync with where the split was saved.
adata = anndata.read_h5ad(train_path)

In [None]:
# Display the AnnData object summary
print("AnnData object summary:")
print(adata)

# Display the first few rows of the observation metadata
print("\nFirst 5 rows of adata.obs:")
print(adata.obs.head())

# Display available layers
print("\nAvailable layers in adata:")
print(adata.layers.keys())

# Display the first 5x5 block of the counts layer, only if the layer exists
# (indexing a missing layer would raise KeyError).
if "counts" in adata.layers:
    counts_block = adata.layers["counts"][:5, :5]
    # Layers may be stored as scipy sparse matrices, which pd.DataFrame does
    # not turn into a proper table; densify the small slice first.
    if hasattr(counts_block, "toarray"):
        counts_block = counts_block.toarray()
    print("\nFirst 5x5 of counts layer:")
    print(pd.DataFrame(counts_block,
                       columns=adata.var_names[:5],
                       index=adata.obs_names[:5]))
else:
    print("\nNo 'counts' layer found in adata.layers.")

In [None]:
# 1. How many unique technologies are present, and what are their names?



In [None]:
# 2. How many samples (cells) belong to each technology?



In [None]:
# 3. What is the total number of genes measured in the dataset?



In [None]:
# 4. What is the total number of samples (cells) in the dataset?



## Variance analysis

In [None]:
# 5. For each technology, calculate the variance of each gene across the cells of that technology.



In [None]:
# 6. Compute a weighted average of these variances for each gene, using the number of cells per technology as weights.



In [None]:
# 7. Save a list containing the top 2000 genes, ranked by the weighted-average variance from step 6
#top2000_genes = ...

# Save to a JSON file
#with open("./data_output/top2000_genes_centralized.json", "w") as f:
#    json.dump(top2000_genes, f, indent=2)