In [1]:
# %%
import scanpy as sc
import polars as pl
from sklearn.model_selection import StratifiedShuffleSplit
from datasets import Dataset, DatasetDict, ClassLabel
from datasets import Features, Value, ClassLabel, Sequence  
import os
import numpy as np
import pandas as pd

import pyarrow as pa
import pyarrow.parquet as pq

def merge_h5ad_to_parquet_arrow(file_paths):
    all_tables = []
    total_cells_loaded = 0

    for i, fp in enumerate(file_paths, 1):
        print(f"Loading file {i}/{len(file_paths)}: {fp} ...")
        adata = sc.read_h5ad(fp)
        cell_ids = adata.obs_names.to_list()
        n_cells = len(cell_ids)
        n_features = adata.X.shape[1]
        
        print(f"  - File {i}: {n_cells:,} cells, {n_features:,} features")
        total_cells_loaded += n_cells

        emb = adata.X
        if hasattr(emb, "toarray"):
            emb = emb.toarray()

        # Arrow FixedSizeListArray
        flat_emb = emb.ravel()
        dim = emb.shape[1]
        emb_array = pa.FixedSizeListArray.from_arrays(
            pa.array(flat_emb, type=pa.float32()), dim
        )

        table = pa.table({
            "cell_id": pa.array(cell_ids, type=pa.string()),
            "embedding": emb_array
        })

        all_tables.append(table)

    print(f"\nTotal cells loaded: {total_cells_loaded:,}")
    print("Concatenating all Arrow tables...")
    merged_table = pa.concat_tables(all_tables)

    # Deduplicazione basata su cell_id (in-memory, potrebbe essere lento se enorme)
    df_temp = merged_table.to_pandas()
    before = len(df_temp)
    df_temp = df_temp.drop_duplicates(subset=["cell_id"])
    after = len(df_temp)
    print(f"Duplicates removed: {before - after:,}")

    # Torna a Arrow e salva
    dedup_table = pa.Table.from_pandas(df_temp)

    return dedup_table


def add_cell_annotations(pl_df, adata: sc.AnnData) -> pl.DataFrame:
    # Conversione sicura se pl_df è un Arrow Table
    if isinstance(pl_df, pa.Table):
        print("Converting Arrow Table to Polars DataFrame...")
        pl_df = pl.from_pandas(pl_df.to_pandas())
    elif not isinstance(pl_df, pl.DataFrame):
        raise TypeError("Expected pl.DataFrame or pyarrow.Table as input")

    print(f"Adding annotations to {len(pl_df):,} cells...")

    # Prepara DataFrame delle annotazioni
    obs_df = adata.obs[['ann_finest_level', 'cell_type']].copy()
    obs_df['cell_id'] = obs_df.index.astype(str)
    obs_pl_df = pl.from_pandas(obs_df.reset_index(drop=True))

    # Join su 'cell_id'
    if "cell_id" not in pl_df.columns:
        raise ValueError("Missing 'cell_id' column in main dataframe")

    merged_df = pl_df.join(obs_pl_df, on='cell_id', how='left')

    # Filtra direttamente le righe non desiderate
    filtered_df = merged_df.filter(
        (pl.col('cell_type') != "unknown") &
        (pl.col('ann_finest_level').is_not_null()) &
        (pl.col('ann_finest_level') != "Unknown")
    )

    print(f"Cells retained after filtering: {len(filtered_df):,}")

    return filtered_df




