In [24]:
import scanpy as sc

In [25]:
path = '/mnt/storage/Daniele/preprocessed_data/cellxgene_census_human/'

In [26]:
path_m = '/mnt/storage/Daniele/preprocessed_data/cellxgene_census_mouse/'

In [27]:
ad = sc.read_h5ad(f'{path}0_preprocessed.h5ad', backed = 'r')

In [28]:
ad_m = sc.read_h5ad(f'{path_m}0_preprocessed.h5ad', backed = 'r')

In [29]:
import time

## Old versions

In [6]:
import gc
from pathlib import Path
from typing import List, Optional, Union
import numpy as np
from anndata import AnnData
from datasets import Dataset, Features, Sequence, Value, concatenate_datasets, load_from_disk
from joblib import Parallel, delayed

class AnnDataHFConverter:
    """
    A class to convert AnnData objects to Hugging Face datasets with support for batching and parallel processing.

    Attributes:
        adata (AnnData): The AnnData object containing the single-cell dataset to be converted.
        batch_size (int, optional): The size of each batch to process. Defaults to the number of observations in `adata`.
        gene_tokens (str): Column name in `adata.var` representing gene tokens.
        columns (List[str], optional): Additional observation metadata columns to include in the dataset.
        n_workers (int, optional): The number of worker processes for parallel computation. Defaults to 1 (sequential processing).

    Methods:
        generate_datasets(save: Optional[str] = None, merge: Optional[bool] = False): Converts `adata` into one or more Hugging Face datasets.
        _prepare_features(): Prepares feature descriptions for the Hugging Face dataset based on `adata` structure.
        _adata_to_dataset(adata_batch: AnnData, features: Features): Converts a single batch of AnnData to a Hugging Face Dataset.
    """
    def __init__(self, adata: AnnData, columns: List[str] = ['total_counts', 'n_genes_by_counts'], batch_size: Optional[int] = None, n_workers: Optional[int] = 1, save: Optional[str] = None, merge: Optional[bool] = True):
        """
        Initializes the AnnDataHFConverter_parallel object with the specified AnnData, batch size, gene tokens, columns, and number of workers.

        Parameters:
            columns (List[str]): Additional columns from `adata.obs` to be included in the datasets.
            batch_size (int, optional): Number of observations to process per batch. If None, the whole dataset is processed as one batch.
            n_workers (int, optional): Number of parallel worker processes to use. Defaults to 1 for sequential processing.
        """

        self.adata = adata
        self.columns = columns
        self.batch_size = batch_size if batch_size is not None else self.adata.n_obs
        self.n_workers = n_workers
        self.save = save
        self.merge = merge

    def generate_datasets(self) -> Optional[Union[Dataset, List[Dataset]]]:
        """
        Generates and optionally saves and/or merges datasets from the AnnData object.

        Parameters:
            save (str, optional): Path to save individual batches as separate datasets. If not specified, datasets are not saved.
            merge (bool, optional): Whether to merge all batch datasets into a single dataset. If True and `save` is specified, the merged dataset is saved.

        Returns:
            Union[Dataset, List[Dataset], None]: The resulting dataset(s) if not saved, or None if saved to disk.
        """
        
        features = self._prepare_features()
        if save:
            save_path = Path(save)
            save_path.mkdir(parents=True, exist_ok=True)

        datasets = []

        def process_batch(idx, start_idx, end_idx):
            adata_batch = self.adata[start_idx:end_idx]
            dataset = self._adata_to_HF(adata_batch, features)
            if save:
                batch_path = save_path / f"{idx}.dataset"
                dataset.save_to_disk(str(batch_path))
            del adata_batch
            gc.collect()
            return dataset

        adata_splits = [(idx, start_idx, min(start_idx + self.batch_size, self.adata.n_obs)) for idx, start_idx in enumerate(range(0, self.adata.n_obs, self.batch_size))]
        
        if self.n_workers > 1:
            datasets = Parallel(n_jobs=self.n_workers)(delayed(process_batch)(idx, start, end) for idx, start, end in adata_splits)
        else:
            for idx, start, end in adata_splits:
                datasets.append(process_batch(idx, start, end))

        if merge:
            if save:
                batches = list(save_path.iterdir())
                batches_datasets = [load_from_disk(str(save_path / batch)) for batch in sorted(batches)]
                dataset = concatenate_datasets(batches_datasets)
                merged_path = save_path / "merged.dataset"
                dataset.save_to_disk(str(merged_path))
                return None
            else:
                dataset = concatenate_datasets(datasets)
                return dataset
        elif not save:
            return datasets

        return None

    def _prepare_features(self) -> Features:
        """
        Prepares a Features object for the Hugging Face dataset based on the columns specified in the AnnData object.

        Returns:
            Features: A dictionary of features with data types corresponding to the columns in the AnnData object.
        """
        features = {'gexp': Sequence(Value("float32")), 'protein_embeddings': Sequence(Value("int32")), 'gene_name': Sequence(Value("string"))}
        if self.columns:
            for col in self.columns:
                if self.adata.obs[col].dtype.name in ['category', 'category_']:
                    features[col] = Value("string")
                else:
                    features[col] = Value(str(self.adata.obs[col].dtype))
        return Features(features)

    def _adata_to_HF(self, adata_batch: AnnData, features: Features) -> Dataset:
        """
        Converts a single batch of AnnData to a Hugging Face Dataset, more efficiently.

        Parameters:
            adata_batch (AnnData): A batch slice of the main AnnData object.
            features (Features): The features for the dataset as prepared by `_prepare_features`.

        Returns:
            Dataset: A single batch converted into a Hugging Face Dataset.
        """
        obs = adata_batch.obs
        gexp = adata_batch.X.toarray() if not isinstance(adata_batch.X, np.ndarray) else adata_batch.X
        gene_tokens = adata_batch.var[self.gene_tokens].to_numpy().astype(np.int32)
        gene_names = adata_batch.var_names.to_numpy()
        data_dict = {
            'gexp': [gexp[cell] for cell in range(adata_batch.n_obs)],
            'gene_token': [gene_tokens for _ in range(adata_batch.n_obs)],
            'gene_name': [gene_names for _ in range(adata_batch.n_obs)]
        }
        if self.columns:
            for col in self.columns:
                data_dict[col] = adata_batch.obs[col].to_numpy()

        dataset = Dataset.from_dict(data_dict, features=features)
        dataset.set_format(type="torch")

        return dataset


  from .autonotebook import tqdm as notebook_tqdm


In [None]:
import json
from pathlib import Path
from typing import Dict, List, Union, Optional

import numpy as np
import pandas as pd
import scanpy as sc
from anndata import AnnData
from scipy.sparse import csr_matrix, hstack, issparse
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


