Script to analyze the multiple myeloma results. This includes plotting the representative tree and gene expression analysis.

In [25]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
import matplotlib.colors as mcolors
import matplotlib.cm as cm
from IPython.core.pylabtools import figsize
from networkx.algorithms.bipartite.cluster import clustering
from torchgen.gen_functionalization_type import return_str

from src_python.cell_tree import CellTree

In [26]:
study_num = "bc03" #"mm34" #"bc03"
primary_cells = []

if study_num == "bc03":
    primary_cells = ["SRR5023442", "SRR5023443", "SRR5023444", "SRR5023445", "SRR2973275", "SRR2973351", "SRR2973352", "SRR2973353", "SRR2973354", "SRR2973355", "SRR2973356", "SRR2973357", "SRR2973358", "SRR2973359", "SRR2973360", "SRR2973361", "SRR2973362", "SRR2973363", "SRR2973364", "SRR2973365", "SRR2973366", "SRR2973367", "SRR2973368", "SRR2973369", "SRR2973370", "SRR2973371", "SRR2973372", "SRR2973373", "SRR2973374", "SRR2973375", "SRR2973376", "SRR2973377", "SRR2973378", "SRR2973379", "SRR2973380", "SRR2973381", "SRR2973382", "SRR2973383"]

if study_num == "bc07":
    primary_cells =  ["SRR2973484", "SRR5023558", "SRR5023559", "SRR5023560", "SRR2973277", "SRR2973437", "SRR2973438", "SRR2973439", "SRR2973440", "SRR2973441", "SRR2973442", "SRR2973443", "SRR2973444", "SRR2973445", "SRR2973446", "SRR2973447", "SRR2973448", "SRR2973449", "SRR2973450", "SRR2973451", "SRR2973452", "SRR2973453", "SRR2973454", "SRR2973455", "SRR2973456", "SRR2973457", "SRR2973458", "SRR2973459", "SRR2973460", "SRR2973461", "SRR2973462", "SRR2973463", "SRR2973464", "SRR2973465", "SRR2973466", "SRR2973467", "SRR2973468", "SRR2973469", "SRR2973470", "SRR2973471", "SRR2973472", "SRR2973473", "SRR2973474", "SRR2973475", "SRR2973476", "SRR2973477", "SRR2973478", "SRR2973479", "SRR2973480", "SRR2973481", "SRR2973482", "SRR2973483"]

if study_num == "mm34":
    primary_cells = ["SRR6710302", "SRR6710303", "SRR6710304", "SRR6710305", "SRR6710306", "SRR6710307", "SRR6710308", "SRR6710309", "SRR6710310", "SRR6710311", "SRR6710312", "SRR6710313", "SRR6710314", "SRR6710315", "SRR6710316", "SRR6710317", "SRR6710318", "SRR6710319", "SRR6710320", "SRR6710321", "SRR6710322", "SRR6710323", "SRR6710324", "SRR6710325", "SRR6710326", "SRR6710327", "SRR6710328", "SRR6710329", "SRR6710330", "SRR6710331", "SRR6710332", "SRR6710333", "SRR6710334", "SRR6710335", "SRR6710336", "SRR6710337", "SRR6710338", "SRR6710339", "SRR6710340", "SRR6710341", "SRR6710342", "SRR6710343", "SRR6710344", "SRR6710345", "SRR6710346", "SRR6710347", "SRR6710348", "SRR6710349", "SRR6710350", "SRR6710351", "SRR6710352", "SRR6710353", "SRR6710354", "SRR6710355", "SRR6710356", "SRR6710357", "SRR6710358", "SRR6710359", "SRR6710360", "SRR6710361", "SRR6710362", "SRR6710363", "SRR6710364", "SRR6710365", "SRR6710366"]

if study_num == "mm16":
    primary_cells = ["SRR6710256", "SRR6710257", "SRR6710258", "SRR6710259", "SRR6710260", "SRR6710261", "SRR6710262", "SRR6710263", "SRR6710264", "SRR6710265", "SRR6710266", "SRR6710267", "SRR6710268", "SRR6710269", "SRR6710270", "SRR6710271", "SRR6710272", "SRR6710273", "SRR6710274", "SRR6710275", "SRR6710276", "SRR6710277", "SRR6710278"]

