In [None]:
# Notebook setup: enable the IPython autoreload extension and configure
# how inline figures are rendered.
%load_ext autoreload
# Mode 2 re-imports modules before each cell executes, so edits to local
# packages are picked up without restarting the kernel.
%autoreload 2
# Ask the inline backend for 'retina' (high-DPI) figure output.
%config InlineBackend.figure_format='retina'

In [None]:
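# Disabled: optional wait loop that polls every 10 seconds until `done.csv`
# exists in the training-dataset directory. Uncomment to block the notebook
# until that marker file is present.
#import time
#import os
#
#while True:
#    
#    if os.path.exists(
#        '/GPUData_xingjie/SCMG/contrastive_embedding_training/training_dataset/done.csv'
#    ):
#        break
#
#    time.sleep(10)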

In [None]:
import os
import json

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import anndata
import scanpy as sc
import datasets

In [None]:
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F

# Set PyTorch's global float32 matmul precision to 'high'. Per the PyTorch
# documentation this permits faster reduced-precision matmul kernels on
# supported hardware, trading a little precision for speed.
torch.set_float32_matmul_precision('high')

In [None]:
from scmg.model.contrastive_embedding import (CellEmbedder, 
                            train_contrastive_embedder)

In [None]:
# Root directory that holds every artifact read by this notebook.
input_path = '/GPUData_xingjie/SCMG/contrastive_embedding_training/training_dataset'

# Expression dataset, configured to hand back torch-formatted rows.
exp_dataset_dir = os.path.join(input_path, 'combined_exp')
exp_data = datasets.load_from_disk(exp_dataset_dir).with_format("torch")
print(f'The expression dataset contains {len(exp_data)} points.')

# Edge dataset, configured the same way.
edge_dataset_dir = os.path.join(input_path, 'edge_dataset')
edge_data = datasets.load_from_disk(edge_dataset_dir).with_format("torch")
print(f'The edge dataset contains {len(edge_data)} points.')

# Shuffled mini-batches of 4096 edges, served by 8 persistent worker processes.
edge_loader = torch.utils.data.DataLoader(
    edge_data,
    batch_size=4096,
    shuffle=True,
    num_workers=8,
    persistent_workers=True,
)

# Dataset-id mapping stored next to the datasets; passed to the model below.
dataset_id_map_file = os.path.join(input_path, 'dataset_id_map.json')
with open(dataset_id_map_file) as map_handle:
    dataset_id_map = json.load(map_handle)

In [None]:
# Per-cell table; its `dataset_id` column says which dataset each cell comes
# from, and its index values are what the map below stores.
cell_id_df = pd.read_parquet(os.path.join(input_path, 'cell_id.parquet'))

# Both Qiu keys map to the same pooled set of cells (as in the original code).
# FIX: the 2024 id was previously written without the ':all' suffix that every
# other dataset id -- including this map's own key -- carries, so it most
# likely matched no rows and silently dropped that dataset's cells.
qiu_pooled_ids = [
    'Qiu_Organogenesis_MM_2022:all',
    'Qiu_whole_embryo_dev_MM_2024:all',
]

# Map key -> `dataset_id` values whose cells are assigned to that key.
# Insertion order matches the original hand-written assignments.
dataset_to_source_ids = {
    'AllenBrain_WB_MM_2023:all': ['AllenBrain_WB_MM_2023:all'],
    'Suo_ImmuneDev_HS_2022:all': ['Suo_ImmuneDev_HS_2022:all'],
    'Tabula_Sapiens_HS_2022:all': ['Tabula_Sapiens_HS_2022:all'],
    'Qiu_Organogenesis_MM_2022:all': qiu_pooled_ids,
    'Qiu_whole_embryo_dev_MM_2024:all': qiu_pooled_ids,
}

# Surface ids that match no cell at all: a misspelled id would otherwise just
# produce a smaller (or empty) index array without any error.
known_dataset_ids = set(cell_id_df['dataset_id'].unique())
for dataset_key, source_ids in dataset_to_source_ids.items():
    unmatched_ids = [sid for sid in source_ids if sid not in known_dataset_ids]
    if unmatched_ids:
        print(f'WARNING: no cells found for dataset id(s) {unmatched_ids} '
              f'(map key {dataset_key!r}).')

# For each key, collect the index values of all cells from its source datasets.
dataset_to_cell_idx_map = {
    dataset_key: cell_id_df.index[
        cell_id_df['dataset_id'].isin(source_ids).to_numpy()
    ].values
    for dataset_key, source_ids in dataset_to_source_ids.items()
}

In [None]:
# Device the model is moved to before training.
device = 'cuda:1'

# The model's gene dimension is read from the first expression record.
n_input_genes = exp_data[0]['X_exp'].shape[0]
model = CellEmbedder(n_genes=n_input_genes, dataset_id_map=dataset_id_map)
model = model.to(device)

# Directory handed to the trainer as its output location.
embedder_output_dir = '/GPUData_xingjie/Softwares/SCMG_dev/tests/contrastive_embedding/trained_embedder'

train_contrastive_embedder(
    model=model,
    edge_loader=edge_loader,
    exp_data=exp_data,
    dataset_to_cell_idx_map=dataset_to_cell_idx_map,
    num_epochs=100,
    output_path=embedder_output_dir,
)

In [None]:
# Read back the loss history JSON from the trained-embedder directory.
loss_history_file = '/GPUData_xingjie/Softwares/SCMG_dev/tests/contrastive_embedding/trained_embedder/loss_history.json'
with open(loss_history_file) as f:
    loss_history = json.load(f)

# Only entries with index in [start, stop) are drawn.
start, stop = 1, 1000
plot_window = slice(start, stop)

# One figure per loss term: each entry is averaged along axis 1 and plotted
# against its position in the history.
for k, loss_records in loss_history.items():
    entry_index = np.arange(len(loss_records))[plot_window]
    entry_mean = np.array(loss_records).mean(axis=1)[plot_window]

    fig, ax = plt.subplots()
    ax.plot(entry_index, entry_mean)
    ax.set_title(k)
    plt.show()

In [None]:
## Resume training from the last saved checkpoint (disabled; uncomment to use)
#model_path = '/GPUData_xingjie/Softwares/SCMG_dev/tests/contrastive_embedding/trained_embedder/'
#
#model = torch.load(os.path.join(model_path, 'model.pt'))
#model.load_state_dict(torch.load(os.path.join(model_path, 'best_state_dict.pth')))
#
#device = 'cuda:1'
#model.to(device)
#
#with open('/GPUData_xingjie/Softwares/SCMG_dev/tests/contrastive_embedding/trained_embedder/loss_history.json') as f:
#    loss_history = json.load(f)
#
#train_contrastive_embedder(
#    model=model,
#    edge_loader=edge_loader,
#    exp_data=exp_data,
#    dataset_to_cell_idx_map=dataset_to_cell_idx_map,
#    num_epochs=100,
#    output_path='/GPUData_xingjie/Softwares/SCMG_dev/tests/contrastive_embedding/trained_embedder',
#)