In [None]:
import anndata as ad
import lightning as L
from tqdm import tqdm

import warnings
import os
from os.path import join

In [None]:
# Load IPython's autoreload extension so edited project modules (here: `modlyn`)
# can be reloaded without restarting the kernel.
%load_ext autoreload

In [None]:
# A bare `%autoreload` (no mode argument) reloads all modules once, right now.
# It does NOT enable automatic reloading before every cell; that would be `%autoreload 2`.
# Re-run this cell after editing `modlyn` to pick up the changes.
%autoreload
from modlyn.io.datamodules import ClassificationDataModule
from modlyn.models.linear import Linear
from modlyn.io.loading import read_lazy

In [None]:
# Root directory holding the Tahoe-100M chunks. Override it with the TAHOE100M_DIR
# environment variable instead of editing this hardcoded cluster path.
DATA_DIR = os.environ.get("TAHOE100M_DIR", "/mnt/dssmcmlfs01/tahoe100M")

# `os.listdir` returns entries in arbitrary, filesystem-dependent order. Sort them so the
# row order of the concatenated object is stable; the positional train/val split
# (`adata[:80_527_360]`) depends on it. Hidden entries (e.g. `.DS_Store`) are not chunks.
chunk_names = sorted(name for name in os.listdir(DATA_DIR) if not name.startswith("."))
assert chunk_names, f"no data chunks found in {DATA_DIR}"

with warnings.catch_warnings():
    # Meant to hide zarr warnings that the zarr v3 codec is not final yet.
    # NOTE: `simplefilter("ignore")` silences *all* warnings raised while reading, not only those.
    warnings.simplefilter("ignore")
    adata = ad.concat([read_lazy(join(DATA_DIR, chunk)) for chunk in tqdm(chunk_names)])

In [None]:
# Integer-encode the cell line as the classification target `y`.
# `astype("category")` derives the categories from the values present, and `.cat.codes`
# maps each value to its category position; missing values get code -1.
adata.obs["y"] = adata.obs["cell_line"].astype("category").cat.codes.to_numpy().astype("i8")

# -1 is not a valid class index, so fail here rather than during training.
assert adata.obs["y"].ge(0).all(), "obs['cell_line'] contains missing values (category code -1)"

In [None]:
# Contiguous, positional train/validation split (no shuffling). Because it is by row
# position, which cells end up in validation depends on the order in which the chunks
# were concatenated when loading.
N_TRAIN = 80_527_360  # = 39_320 * BATCH_SIZE
BATCH_SIZE = 2048

# Guard against a silently empty validation set if fewer rows were loaded than expected.
assert N_TRAIN < adata.n_obs, f"N_TRAIN={N_TRAIN:,} leaves no validation rows (n_obs={adata.n_obs:,})"

adata_train = adata[:N_TRAIN]
adata_val = adata[N_TRAIN:]

datamodule = ClassificationDataModule(
    adata_train=adata_train,
    adata_val=adata_val,
    label_column="y",
    # Presumably forwarded to the underlying DataLoaders: drop the incomplete final batch
    # for training, keep it for validation so no validation rows are skipped.
    train_dataloader_kwargs={
        "batch_size": BATCH_SIZE,
        "drop_last": True,
    },
    val_dataloader_kwargs={
        "batch_size": BATCH_SIZE,
        "drop_last": False,
    },
)

In [None]:
# Seed Python, NumPy and torch RNGs (and dataloader workers) before building the model,
# so weight initialisation and later random draws are reproducible under Restart & Run All.
SEED = 0
L.seed_everything(SEED, workers=True)

# Linear classifier over all genes with one output per distinct label in `y`.
linear = Linear(
    n_genes=adata.n_vars,
    n_covariates=adata.obs["y"].nunique(),
    learning_rate=1e-2,
)

In [None]:
# Lightning stops at whichever of `max_steps` / `max_epochs` is reached first. An epoch
# here is presumably tens of thousands of batches, so `max_steps` is the binding limit:
# only fit a few steps for the sake of this tutorial.
trainer_kwargs = dict(
    max_steps=1000,
    max_epochs=3,
    log_every_n_steps=100,
)
trainer = L.Trainer(**trainer_kwargs)

In [None]:
# Train the linear model; train/val dataloaders come from the datamodule.
trainer.fit(linear, datamodule=datamodule)