In [27]:
n_bootstrap = 1000
np.random.seed(0)
n_rounds = 3
model = "sclineager"   # "sciterna"

path = rf"../data/results/{study_num}/{model}"
os.makedirs(os.path.join(path, f"{model}_ct_trees"), exist_ok=True)

ref = np.array(pd.read_csv(os.path.join(f"../data/input_data/{study_num}", "ref.csv")))
cell_names = list(ref[:,0])
cell_indices = [cell_names.index(item) for item in primary_cells if item in cell_names]
ref = ref[:, 1:]

alt = np.array(pd.read_csv(os.path.join(f"../data/input_data/{study_num}", "alt.csv")))[:, 1:]

def plot_trees(path, round, plot_mutations=False, min_num=0):
    for i in range(min_num, n_rounds):
        path_parent = os.path.join(path, f"{model}_parent_vec", f"{model}_parent_vec_{i}{round}.txt")
        path_mut_loc = os.path.join(path, f"{model}_mutation_location", f"{model}_mutation_location_{i}{round}.txt")
        path_selected = os.path.join(path, f"{model}_selected_loci", f"{model}_selected_loci_{i}{round}.txt")

        parent_vec = np.loadtxt(path_parent, dtype=int)
        if os.path.exists(path_mut_loc) == False:
            mut_locs = []
            selected_mutations = []
        else:
            mut_locs = np.loadtxt(path_mut_loc, dtype=int)
            selected_mutations = np.loadtxt(path_selected, dtype=int)

        n_cells = int(((len(parent_vec)+1)/2))
        ct = CellTree(n_cells=n_cells, n_mut=len(selected_mutations))
        ct.use_parent_vec(parent_vec)

        if plot_mutations:
            ct.mut_loc = mut_locs
            graph = ct.to_graphviz()
            graph.render(os.path.join(path, f"{model}_ct_trees", f"ct_tree_{i}_{round}"), format='png', cleanup=True)

        ct.mut_loc = []
        graph_branches = ct.to_graphviz()
        for n in range(n_cells):
            if study_num == "mm34":
                if n in cell_indices:
                    graph_branches.node(str(n), shape='circle', style='filled', color="red")
                else:
                    graph_branches.node(str(n), shape='circle', style='filled', color="blue") # metastasis
            elif study_num == "mm16":
                if n < 23:
                    graph_branches.node(str(n), shape='circle', style='filled', color="red")
                else:
                    graph_branches.node(str(n), shape='circle', style='filled', color="blue") # after treatment
            elif study_num == "bc03":
                if n in cell_indices:
                    graph_branches.node(str(n), shape='circle', style='filled', color="red")
                else:
                    graph_branches.node(str(n), shape='circle', style='filled', color="blue") # lymph node metastasis
            elif study_num == "bc07":
                if n in cell_indices:
                    graph_branches.node(str(n), shape='circle', style='filled', color="red")
                else:
                    graph_branches.node(str(n), shape='circle', style='filled', color="blue") # lymph node metastasis


        graph_branches.render(os.path.join(path, f"{model}_ct_trees", f"ct_tree_{i}_{round}_branches"), format='png', cleanup=True)

if model == "sciterna":
    plot_trees(path, f"0r", plot_mutations=True)
    path_bootstrap = rf"../data/results/{study_num}/{model}_bootstrap"
    os.makedirs(os.path.join(path_bootstrap, f"{model}_ct_trees"), exist_ok=True)
    for i in range(n_bootstrap):
        plot_trees(path_bootstrap, f"{i}r", plot_mutations=False)
else:
    plot_trees(path, "", plot_mutations=True, min_num = 2)

In [28]:
import seaborn as sns

vaf = alt/(ref+alt)
vafq = pd.isna(vaf)
missing = np.sum(vafq, axis=1)
vaf[vafq] = 0.0
vaf = np.array(vaf, dtype=float)

plt.figure(figsize=(5,10))
sns.heatmap(vaf)
plt.title("Variant Allele Frequency (VAF) Heatmap")
plt.xlabel("Samples")
plt.ylabel("Mutations")
plt.show()

ZeroDivisionError: division by zero

In [None]:
missing

