In [1]:
from scipy.stats import spearmanr
import pandas as pd
import numpy as np
import logging
import torch
import csv
from tqdm import tqdm
import anndata as ad
import scipy.sparse as sp
from scipy.io import mmwrite
from io import StringIO
from evaluate_results import get_ranked_pearson_corr
from evaluate_results import compute_rmse


import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import seaborn as sns
from collections import defaultdict

import sys
import os

sys.path.append(os.path.join(sys.path[0], '../'))
from model.manager_for_sagittarius import Sagittarius_Manager

columns_metadata = pd.read_csv("/Users/jc/Documents/GitHub/Sagittarius/EvoDevo/dataset/GSE190147_scirnaseq_gene_matrix.columns.csv", delimiter="\t")
rows_metadata = pd.read_csv("/Users/jc/Documents/GitHub/Sagittarius/EvoDevo/dataset/GSE190147_scirnaseq_gene_matrix.rows.csv", delimiter="\t")
gene_expression_matrix = pd.read_csv("/Users/jc/Documents/GitHub/Sagittarius/EvoDevo/dataset/GSE190147_scirnaseq_gene_matrix.txt", delimiter=" ")
gene_expression_matrix_file = "/Users/jc/Documents/GitHub/Sagittarius/EvoDevo/dataset/GSE190147_scirnaseq_gene_matrix.txt"
columns_metadata_file = "/Users/jc/Documents/GitHub/Sagittarius/EvoDevo/dataset/GSE190147_scirnaseq_gene_matrix.columns.csv"
rows_metadata_file = "/Users/jc/Documents/GitHub/Sagittarius/EvoDevo/dataset/GSE190147_scirnaseq_gene_matrix.rows.csv"
gene_expression_mtx = "/Users/jc/Documents/GitHub/Sagittarius/GSE190147_scirnaseq_gene_matrix.mtx"

df = pd.read_csv('/Users/jc/Documents/GitHub/Sagittarius/935117.csv')

In [2]:
# create set of wanted cell identifiers
cell_types = set()
for cell_value in df['cell']:
    cell_types.add(cell_value)

# gather wanted row indices by checking if the column_metadata cell identifier is in the wanted cell identifiers
wanted_values = set()
for index, cell_id in enumerate(columns_metadata['cell']):
    if cell_id in cell_types:
        wanted_values.add(index)

# map to seurat cluster. important for generating ys

shortened_meta = '/Users/jc/Documents/GitHub/Sagittarius/935117.csv'

cell_type_mapping = {}

with open(shortened_meta, newline='') as csvfile:
    csv_reader = csv.DictReader(csvfile)

    for row in csv_reader:
        cell = row['cell']
        seurat_cluster = row['seurat_clusters']

        cell_type_mapping[cell] = int(seurat_cluster)

with open('seurat_clusters_dict.txt', 'w') as file:
    for key, value in cell_type_mapping.items():
        file.write(f'{key}: {value}\n')

In [3]:
# 5 timepoints
timepoints = ['hrs_00_02', 'hrs_01_03', 'hrs_02_04', 'hrs_03_07', 'hrs_04_08']

time_mapping = {
    'hrs_00_02': 0,
    'hrs_01_03': 1,
    'hrs_02_04': 2,
    'hrs_03_07': 3, 
    'hrs_04_08': 4, 
    'hrs_06_10': 5, 
    'hrs_08_12': 6,
    'hrs_10_14': 7,
    'hrs_12_16': 8, 
    'hrs_14_18': 9,
    'hrs_16_20': 10, 
}

In [4]:
# the genes for train expr (from rows_metadata)
gene_types = {}
counter = 0
with open("/Users/jc/Documents/GitHub/Sagittarius/models/gene_names_100.csv", newline='') as gene_file:
    csv_reader = csv.reader(gene_file)
    next(csv_reader, None)
    for index, row in enumerate(csv_reader):
        if index == 0:
            print(row[0])
        gene_name = row[0]
        
        if (gene_name not in gene_types):
            gene_types[gene_name] = counter # this should be 0-based indexing
            counter += 1

gene_types = {key: int(value) for key, value in gene_types.items()}
with open('gene_types_dict.txt', 'w') as file:
    for key, value in gene_types.items():
        file.write(f'{key}: {value}\n')


Ranbp9


In [5]:
# set stuff up
N = len(cell_types) # number of samples
T = 1
M = len(gene_types)
expr = torch.zeros(N, T, M)
ys_tensor = torch.zeros(N, T)
ts = torch.zeros(N, T)

In [6]:
# main loop for populating expr
gene_types = {}
counter = 0
with open("/Users/jc/Documents/GitHub/Sagittarius/models/gene_names_100.csv", newline='') as gene_file:
    csv_reader = csv.reader(gene_file)
    next(csv_reader, None)
    for index, row in enumerate(csv_reader):
        if index == 0:
            print(row[0])
        gene_name = row[0]
        
        if (gene_name not in gene_types):
            gene_types[gene_name] = counter # this should be 0-based indexing
            counter += 1

gene_types = {key: int(value) for key, value in gene_types.items()}


N = len(cell_types) # number of samples
T = len(timepoints)
M = len(gene_types)
expr = torch.zeros(N, T, M)
ys_tensor = torch.zeros(N, T)
ts = torch.zeros(N, T)

