In [None]:
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format='retina'

In [None]:
import os
import json

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import anndata
import scanpy as sc
import datasets

In [None]:
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F

# Let float32 matmuls use faster reduced-precision kernels (TF32 / bf16 split)
# where available, trading a little accuracy for training speed.
torch.set_float32_matmul_precision(precision='high')

In [None]:
from scmg.model.manifold_generation import (ConditionalDiffusionModel, 
                                            train_diffusion_model)

In [None]:
# Training data: a Hugging Face `datasets` dataset saved on disk, served as
# torch tensors by a shuffling DataLoader.
# Alternative single-source dataset previously used here:
#   /GPUData_xingjie/SCMG/manifold_generator_training/datasets/standard_adata_Tabula_Sapiens_HS_2022_all_0
DATASET_DIR = '/GPUData_xingjie/SCMG/manifold_generator_training/training_dataset_combined/dataset/'
BATCH_SIZE = 4096
# 48 workers was the original setting; cap it at the number of CPUs so the
# cell does not oversubscribe smaller machines. The cap is always >= 1, which
# persistent_workers=True requires.
NUM_WORKERS = min(48, os.cpu_count() or 1)

all_data = datasets.load_from_disk(DATASET_DIR).with_format("torch")
print(f'The dataset contains {len(all_data)} points.')

data_loader = torch.utils.data.DataLoader(
    all_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    # Keep worker processes alive between epochs / repeated iter() calls.
    persistent_workers=True,
)

In [None]:
# Conditioning vocabulary: the cell-type labels listed in cell_types.csv.
cell_types_table = pd.read_csv(
    '/GPUData_xingjie/Softwares/SCMG_dev/tests/manifold_generator/cell_types.csv'
)
condition_classes = cell_types_table['cell_type'].values

# GPU used for training.
device = 'cuda:1'

# Hyperparameter meanings below are inferred from argument names; confirm
# against scmg.model.manifold_generation.ConditionalDiffusionModel.
model = ConditionalDiffusionModel(
    n_feature=512,             # presumably the latent dimensionality
    n_time_feature=256,        # presumably the timestep-embedding size
    condition_classes=condition_classes,
    n_condition_feature=512,   # presumably the condition-embedding size
    n_steps=1000,              # presumably the number of diffusion steps
    n_network_blocks=8,
)
model = model.to(device)

In [None]:
# Train the diffusion model. This is the expensive step, so it is skipped
# when the best checkpoint (the file the evaluation cell loads) already
# exists. That way Restart & Run All does not silently launch a new
# 1000-epoch run. Set RETRAIN = True to force training.
MODEL_OUTPUT_DIR = '/GPUData_xingjie/Softwares/SCMG_dev/tests/manifold_generator/trained_diffusion_model'
RETRAIN = False

if RETRAIN or not os.path.exists(os.path.join(MODEL_OUTPUT_DIR, 'best_state_dict.pth')):
    train_diffusion_model(
        model,
        data_loader,
        num_epochs=1000,
        output_path=MODEL_OUTPUT_DIR,
        lr=1e-4,
    )
else:
    print(f'Found an existing checkpoint in {MODEL_OUTPUT_DIR}; '
          'set RETRAIN = True to train again.')

In [None]:
# Plot every curve stored in the training loss history, one figure per key.
LOSS_HISTORY_PATH = '/GPUData_xingjie/Softwares/SCMG_dev/tests/manifold_generator/trained_diffusion_model/loss_history.json'
with open(LOSS_HISTORY_PATH) as f:
    loss_history = json.load(f)

# Window of recorded points to show. Dropping the first points presumably
# hides the initial transient so it does not dominate the y-axis.
start, stop = 10, 10000
for name, values in loss_history.items():
    values = np.asarray(values)
    record_index = np.arange(len(values))
    fig, ax = plt.subplots()
    ax.plot(record_index[start:stop], values[start:stop])
    ax.set(title=name, xlabel='record index', ylabel=name)
    plt.show()

In [None]:
# Load the best checkpoint of the trained diffusion model for evaluation.
MODEL_DIR = '/GPUData_xingjie/Softwares/SCMG_dev/tests/manifold_generator/trained_diffusion_model'
device = 'cuda:0'

# model.pt stores the whole pickled module, so weights_only=False is required
# (PyTorch >= 2.6 defaults to weights_only=True, which rejects arbitrary
# objects). Unpickling can execute arbitrary code: only load trusted files.
# map_location loads tensors straight onto `device` instead of the GPU they
# were saved from.
best_model = torch.load(os.path.join(MODEL_DIR, 'model.pt'),
                        map_location=device, weights_only=False)
best_model.load_state_dict(
    torch.load(os.path.join(MODEL_DIR, 'best_state_dict.pth'), map_location=device)
)

best_model.to(device)
best_model.eval()

In [None]:
# Draw one shuffled batch of real cells and look at the distribution of the
# per-cell L2 norms of their latent vectors.
batch = next(iter(data_loader))
Z_shift = batch['X_ce_latent'].numpy()
real_norms = np.linalg.norm(Z_shift, axis=1)

fig, ax = plt.subplots()
ax.hist(real_norms, bins=100)
ax.set(title='Real latent norms', xlabel='L2 norm of X_ce_latent', ylabel='count')
plt.show()

In [None]:
# Generate latents conditioned on the cell types of `batch`, then compare
# their norm distribution with the real latents of the same batch.
# Sampling needs no autograd graph; no_grad avoids keeping activations alive.
# NOTE(review): assumes generate() does not itself rely on gradients.
with torch.no_grad():
    generated_zs = best_model.generate(batch['cell_type']).detach().cpu().numpy()

fig, ax = plt.subplots()
ax.hist(np.linalg.norm(batch['X_ce_latent'].numpy(), axis=1),
        bins=100, alpha=0.5, label='real')
ax.hist(np.linalg.norm(generated_zs, axis=1),
        bins=100, alpha=0.5, label='generated')
ax.set(title='Latent norms: real vs. generated', xlabel='L2 norm', ylabel='count')
ax.legend()
plt.show()

In [None]:
# List the distinct cell types in a fresh shuffled batch. This uses its own
# name so it does not overwrite `batch`; re-running the generation cell later
# would otherwise silently condition on a different batch.
check_batch = next(iter(data_loader))
np.unique(check_batch['cell_type'])