In [None]:
np.argwhere(posterior_node_after_branching == 5)

In [None]:
selected = np.loadtxt(os.path.join(path_bootstrap, "selected.txt"), delimiter=',', dtype=int)
ref = pd.read_csv(os.path.join(f"../data/input_data/{study_num}", "ref.csv"))

all_individual_overdispersions_h = np.full((n_bootstrap, ref.shape[1]), np.nan)
all_individual_dropouts = np.full((n_bootstrap, ref.shape[1]), np.nan)

for i in range(n_bootstrap):
    for j in range(n_rounds):
        selected_mutations = np.loadtxt(os.path.join(path_bootstrap, "sciterna_selected_loci", f"sciterna_selected_loci_{j}r{i}.txt"), dtype=int)
        individual_dropouts = np.loadtxt(os.path.join(path_bootstrap, "sciterna_individual_dropout_probs", f"sciterna_individual_dropout_probs_{j}r{i}.txt"))
        individual_overdispersions_h = np.loadtxt(os.path.join(path_bootstrap, "sciterna_individual_overdispersions_h", f"sciterna_individual_overdispersions_h_{j}r{i}.txt"))

        unique_mutations = np.unique(selected_mutations)
        for mut in unique_mutations:
            indices = np.where(selected_mutations == mut)[0]
            mean_dropout = np.mean(individual_dropouts[indices])
            mean_overdispersion = np.mean(individual_overdispersions_h[indices])

            all_individual_dropouts[i, mut] = mean_dropout
            all_individual_overdispersions_h[i, mut] = mean_overdispersion

In [None]:
sufficient_data_columns_od = ~(np.any(all_individual_overdispersions_h == 6, axis=0))
sufficient_data_columns_dropout = ~(np.any(all_individual_dropouts == 0.2, axis=0))

mean_overdispersions = np.nanmean(all_individual_overdispersions_h[:, sufficient_data_columns_od], axis=0)
mean_dropouts = np.nanmean(all_individual_dropouts[:, sufficient_data_columns_dropout], axis=0)

In [None]:
plt.hist(mean_overdispersions, bins=100)
plt.title("Overdispersion Heterozygous distribution")
plt.show()

plt.hist(mean_dropouts, bins=100)
plt.title("Dropout heterozygous distribution")
plt.show()

In [None]:
path_sclineager = f"../data/results/{study_num}/sclineager"
path_dendro = f"../data/results/{study_num}/dendro"

clones = np.arange(2,6)

def plot_model_trees(path, model):

    os.makedirs(os.path.join(path, f"{model}_ct_trees"), exist_ok=True)

    for c in clones:
        path_parent = os.path.join(path, f"{model}_parent_vec", f"{model}_parent_vec_{c}.txt")

        parent_vec = np.loadtxt(path_parent, dtype=int)

        n_cells = int(((len(parent_vec)+1)/2))
        ct = CellTree(n_cells=n_cells)
        ct.use_parent_vec(parent_vec)

        ct.mut_loc = []
        graph = ct.to_graphviz()
        for n in range(n_cells):
            if n < 65:
                graph.node(str(n), shape='circle', style='filled', color="red")
            elif n >=65:
                graph.node(str(n), shape='circle', style='filled', color="blue")
        graph.render(os.path.join(path, f"{model}_ct_trees", f"ct_tree_{c}_branches"), format='png', cleanup=True)

plot_model_trees(path_sclineager, "sclineager")
plot_model_trees(path_dendro, "dendro")

In [None]:
def create_mutation_matrix(parent_vector, mutation_indices, ct):
    n_nodes = len(parent_vector)
    n_leaves = int((n_nodes+1)/2)
    n_mutations = len(mutation_indices)

    mutation_matrix = np.zeros((n_nodes, n_mutations), dtype=int)

    for mutation_idx, cell_idx in enumerate(mutation_indices):
        children = [c for c in ct.dfs(cell_idx)]
        for cell in children:  # Traverse all cells below the mutation cell
            mutation_matrix[cell, mutation_idx] = 1  # Mark cells with the mutation

    return mutation_matrix[:n_leaves].T

def are_all_arrays_identical(arrays_list):
    first_array = arrays_list[0]
    return all(np.array_equal(first_array, array) for array in arrays_list[1:])

