# Train a scVI model using Lamin

In [None]:
from __future__ import annotations

import os
from pprint import pprint

import pytest
import scvi
from scvi.data import synthetic_iid
from scvi.dataloaders import MappedCollectionDataModule

In [None]:
# Initialise a lamin instance with local storage in ./lamindb_collection via the CLI.
os.system("lamin init --storage ./lamindb_collection")  # one-time setup for the GitHub runner (comment out when running locally)
import lamindb as ln  # imported here, after the `lamin init` CLI call above

ln.setup.init()  # one-time setup for the GitHub runner (comment out when running locally)

# Prepare test data: two datasets produced by scvi.data.synthetic_iid.
adata1 = synthetic_iid()
adata2 = synthetic_iid()

# Turn each dataset into a lamindb artifact under its own key and save it.
artifact1 = ln.Artifact.from_anndata(adata1, key="part_one.h5ad").save()
artifact2 = ln.Artifact.from_anndata(adata2, key="part_two.h5ad").save()

# Group both artifacts into a single collection keyed "gather".
collection = ln.Collection([artifact1, artifact2], key="gather")
# Disabled check: calling mapped() on the collection before it has been saved.
# with collection.mapped() as ls_ds:
#    assert ls_ds.__class__.__name__ == "MappedCollection"
collection.save()

# Query the artifacts attached to the saved collection and show them via .df().
artifacts = collection.artifacts.all()
artifacts.df()

# Large-data example (disabled): fetch the "covid_normal_lung" collection from
# the "laminlabs/cellxgene" instance instead of using the synthetic data above.
# ln.track("d1kl7wobCO1H0005")
# ln.setup.init(name="lamindb_instance_name", storage=save_path)  # TODO: is this needed in the GitHub test?
# ln.setup.init()
# collection = ln.Collection.using("laminlabs/cellxgene").get(name="covid_normal_lung")
# artifacts = collection.artifacts.all()
# artifacts.df()

# Build a datamodule over the collection, with "batch" as the batch key,
# batches of 1024 and join="inner" across the artifacts.
datamodule = MappedCollectionDataModule(
    collection, batch_key="batch", batch_size=1024, join="inner"
)

# Print the sizes the datamodule reports (n_obs, n_vars, n_batch).
print(datamodule.n_obs, datamodule.n_vars, datamodule.n_batch)

# Show the registry; it is what the model is constructed from below.
pprint(datamodule.registry)

# Create the model from the registry alone: no AnnData object is passed.
model = scvi.model.SCVI(adata=None, registry=datamodule.registry)
pprint(model.summary_stats)
pprint(model.module)
# Single dataloader that the evaluation calls further down all receive.
inference_dataloader = datamodule.inference_dataloader()

# Train for one epoch with the datamodule as the data source.
model.train(max_epochs=1, batch_size=1024, datamodule=datamodule)

# Run each evaluation method of the trained model on the inference dataloader.
# The results are assigned to `_`, i.e. not used afterwards.
_ = model.get_elbo(dataloader=inference_dataloader)
_ = model.get_marginal_ll(dataloader=inference_dataloader)
_ = model.get_reconstruction_error(dataloader=inference_dataloader)
_ = model.get_latent_representation(dataloader=inference_dataloader)

# Save the model to "lamin_model" without AnnData, passing the datamodule instead.
model.save("lamin_model", save_anndata=False, overwrite=True, datamodule=datamodule)
# Build a query model from the saved reference using only the registry
# (adata=False), then train and evaluate it the same way as the reference model.
model_query = model.load_query_data(
    adata=False, reference_model="lamin_model", registry=datamodule.registry
)
model_query.train(max_epochs=1, datamodule=datamodule)
_ = model_query.get_elbo(dataloader=inference_dataloader)
_ = model_query.get_marginal_ll(dataloader=inference_dataloader)
_ = model_query.get_reconstruction_error(dataloader=inference_dataloader)
_ = model_query.get_latent_representation(dataloader=inference_dataloader)

# Load the collection into a single object with join="inner" and register it
# with scVI, using "batch" as the batch key.
adata = collection.load(join="inner")
scvi.model.SCVI.setup_anndata(adata, batch_key="batch")
# Calling load_query_data without reference_model is expected to raise ValueError.
with pytest.raises(ValueError):
    model.load_query_data(adata=adata)
# Same query workflow as before, but driven by the in-memory adata.
model_query_adata = model.load_query_data(adata=adata, reference_model="lamin_model")
model_query_adata.train(max_epochs=1)
# Evaluation without an explicit dataloader ...
_ = model_query_adata.get_elbo()
_ = model_query_adata.get_marginal_ll()
_ = model_query_adata.get_reconstruction_error()
_ = model_query_adata.get_latent_representation()
# ... and once more with the datamodule's inference dataloader.
_ = model_query_adata.get_latent_representation(dataloader=inference_dataloader)

# Save/load round trip without AnnData: save with the datamodule, then load
# with adata=False and rebuild a query model from the registry.
# NOTE(review): the return values of the next two calls are discarded;
# presumably these are smoke checks that loading succeeds - confirm.
model.save("lamin_model", save_anndata=False, overwrite=True, datamodule=datamodule)
model.load("lamin_model", adata=False)
model.load_query_data(adata=False, reference_model="lamin_model", registry=datamodule.registry)

# Same saved model, now loaded together with the in-memory adata.
model.load_query_data(adata=adata, reference_model="lamin_model")
model_adata = model.load("lamin_model", adata=adata)
# NOTE(review): setup_anndata was already called on this adata earlier in the
# script; confirm whether repeating it after load() is required.
scvi.model.SCVI.setup_anndata(adata, batch_key="batch")
model_adata.train(max_epochs=1)
# Save again, this time with save_anndata=True, to a separate directory, then
# load that directory back and build a query model from it.
model_adata.save("lamin_model_anndata", save_anndata=True, overwrite=True, datamodule=datamodule)
model_adata.load("lamin_model_anndata")
model_adata.load_query_data(
    adata=adata, reference_model="lamin_model_anndata", registry=datamodule.registry
)