In [None]:
import os
import json
import random
import pickle
from random import sample
from collections import Counter

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

import anndata
import scanpy as sc

from datasets import load_from_disk, Dataset, concatenate_datasets
from tqdm import tqdm

from src.utils import post_process_generated_cell_sentences, convert_cell_sentence_back_to_expression_vector

In [None]:
SEED = 42  # single seed shared by every RNG used in this notebook
np.random.seed(SEED)  # legacy NumPy global RNG
random.seed(SEED)  # Python stdlib RNG

# Load Data

In [None]:
# Location of the preprocessed AnnData file. The default is a machine-specific
# absolute path; set the PREPROCESSED_ADATA_PATH environment variable to point
# at your own copy instead of editing the notebook.
PREPROCESSED_ADATA_PATH = os.environ.get(
    "PREPROCESSED_ADATA_PATH",
    "/home/sr2464/Desktop/cell2sentence-ft/preprocessed_adata.h5ad",
)
processed_adata = anndata.read_h5ad(PREPROCESSED_ADATA_PATH)
processed_adata

In [None]:
# Load the three Arrow splits holding the cell sentences and stack them in
# train -> val -> test order. Later cells index this dataset by the same row
# position as the reordered AnnData, so the order matters.
train_ds, val_ds, test_ds = (
    load_from_disk("cell_sentence_arrow_ds/train_arrow_ds"),
    load_from_disk("cell_sentence_arrow_ds/val_arrow_ds"),
    load_from_disk("cell_sentence_arrow_ds/test_arrow_ds"),
)
split_datasets = [train_ds, val_ds, test_ds]
total_ds = concatenate_datasets(split_datasets)
total_ds

In [None]:
# Reorder the processed AnnData rows to match the Arrow dataset: load each
# split's row indices and concatenate them in the same train -> val -> test
# order that was used to build `total_ds`.
# NOTE(review): presumably these index rows of the preprocessed AnnData — confirm.
train_partition_indices = np.load("cell_sentences/train_partition_indices.npy")
val_partition_indices = np.load("cell_sentences/val_partition_indices.npy")
test_partition_indices = np.load("cell_sentences/test_partition_indices.npy")

all_indices = np.concatenate(
    (train_partition_indices, val_partition_indices, test_partition_indices)
)
all_indices.shape

In [None]:
# Reorder (and subset) the AnnData rows so that row i describes the same cell as
# row i of `total_ds`; every later cell relies on this positional alignment.
# NOTE: not idempotent — re-running this cell applies the reordering a second
# time. Re-run from the loading cell instead.
assert len(np.unique(all_indices)) == len(all_indices), (
    "partition indices overlap between train/val/test splits"
)
processed_adata = processed_adata[all_indices, :].copy()  # Reorders rows
assert processed_adata.n_obs == len(total_ds), (
    f"row count mismatch: AnnData has {processed_adata.n_obs} rows, "
    f"Arrow dataset has {len(total_ds)}"
)
processed_adata

In [None]:
# Inspect the storage type of the expression matrix before densifying it.
processed_adata.X

In [None]:
# Convert the expression matrix to a dense ndarray so individual rows can be
# edited in the top-100 restriction below. The sparse check keeps the cell safe
# to re-run (a dense ndarray has no `.toarray()` and the original call crashed).
import scipy.sparse

if scipy.sparse.issparse(processed_adata.X):
    processed_adata.X = processed_adata.X.toarray()
type(processed_adata.X)

# Restrict data to the 100 most highly expressed genes

Restrict cell sentences to first 100 genes

In [None]:
def _first_100_gene_words(example):
    """Return the first 100 space-separated words of the example's `input_ids` string."""
    gene_words = example["input_ids"].split(" ")
    return {"first_100_gene_words": gene_words[:100]}


total_ds = total_ds.map(_first_100_gene_words)
total_ds

In [None]:
# Sanity check: the first cell's truncated sentence should hold 100 gene words
# (fewer only if its full sentence was shorter than 100 words).
len(total_ds[0]["first_100_gene_words"])

Restrict expression vectors to top 100 genes

In [None]:
# Nonzero gene counts for five sample cells (rows 0, 160, ..., 640) before the
# top-100 restriction is applied.
sample_nonzero_counts = [
    np.count_nonzero(processed_adata.X[row]) for row in range(0, 800, 160)
]
print(*sample_nonzero_counts, sep="\n")

In [None]:
# Zero out every gene except the `top_k` most highly expressed in each cell, so
# the expression vectors line up with the 100-word cell sentences built above.
# Selection is by index (argpartition) rather than a `<= threshold` mask: such a
# mask also drops the 100th-highest gene itself (plus anything tied with it), so
# it kept at most 99 genes. Here exactly `top_k` entries survive; ties at the
# cutoff are broken arbitrarily, so tied genes may differ from the ones that
# ended up in the cell sentence.
top_k = 100
n_genes = processed_adata.X.shape[1]
for cell_idx in tqdm(range(processed_adata.X.shape[0])):
    if n_genes <= top_k:
        break  # every gene is already within the top k; nothing to drop
    cell_expr_vector = processed_adata.X[cell_idx]
    top_gene_idx = np.argpartition(cell_expr_vector, -top_k)[-top_k:]
    drop_mask = np.ones(n_genes, dtype=bool)
    drop_mask[top_gene_idx] = False
    cell_expr_vector[drop_mask] = 0
    processed_adata.X[cell_idx] = cell_expr_vector

In [None]:
# Nonzero gene counts for the same five sample cells after the top-100
# restriction; each should be at most 100.
counts_after_restriction = (
    np.count_nonzero(processed_adata.X[row]) for row in range(0, 800, 160)
)
print("\n".join(str(count) for count in counts_after_restriction))