In [None]:
# with open("../data/input_data/gencode.v21.annotation.gtf", 'r') as file:
#     for line in file:
#         if line.startswith("#"):
#                 continue
#         fields = line.strip().split('\t')
#         if fields[2] == 'gene':
#             attributes = {key_value.split(' ')[0]: key_value.split(' ')[1].strip('"') for key_value in
#                           fields[8].split('; ') if key_value}
#             gene_names.append([fields[0], fields[3], fields[4], attributes["gene_name"]])
# 
# df = pd.DataFrame(gene_names, columns=['chromosome', 'start', 'end', 'gene'])
# df['start'] = df['start'].astype(int)
# df['end'] = df['end'].astype(int)
# df.to_csv("../data/input_data/mm34/gene_positions.csv", index=False)

def convert_location_to_gene(locs):

    loc_to_gene = []

    df = pd.read_csv("../data/input_data/mm34/gene_positions.csv", index_col=False)
    for location in locs:
        chrom, pos = location.split(":")[0], location.split(":")[1]
        pos = int(pos)
        matching_rows = df[(df['chromosome'] == chrom) & (df['start'] <= pos) & (df['end'] >= pos)]
        matching_genes = matching_rows['gene'].tolist()
        loc_to_gene.append(matching_genes)

    return loc_to_gene

In [None]:
n_bootstrap = 1000
np.random.seed(0)
use_summary_statistics = False # Uses the summary files saved in the data/ directory. If raw output files were generated, set to False

path = rf"../data/results/{study_num}/sciterna_bootstrap"

reference = pd.read_csv(rf'../data/input_data/{study_num}/ref.csv', index_col=0)

