In [2]:
from scipy.stats import spearmanr
import pandas as pd
import numpy as np
import logging
import torch
import csv
from tqdm import tqdm
import anndata as ad
import scipy.sparse as sp
from scipy.io import mmwrite
from io import StringIO
from evaluate_results import get_ranked_pearson_corr
from evaluate_results import compute_rmse
from evaluate_results import get_ranked_spearman_corr
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import seaborn as sns
from collections import defaultdict

import sys
import os

sys.path.append(os.path.join(sys.path[0], '../'))
from model.manager_for_sagittarius import Sagittarius_Manager

# Locations of the GSE190147 input files, declared once and reused by the loaders below.
gene_expression_matrix_file = "../data_utils/data/static/GSE190147_scirnaseq_gene_matrix.txt"
columns_metadata_file = "../data_utils/data/static/GSE190147_scirnaseq_gene_matrix.columns.csv"
rows_metadata_file = "../data_utils/data/static/GSE190147_scirnaseq_gene_matrix.rows.csv"
gene_expression_mtx = "../data_utils/data/static/GSE190147_scirnaseq_gene_matrix.mtx"

# The two metadata tables are read as tab-separated, the matrix text file as space-separated.
columns_metadata = pd.read_csv(columns_metadata_file, sep="\t")
rows_metadata = pd.read_csv(rows_metadata_file, sep="\t")
gene_expression_matrix = pd.read_csv(gene_expression_matrix_file, sep=" ")


In [5]:
# Lookup from cell id to seurat cluster number.
# Every line of the file is parsed as "<cell id>: <integer cluster>".

cell_type_mapping = {}
with open('../data_utils/data/exp1/seurat_clusters_dict.txt', 'r') as cluster_file:
    for record in cluster_file:
        fields = record.strip().split(': ')
        cell_type_mapping[fields[0].strip()] = int(fields[1].strip())


In [3]:
# Time-window labels in order; a label's position in this tuple is its integer time index.
_ordered_time_labels = (
    'hrs_00_02',
    'hrs_01_03',
    'hrs_02_04',
    'hrs_03_07',
    'hrs_04_08',
    'hrs_06_10',
    'hrs_08_12',
    'hrs_10_14',
    'hrs_12_16',
    'hrs_14_18',
    'hrs_16_20',
)

# label -> integer index (0 .. 10)
time_mapping = {label: position for position, label in enumerate(_ordered_time_labels)}

# 5 timepoints: the first five labels
timepoints = list(_ordered_time_labels[:5])

In [4]:
# Gene name -> integer index used along the last axis of the expression tensors.
# Every line of GENES_MAPPING_FINAL.txt is parsed as "<gene name>: <integer index>".

gene_types = {}
with open('../data_utils/data/static/GENES_MAPPING_FINAL.txt', 'r') as gene_map_file:
    for entry in gene_map_file:
        pieces = entry.strip().split(': ')
        gene_name_field = pieces[0].strip()
        gene_types[gene_name_field] = int(pieces[1].strip())


In [32]:
# Disabled cell: the loader below is wrapped in a bare triple-quoted string, so none of it
# executes; as the last expression of the cell the string is only echoed as the cell output.
# If enabled, it would read 'cell_ids_dict.txt' (lines parsed as "<cell id>: <int>") into a
# `cell_id` dict, i.e. set up which indices cell ids will match to.
# Original author's note: shouldn't matter as much though
"""

cell_id = {}
with open('cell_ids_dict.txt', 'r') as file:
    for line in file:
        parts = line.strip().split(': ')
        key = parts[0].strip()
        value = int(parts[1].strip())
        cell_id[key] = value
"""


"\n\ncell_id = {}\nwith open('cell_ids_dict.txt', 'r') as file:\n    for line in file:\n        parts = line.strip().split(': ')\n        key = parts[0].strip()\n        value = int(parts[1].strip())\n        cell_id[key] = value\n"

In [11]:
# Allocate the zero-filled training tensors that the loading loop fills in.
N = 100  # number of samples
T = 1  # size of the second axis; the loading loop only ever writes index 0
M = len(gene_types)  # one slot per entry of gene_types

expr = torch.zeros((N, T, M))
ys_tensor = torch.zeros((N, T))
ts = torch.zeros_like(ys_tensor)

In [13]:
# main loop for populating expr
# Every data row of the CSV is parsed as three integers (row_idx, col_idx, value):
# col_idx selects the cell in columns_metadata, row_idx selects the gene in rows_metadata,
# and value is written into expr at that cell's row and that gene's index.