class AnnDataProcessor:
    def __init__(self, adata: AnnData, protein_embeddings: Path, var_names: Optional[str] = None, isMouse: Optional[bool] = False) -> None:
        """
        Initialize the processor with an AnnData object, specifying options for gene identifiers and QC metrics.

        Args:
            adata (AnnData): The AnnData object to be processed.
            protein_embeddings (Path): Path to the file containing protein embeddings.
            var_names (Optional[str]): Column in `adata.var` to use for gene identifiers, defaults to `adata.var_names` if None.
            isMouse (Optional[bool]): Set to True if the dataset originates from mice, influencing the loaded gene vocabularies.
        """
        logging.info("Initializing AnnDataProcessor")
        #self.genes_dictionary = self._load_genes_dictionary('mouse_gene_vocab.json' if isMouse else 'human_gene_vocab.json')
        self.genes_dictionary = human_dict
        self.protein_coding_genes = set(self.genes_dictionary.keys())
        self.adata = adata
        self.protein_embeddings = protein_embeddings
        self.adata.var_names = self.adata.var[var_names] if var_names else self.adata.var_names
        self.var_names = var_names
        self.computeQC = computeQC
        logging.info("AnnDataProcessor initialized successfully")

    def process_adata(self) -> None:
        """
        Processes the AnnData object by ensuring all required genes are present, and computing QC metrics if specified.
        """
        logging.info("Processing AnnData object")
        missing_genes = self.protein_coding_genes - set(self.adata.var_names)
        if missing_genes:
            logging.info(f"Extending AnnData with {len(missing_genes)} missing genes")
            self._extend_anndata(list(missing_genes))
        self.adata = self.adata[:, self.adata.var_names.isin(self.protein_coding_genes)].copy()
        self.adata.var['feature_id'] = [self.genes_dictionary[gene][1] for gene in self.adata.var_names]
        self.adata.var['feature_name'] = self.adata.var_names
        self.adata.var = self.adata.var[['feature_id', 'feature_name', 'extended']]
        logging.info("AnnData processing complete")
        logging.info("Loading protein embeddings")
        self.protein_embeddings = pd.read_parquet(self.protein_embeddings).loc[self.adata.var_names.values]

    def _load_genes_dictionary(self, filename: str) -> Dict[str, List[Union[int, str]]]:
        """
        Load a JSON file containing gene information and return it as a dictionary.

        Args:
            filename (str): Filename of the JSON to load, relative to a predefined directory.

        Returns:
            Dict[str, List[Union[int, str]]]: Dictionary containing gene information.
        """
        relative_path = Path(__file__).parent / 'resources' / filename
        with open(relative_path, 'r') as file:
            data = json.load(file)
        return data

    def _extend_anndata(self, missing_genes: List[str]) -> None:
        """
        Extend the AnnData object by adding missing genes. Adds these genes to the var DataFrame and sets their expression values to -1 (masked).

        Args:
            missing_genes (List[str]): List of gene names that are missing and need to be added.
        """
        logging.info(f"Adding {len(missing_genes)} missing genes to AnnData")
        new_var = pd.DataFrame(index=missing_genes)
        new_var['extended'] = True  # Mark these genes as extended
        if not issparse(self.adata.X):
            self.adata.X = csr_matrix(self.adata.X)
        masked_matrix = csr_matrix((self.adata.n_obs, len(missing_genes)), dtype=float)
        new_X = hstack([self.adata.X, masked_matrix]) 
        extended_var = pd.concat([self.adata.var, new_var], axis=0)
        extended_var['extended'] = extended_var['extended'].fillna(value = False)

        self.adata = AnnData(
            X=new_X,
            obs=self.adata.obs,
            var=extended_var
        )
        logging.info("AnnData extension complete")


## Old try

In [None]:
import json
from importlib import resources
from pathlib import Path
from typing import Dict, List, Optional, Union

import numpy as np
import pandas as pd
import scanpy as sc
from anndata import AnnData
from scipy.sparse import csr_matrix, hstack, issparse


class SetupAnnData:
    def __init__(self, adata: AnnData, var_names: Optional[str] = None, isMouse: Optional[bool] = False, computeQC: Optional[bool] = False) -> None:
        """
        Initialize the modifier with an AnnData object. Load default vocabularies based on the species.

        Parameters:
        adata (AnnData): The AnnData object to modify.
        var_names (Optional[str]): Column in adata.var to use for gene identifiers, defaults to adata.var_names if None.
        isMouse (Optional[bool]): Indicates if the dataset is from mouse, affects loaded vocabularies.
        computeQC (Optional[bool]): If True, computes QC metrics and adds them to adata.
        """
        self.vocab = h_d#self.load_json('human_gene_vocab.json') if not isMouse else self.load_json('mouse_gene_vocab.json')
        self.gene_tokens_vocab = {k: v[0] for k, v in self.vocab.items()}
        self.gene_ensembl_vocab = {k: v[1] for k, v in self.vocab.items()}
        self._adata = adata.copy()
        self._adata.var_names = self._adata.var_names if var_names else self._adata.var_names
        if computeQC:
            sc.pp.calculate_qc_metrics(self._adata, inplace=True)

    def load_json(self, vocab_name: str) -> Dict[str, List[Union[int,str]]]:
        """
        Load a JSON file from the specified relative path and return it as a dictionary.

        Parameters:
        filename (str): The filename of the JSON file to load relative to the class file location.

        Returns:
        Dict[str, List[int,str]]: The loaded JSON data.
        """
        relative_path = Path(__file__).parent.parent / 'resources' / vocab_name
        with open(relative_path) as file:
            data = json.load(file)
        if not isinstance(data, dict) or not all(isinstance(k, str) and isinstance(v, list) and isinstance(v[0], int) and isinstance(v[1], str) for k, v in data.items()):
            raise ValueError("Format as expected: Dict[str, List[str]]")  # noqa: TRY003
        return data

    def map_gene_ids(self) -> None:
        """
        Update the AnnData object's var DataFrame based on the loaded JSON mapping for gene tokens and Ensembl IDs.
        """
        gene_ids_mapping = pd.Series(self.gene_tokens_vocab)
        gene_ensembl_mapping = pd.Series(self.gene_ensembl_vocab)
        target_names = self._adata.var_names.astype(str)
        self._adata.var['gene_token'] = target_names.map(gene_ids_mapping)
        self._adata.var['gene_ensembl'] = target_names.map(gene_ensembl_mapping)

    def extend_anndata(self, missing_genes: List[str]) -> None:
        """
        Extend the AnnData object by adding missing genes. Adds these genes to the var DataFrame and sets their expression values to -1 (masked).
        """
        new_var = pd.DataFrame(index=missing_genes)
        masked_matrix = csr_matrix(np.full((self._adata.n_obs, len(missing_genes)), -1))  # Fill matrix with -1, mask on the flight missing genes
        current_X = self._adata.X.tocsr() if issparse(self._adata.X) else csr_matrix(self._adata.X)
        new_X = hstack([current_X, masked_matrix]).tocsr()
        self._adata = AnnData(
            X=new_X,
            obs=self._adata.obs,
            var=pd.concat([self._adata.var, new_var], axis=0)
        )
        self._adata.var_names_make_unique()

    def process_adata(self) -> AnnData:
        """
        Return the modified AnnData object after ensuring all genes in the vocabulary are present and mapped.
        """
        adata_genes = set(self._adata.var_names)
        vocab_genes = set(self.gene_tokens_vocab.keys())
        missing_genes = vocab_genes - adata_genes
        if missing_genes:
            self.extend_anndata(list(missing_genes))
        self._adata = self._adata[:,list(vocab_genes)].copy()
        self.map_gene_ids()

        return self._adata