if use_summary_statistics == False:
    os.makedirs(os.path.join(path, "sciterna_mutations_before_branching"), exist_ok=True)
    os.makedirs(os.path.join(path, "sciterna_nodes_after_branching"), exist_ok=True)
    os.makedirs(os.path.join(path, "sciterna_nodes_before_branching"), exist_ok=True)
    os.makedirs(os.path.join(path, "sciterna_mutations_branching"), exist_ok=True)
    os.makedirs(os.path.join(path, "sciterna_mutations_after_branching"), exist_ok=True)
    
    all_loci = reference.columns
    all_cells = reference.index
    mut_indicator = {}
    
    def mutations_node(parent_vector, mutation_indices, ct, node):
        n_cells = len(parent_vector)
        n_mutations = len(mutation_indices)
    
        mutation_matrix = np.zeros((n_cells, n_mutations), dtype=int)
    
        for mutation_idx, cell_idx in enumerate(mutation_indices):
            children = [c for c in ct.dfs_experimental(cell_idx)]
            for cell in children:  # Traverse all cells below the mutation cell
                mutation_matrix[cell, mutation_idx] = 1  # Mark cells with the mutation
    
        return mutation_matrix[node]
    
    for i in range(0, n_bootstrap):
        for r in range(n_rounds):
            path_parent = os.path.join(path, "sciterna_parent_vec", f"sciterna_parent_vec_{r}r{i}.txt")

            if os.path.exists(os.path.join(path, rf"sciterna_mutations_branching\mutations_branching_{r}r{i}.txt")):
                continue

            path_mut_loc = os.path.join(path, "sciterna_mutation_location", f"sciterna_mutation_location_{r}r{i}.txt")
            path_selected = os.path.join(path, "sciterna_selected_loci", f"sciterna_selected_loci_{r}r{i}.txt")
            parent_vec = np.loadtxt(path_parent, dtype=int)
            mut_locs = np.loadtxt(path_mut_loc, dtype=int)
            mutations = np.loadtxt(path_selected)
            for loci in all_loci:
                mut_indicator[loci] = []

            n_cells = int(((len(parent_vec)+1)/2))
            ct = CellTree(n_cells=n_cells, n_mut=len(mutations))
            ct.use_parent_vec(parent_vec)
            ct.mut_loc = mut_locs

            # mutation_matrix = create_mutation_matrix(ct.parent_vec, ct.mut_loc, ct)
            #
            # for mi, mut in enumerate(mutations):
            #     mut_indicator[all_loci[int(mut)]].append(mutation_matrix[mi])
            #
            # for k in mut_indicator.keys():
            #     if len(mut_indicator[k]) > 1:
            #         if not are_all_arrays_identical(mut_indicator[k]):
            #             print("Problem")
            #         else:
            #             mut_indicator[k] = [mut_indicator[k][0]]
            #
            # dataframe = {}
            #
            # for key, value in mut_indicator.items():
            #     if len(value) > 0:
            #         dataframe[key] = pd.DataFrame([value[0]], columns=all_cells)
            #     else:
            #         dataframe[key] = pd.DataFrame([np.nan] * len(all_cells)).T
            #         dataframe[key].columns = all_cells
            #
            # df = pd.concat(dataframe.values(), axis=0)
            # df.index = all_loci
            # df.to_csv(os.path.join(path, "sciterna_mut_indicator", f"sciterna_mut_indicator_{r}r{i}.csv"), index=True)

            # optimize capture rate + purity to determine the majority metastasis branch
            branching_node = 0
            branching_nodes_max = 0
            for node in range(len(parent_vec)):
                if ct.isleaf(node):
                    continue

                primary = 0
                metastasis = 0
                for sn in ct.dfs_experimental(node):
                    if ct.isleaf(sn):
                        if sn < 65:
                            primary += 1
                        elif sn >=65:
                            metastasis += 1

                capture_rate = metastasis/(n_cells - 65)
                purity = metastasis/(metastasis + primary)
                branching_nodes_new = capture_rate + purity

                if branching_nodes_max < branching_nodes_new:
                    branching_node = node
                    branching_nodes_max = branching_nodes_new
            print(branching_node, branching_nodes_max)
            if branching_nodes_max < 1.9:
                print("Potential problem (Likely wrong root): ", i, r)

            mutations_indices = mutations_node(parent_vec, mut_locs, ct, branching_node)
            mutations_branching = mutations[mutations_indices==1]
            np.savetxt(os.path.join(path, rf"sciterna_mutations_branching\mutations_branching_{r}r{i}.txt"), mutations_branching, fmt='%i')

            mutations_after_metastasis = []
            metastasis_nodes = [n for n in ct.dfs_experimental(branching_node) if ct.isleaf(n)]
            primary_nodes = [n for n in range(n_cells) if n not in metastasis_nodes]

            np.savetxt(os.path.join(path, rf"sciterna_nodes_after_branching\nodes_after_branching_{r}r{i}.txt"), metastasis_nodes, fmt='%i')
            np.savetxt(os.path.join(path, rf"sciterna_nodes_before_branching\nodes_before_branching_{r}r{i}.txt"), primary_nodes, fmt='%i')

            for ni, mut_loc in enumerate(mut_locs):
                if mut_loc in metastasis_nodes[1:]:
                    mutations_after_metastasis.append(mutations[ni])

            mutations_before_metastasis = [m for m in mutations if m not in mutations_after_metastasis]

            np.savetxt(os.path.join(path, rf"sciterna_mutations_before_branching\mutations_before_branching_{r}r{i}.txt"), mutations_before_metastasis, fmt='%i')
            np.savetxt(os.path.join(path, rf"sciterna_mutations_after_branching\mutations_after_branching_{r}r{i}.txt"), mutations_after_metastasis, fmt='%i')

In [None]:
if use_summary_statistics == False:
    primary = []
    metastasis = []
    primary_nodes = []
    metastasis_nodes = []
    branching = []
    r = 2
    for i in range(n_bootstrap):
        path_primary_node = os.path.join(path, rf"sciterna_nodes_before_branching\nodes_before_branching_{r}r{i}.txt")
        path_metastasis_node = os.path.join(path, rf"sciterna_nodes_after_branching\nodes_after_branching_{r}r{i}.txt")
        path_primary = os.path.join(path, rf"sciterna_mutations_before_branching\mutations_before_branching_{r}r{i}.txt")
        path_metastasis = os.path.join(path, rf"sciterna_mutations_after_branching\mutations_after_branching_{r}r{i}.txt")
        path_branching = os.path.join(path, rf"sciterna_mutations_branching\mutations_branching_{r}r{i}.txt")

        primary_data = np.loadtxt(path_primary, dtype=int)
        metastasis_data = np.loadtxt(path_metastasis, dtype=int)
        primary_nodes_data = np.loadtxt(path_primary_node, dtype=int)
        metastasis_nodes_data = np.loadtxt(path_metastasis_node, dtype=int)
        branching_data = np.loadtxt(path_branching, dtype=int)

        primary.extend(primary_data)
        metastasis.extend(metastasis_data)
        primary_nodes.extend(primary_nodes_data)
        metastasis_nodes.extend(metastasis_nodes_data)
        branching.extend(branching_data)
        
    unique_mutations_branching, counts_mutations_branching = np.unique(np.array(branching), return_counts=True)
    selected = np.loadtxt(os.path.join(path, "selected.txt"), delimiter=',', dtype=int)
    selected_mutations = np.unique(selected)