counter = 0
cell_id = dict()
with open("/Users/jc/Documents/GitHub/Sagittarius/models/new_cell_types.csv", "r") as matrix_file:
    csv_reader = csv.reader(matrix_file)
    next(csv_reader)

    for line in tqdm(csv_reader, desc="Processing"):
        try:
            row_idx, col_idx, value = map(int, line)
                    
        except ValueError as ve:
            logging.error(f"Error parsing line: {line.strip()} - {ve}")
            continue
        
        tp = columns_metadata.iloc[col_idx, 4]
        
        cell_type = columns_metadata.iloc[col_idx, 0] # when i say cell_type i mean id
        cell_id[cell_type] = len(cell_id) # when i say cell_id dict, it holds cell ids and the order in which stuff was added
        gene_name = rows_metadata.iloc[row_idx - 2, 0]
        #if counter == 0:
            #print(cell_type)
        expr[len(cell_id) - 1, 0, gene_types[gene_name]] = value
        
        #print(cell_type_mapping[cell_type])

        ys_tensor[cell_type_mapping[cell_type], 0] = cell_type_mapping[cell_type]
        ts[cell_type_mapping[cell_type], 0] = time_mapping[tp]

        if cell_type not in cell_types and not tp in timepoints:
            print("ok now something's weird")

ys = [ys_tensor]

Ranbp9


Processing: 38422it [00:01, 34144.27it/s]


In [7]:
# motivation here is to just generate the cell_id thing through the loop once and have stuff in the same place in expr
# so it's like not so rng

with open('cell_ids_dict.txt', 'w') as file:
    for key, value in cell_id.items():
        file.write(f'{key}: {value}\n')

In [8]:
# create mask
mask = torch.zeros(N, T)
for i in range(N):
    for j in range(T):
        if not (expr[i, j, :] == 0).all():
            mask[i, j] = 1

In [9]:
manager = Sagittarius_Manager(
    input_dim=M,
    num_classes=1, 
    class_sizes=[54],
    cvae_catdims=[2],
    cvae_hiddendims=[128, 64],
    cvae_ld=32,
    attn_heads=4,
    num_ref_points=10,
    temporal_dim=16,
    tr_catdims=[8],
    minT=0,
    maxT=10, 
    device='cpu',
    transformer_dim=None, # Not using encoder/decoder
    batch_size=16,
    beta=1.0,
    train_transfer=False, 
    num_cont=1, # Only continuous variable is time
    rec_loss='mse'
)

manager.train_model(expr, ts, ys, mask, reload=False, mfile="trained_model4.pth", num_epochs=1000, lr=0.001)
newpred = manager.reconstruct()
print(newpred.shape)

# ig just hold a copy for literally no reason at all
pred = newpred

gt = expr
gt = torch.masked_select(expr,torch.stack([mask.bool() for _ in range(M)], dim = -1)).view(-1, M)

# at this point pred = newpred, gt and pred have the same dimensions and are 2D

  x = torch.tensor(expr[bstart:bend], dtype=torch.float32).to(self.device) # Numpy don't have Float, unlike tensors
  return self._call_impl(*args, **kwargs)
100%|███████████████████████████████████████| 1000/1000 [00:54<00:00, 18.34it/s]


torch.Size([100, 5731])


In [10]:
# individual statistics for spearman

gt_copy = gt
pred_copy = pred

avg_st = []

for i in range(100):

        # Check if this timepoint is associated with values
    if not torch.all(gt_copy[i] == 0):
        gene_ex_vec = gt_copy[i]
        pred_ex_vec = pred_copy[i]
        nonzero_mask = torch.where(gene_ex_vec != 0, 1, 0)
        partial_gene_ex = torch.masked_select(gene_ex_vec, nonzero_mask.bool())
        partial_pred_ex = torch.masked_select(pred_ex_vec, nonzero_mask.bool())

        # See if there's actual correlation to compute ?
        if len(partial_gene_ex) > 1:
            avg_st.append(spearmanr(partial_gene_ex.detach().numpy(), partial_pred_ex.detach().numpy()))
print(avg_st)
print(len(avg_st))

[SignificanceResult(statistic=0.8252824130971429, pvalue=8.865564893140931e-51), SignificanceResult(statistic=0.8547159841398466, pvalue=3.1940992993971814e-130), SignificanceResult(statistic=0.8539239720937588, pvalue=7.289843650425341e-75), SignificanceResult(statistic=0.8694966305789625, pvalue=2.4911292413653712e-108), SignificanceResult(statistic=0.8664298687403662, pvalue=5.949845349946883e-75), SignificanceResult(statistic=0.8496264681866003, pvalue=1.5073779265954363e-60), SignificanceResult(statistic=0.8504329379665797, pvalue=2.491423629927154e-47), SignificanceResult(statistic=0.8345638650459208, pvalue=8.282107313717345e-118), SignificanceResult(statistic=0.8735803563638205, pvalue=3.5121964857275334e-158), SignificanceResult(statistic=0.8277568414258824, pvalue=1.5196357297250423e-80), SignificanceResult(statistic=0.7702669272224361, pvalue=3.0149880862434236e-56), SignificanceResult(statistic=0.7324418257385974, pvalue=1.4879540655849316e-33), SignificanceResult(statistic