# cell id -> row index in expr / ys_tensor / ts, assigned in order of first appearance.
# An explicit mapping keeps a cell on its own row even when its entries are not
# contiguous in the file (indexing by "number of cells seen so far" would not).
cell_row = {}

with open("../data_utils/data/exp1/exp1_filtered_matrix.csv", "r") as matrix_file:
    csv_reader = csv.reader(matrix_file)
    # Discard the first row (presumably a header); the default avoids StopIteration on an empty file.
    next(csv_reader, None)

    for line in tqdm(csv_reader, desc="Processing"):
        try:
            row_idx, col_idx, value = map(int, line)
        except ValueError as ve:
            # `line` is the list of fields yielded by csv.reader (it has no .strip()),
            # so it is logged as-is and the row is skipped.
            logging.error("Error parsing line: %s - %s", line, ve)
            continue

        # NOTE(review): col_idx is used unshifted while row_idx is shifted by 2 —
        # confirm both offsets against how the CSV was exported.
        cell_type = columns_metadata.iloc[col_idx, 0] # when i say cell_type i mean id
        gene_name = rows_metadata.iloc[row_idx - 2, 0]

        row = cell_row.get(cell_type)
        if row is None:
            # First sighting of this cell: claim the next free row and record its
            # cluster label and time index once.
            row = len(cell_row)
            cell_row[cell_type] = row
            tp = columns_metadata.iloc[col_idx, 4]
            ys_tensor[row, 0] = cell_type_mapping[cell_type]
            ts[row, 0] = time_mapping[tp]

        expr[row, 0, gene_types[gene_name]] = value

# Kept as a set of the cell ids that were seen, as before.
cell_id = set(cell_row)
print(len(cell_id))
ys = [ys_tensor]

Processing: 36333it [00:02, 15335.89it/s]

100





In [15]:
# create mask
# mask[i, j] is 1 when expr[i, j, :] holds at least one entry different from 0, otherwise 0.
mask = torch.zeros(N, T)
mask[(expr[:N, :T] != 0).any(dim=-1)] = 1

In [17]:
# Constructor arguments for the Sagittarius manager, gathered in one mapping so they can
# be inspected before the model is built.
sagittarius_kwargs = dict(
    input_dim=M,
    num_classes=1,
    class_sizes=[54],
    cvae_catdims=[2],
    cvae_hiddendims=[128, 64],
    cvae_ld=32,
    attn_heads=4,
    num_ref_points=10,
    temporal_dim=16,
    tr_catdims=[8],
    minT=0,
    maxT=10,
    device='cpu',
    transformer_dim=None,  # Not using encoder/decoder for this experiment
    batch_size=16,
    beta=1.0,
    train_transfer=False,
    num_cont=1,  # Only continuous variable is time
    rec_loss='mse',
)
manager = Sagittarius_Manager(**sagittarius_kwargs)

# NOTE(review): the evaluation cell later reloads 'weights/11313exp1.pth' — confirm that
# train_model stores this file under weights/.
weights = '11313exp1.pth'
manager.train_model(
    expr,
    ts,
    ys,
    mask,
    reload=False,
    mfile=weights,
    num_epochs=1000,
    lr=0.001,
)

h


100%|███████████████████████████████████████| 1000/1000 [01:35<00:00, 10.47it/s]


In [5]:
# set up test set training
# Builds test_expr / test_ts / test_ys / test_mask from the *_TEST files, following the
# same steps used for the training tensors.
# NOTE: this cell rebinds N, T, M, cell_type_mapping and cell_id; after it runs,
# cell_type_mapping holds the TEST clusters, so re-running the training loop cell
# without first re-running its own mapping cell would use the wrong labels.

N = 100 # number of samples
T = 1
M = len(gene_types)
test_expr = torch.zeros(N, T, M)
test_ys_tensor = torch.zeros(N, T)
test_ts = torch.zeros(N, T)

# cell id -> seurat cluster for the test cells; lines are parsed as "<cell id>: <int>".
cell_type_mapping = {}
with open('../data_utils/data/exp1/seurat_clusters_dict_TEST.txt', 'r') as file:
    for line in file:
        parts = line.strip().split(': ')
        key = parts[0].strip()
        value = int(parts[1].strip())
        cell_type_mapping[key] = value