else:
    metastasis_nodes = np.loadtxt(os.path.join(path, "metastasis_nodes.txt"))

In [None]:
unique_cells, posterior_node_after_branching = np.unique(np.array(metastasis_nodes), return_counts=True)
heatmap_data = posterior_node_after_branching.reshape(1, -1)
plt.figure(figsize=(30, 2))
plt.imshow(heatmap_data, cmap='viridis', aspect='auto')
plt.colorbar(label='Count')
plt.xticks(ticks=np.arange(len(unique_cells)), labels=posterior_node_after_branching, rotation=90)
plt.yticks([])
plt.xlabel('Nodes')
plt.title('Counts per Cell of Appearing in Majority Metastasis Branch', fontsize=32)
plt.grid(False)

plt.show()

In [None]:
top_loci = np.loadtxt(rf"../data/results/{study_num}/sciterna_bootstrap/selected.txt", dtype=int)

In [None]:
selected_loci = reference.columns[top_loci]
selected_genes = convert_location_to_gene(selected_loci)

In [None]:
genes = []
for s in selected_genes:
    genes.extend(s)
gens, cnt = np.unique(genes, return_counts=True)

print(cnt[np.argsort(cnt)][-30:])
print(gens[np.argsort(cnt)][-30:])

In [None]:
expression_counts_scaled = pd.read_csv(rf"../data/input_data/{study_num}/normalized_expression_data.csv", index_col=0)

expression_counts_scaled['sum_counts'] = expression_counts_scaled.sum(axis=1)
df_sorted = expression_counts_scaled.sort_values(by='sum_counts', ascending=False)
top_10000 = df_sorted.head(1000)
top_10000 = top_10000.drop(columns=['sum_counts'])
expression_counts_scaled = expression_counts_scaled.drop(columns=['sum_counts'])

In [None]:
top_5_indices = np.argsort(np.mean(expression_counts_scaled, axis=1))[-20:]
top_5_rows = expression_counts_scaled.iloc[top_5_indices]
top_5_rows

In [None]:
categories = {
    "Most expressed": ["IGKV2-28", "B2M", "MALAT1"],
    "largest_difference_groups": ['IGJ', 'FCRL2', 'CD74', 'ABCA1', 'TBL1XR1', 'HLA-B'], #'CTSB'
    "selected": ['CTSB'],
    "most_mutated": ['DCC', 'PELI1', 'NBPF14', 'VPS41', 'ATP11B'] # 'CTSB'
}

filtered_df_top_genes = pd.DataFrame()

for category, genes in categories.items():
    category_genes = expression_counts_scaled[expression_counts_scaled.index.isin(genes)]
    category_row_sums = category_genes.sum(axis=1)
    top_genes = category_row_sums.nlargest(7).index
    filtered_df_top_genes = pd.concat([filtered_df_top_genes, expression_counts_scaled.loc[top_genes]])

In [None]:
sorted_indices = np.argsort(posterior_node_after_branching)
mask_less_than_65 = sorted_indices < 65

plt.figure(figsize=(40, 16))

sorted_df = filtered_df_top_genes.iloc[:, sorted_indices]
col_colors = ['red' if val else 'blue' for val in mask_less_than_65]

plt.figure(figsize=(40, 18))
ax = sns.heatmap(sorted_df, cmap='coolwarm')

for n, tick_label in enumerate(ax.get_xticklabels()):
    tick_label.set_color(col_colors[n])
        