In [None]:
import gc
from pathlib import Path
from typing import List, Optional, Union

import numpy as np
from anndata import AnnData

from datasets import (  # type: ignore  # noqa: PGH003
    Dataset,
    Features,
    Sequence,
    Value,
    concatenate_datasets,
    load_from_disk,
)


class AnnDataHFConverter:
    def __init__(self, adata: AnnData, batch_size: Optional[int] = None, gene_tokens: str = 'gene_token', columns: Optional[List[str]] = ['total_counts', 'n_genes_by_counts']) -> None:
        self.adata = adata
        self.batch_size = batch_size if batch_size is not None else self.adata.n_obs
        self.gene_tokens = gene_tokens
        self.columns = columns

    def generate_datasets(self, save: Optional[str] = None, merge: Optional[bool] = False) -> Optional[Union[Dataset, List[Dataset]]]:
        features = self._prepare_features()
        total_cells = self.adata.n_obs

        if save:
            save_path = Path(save)
            if not save_path.exists():
                save_path.mkdir(parents=True, exist_ok=True)

        if self.batch_size == total_cells and merge:
            merge = False

        datasets = []
        for idx, start_idx in enumerate(range(0, total_cells, self.batch_size)):
            end_idx = min(start_idx + self.batch_size, total_cells)
            adata_batch = self.adata[start_idx:end_idx].copy()
            dataset = self._adata_to_dataset(adata_batch, features)
            if save:
                batch_path = save_path / f"{idx}.dataset"
                dataset.save_to_disk(str(batch_path))
            datasets.append(dataset)
            del adata_batch
            gc.collect()

        if merge:
            if save:
                batches = list(save_path.iterdir())
                batches_datasets = [load_from_disk(str(save_path / batch)) for batch in sorted(batches)]
                dataset = concatenate_datasets(batches_datasets)
                merged_path = save_path / "merged.dataset"
                dataset.save_to_disk(str(merged_path))
                return None
            else:
                dataset = concatenate_datasets(datasets)
                return dataset
        elif not save:
            return datasets

        return None

    def _prepare_features(self) -> Features:
        features = {'gexp': Sequence(Value("float32")), 'gene_token': Sequence(Value("int32")), 'gene_name': Sequence(Value("string"))}
        if self.columns:
            for col in self.columns:
                if self.adata.obs[col].dtype.name in ['category', 'category_']:
                    features[col] = Value("string")
                else:
                    features[col] = Value(str(self.adata.obs[col].dtype))
        return Features(features)


    def _adata_to_dataset(self, adata_batch: AnnData, features: Features) -> Dataset:
        """
        Converts a single batch of AnnData to a Hugging Face Dataset, more efficiently.

        Parameters:
        adata_batch (AnnData): A batch slice of the main AnnData object.

        Returns:
        Dataset: A single batch converted into a Hugging Face Dataset.
        """
        obs = adata_batch.obs
        gexp = adata_batch.X.toarray() if not isinstance(adata_batch.X, np.ndarray) else adata_batch.X
        gene_tokens = adata_batch.var[self.gene_tokens].to_numpy().astype(np.int32)
        gene_names = adata_batch.var_names.to_numpy()

        data_dict = {
            'gexp': [gexp[i].squeeze() for i in range(adata_batch.n_obs)],
            'gene_token': [gene_tokens for _ in range(adata_batch.n_obs)],
            'gene_name': [gene_names for _ in range(adata_batch.n_obs)]
        }

        if self.columns:
            for col in self.columns:
                data_dict[col] = obs[col].to_numpy()

        dataset = Dataset.from_dict(data_dict, features=features)
        dataset.set_format(type="torch")

        return dataset


In [None]:
import gc
from pathlib import Path
from typing import List, Optional, Union
import numpy as np
from anndata import AnnData
from datasets import Dataset, Features, Sequence, Value, concatenate_datasets, load_from_disk
from joblib import Parallel, delayed