# cell id -> row index in the test tensors, assigned in order of first appearance.
# An explicit mapping keeps a cell on its own row even when its entries are not
# contiguous in the file.
cell_row = {}
with open("../data_utils/data/exp1/exp1_filtered_matrix_TEST.csv", "r") as matrix_file:
    csv_reader = csv.reader(matrix_file)
    # Discard the first row (presumably a header); the default avoids StopIteration on an empty file.
    next(csv_reader, None)

    for line in tqdm(csv_reader, desc="Processing"):
        try:
            row_idx, col_idx, value = map(int, line)
        except ValueError as ve:
            # `line` is the list of fields yielded by csv.reader (it has no .strip()),
            # so it is logged as-is and the row is skipped.
            logging.error("Error parsing line: %s - %s", line, ve)
            continue

        # NOTE(review): col_idx is used unshifted while row_idx is shifted by 2 —
        # confirm both offsets against how the CSV was exported.
        cell_type = columns_metadata.iloc[col_idx, 0] # when i say cell_type i mean id
        gene_name = rows_metadata.iloc[row_idx - 2, 0]

        row = cell_row.get(cell_type)
        if row is None:
            # First sighting of this cell: claim the next free row and record its
            # cluster label and time index once.
            row = len(cell_row)
            cell_row[cell_type] = row
            tp = columns_metadata.iloc[col_idx, 4]
            test_ys_tensor[row, 0] = cell_type_mapping[cell_type]
            test_ts[row, 0] = time_mapping[tp]

        test_expr[row, 0, gene_types[gene_name]] = value

# Kept as a set of the cell ids that were seen, as before.
cell_id = set(cell_row)
print(len(cell_id))
test_ys = [test_ys_tensor]

# test_mask[i, j] is 1 when test_expr[i, j, :] has at least one entry different from 0.
test_mask = torch.zeros(N, T)
test_mask[(test_expr != 0).any(dim=-1)] = 1

Processing: 37078it [00:02, 15774.93it/s]

100





In [6]:
# Build a manager with the same hyperparameters as the training cell, then point it at
# the saved weights and the test tensors.
manager = Sagittarius_Manager(
    input_dim=M,
    num_classes=1, 
    class_sizes=[54],
    cvae_catdims=[2],
    cvae_hiddendims=[128, 64],
    cvae_ld=32,
    attn_heads=4,
    num_ref_points=10,
    temporal_dim=16,
    tr_catdims=[8],
    minT=0,
    maxT=10, 
    device='cpu',
    transformer_dim=None, # Not using encoder/decoder for this experiment
    batch_size=16,
    beta=1.0,
    train_transfer=False, 
    num_cont=1, # Only continuous variable is time
    rec_loss='mse'
)
# NOTE(review): the training cell passes mfile='11313exp1.pth' (no directory) — confirm
# that file really ends up at this path.
weights = 'weights/11313exp1.pth'
# NOTE(review): reload=True presumably restores the weights from `mfile` rather than
# training from scratch — confirm in Sagittarius_Manager.train_model.
manager.train_model(test_expr, test_ts, test_ys, test_mask, reload=True, mfile=weights)
newpred = manager.reconstruct()
print(newpred.shape)

# pred is a second name for newpred (an alias; no data is copied)
pred = newpred

# Ground truth: keep only the (sample, timepoint) slots flagged in test_mask.
# Boolean indexing over the first two axes yields a 2D (num_selected, M) tensor in the
# same order as masked_select + view, without stacking M copies of the mask.
gt = test_expr[test_mask.bool()]

# The metrics below compare pred and gt row by row, so their shapes must agree.
assert pred.shape == gt.shape, f"pred {tuple(pred.shape)} and gt {tuple(gt.shape)} differ"

  return self._call_impl(*args, **kwargs)


h
torch.Size([100, 21216])


In [None]:
# Ranked Spearman correlation between prediction and ground truth, printed twice:
# first with get_per_sequence=False, then with get_per_sequence=True.
for per_sequence in (False, True):
    correlation = get_ranked_spearman_corr(pred, gt, get_per_sequence=per_sequence)
    print(correlation)

100


In [None]:
# individual statistics for spearman
# For every row, correlate ground truth and prediction over the positions where the
# ground truth is nonzero.

# Aliases of gt / pred (no data is copied).
gt_copy = gt
pred_copy = pred

avg_st = []  # one scipy spearmanr result per row that has enough nonzero entries

# Iterate over the rows that actually exist: gt can have fewer than N rows when the
# mask dropped samples, so indexing with range(N) could run past the end.
for gene_ex_vec, pred_ex_vec in zip(gt_copy, pred_copy):
    nonzero_mask = gene_ex_vec != 0
    partial_gene_ex = gene_ex_vec[nonzero_mask]
    partial_pred_ex = pred_ex_vec[nonzero_mask]

    # A correlation needs at least two points; this also skips all-zero rows.
    if len(partial_gene_ex) > 1:
        avg_st.append(spearmanr(partial_gene_ex.detach().numpy(), partial_pred_ex.detach().numpy()))
print(avg_st)
print(len(avg_st))