for tick_label in ax.get_yticklabels():
    tick_label.set_fontsize(26)   
    tick_label.set_rotation(0)
    
for idx in [3,9,10]:
    plt.hlines(idx, *plt.xlim(), color='black', linewidth=4)
    
plt.text(1.02, 0.91, "Most Expressed", va='center', ha='left', fontsize=24, fontweight='bold', rotation=90, transform=plt.gca().transAxes)
plt.text(1.02, 0.58, "Largest Difference Samples", va='center', ha='left', fontsize=24, fontweight='bold', rotation=90, transform=plt.gca().transAxes)
plt.text(1.02, 0.18, "Most SNVs", va='center', ha='left', fontsize=24, fontweight='bold', rotation=90, transform=plt.gca().transAxes)
plt.title("Gene Expression Ordered by Likelihood of Cell Being in the Metastasis Branch", fontsize=50, pad=25)
plt.ylabel("Genes", fontsize=26, fontweight='bold')
plt.xlabel("Cells", fontsize=26, fontweight='bold', labelpad=10)
# plt.savefig(r"../data/results/figures/gene_expression.pdf", format="pdf")
plt.show()

In [None]:
r = 2
path = rf"../data/results/{study_num}/sciterna"
reference = pd.read_csv(rf'../data/input_data/{study_num}/ref.csv', index_col=0)
path_parent = os.path.join(path, "sciterna_parent_vec", f"sciterna_parent_vec_{r}r0.txt")
path_mut_loc = os.path.join(path, "sciterna_mutation_location", f"sciterna_mutation_location_{r}r0.txt")
path_selected = os.path.join(path, "sciterna_selected_loci", f"sciterna_selected_loci_{r}r0.txt")
parent_vec = np.loadtxt(path_parent, dtype=int)
mut_locs = np.loadtxt(path_mut_loc, dtype=int)
selected_mutations = np.loadtxt(path_selected, dtype=int)

selected_loci = reference.columns[selected_mutations]
selected_genes = convert_location_to_gene(selected_loci)

ctsb_indices = [n for n, c in enumerate(selected_genes) if "CTSB" in c]

In [None]:
n_cells = int(((len(parent_vec)+1)/2))
ct = CellTree(n_cells=n_cells, n_mut=0)
ct.use_parent_vec(parent_vec)
ct.mut_loc = mut_locs[ctsb_indices]

graph = ct.to_graphviz(gene_names=["C" for _ in range(3000)])
graph.attr(dpi='50')
graph.attr(rankdir='LR')

color_row = filtered_df_top_genes.loc["CTSB"].values
color_row = color_row / np.max(color_row)

cmap = plt.get_cmap('Greens')
colors = cmap(color_row)

for n in range(int((len(parent_vec)+1)/2)):
    rgba_color = colors[n]
    hex_color = mcolors.to_hex(rgba_color)

    if n < 65: # primary
        prob = (n_bootstrap - posterior_node_after_branching[n]) / 500
        graph.node(str(n), label="", shape='circle', style='filled', color="red", 
                   fillcolor=hex_color, fixedsize="true", width=str(prob), height=str(prob), penwidth="10")
    elif n >= 65: # metastasis
        prob = posterior_node_after_branching[n] / 500
        graph.node(str(n), label="", shape='circle', style='filled', fillcolor=hex_color, 
                   color="blue", fixedsize="true", width=str(prob), height=str(prob), penwidth="15")
graph.attr(ratio="0.23") 

graph_file = "../data/results/figures/representative_tree"
graph.render(graph_file, format='pdf', cleanup=True)

In [None]:
fig, ax = plt.subplots(figsize=(6, 1))
norm = mcolors.Normalize(vmin=0, vmax=1)
sm = cm.ScalarMappable(cmap='Greens', norm=norm)
sm.set_array([])

cbar = fig.colorbar(sm, ax=ax, orientation='horizontal', fraction=2, pad=0, aspect=5)
cbar.set_label('Gene Expression CTSB', fontsize=24)
cbar.ax.tick_params(labelsize=10)
ax.axis('off')
colorbar_path = os.path.join(path, "gene_expression_colorbar.svg")
plt.show()