class AnnDataHFConverter_parallel:
    """
    A class to convert AnnData objects to Hugging Face datasets with support for batching and parallel processing.

    Attributes:
        adata (AnnData): The AnnData object containing the single-cell dataset to be converted.
        batch_size (int, optional): The size of each batch to process. Defaults to the number of observations in `adata`.
        gene_tokens (str): Column name in `adata.var` representing gene tokens.
        columns (List[str], optional): Additional observation metadata columns to include in the dataset.
        n_workers (int, optional): The number of worker processes for parallel computation. Defaults to 1 (sequential processing).

    Methods:
        generate_datasets(save: Optional[str] = None, merge: Optional[bool] = False): Converts `adata` into one or more Hugging Face datasets.
        _prepare_features(): Prepares feature descriptions for the Hugging Face dataset based on `adata` structure.
        _adata_to_dataset(adata_batch: AnnData, features: Features): Converts a single batch of AnnData to a Hugging Face Dataset.
    """
    def __init__(self, adata: AnnData, columns: List[str] = ['total_counts', 'n_genes_by_counts'], batch_size: Optional[int] = None, gene_tokens: str = 'gene_token', n_workers: Optional[int] = 1):
        """
        Initializes the AnnDataHFConverter_parallel object with the specified AnnData, batch size, gene tokens, columns, and number of workers.

        Parameters:
            adata (AnnData): The AnnData object to be processed.
            columns (List[str]): Additional columns from `adata.obs` to be included in the datasets.
            batch_size (int, optional): Number of observations to process per batch. If None, the whole dataset is processed as one batch.
            gene_tokens (str): The key in the `adata.var` DataFrame that corresponds to the gene tokens.
            n_workers (int, optional): Number of parallel worker processes to use. Defaults to 1 for sequential processing.
        """
        self.adata = adata
        self.batch_size = batch_size if batch_size is not None else self.adata.n_obs
        self.gene_tokens = gene_tokens
        self.columns = columns
        self.n_workers = n_workers

    def generate_datasets(self, save: Optional[str] = None, merge: Optional[bool] = False) -> Optional[Union[Dataset, List[Dataset]]]:
        """
        Generates and optionally saves and/or merges datasets from the AnnData object.

        Parameters:
            save (str, optional): Path to save individual batches as separate datasets. If not specified, datasets are not saved.
            merge (bool, optional): Whether to merge all batch datasets into a single dataset. If True and `save` is specified, the merged dataset is saved.

        Returns:
            Union[Dataset, List[Dataset], None]: The resulting dataset(s) if not saved, or None if saved to disk.
        """
        features = self._prepare_features()
        total_cells = self.adata.n_obs

        if save:
            save_path = Path(save)
            save_path.mkdir(parents=True, exist_ok=True)

        datasets = []

        def process_batch(idx, start_idx, end_idx):
            adata_batch = self.adata[start_idx:end_idx]
            dataset = self._adata_to_dataset(adata_batch, features)
            if save:
                batch_path = save_path / f"{idx}.dataset"
                dataset.save_to_disk(str(batch_path))
            del adata_batch
            gc.collect()
            return dataset

        adata_splits = [(idx, start_idx, min(start_idx + self.batch_size, total_cells)) for idx, start_idx in enumerate(range(0, total_cells, self.batch_size))]
        
        if self.n_workers > 1:
            datasets = Parallel(n_jobs=self.n_workers)(delayed(process_batch)(idx, start, end) for idx, start, end in adata_splits)
        else:
            for idx, start, end in adata_splits:
                datasets.append(process_batch(idx, start, end))

        if merge:
            if save:
                batches = list(save_path.iterdir())
                batches_datasets = [load_from_disk(str(save_path / batch)) for batch in sorted(batches)]
                dataset = concatenate_datasets(batches_datasets)
                merged_path = save_path / "merged.dataset"
                dataset.save_to_disk(str(merged_path))
                return None
            else:
                dataset = concatenate_datasets(datasets)
                return dataset
        elif not save:
            return datasets

        return None

    def _prepare_features(self) -> Features:
        """
        Prepares a Features object for the Hugging Face dataset based on the columns specified in the AnnData object.

        Returns:
            Features: A dictionary of features with data types corresponding to the columns in the AnnData object.
        """
        features = {'gexp': Sequence(Value("float32")), 'gene_token': Sequence(Value("int32")), 'gene_name': Sequence(Value("string"))}
        if self.columns:
            for col in self.columns:
                if self.adata.obs[col].dtype.name in ['category', 'category_']:
                    features[col] = Value("string")
                else:
                    features[col] = Value(str(self.adata.obs[col].dtype))
        return Features(features)

    def _adata_to_dataset(self, adata_batch: AnnData, features: Features) -> Dataset:
        """
        Converts a single batch of AnnData to a Hugging Face Dataset, more efficiently.

        Parameters:
            adata_batch (AnnData): A batch slice of the main AnnData object.
            features (Features): The features for the dataset as prepared by `_prepare_features`.

        Returns:
            Dataset: A single batch converted into a Hugging Face Dataset.
        """
        obs = adata_batch.obs
        gexp = adata_batch.X.toarray() if not isinstance(adata_batch.X, np.ndarray) else adata_batch.X
        gene_tokens = adata_batch.var[self.gene_tokens].to_numpy().astype(np.int32)
        gene_names = adata_batch.var_names.to_numpy()
        data_dict = {
            'gexp': [gexp[cell] for cell in range(adata_batch.n_obs)],
            'gene_token': [gene_tokens for _ in range(adata_batch.n_obs)],
            'gene_name': [gene_names for _ in range(adata_batch.n_obs)]
        }
        if self.columns:
            for col in self.columns:
                data_dict[col] = adata_batch.obs[col].to_numpy()

        dataset = Dataset.from_dict(data_dict, features=features)
        dataset.set_format(type="torch")

        return dataset


In [None]:
import json
with open('/home/daniele/Code/scGraph/scgraph/resources/human_gene_vocab.json','r') as file:
    h_d = json.load(file)

In [None]:
setup = SetupAnnData(ad)
ad_ = setup.process_adata()

In [None]:
import torch
from torch.utils.data import Dataset
import numpy as np
from anndata import AnnData
from typing import List, Optional

class AnnDataTorchDataset(Dataset):
    """
    A PyTorch Dataset to handle AnnData objects for processing with PyTorch models.
    
    Attributes:
        adata (AnnData): An AnnData object containing the single-cell dataset.
        gene_tokens (str): The column name in adata.var representing gene tokens.
        columns (List[str]): List of additional observation metadata columns to include in the dataset.
    """
    def __init__(self, adata: AnnData, gene_tokens: str = 'gene_token', columns: List[str] = ['total_counts', 'n_genes_by_counts']):
        """
        Initializes the AnnDataTorchDataset object with the specified AnnData and configuration.

        Parameters:
            adata (AnnData): The AnnData object to be processed.
            gene_tokens (str): The key in the adata.var DataFrame that corresponds to the gene tokens.
            columns (List[str]): Additional columns from adata.obs to be included in the datasets.
        """
        self.adata = adata
        self.gene_tokens = gene_tokens
        self.columns = columns

        self.gexp = self.adata.X.toarray() if not isinstance(self.adata.X, np.ndarray) else self.adata.X
        self.gene_tokens = self.adata.var[gene_tokens].to_numpy().astype(np.int32)
        self.gene_names = self.adata.var_names.to_numpy()
        self.obs_data = {col: self.adata.obs[col].to_numpy() for col in self.columns}

    def __len__(self):
        return self.adata.n_obs

    def __getitem__(self, idx):
        sample_data = {
            'gexp': torch.tensor(self.gexp[idx], dtype=torch.float32),
            'gene_token': torch.tensor(self.gene_tokens, dtype=torch.int32),
            #'gene_name': self.gene_names  # This might need adjustment if you want it as a tensor
        }
        
        for col in self.columns:
            sample_data[col] = [self.adata.obs[col][idx]]#.to_numpy()

        return sample_data

# Example usage
# Load your AnnData object somehow, for example:
# adata = AnnData.read_h5ad('your_data_file.h5ad')
# dataset = AnnDataTorchDataset(adata)
# DataLoader can be used to create batches from the dataset
# from torch.utils.data import DataLoader
# dataloader = DataLoader(dataset, batch_size=10, shuffle=True)


In [None]:
esm_human = pd.read_parquet('/home/daniele/Code/scGraph/scgraph/protein_embeddings/human_embeddings.parquet')

In [None]:
esm_mouse = pd.read_parquet('/home/daniele/Code/scGraph/scgraph/protein_embeddings/mouse_embeddings.parquet')

In [None]:
human_genes = set(ad.var_names).intersection(set(esm_human.index))

In [None]:
mouse_genes = set(ad_m.var_names).intersection(set(esm_mouse.index))

