In [None]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# re-imported before each cell execution, so edits to the project's .py files
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import sys

# Project root relative to the notebook's working directory. It is put on the
# module search path for the project-local imports below and is reused later
# to build data paths.
base_path = ".."
# Guard the append so that re-running this cell does not pile up duplicate
# entries in sys.path.
if base_path not in sys.path:
    sys.path.append(base_path)

import pickle
import scanpy as sc
import squidpy as sq
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch_geometric as pyg
from torch_geometric.data import Dataset, Data, Batch
from torch_geometric.loader import DataLoader
import torch_geometric.utils as pyg_utils
import torch_geometric as pyg
import pytorch_lightning as pl
from pytorch_lightning.utilities.model_summary import ModelSummary
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import (
    RichProgressBar,
    TQDMProgressBar,
    ModelCheckpoint,
)
import wandb
import pandas as pd


from utils.gnn_utils import *
from utils.gnn import *
from utils.module import *
from utils.utils import seed_everything
from src.data.datasets import *

from configs.main_config import config


# Silence every warning category for the rest of the session.
# NOTE(review): this also hides deprecation and numerical warnings raised
# during training; narrow the filter when debugging.
import warnings

warnings.simplefilter("ignore")

# Render inline figures at 80 dpi.
plt.rcParams.update({"figure.dpi": 80})

# Fix the random seed via the project helper from utils.utils
# (presumably seeds all RNGs in use; confirm in the helper).
seed_everything(seed=42)

In [None]:
# --- Locations on disk -------------------------------------------------------
data_dir = f"{base_path}/data"
bulk_data_dir = f"{data_dir}/bulk"
sim_data_dir = f"{data_dir}/simulated"

experiment_dir = "experiments/experiment_bulk_real"
prepared_dir = f"{experiment_dir}/datasets"

# --- Raw inputs --------------------------------------------------------------
# The bulk expression table is transposed right after reading.
bulk_data = sc.read_text(f"{bulk_data_dir}/GSE120502.txt").T
sim_data = sc.read_h5ad(f"{sim_data_dir}/simulated_pbmc8k_qc.h5ad")

# --- Previously prepared experiment data -------------------------------------
X_real, X_real_train, X_sim, y_sim = load_prepared_data(prepared_dir)

# Cell type names; also used as the column order for the real ground truth.
celltype_names = load_celltypes(f"{prepared_dir}/celltypes.txt")

# Real ground truth (if available), with columns ordered like celltype_names.
y_real, y_real_df = load_real_groundtruth(
    f"{bulk_data_dir}/gt_GSE120502.txt",
    col_order=celltype_names,
)

# No spatial data object is set here; None is passed on to the dataset and
# model constructors as-is.
st_data = None

In [None]:
# `celltype_names` and `X_real, X_real_train, X_sim, y_sim` are already loaded
# by the data-loading cell above with the same arguments, so they are not read
# from disk a second time. Only the sample names are new here.

# load sample names
sample_names = load_sample_names(f"{experiment_dir}/datasets/sample_names.txt")

In [None]:
# Train and test datasets differ only in the real expression matrix they are
# given; the simulated data, real ground truth and spatial data are shared.
dataset_kwargs = dict(y_real=y_real, st_data=st_data)
train_data = prepare_dataset(X_real_train, X_sim, y_sim, **dataset_kwargs)
test_data = prepare_dataset(X_real, X_sim, y_sim, **dataset_kwargs)

# Options common to both loaders: batches are built in the main process
# (num_workers=0) with pinned host memory.
loader_kwargs = dict(pin_memory=True, num_workers=0)

# Training: shuffled mini-batches of 64, dropping the last incomplete batch.
train_loader = DataLoader(
    train_data, batch_size=64, shuffle=True, drop_last=True, **loader_kwargs
)

# Evaluation: fixed order, batches of up to 3000 samples.
test_loader = DataLoader(
    test_data, batch_size=3000, shuffle=False, **loader_kwargs
)

In [None]:
# Network dimensions are read from the second axis of the prepared arrays.
n_genes = X_real.shape[1]
n_celltypes = y_sim.shape[1]
net = Dissect(num_genes=n_genes, num_celltypes=n_celltypes)

# Lightning module wrapping the network.
# Optional regularisation arguments, currently left at their defaults:
#   l1_lambda=0.0, l2_lambda=0.0
model_kwargs = dict(
    sim_loss_fn="kl_div",
    spatial_data=st_data,
    celltype_names=celltype_names,
    sample_names=sample_names,
)
model = DeconvolutionModel(net, **model_kwargs)

# Checkpointing: track the lowest "train/total_loss" and always keep the last
# checkpoint as well; files are written to ./checkpoints.
checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints",
    monitor="train/total_loss",
    mode="min",
    save_last=True,
)

In [None]:
# Close any W&B run left open by a previous execution of this cell before
# starting a new one.
wandb.finish()
wandb_mode = "online"
# wandb_mode = "disabled"
wandb_logger = WandbLogger(project="dissect-spatial", log_model=True, mode=wandb_mode)

# training
trainer = pl.Trainer(
    max_epochs=5000,
    max_steps=-1,
    accelerator="gpu",
    # Integer 1 means exactly one training batch per epoch (a float 1.0 would
    # mean 100% of the batches), so each epoch is a single optimisation step.
    limit_train_batches=1,
    log_every_n_steps=1,
    check_val_every_n_epoch=500,
    # NOTE(review): hard-coded GPU index; this fails on machines with fewer
    # than 8 GPUs. Adjust to the local setup.
    devices=[7],
    precision=32,
    logger=wandb_logger,
    deterministic="warn",
    enable_checkpointing=True,
    # fast_dev_run=True,
    # profiler="simple",
    enable_progress_bar=False,
    callbacks=[checkpoint_callback],
)

# The test loader doubles as the validation loader; set to None to skip
# validation entirely.
val_dataloaders = test_loader
# val_dataloaders = None
try:
    trainer.fit(model, train_loader, val_dataloaders=val_dataloaders)
finally:
    # Always close the W&B run, even if training raises or is interrupted,
    # so the run is not left dangling.
    wandb.finish()