def create_hf_dataset_from_polars_chunked(
    pl_df: pl.DataFrame,
    output_dir: str,
    val_size=0.2,
    test_size=0.3,
    random_state=42,
    chunk_size=100000
):
    print(f"\nCreating HuggingFace dataset from {len(pl_df):,} cells...")

    n_rows = len(pl_df)
    chunks = []
    for start in range(0, n_rows, chunk_size):
        end = min(start + chunk_size, n_rows)
        print(f"Processing chunk {start//chunk_size + 1}/{(n_rows-1)//chunk_size + 1}: rows {start} to {end-1}")
        chunk_df = pl_df.slice(start, end - start)
        chunk_pandas = chunk_df.to_pandas()
        chunks.append(chunk_pandas)

    df = pd.concat(chunks, ignore_index=True)
    df['ann_finest_level'] = df['ann_finest_level'].cat.remove_unused_categories()
    df['cell_type'] = df['cell_type'].cat.remove_unused_categories()
    print(f"Final DataFrame shape: {df.shape}")

    # Convert embedding column to np.array float32
    print("Converting embeddings to numpy arrays...")
    df['embedding'] = df['embedding'].apply(lambda x: np.array(x, dtype=np.float32))

    # Include all unique values, including "Unknown" if present
    ann_labels = sorted(df['ann_finest_level'].unique())
    celltype_labels = sorted(df['cell_type'].unique())
    
    
    ann_labels = [label for label in ann_labels if pd.notnull(label)]
    celltype_labels = [label for label in celltype_labels if pd.notnull(label)]

    ann_class_label = ClassLabel(names=ann_labels)
    ct_class_label = ClassLabel(names=celltype_labels)

    df['ann_finest_level'] = df['ann_finest_level'].apply(lambda x: ann_class_label.str2int(x) if pd.notnull(x) else -1)
    df['cell_type'] = df['cell_type'].apply(lambda x: ct_class_label.str2int(x) if pd.notnull(x) else -1)

    # Stratified split
    stratifier = StratifiedShuffleSplit(n_splits=1, test_size=val_size + test_size, random_state=random_state)
    train_idx, temp_idx = next(stratifier.split(df, df['cell_type']))

    df_train = df.iloc[train_idx]
    df_temp = df.iloc[temp_idx]

    relative_test_size = test_size / (val_size + test_size)
    stratifier_temp = StratifiedShuffleSplit(n_splits=1, test_size=relative_test_size, random_state=random_state)
    val_idx, test_idx = next(stratifier_temp.split(df_temp, df_temp['cell_type']))

    df_val = df_temp.iloc[val_idx]
    df_test = df_temp.iloc[test_idx]

    # Define feature schema: embedding is fixed size list of float32
    embedding_dim = df['embedding'].iloc[0].shape[0]
    features = Features({
        "cell_id": Value("string"),
        "embedding": Sequence(Value("float32"), length=embedding_dim),
        "ann_finest_level": ann_class_label,
        "cell_type": ct_class_label
    })

    def create_hf(ds_df):
        ds = Dataset.from_pandas(ds_df.reset_index(drop=True), features=features)
        return ds

    ds_dict = DatasetDict({
        "train": create_hf(df_train),
        "validation": create_hf(df_val),
        "test": create_hf(df_test)
    })

    os.makedirs(output_dir, exist_ok=True)
    ds_dict.save_to_disk(output_dir)

    print(f"✓ Dataset saved with ClassLabel and embedding as fixed-size float32 arrays.")
    return ds_dict

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
file_paths = ['/equilibrium/datasets/TCGA-histological-data/lung_embeddings_lechunck/part_800000_to_2282446/batched_analysis_model_gene_expression_embedding_t4_part_800000_to_2282446.h5ad', 
              '/equilibrium/datasets/TCGA-histological-data/lung_embeddings/batched_analysis_backup_800000.h5ad']

# Merge dei file con embedding già come lista
merged_df = merge_h5ad_to_parquet_arrow(file_paths)

Loading file 1/2: /equilibrium/datasets/TCGA-histological-data/lung_embeddings_lechunck/part_800000_to_2282446/batched_analysis_model_gene_expression_embedding_t4_part_800000_to_2282446.h5ad ...
  - File 1: 1,482,447 cells, 3,072 features
Loading file 2/2: /equilibrium/datasets/TCGA-histological-data/lung_embeddings/batched_analysis_backup_800000.h5ad ...
  - File 2: 800,000 cells, 3,072 features

Total cells loaded: 2,282,447
Concatenating all Arrow tables...
Duplicates removed: 0