In [None]:
gene_dict = {'protein_coding_genes_human': list(human_genes), 'protein_coding_genes_mouse': list(mouse_genes)}

In [None]:
json_file = '/home/daniele/Code/scGraph/scgraph/resources/gene_dictionary.json'
with open(json_file, 'w') as file:
    json.dump(gene_dict, file, indent = 4)

# Try esm

In [None]:
import json
json_file = '/home/daniele/Code/scGraph/scgraph/resources/gene_dictionary.json'
with open(json_file, 'r') as file:
    gene_dict = json.load(file)

In [2]:
import json
json_file = '/home/daniele/Code/scGraph/scgraph/resources/human_gene_vocab.json'
with open(json_file, 'r') as file:
    human_dict = json.load(file)

In [1]:
import json
json_file = '/home/daniele/Code/scGraph/scgraph/resources/mouse_gene_vocab.json'
with open(json_file, 'r') as file:
    m_d = json.load(file)

In [None]:
len(set(h_d.keys()).intersection(set(ad.var_names)))

In [None]:
len(set(m_d.keys()).intersection(set(ad_m.var_names)))

In [None]:
import json
json_file = '/home/daniele/Code/scGraph/scgraph/resources/human_gene_vocab.json'
with open(json_file, 'w') as file:
    json.dump(human_dict,file,indent=4)

In [None]:
import json
json_file = '/home/daniele/Code/scGraph/scgraph/resources/mouse_gene_vocab.json'
with open(json_file, 'w') as file:
    json.dump(mouse_dict,file,indent=4)

In [5]:
import pandas as pd
ortho = pd.read_table('/home/daniele/orthologs_table.txt')

In [6]:
ortho

Unnamed: 0,human_entrez_gene,human_ensembl_gene,human_assert_ids,mouse_entrez_gene,mouse_ensembl_gene,mouse_assert_ids
0,-,ENSG00000212395,ENSG00000212395,-,ENSMUSG00002076924,ENSMUSG00002076924
1,-,ENSG00000212371,ENSG00000212371,115487044,ENSMUSG00000065089,ENSMUSG00000065089
2,-,ENSG00000252473,ENSG00000252473,115489211,ENSMUSG00000119329,ENSMUSG00000119329
3,-,ENSG00000252408,ENSG00000252408,115489209,ENSMUSG00000064536,ENSMUSG00000064536
4,-,ENSG00000252762,ENSG00000252762,-,ENSMUSG00002076040,ENSMUSG00002076040
...,...,...,...,...,...,...
68075,79699,ENSG00000162378,"OG6_107195,79699,ENOG5035HRW,ENSP00000294353,1...",414872,ENSMUSG00000034636,"OG6_107195,414872,ENOG5035HRW,ENSMUSP000000438..."
68076,79699,ENSG00000162378,OG6_107195,230590,ENSMUSG00000034645,OG6_107195
68077,7791,ENSG00000159840,"621894at9347,OG6_127905,ENSP00000324422,HUMAN9...",22793,ENSMUSG00000029860,"621894at9347,OG6_127905,ENSMUSP00000070427,MOU..."
68078,23140,ENSG00000074755,"9027,O43149,HUMAN|HGNC=29027|UniProtKB=O43149,...",195018,ENSMUSG00000055670,"9027,Q5SSH7,MOUSE|MGI=MGI=2444286|UniProtKB=Q5..."


In [9]:
h_d_ensembl = {v[1]:k for k,v in human_dict.items()}

In [12]:
ortho_subset = ortho[ortho['mouse_ensembl_gene'].isin([v[1] for v in m_d.values()])]

In [15]:
len(ortho_subset.mouse_ensembl_gene.unique())

16076

In [19]:
len([v[1] for v in m_d.values()])

16161

In [20]:
mapping = {k:v for k,v in zip(ortho['mouse_ensembl_gene'],ortho['human_ensembl_gene'])}

In [None]:
a

