In [None]:
import anndata as ad
import lightning as L

import os
from os.path import join

from modlyn.io.datamodules import ClassificationDataModule
from modlyn.models.linear import Linear
from modlyn.io.loading import read_lazy

In [None]:
# Directory holding the tahoe-100M chunk files. Can be overridden through the
# TAHOE_100M_DIR environment variable so the notebook is not tied to one machine.
DATA_DIR = os.environ.get("TAHOE_100M_DIR", "/home/sagemaker-user/tahoe-100M")

# os.listdir returns entries in arbitrary order; sort them so the concatenated
# object (and therefore the positional train/val split below) is identical on
# every run.
# NOTE(review): every entry is assumed to be a readable chunk — hidden files
# such as .ipynb_checkpoints would break read_lazy; confirm or filter.
adata_chunks = sorted(os.listdir(DATA_DIR))

# One lazily-read AnnData per chunk file, concatenated along the observation axis.
adata = ad.concat([read_lazy(join(DATA_DIR, chunk)) for chunk in adata_chunks])

In [None]:
# Integer-encode each observation's cell line into the "y" label column.
# pandas assigns code -1 to missing values — worth checking if cell_line can be NaN.
cell_line_categorical = adata.obs["cell_line"].astype("category")
adata.obs["y"] = cell_line_categorical.cat.codes.to_numpy()

In [None]:
# Positional split: the first N_TRAIN_CELLS observations are used for training,
# everything after them for validation. 80_527_360 == 39_320 * 2048, i.e. a
# whole number of training batches.
# NOTE(review): this is a contiguous split, so validation is whatever cells come
# last in the concatenated object — confirm that is the intended hold-out.
N_TRAIN_CELLS = 80_527_360
BATCH_SIZE = 2048

adata_train, adata_val = adata[:N_TRAIN_CELLS], adata[N_TRAIN_CELLS:]

datamodule = ClassificationDataModule(
    adata_train=adata_train,
    adata_val=adata_val,
    label_column="y",
    # Training loader: shuffled, trailing partial batch dropped.
    train_dataloader_kwargs=dict(batch_size=BATCH_SIZE, shuffle=True, drop_last=True),
    # Validation loader: fixed order, every cell kept.
    val_dataloader_kwargs=dict(batch_size=BATCH_SIZE, shuffle=False, drop_last=False),
)

In [None]:
# Seed Python, NumPy and torch RNGs before building the model so that weight
# initialisation and the shuffled training order are reproducible across runs.
SEED = 0
L.seed_everything(SEED)

# Linear model over all genes; n_covariates is the number of distinct labels in
# "y" (presumably the number of output classes — confirm against modlyn's Linear).
linear = Linear(
    n_genes=adata.n_vars,
    n_covariates=adata.obs["y"].nunique(),
    learning_rate=1e-2,
)

In [None]:
# Lightning trainer: three epochs, metrics logged every 100 training steps.
trainer = L.Trainer(log_every_n_steps=100, max_epochs=3)

In [None]:
# Train the linear model; the datamodule provides the train and validation loaders.
trainer.fit(linear, datamodule=datamodule)