In [3]:
adata = sc.read_h5ad('/equilibrium/datasets/TCGA-histological-data/2b415299-371a-4e95-8cba-8c6036a72ad5.h5ad', backed='r+')

In [4]:
adata.obs['ann_finest_level']

CGATGTAAGTTACGGG_SC10                               Alveolar macrophages
cc05p_CATGCCTGTGTGCCTG_carraro_csmc                              Unknown
ATTCTACCAAGGTTCT_HD68                              EC aerocyte capillary
D062_TGACCCTTCAAACCCA-sub_wang_sub_batch3           Alveolar fibroblasts
muc9826_GTCGTGAGAGGA_mayr                      Multiciliated (non-nasal)
                                                         ...            
TTGTGGATCGTTCCTG_5-PX5-sub_mould                    Alveolar macrophages
TCAGGATCAAGACGTG_F02526                        Multiciliated (non-nasal)
CAACCTCTCATGTAGC-WSSS8015042-0_meyer_unpubl         Alveolar macrophages
022C-b_GGATGTTTCCAAGTAC_adams                                    Unknown
145I-a_GTCGTAACAGTAGAGC_adams                                    Unknown
Name: ann_finest_level, Length: 2282447, dtype: category
Categories (62, object): ['AT0', 'AT1', 'AT2', 'AT2 proliferating', ..., 'T cells proliferating', 'Tuft', 'Unknown', 'pre-TB secretory']

In [5]:
dataset_label = add_cell_annotations(merged_df, adata)

Converting Arrow Table to Polars DataFrame...
Adding annotations to 2,282,447 cells...
Cells retained after filtering: 1,659,134


In [6]:
import polars as pl

create_hf_dataset_from_polars_chunked(
    dataset_label, 
    output_dir='/equilibrium/datasets/TCGA-histological-data/huggingface/lung_embeddings',
    chunk_size=1000 
)


Creating HuggingFace dataset from 1,659,134 cells...
Processing chunk 1/1660: rows 0 to 999
Processing chunk 2/1660: rows 1000 to 1999
Processing chunk 3/1660: rows 2000 to 2999
Processing chunk 4/1660: rows 3000 to 3999
Processing chunk 5/1660: rows 4000 to 4999
Processing chunk 6/1660: rows 5000 to 5999
Processing chunk 7/1660: rows 6000 to 6999
Processing chunk 8/1660: rows 7000 to 7999
Processing chunk 9/1660: rows 8000 to 8999
Processing chunk 10/1660: rows 9000 to 9999
Processing chunk 11/1660: rows 10000 to 10999
Processing chunk 12/1660: rows 11000 to 11999
Processing chunk 13/1660: rows 12000 to 12999
Processing chunk 14/1660: rows 13000 to 13999
Processing chunk 15/1660: rows 14000 to 14999
Processing chunk 16/1660: rows 15000 to 15999
Processing chunk 17/1660: rows 16000 to 16999
Processing chunk 18/1660: rows 17000 to 17999
Processing chunk 19/1660: rows 18000 to 18999
Processing chunk 20/1660: rows 19000 to 19999
Processing chunk 21/1660: rows 20000 to 20999
Processing ch

Saving the dataset (21/21 shards): 100%|██████████| 829567/829567 [04:04<00:00, 3394.65 examples/s]
Saving the dataset (9/9 shards): 100%|██████████| 331826/331826 [01:18<00:00, 4253.63 examples/s] 
Saving the dataset (13/13 shards): 100%|██████████| 497741/497741 [02:20<00:00, 3539.17 examples/s]


✓ Dataset saved with ClassLabel and embedding as fixed-size float32 arrays.


DatasetDict({
    train: Dataset({
        features: ['cell_id', 'embedding', 'ann_finest_level', 'cell_type'],
        num_rows: 829567
    })
    validation: Dataset({
        features: ['cell_id', 'embedding', 'ann_finest_level', 'cell_type'],
        num_rows: 331826
    })
    test: Dataset({
        features: ['cell_id', 'embedding', 'ann_finest_level', 'cell_type'],
        num_rows: 497741
    })
})