In [None]:
mapping = {}
for i in range(ortho_subset.shape[0]):
    mapping[

In [None]:
mapping = {}
for k in mapping.keys():
    human = mapping[k]

In [None]:
mapping_last = {k:v for k,v in zip(

In [31]:
ad.isbacked

True

In [32]:
import json
from pathlib import Path
from typing import Dict, List, Union, Optional, Tuple

import numpy as np
import pandas as pd
import scanpy as sc
from anndata import AnnData
from scipy.sparse import csr_matrix, hstack, issparse
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


class AnnDataProcessor:
    def __init__(self, adata: AnnData, protein_embeddings: Path, var_names: Optional[str] = None, isMouse: Optional[bool] = False) -> None:
        """
        Initialize the processor with an AnnData object, specifying options for gene identifiers and QC metrics.

        Args:
            adata (AnnData): The AnnData object to be processed.
            protein_embeddings(Path): path to the esm protein embeddings stored as parquet file.
            var_names (Optional[str]): Column in `adata.var` to use for gene identifiers, defaults to `adata.var_names` if None.
            isMouse (Optional[bool]): Set to True if the dataset originates from mice, influencing the loaded gene vocabularies.
        """
        
        logging.info("Initializing AnnDataProcessor")
        #self.genes_dictionary = self._load_genes_dictionary('mouse_gene_vocab.json' if isMouse else 'human_gene_vocab.json')
        self.genes_dictionary = human_dict
        self.protein_coding_genes = set(self.genes_dictionary.keys())
        self.adata = adata
        self.protein_embeddings = protein_embeddings
        self.adata.var_names = self.adata.var[var_names] if var_names else self.adata.var_names
        if 'total_counts' not in self.adata.obs.columns:
            logging.info("total_counts not detected, computing QC metrics.")
            sc.pp.calculate_qc_metrics(self.adata, inplace = True)
        logging.info("AnnDataProcessor initialized successfully")

    def process(self) -> Tuple[AnnData, pd.DataFrame]:
        """
        Processes the AnnData object by ensuring all required genes are present, and computing QC metrics if specified.
        """
        
        logging.info("Processing AnnData object")
        missing_genes = self.protein_coding_genes - set(self.adata.var_names)
        if missing_genes:
            logging.info(f"Extending AnnData with {len(missing_genes)} missing genes")
            self._extend_anndata(list(missing_genes))
        else:
            self.adata.var['extended'] = 0
        self.adata = self.adata[:, self.adata.var_names.isin(self.protein_coding_genes)].copy()
        self.adata.var['feature_id'] = [self.genes_dictionary[gene][1] for gene in self.adata.var_names]
        self.adata.var['feature_name'] = self.adata.var_names
        self.adata.var = self.adata.var[['feature_id', 'feature_name', 'extended']]        
        logging.info("AnnData preprocessed succesfully.")
        logging.info("Reading esm protein embeddings")
        self.adata = self.adata[:,list(self.protein_coding_genes)]
        protein_embeddings = pd.read_parquet(self.protein_embeddings).loc[list(self.protein_coding_genes)]        
        return self.adata, protein_embeddings
    
    def _load_genes_dictionary(self, filename: str) -> Dict[str, List[Union[int, str]]]:
        """
        Load a JSON file containing gene information and return it as a dictionary.

        Args:
            filename (str): Filename of the JSON to load, relative to a predefined directory.

        Returns:
            Dict[str, List[Union[int, str]]]: Dictionary containing gene information.
        """
        
        relative_path = Path(__file__).parent / 'resources' / filename
        with open(relative_path, 'r') as file:
            data = json.load(file)
        return data

    def _extend_anndata(self, missing_genes: List[str]) -> None:
        """
        Extend the AnnData object by adding missing genes. Adds these genes to the var DataFrame and sets their expression values to -1 (masked).

        Args:
            missing_genes (List[str]): List of gene names that are missing and need to be added.
        """
        
        logging.info(f"Adding {len(missing_genes)} missing genes to AnnData")
        new_var = pd.DataFrame(index=missing_genes)
        new_var['extended'] = 1  # Mark these genes as extended
        if not issparse(self.adata.X):
            self.adata.X = csr_matrix(self.adata.X)
        masked_matrix = csr_matrix((self.adata.n_obs, len(missing_genes)), dtype=float)
        new_X = hstack([self.adata.X, masked_matrix]) 
        extended_var = pd.concat([self.adata.var, new_var], axis=0)
        extended_var['extended'] = extended_var['extended'].fillna(value = 0)

        self.adata = AnnData(
            X=new_X,
            obs=self.adata.obs,
            var=extended_var
        )
        logging.info("AnnData extension complete")


In [43]:
import shutil
import gc
from pathlib import Path
from typing import List, Optional, Union, Generator
import numpy as np
from anndata import AnnData
from datasets import Dataset, Features, Sequence, Value, Array2D, concatenate_datasets, load_from_disk
from joblib import Parallel, delayed

class AnnDataHFConverter:
    """
    A class to convert AnnData objects to Hugging Face datasets with support for batching and parallel processing.

    Attributes:
        adata (AnnData): The AnnData object containing the single-cell dataset to be converted.
        batch_size (Optional[int]): The size of each batch to process. Defaults to the number of observations in `adata`.
        columns (List[str]): Additional observation metadata columns to include in the dataset.
        n_workers (int): The number of worker processes for parallel computation. Defaults to 1 (sequential processing).
        save (Optional[str]): Path to save the processed datasets.
        merge (bool): Whether to merge all processed batches into a single dataset.
    """
    def __init__(self, adata: AnnData, covariates: Optional[Union[str, List[str]]] = None, batch_size: Optional[int] = None, save: Optional[str] = None, merge: bool = True, n_workers: int = 1, cache_dir: Optional[str] = None):
        """
        Initializes the AnnDataHFConverter object with the specified AnnData, batch size, columns, and number of workers.

        Parameters:
            adata (AnnData): The AnnData object to convert, preprocessed with the `AnnDataProcessor` class.
            columns (List[str]): Additional columns from `adata.obs` to be included in the datasets.
            batch_size (Optional[int]): Number of observations to process per batch. If None, processes the entire dataset as one batch.
            save (Optional[str]): Path to save individual batches as separate datasets. If not specified, datasets are not saved.
            merge (bool): Whether to merge all batch datasets into a single dataset. If True and `save` is specified, the merged dataset is saved.
            n_workers (int): Number of parallel worker processes to use. Defaults to 1 for sequential processing.
            cache_dir(Optional[str]): the cache dir used by Hugging face to store the dataset while generating it.
        """

        self.adata = adata
        self.covariates = ['total_counts', 'n_genes_by_counts']#, 'extended']
        self.cache_dir = cache_dir
        #assert 'extended' in self.adata.var.columns, "Please preprocess the AnnData with the `AnnDataProcessor` class."
        
        if covariates:
            if isinstance(covariates, str):
                self.covariates.append(covariates)
            elif isinstance(covariates, list):
                self.covariates.extend(covariates)
            else:
                raise ValueError("additional_covariates must be either a string or a list of strings.")
                
        self.batch_size = batch_size if batch_size is not None else self.adata.n_obs
        self.n_workers = n_workers
        self.save = save
        self.merge = merge
        
        if (self.merge or n_workers!=1) and self.batch_size == self.adata.n_obs:
            self.merge = False
            self.n_workers = 1
       
    def generate_datasets(self) -> Optional[Union[Dataset, List[Dataset]]]:
        """
        Generates and optionally saves and/or merges datasets from the AnnData object.

        Returns:
            Optional[Union[Dataset, List[Dataset]]]: The resulting dataset(s) if not saved, or None if saved to disk.
        """
        
        features = self._prepare_features()
        
        if self.save:
            save_path = Path(self.save)
            save_path.mkdir(parents=True, exist_ok=True)

        datasets = []

        def process_batch(idx, start_idx, end_idx):
            adata_batch = self.adata[start_idx:end_idx]#.copy()
            generator = self._create_generator(adata_batch)
            dataset = Dataset.from_generator(generator=generator, features=features, cache_dir=self.cache_dir)
            dataset.set_format(type="torch")
            batch_path = save_path / f"batch_{idx}.dataset" if self.save else None
            if self.save:
                dataset.save_to_disk(str(batch_path))
            del adata_batch
            gc.collect()
            return dataset, batch_path

        adata_splits = [(idx, start_idx, min(start_idx + self.batch_size, self.adata.n_obs)) for idx, start_idx in enumerate(range(0, self.adata.n_obs, self.batch_size))]
        
        if self.n_workers > 1:
            results = Parallel(n_jobs=self.n_workers)(delayed(process_batch)(idx, start, end) for idx, start, end in adata_splits)
        else:
            results = [process_batch(idx, start, end) for idx, start, end in adata_splits]

        datasets, paths = zip(*results) if results else ([], [])

        if self.merge:
            dataset = concatenate_datasets(list(datasets))
            if self.save:
                merged_path = save_path / "merged.dataset"
                dataset.save_to_disk(str(merged_path))
                for path in paths:
                    if path and path.exists():
                        shutil.rmtree(str(path))
                print(f"Individual batches removed, merged file at: {merged_path}")
            return dataset

        return list(datasets) if not self.save else None

        return None

    def _prepare_features(self) -> Features:
        """
        Prepares a Features object for the Hugging Face dataset based on the columns specified in the AnnData object.
    
        Returns:
            Features: A dictionary of features with data types corresponding to the columns in the AnnData object.
        """
        
        features = {
            'gexp': Sequence(Value("float32")),
            'gene_names': Sequence(Value("string"))
        }
        for col in self.covariates:
            if col in self.adata.obs.columns :
                dtype = self.adata.obs[col].dtype
                features[col] = Value("string") if dtype.name in ['category', 'object'] else Value(str(dtype))
            elif col in self.adata.var.columns:
                dtype = self.adata.var[col].dtype
                features[col] = Sequence(Value("string")) if dtype.name in ['category', 'object'] else Sequence(Value(str(dtype)))        
            else:
                raise ValueError(f'Covariate {col} not found in either .obs or .var')
            
        return Features(features)
     
    def _create_generator(self, adata_batch: AnnData) -> Generator:
        """
        Create a generator function that can be passed to Dataset.from_generator.

        Parameters:
        adata_batch(
        """
        def generator():
            for i in range(adata_batch.n_obs):
                cell_data = {
                    'gexp': adata_batch.X.getrow(i).A.squeeze(),
                    'gene_names': adata_batch.var_names.values
                }
                    
                for col in self.covariates:
                    if col in adata_batch.obs.columns:
                        cell_data[col] = adata_batch.obs[col].iloc[i]
                    elif col in adata_batch.var.columns:
                        cell_data[col] = adata_batch.var[col].tolist()
                    
                yield cell_data
        
        return generator


In [9]:
import torch
from torch.utils.data import Dataset, DataLoader
import anndata as ad
import numpy as np

In [10]:
class AnnDataTorchDataset(Dataset):
    """
    A Torch Dataset wrapper for an AnnData object to provide data cell by cell.

    Attributes:
        adata (ad.AnnData): The AnnData object containing the single-cell dataset.
    """

    def __init__(self, adata: ad.AnnData, protein_embeddings: pd.DataFrame, covariates: Optional[Union[str, List[str]]] = None):
        """
        Initializes the dataset with an AnnData object.

        Parameters:
        adata (ad.AnnData): An AnnData object loaded with single-cell data.
        """
        self.adata = adata
        self.X = self.adata.X if isinstance(self.adata.X,np.ndarray) else self.adata.X.A
        self.covariates = ['total_counts', 'n_genes_by_counts', 'extended']
        self.protein_embeddings = protein_embeddings.values.squeeze()
        assert 'extended' in self.adata.var.columns, "Please preprocess the AnnData with the `AnnDataProcessor` class."
        
        if covariates:
            if isinstance(covariates, str):
                self.covariates.append(covariates)
            elif isinstance(covariates, list):
                self.covariates.extend(covariates)
            else:
                raise ValueError("additional_covariates must be either a string or a list of strings.")

    def __len__(self):
        """
        Returns the total number of cells in the dataset.
        """
        return self.adata.n_obs

    def __getitem__(self, idx):
        """
        Retrieves the expression values for the cell at the provided index.

        Parameters:
        idx (int): The index of the cell to retrieve.

        Returns:
        torch.Tensor: The expression values of the cell as a tensor.
        """
        cell = {'gexp': self.X[idx], 'protein_embeddings': self.protein_embeddings}                 
        for col in self.covariates:
            if col in self.adata.obs.columns:
                cell[col] = self.adata.obs[col].iloc[idx]
            elif col in self.adata.var.columns:
                cell[col] = torch.tensor(self.adata.var[col].values)

                    
        return cell

In [None]:
            for col in ['total_counts', 'extended']:
                if col in adata_batch.obs:
                    cell_data[col] = adata_batch.obs[col].iloc[i]
                elif col in adata_batch.var:
                    cell_data[col] = adata_batch.var[col].values
            
            yield cell_data

In [36]:
ad

AnnData object with n_obs × n_vars = 673631 × 60530 backed at '/mnt/storage/Daniele/preprocessed_data/cellxgene_census_human/0_preprocessed.h5ad'
    obs: 'tissue_ontology_term_id', 'suspension_type', 'sex_ontology_term_id', 'cell_type', 'sex', 'assay_ontology_term_id', 'cell_type_ontology_term_id', 'disease', 'nnz', 'assay', 'development_stage_ontology_term_id', 'tissue', 'self_reported_ethnicity', 'development_stage', 'is_primary_data', 'soma_joinid', 'tissue_general_ontology_term_id', 'n_measured_vars', 'dataset_id', 'tissue_general', 'disease_ontology_term_id', 'tissue_type', 'raw_sum', 'raw_mean_nnz', 'self_reported_ethnicity_ontology_term_id', 'donor_id', 'raw_variance_nnz', 'observation_joinid', 'unique_donor_id', 'species', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_50_genes', 'pct_counts_in_top_100_genes', 'pct_counts_in_top_200_genes', 'pct_counts_in_top_500_genes', 'total_counts_mt', 'log1p_total_counts_mt', 'pct_

In [37]:
ad_cut = ad[:1000,:1000]

In [38]:
ad_cut

View of AnnData object with n_obs × n_vars = 1000 × 1000 backed at '/mnt/storage/Daniele/preprocessed_data/cellxgene_census_human/0_preprocessed.h5ad'
    obs: 'tissue_ontology_term_id', 'suspension_type', 'sex_ontology_term_id', 'cell_type', 'sex', 'assay_ontology_term_id', 'cell_type_ontology_term_id', 'disease', 'nnz', 'assay', 'development_stage_ontology_term_id', 'tissue', 'self_reported_ethnicity', 'development_stage', 'is_primary_data', 'soma_joinid', 'tissue_general_ontology_term_id', 'n_measured_vars', 'dataset_id', 'tissue_general', 'disease_ontology_term_id', 'tissue_type', 'raw_sum', 'raw_mean_nnz', 'self_reported_ethnicity_ontology_term_id', 'donor_id', 'raw_variance_nnz', 'observation_joinid', 'unique_donor_id', 'species', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_50_genes', 'pct_counts_in_top_100_genes', 'pct_counts_in_top_200_genes', 'pct_counts_in_top_500_genes', 'total_counts_mt', 'log1p_total_counts_mt', 

In [14]:
#del ad_processor

In [39]:
ad_processor = AnnDataProcessor(ad_cut, '/home/daniele/Code/scGraph/scgraph/protein_embeddings/human_embeddings.parquet', var_names = 'feature_name')

2024-05-06 11:03:52,110 - INFO - Initializing AnnDataProcessor


ValueError: To copy an AnnData object in backed mode, pass a filename: `.copy(filename='myfilename.h5ad')`. To load the object into memory, use `.to_memory()`.

In [23]:
adata, embs = ad_processor.process()

2024-05-06 00:04:56,634 - INFO - Processing AnnData object
2024-05-06 00:04:56,643 - INFO - Extending AnnData with 18834 missing genes
2024-05-06 00:04:56,646 - INFO - Adding 18834 missing genes to AnnData
2024-05-06 00:04:56,698 - INFO - AnnData extension complete
2024-05-06 00:04:57,353 - INFO - AnnData preprocessed succesfully.
2024-05-06 00:04:57,354 - INFO - Reading esm protein embeddings


In [24]:
adata.var

Unnamed: 0,feature_id,feature_name,extended
OR4A15,ENSG00000181958,OR4A15,1.0
H2BC21,ENSG00000184678,H2BC21,1.0
ZNF235,ENSG00000159917,ZNF235,1.0
PCDHGA5,ENSG00000253485,PCDHGA5,1.0
AQP7B,ENSG00000259916,AQP7B,1.0
...,...,...,...
PRKCH,ENSG00000027075,PRKCH,0.0
NFIC,ENSG00000141905,NFIC,1.0
SSB,ENSG00000138385,SSB,1.0
GFOD2,ENSG00000141098,GFOD2,1.0


In [25]:
embs

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,2550,2551,2552,2553,2554,2555,2556,2557,2558,2559
OR4A15,0.075709,-0.016977,0.116577,-0.047384,-0.055592,-0.097196,-0.125780,0.117031,0.041157,-0.107229,...,0.199900,0.070354,-0.046446,-0.039389,0.124881,0.014685,-0.161402,0.018862,-0.254390,-0.049054
H2BC21,0.024577,0.027296,-0.019227,-0.009875,0.129760,0.018530,-0.078625,0.119971,0.117559,0.018290,...,-0.014135,0.029447,0.070439,-0.011796,0.070686,-0.077867,0.032035,0.050242,-0.089916,-0.002666
ZNF235,0.017618,-0.014136,0.073148,-0.048830,-0.077148,-0.064344,-0.052398,0.126266,0.087634,0.047113,...,-0.008405,0.042953,0.001227,0.027062,0.033975,-0.056882,-0.025897,0.023816,-0.230112,0.015032
PCDHGA5,-0.026696,-0.029550,-0.011700,-0.070574,-0.058281,-0.020026,-0.137816,0.121500,0.035187,-0.072562,...,0.076399,-0.009979,-0.005953,0.008551,0.006016,-0.002289,-0.061308,0.053673,-0.123048,-0.069099
AQP7B,-0.033831,-0.002237,0.149922,-0.070336,-0.074245,-0.082310,-0.046303,0.140770,0.112348,-0.066160,...,0.157253,-0.052053,0.021678,0.047585,0.156918,-0.076401,-0.015832,0.035389,-0.329794,-0.100048
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
PRKCH,0.047953,-0.078331,-0.001375,-0.063575,0.092070,-0.107469,-0.077264,0.058905,0.043936,-0.018209,...,-0.009887,0.005473,0.007390,0.055760,-0.072412,-0.073535,0.043553,-0.052560,-0.113003,-0.050934
NFIC,0.029466,-0.002191,0.058328,-0.007980,-0.016677,-0.008492,-0.030325,0.076611,-0.025818,-0.044419,...,0.036781,-0.106382,0.027533,0.075052,0.039753,-0.021913,-0.018525,0.043329,-0.079458,0.008996
SSB,0.011891,-0.036905,0.068846,0.002096,-0.013581,-0.059306,-0.073502,0.078924,0.036888,-0.008898,...,0.091435,-0.050210,0.063731,-0.097927,0.174676,-0.019236,0.048399,0.123177,-0.116802,-0.065504
GFOD2,0.006096,-0.028234,0.024738,-0.004632,0.008666,0.013890,-0.043456,0.036703,-0.009366,0.000409,...,-0.019945,-0.003638,0.011175,-0.058379,-0.010992,-0.053960,-0.052604,0.069502,-0.070411,0.004945


In [None]:
gc.collect()

In [44]:
now = time.time()
cv = AnnDataHFConverter(ad_cut, )#batch_size = 100, n_workers = 10, merge = True)
dataset = cv.generate_datasets()
print(time.time() - now)

ValueError: Currently, you cannot index repeatedly into a backed AnnData, that is, you cannot make a view of a view.

In [29]:
dataset[0]['gene_names']

['OR4A15',
 'H2BC21',
 'ZNF235',
 'PCDHGA5',
 'AQP7B',
 'ETNK2',
 'KALRN',
 'PDZD11',
 'PPP6R1',
 'ZC3H7B',
 'FAM120A',
 'ATXN3L',
 'SLC24A2',
 'PAQR7',
 'GCNT7',
 'STK32C',
 'MRPL36',
 'ITGB1',
 'RFTN2',
 'BNIP5',
 'FLOT2',
 'CTXND1',
 'PTPN13',
 'SLC35A3',
 'DEPDC1',
 'HAUS5',
 'AVP',
 'ADPRM',
 'VWDE',
 'DGAT2',
 'CBX2',
 'ZNF398',
 'IRX5',
 'UGT2A3',
 'USP14',
 'WDR90',
 'TRPC4AP',
 'RCBTB2',
 'CNTN1',
 'ARTN',
 'OR10A7',
 'EMC6',
 'KIAA0513',
 'ARNT2',
 'C6orf120',
 'NFXL1',
 'GSTP1',
 'SLC37A2',
 'CASC2',
 'PODNL1',
 'OR52J3',
 'IGFN1',
 'KRTAP10-4',
 'METAP2',
 'OSBPL2',
 'TCEAL9',
 'NNMT',
 'CCDC141',
 'IPMK',
 'UGT1A9',
 'CMKLR2',
 'HACD2',
 'PABPC1L',
 'JDP2',
 'PCYT1A',
 'TEX264',
 'TUSC3',
 'TMED1',
 'RER1',
 'LCN9',
 'HLCS',
 'CCDC153',
 'TAP2',
 'ESR2',
 'ST7',
 'SPTY2D1',
 'BTBD8',
 'LMAN1',
 'PDYN',
 'FRY',
 'WDR41',
 'S1PR1',
 'GOLGA8J',
 'BET1L',
 'IKZF4',
 'CXCL9',
 'TAMALIN',
 'MPZL2',
 'ACTL7A',
 'OS9',
 'SLCO1B3',
 'SGO2',
 'HS3ST6',
 'CCNA1',
 'PITPNM2',
 'U2SURP

In [None]:
torch_dataset = AnnDataTorchDataset(adata, embs)

In [None]:
dataloader = DataLoader(torch_dataset, batch_size = 1)

In [None]:
for i, batch in enumerate(dataloader):
    print(batch[''].shape)
    break

In [None]:
for i, batch in enumerate(dataloader):
    print(batch['protein_embeddings'])#.shape)
    break

In [None]:
list(embs.values.squeeze())

In [None]:
parquet = pd.read_parquet('/home/daniele/Code/scGraph/scgraph/protein_embeddings/human_embeddings.parquet')

In [None]:
parquet

In [None]:
parquet = parquet.loc[human_dict.keys()]

In [None]:
parquet

In [None]:
adata[:, list(human_dict.keys())].X