# Convert cell sentences back to expression vectors

In [None]:
# Read the table of transformation metrics / linear-model parameters; the slope
# and intercept taken from it are passed to
# convert_cell_sentence_back_to_expression_vector further down.
params_csv_path = "transformation_metrics_and_parameters.csv"
dataset_df = pd.read_csv(params_csv_path)
dataset_df.head()

In [None]:
# Slope and intercept sit in columns 2 and 3 of the first row.
# NOTE(review): positional lookup assumes a fixed CSV column layout — consider
# selecting by column name once the header is confirmed.
slope, intercept = (dataset_df.iloc[0, column].item() for column in (2, 3))
print(f"slope: {slope:.4f}, intercept: {intercept:.4f}")

In [None]:
# Load the gene vocabulary: the first space-separated token on each line.
# Names are upper-cased *before* de-duplication, so case variants of one gene
# collapse into a single entry instead of producing duplicate var names.
# Order is first-seen file order (dicts preserve insertion order). A plain set is
# not used for ordering because string-set iteration order changes between
# Python sessions (hash randomization). That would permute the gene axis of the
# reconstructed expression vectors and AnnData from run to run.
vocab_in_file_order = {}
with open("cell_sentences/vocab_human.txt", "r", encoding="utf-8") as fp:
    for line in fp:
        line = line.rstrip()  # remove end whitespace, e.g. newline
        gene_name = line.split(" ")[0].upper()
        if not gene_name:
            continue  # skip blank lines rather than adding an empty gene name
        vocab_in_file_order.setdefault(gene_name, None)

global_vocab_list = list(vocab_in_file_order)
global_vocab = set(global_vocab_list)  # kept for membership checks
print(len(global_vocab_list))
global_vocab_list[30:40:2]

In [None]:
# Convert each truncated cell sentence back into an expression vector:
#   1. post-process the gene words against the vocabulary (words not recognised
#      are replaced by "NOT_A_GENE"),
#   2. convert the cleaned sentence to a vector using the slope/intercept
#      loaded above.
# The column is fetched once; indexing `total_ds[cell_idx]` per cell decodes
# every column of the row (including the full sentence) each time.
first_100_gene_words_column = total_ds["first_100_gene_words"]
# Rows of `total_ds` and `processed_adata` must be aligned one-to-one.
assert len(first_100_gene_words_column) == processed_adata.shape[0], (
    "cell sentences and expression matrix are not row-aligned"
)

all_cell_sentences_converted_back_to_expression = []
for cell_sentence_list in tqdm(first_100_gene_words_column):
    cell_sentence_str = " ".join(cell_sentence_list)

    post_processed_sentence, num_genes_replaced = post_process_generated_cell_sentences(
        cell_sentence=cell_sentence_str,
        global_dictionary=global_vocab_list,
        replace_nonsense_string="NOT_A_GENE",
    )

    reconstructed_expr_vec = convert_cell_sentence_back_to_expression_vector(
        cell_sentence=post_processed_sentence,
        global_dictionary=global_vocab_list,
        slope=slope,
        intercept=intercept,
    )
    all_cell_sentences_converted_back_to_expression.append(reconstructed_expr_vec)

In [None]:
# Stack the per-cell vectors row-wise into one float32 matrix
# (one row per cell; presumably one column per vocabulary gene — confirm).
reconstructed_rows = all_cell_sentences_converted_back_to_expression
all_cell_sentences_converted_back_to_expression = np.stack(
    reconstructed_rows, axis=0, dtype=np.float32
)
all_cell_sentences_converted_back_to_expression.shape

In [None]:
# Wrap the reconstructed expression matrix in an AnnData object; cell and gene
# annotations are attached in the following cells.
reconstructed_matrix = all_cell_sentences_converted_back_to_expression
reconstructed_adata = sc.AnnData(X=reconstructed_matrix)
reconstructed_adata

In [None]:
# Cell-type labels come from the concatenated Arrow dataset, whose row order
# matches the reconstructed matrix.
cell_type_labels = total_ds["cell_type"]
reconstructed_adata.obs["cell_type_label"] = cell_type_labels
reconstructed_adata.obs.head()

In [None]:
# Label the gene axis with the vocabulary used for reconstruction, both as the
# var index and as an explicit "gene_name" column.
gene_names = global_vocab_list
reconstructed_adata.var.index = gene_names
reconstructed_adata.var["gene_name"] = gene_names
reconstructed_adata.var.head()

In [None]:
# Summary of the reconstructed AnnData (shape plus obs/var annotations).
reconstructed_adata

In [None]:
# Summary of the top-100-restricted processed AnnData, for comparison with the reconstruction.
processed_adata

In [None]:
# Re-check nonzero counts of the processed (top-100-restricted) matrix for the
# five sample cells; each should be at most 100.
for sample_row in range(0, 800, 160):
    nonzero_genes = np.count_nonzero(processed_adata.X[sample_row])
    print(nonzero_genes)

In [None]:
# Nonzero counts of the reconstructed vectors for the same five sample cells.
# NOTE(review): these may fall below 100 when sentence words were replaced by
# "NOT_A_GENE" or map to a zero value — depends on the behaviour of
# convert_cell_sentence_back_to_expression_vector; confirm.
for sample_row in range(0, 800, 160):
    nonzero_genes = np.count_nonzero(reconstructed_adata.X[sample_row])
    print(nonzero_genes)

In [None]:
# Save the top-100 processed data and its reconstruction from cell sentences.
outputs_to_save = (
    (processed_adata, "processed_adata_top100genes.h5ad"),
    (reconstructed_adata, "reconstructed_adata_from_cell_sentences_top100genes.h5ad"),
)
for adata_to_save, output_path in outputs_to_save:
    adata_to_save.write_h5ad(output_path)