In [1]:
import os

import pandas as pd
import pytorch_lightning as pl
import scanpy as sc
from torch.utils.data import DataLoader

In [2]:
# Load the raw .h5ad file; later cells read `.X`, `.var.index` and
# `.obs.tissue` from the resulting object.
scdata = sc.read_h5ad("data/raw/HCL_final_USE.h5ad")

In [5]:
class SingleCellDataset:
    """Map-style dataset over a JSON table with one row per cell.

    The table is loaded with ``pd.read_json``. One column holds the per-cell
    label; every other column is treated as a gene-expression feature.

    Parameters
    ----------
    fp : str or path-like
        Location of the JSON table, passed unchanged to ``pd.read_json``.
    label : str, default "tissue"
        Name of the column holding the label of each cell.

    Attributes
    ----------
    df : pandas.DataFrame
        The loaded table.
    genes : list
        Names of all columns except ``label``, in table order.

    Raises
    ------
    KeyError
        If ``label`` is not a column of the loaded table.
    """

    def __init__(self, fp, label="tissue"):
        self.df = pd.read_json(fp)
        self.label = label
        if label not in self.df.columns:
            raise KeyError(f"label column {label!r} not found in {fp!r}")
        # Features are taken from the table that was just loaded (not from a
        # notebook-level variable) and must not include the label column.
        self.genes = [c for c in self.df.columns.tolist() if c != label]
        # Extract the arrays once: rebuilding them per sample would copy the
        # whole table on every __getitem__ call.
        self._X = self.df[self.genes].to_numpy()
        self._labels = self.df[label].to_numpy()

    def __len__(self):
        """Return the number of cells (rows) in the table."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(gene_expression, cell_type)`` for the row at position ``idx``.

        ``gene_expression`` is the 1-D array of feature values of that row and
        ``cell_type`` is the value of the label column for the same row.
        """
        gene_expression = self._X[idx]
        cell_type = self._labels[idx]
        return gene_expression, cell_type

    @property
    def X(self):
        """Feature matrix as a 2-D array, one row per cell, one column per gene."""
        return self._X

    @property
    def y(self):
        """Label column of the table as a pandas Series."""
        return self.df[self.label]

In [6]:
class SingleCellDataModule(pl.LightningDataModule):
    """Lightning data module serving three ``SingleCellDataset`` splits.

    Parameters
    ----------
    train_fp, test_fp, val_fp : str or path-like
        Paths handed to ``SingleCellDataset`` for the train, test and
        validation split respectively.
    batch_size : int
        Batch size used by every data loader this module returns.
    """

    def __init__(self, train_fp, test_fp, val_fp, batch_size):
        super().__init__()
        self.train_fp = train_fp
        self.test_fp = test_fp
        self.val_fp = val_fp
        self.batch_size = batch_size

    def setup(self, stage=None):
        """Load the datasets required for ``stage``.

        ``stage=None`` loads all three splits. ``"fit"`` loads train and
        validation, ``"validate"`` loads validation only and ``"test"`` loads
        test only, so a run does not parse files it never reads.
        """
        if stage in (None, "fit"):
            self.train_dataset = SingleCellDataset(self.train_fp)
        if stage in (None, "fit", "validate"):
            self.val_dataset = SingleCellDataset(self.val_fp)
        if stage in (None, "test"):
            self.test_dataset = SingleCellDataset(self.test_fp)

    def train_dataloader(self):
        """Return the training loader; samples are reshuffled every epoch."""
        # Shuffle so batches do not follow the row order of the stored table.
        return DataLoader(
            self.train_dataset, batch_size=self.batch_size, shuffle=True
        )

    def val_dataloader(self):
        """Return the validation loader (row order preserved)."""
        return DataLoader(self.val_dataset, batch_size=self.batch_size)

    def test_dataloader(self):
        """Return the test loader (row order preserved)."""
        return DataLoader(self.test_dataset, batch_size=self.batch_size)

In [None]:
# Gene names come from the index of `scdata.var`. Entries stored as bytes are
# decoded; entries that are already str are kept as they are, so the cell
# works for either representation.
genes = [
    name.decode() if isinstance(name, bytes) else name
    for name in scdata.var.index.tolist()
]
genes[:10]

In [None]:
# One tissue label per row of `scdata.obs`. Entries stored as bytes are
# decoded; entries that are already str are kept as they are, so the cell
# works for either representation.
tissues = [
    name.decode() if isinstance(name, bytes) else name
    for name in scdata.obs.tissue.tolist()
]
tissues[:10]

In [None]:
# Build the table: one column per gene name from `scdata.X`, plus a trailing
# `tissue` label column.
df = (
    pd.DataFrame(scdata.X, columns=genes)
    .assign(tissue=tissues)
)

In [None]:
# Create the output directory first so the write does not fail when
# `data/intermediate/` does not exist yet. The `.gz` suffix makes pandas
# gzip-compress the file (its default `compression="infer"`).
out_path = "data/intermediate/train.json.gz"
os.makedirs(os.path.dirname(out_path), exist_ok=True)
df.to_json(out_path)