In [None]:
import os
import sys

import numpy as np
import pandas as pd
import matplotlib.cm as cm
import matplotlib.pyplot as plt

from fisher import pvalue

sys.path.insert(0, "..")
from unpast.utils.method import zscore
from unpast.utils.figs import draw_heatmap2

%matplotlib inline

# matplotlib 3.6 renamed the bundled seaborn styles to "seaborn-v0_8*" and 3.8 removed the
# old names; prefer the new name and fall back to "seaborn" on older installations.
plt.style.use("seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "seaborn")

In [None]:
exprs_file = "2_Pelz_timepoint_infos.norm_lib_size_nolog2.tsv"
# read_table is read_csv with a tab separator; the first column holds the row labels
exprs = pd.read_table(exprs_file, index_col=0)

In [None]:
# UnPaSt consensus biclusters; empty cells become "" so member columns can be split safely
unpast_file = "DIP_norm_nolog2.consensus_seed=42.bin=kmeans,pval=0.01,clust=WGCNA,direction=DOWN-UP,ds=3,dch=0.995,max_power=10,precluster=True.biclusters.tsv"
unpast_biclusters = pd.read_csv(unpast_file, sep="\t").fillna("")

In [None]:
# Spectral co-clustering results on the DIPs: selected subset and all clusters
spec_biclusters_sel, spec_biclusters_all = (
    pd.read_csv(path, sep="\t").fillna("")
    for path in ("spec_coclustering_dips_selected.tsv", "spec_coclustering_dips_all.tsv")
)

In [None]:
def get_p_value_symbol(p: float) -> str:
    '''
        Map a p-value to a fixed-width significance marker for plot annotations.

        :param p: p-value of the test
        :return: "***" for p < 1e-6, "** " for p < 1e-5, " * " for p < 0.05,
                 otherwise "ns." (also for NaN, since every comparison is False)
    '''
    # Checked from the strictest cut-off to the loosest; the first match wins.
    for cutoff, symbol in ((0.000001, "***"), (0.00001, "** "), (0.05, " * ")):
        if p < cutoff:
            return symbol
    return "ns."


def bh_correction(sign_matrix: list, q: float = 0.1) -> float:
    '''
        Benjamini-Hochberg (step-up) procedure over all p-values of a nested list.

        :param sign_matrix: list of rows, each row a list of p-values
        :param q: false discovery rate level (default 0.1)

        :return: the largest sorted p-value p_(k) with p_(k) <= k/m * q; every p-value
                 smaller than or equal to it is significant. Returns 0.0 when no
                 p-value passes (or the matrix is empty), so nothing is significant.
    '''
    # p-values in ascending order
    p_sorted = sorted(p for row in sign_matrix for p in row)
    m = len(p_sorted)

    thresh = 0.0
    # k is the 1-based rank. Step-up: keep the LAST rank that passes, even if an
    # earlier rank failed.
    for k, p in enumerate(p_sorted, start=1):
        if p <= k / m * q:
            thresh = p
    return thresh

def fisher_test(unpast_members, spec_members, N):
    '''
        Fisher's exact test on the 2x2 overlap table of two member collections.

        :param unpast_members: iterable of member names (e.g. an UnPaSt bicluster)
        :param spec_members: iterable of member names (e.g. a spectral cluster)
        :param N: size of the background universe the members are drawn from

        :return: the result of fisher.pvalue (callers read its .right_tail)
        :raises ValueError: if N is smaller than the union of both member sets
    '''
    # "".split(" ") yields [""]; the empty token is not a real member and must not
    # count as an overlap between two empty clusters.
    unpast_set = set(unpast_members) - {""}
    spec_set = set(spec_members) - {""}

    shared = len(unpast_set & spec_set)
    bic_only = len(unpast_set - spec_set)
    group_only = len(spec_set - unpast_set)
    rest = N - (shared + bic_only + group_only)
    if rest < 0:
        raise ValueError(f"universe size N={N} is smaller than the union of both member sets")
    return pvalue(shared, bic_only, group_only, rest)

In [None]:
def check_cluster_similarity(unpast_df, spec_co_all_df, feature):
    '''
        Compare every UnPaSt bicluster with every spectral co-clustering cluster using
        a Fisher's exact test and plot the right-tail p-value matrix.

        :param unpast_df: UnPaSt biclusters, one row per bicluster
        :param spec_co_all_df: spectral co-clustering clusters, one row per cluster
        :param feature: column with space-separated member names ("genes" or "samples")

        :return: (matrix, thresh_p_val). matrix[i][j] is the right-tail p-value of the
                 i-th UnPaSt row vs the j-th spectral row (positional). thresh_p_val is the
                 largest p-value accepted by bh_correction (inclusive threshold).
    '''
    unpast_members = [value.split(" ") for value in unpast_df[feature]]
    spec_members = [value.split(" ") for value in spec_co_all_df[feature]]
    n_unpast, n_spec = len(unpast_members), len(spec_members)

    # Background universe: every distinct member seen in either clustering, so N is
    # never smaller than the union of two compared clusters.
    universe = {m for members in spec_members + unpast_members for m in members}
    universe.discard("")
    N = len(universe)

    matrix = [[fisher_test(u, s, N).right_tail for s in spec_members]
              for u in unpast_members]

    # BH correction of the p-values; the returned value is itself significant
    thresh_p_val = bh_correction(matrix)

    plt.figure(figsize=(5, 4))
    for i in range(n_unpast):
        for j in range(n_spec):
            p = matrix[i][j]
            text = "*" if p <= thresh_p_val else "ns."
            # white text on the dark (low p-value) end of viridis
            color = "white" if p < 0.4 else "black"
            plt.annotate(text, xy=(j, i), color=color, ha='center', va='center', fontsize=8, fontweight='bold')

    plt.imshow(matrix, cmap="viridis", interpolation="nearest")
    plt.colorbar(fraction=0.046, pad=0.04, label="p-value")
    plt.xticks(np.arange(n_spec), [f"cl. {j}" for j in range(n_spec)], rotation=90)
    plt.yticks(np.arange(n_unpast), [f"bic. {i}" for i in range(n_unpast)])
    plt.xlabel("Spectral co-clustering")
    plt.ylabel("UnPaSt")
    plt.grid(False)
    plt.tight_layout()

    return matrix, thresh_p_val

# Gene-level similarity between UnPaSt biclusters and all spectral co-clusters
sign_matrix, thresh_p_value = check_cluster_similarity(unpast_biclusters, spec_biclusters_all, feature="genes")

In [None]:
def plot_heatmap(biclusters, feature, bic_prefix):
    '''
        Draw a z-scored expression heatmap of the given biclusters with draw_heatmap2 and
        relabel the x axis with the sampling time (days post infection). Each tick label
        is coloured by its log TCID50 value using the "copper" colormap.

        :param biclusters: bicluster table with space-separated member columns
        :param feature: "genes" -> rows are the union of all bicluster genes, all samples,
                        no column clustering; anything else -> all rows of exprs, the
                        fixed list of 23 samples, columns clustered
        :param bic_prefix: label prefix passed to draw_heatmap2 (e.g. "bic." or "cl.")
    '''
    # Index-aligned: sample ID, its time point (days post infection) and its TCID50.
    sample_ids = ["VB3-Saat", "VB3-7", "VB3-8", "VB3-9", "VB3-13", "VB3-14", "VB3-15", "VB3-16", "VB3-17", "VB3-22", "VB3-24", "VB3-25", "VB3-31", "VB3-32", "VB3-33", "VB3-38", "VB3-40", "VB3-41", "VB3-42", "VB3-45", "VB3-46", "VB3-47", "VB3-48"]
    time_labels = ["Seed", "0.50", "0.99", "1.40", "3.46", "4.00", "4.47", "5.00", "5.48", "7.95", "8.96", "9.42", "12.43", "12.97", "13.50", "16.01", "16.97", "17.45", "18.00", "19.48", "19.99", "20.44", "21.00"]
    tcid50 = np.array([240000, 560000000, 56000000, 56000000, 3200, 42, 7600, 32000000, 5600000, 1300, 3200000, 43000000, 1300, 32000, 180000000, 3200, 560000, 180000000, 7600000, 56000, 5600, 1300, 76000000])

    if feature == "genes":
        genes = set()
        for members in biclusters[feature]:
            genes.update(members.split(" "))
        genes.discard("")
        # sorted() makes the row order independent of string-hash randomisation
        genes = sorted(genes)
        heatmap = draw_heatmap2(zscore(exprs.loc[genes, :]), biclusters, figsize=(5, 5),
                                no_row_colors=False, cluster_rows=False,
                                cluster_columns=False, bicluster_colors="auto",
                                xlabel="timepoints", bic_prefix=bic_prefix)
    else:
        heatmap = draw_heatmap2(zscore(exprs.loc[:, sample_ids]), biclusters, figsize=(5, 5),
                                no_row_colors=False, cluster_rows=False,
                                cluster_columns=True, bicluster_colors="auto",
                                xlabel="timepoints", bic_prefix=bic_prefix)

    ax = heatmap[0].ax_heatmap
    # Column clustering can reorder the samples, so map every tick back to its sample
    # by name. Fall back to the chronological order if the ticks are not sample IDs.
    tick_names = [t.get_text() for t in ax.get_xticklabels()]
    if tick_names and set(tick_names) <= set(sample_ids):
        order = [sample_ids.index(name) for name in tick_names]
    else:
        order = list(range(len(sample_ids)))
    ax.set_xticklabels([time_labels[k] for k in order])
    ax.set_xlabel("Sample time point (days post infection)")

    # matplotlib.cm.get_cmap was removed in matplotlib 3.9; pyplot.get_cmap remains.
    log_tcid50 = np.log(tcid50)
    colors = plt.get_cmap("copper")(log_tcid50 / log_tcid50.max())
    for k, label in zip(order, ax.get_xticklabels()):
        label.set_color(colors[k])

In [None]:
# including fishers test info
def calculate_matrix(unpast_df, spec_co_all_df, sign_matrix, p_thresh, feature):
    x_n = list()
    y_n = list()
    unpast_cl_labels = [f"bic. {i}" for i in range(9)]
    cl_labels = [f"cl. {i}" for i in range(9)]
    plt.figure(figsize=(4, 3))
    matrix_size = 10
    matrix = [[0] * matrix_size for _ in range(matrix_size)]
    jacc_similarities = list()
    jacc_sizes = list()
    for i, row2 in unpast_df.iterrows():
        set1 = set(row2[feature].split(" "))
        if set1 == {''}:
            set1 = set()
        for j, row1 in spec_co_all_df.iterrows():
            set2 = set(row1[feature].split(" "))
            if set2 == {''}:
                set2 = set()
            
            if i == j:
                x_n.append(len(set2))
                y_n.append(len(set1))

            # calc jaccard index
            intersection = len(set1.intersection(set2))
            union = len(set1.union(set2))
            jacc_idx = intersection / union if union != 0 else 0
            if sign_matrix[i][j] < p_thresh:
                matrix[i][j+1] = jacc_idx
                thresh = 0.75 if feature == "samples" else 0.2
                if jacc_idx < thresh:
                    color = "white"
                else:
                    color = "black"    
            else:
                matrix[i][j+1] = np.nan
                color = "black"
                    
            t = f"{round(jacc_idx, 2)}"
            if t == "0.00":
                t = "0"
            plt.annotate(t, xy=(j+1, i), color=color, ha='center', va='center', fontsize=7, fontweight='bold')

    for i in range(matrix_size-1):
        # add n data for spec. co (last row)
        matrix[matrix_size-1][i] = np.nan
        plt.annotate(x_n[i], xy=(i+1, matrix_size-1), color="black", ha='center', va='center', fontsize=8, fontweight='bold')
        # add n data for unpast (first column)
        matrix[i][0] = np.nan
        plt.annotate(y_n[i], xy=(0, i), color="black", ha='center', va='center', fontsize=8, fontweight='bold')
    matrix[matrix_size-1][matrix_size-1] = np.nan


    plt.imshow(matrix, cmap="viridis", interpolation="nearest")
    plt.colorbar(fraction=0.046, pad=0.04, label="Jaccard index")
    plt.axvline(0.5, color='black') # vertical
    plt.axhline(8.5, color='black') # horizontal
    plt.xticks(np.arange(10), [""] + cl_labels, rotation=90)
    plt.yticks(np.arange(10), unpast_cl_labels + [""])
    plt.xlabel("Spectral co-clustering")
    plt.ylabel("UnPaSt")
    plt.grid(False)
    plt.tight_layout()

In [None]:
# Same bicluster order for every panel so the heatmaps are directly comparable
heatmap_order = [3, 1, 4, 6, 0, 8, 7, 5, 2]
# savefig does not create missing directories
os.makedirs("results", exist_ok=True)
for bics, prefix, out_file in (
    (unpast_biclusters, "bic.", "results/unpast_heatmap.png"),
    (spec_biclusters_sel, "cl.", "results/spec_coclus_heatmap_selected.png"),
    (spec_biclusters_all, "cl.", "results/spec_coclus_heatmap_all.png"),
):
    plot_heatmap(bics.loc[heatmap_order, :], "genes", bic_prefix=prefix)
    plt.savefig(out_file, bbox_inches='tight', dpi=300)

In [None]:
# Gene-level Jaccard similarities, masked by the BH-corrected Fisher tests
calculate_matrix(unpast_biclusters, spec_biclusters_all, sign_matrix, thresh_p_value, feature="genes")
plt.savefig("results/jaccard_matrix.png", dpi=300)

## Cluster by time points

In [None]:
# Spectral co-clustering results on the time points: selected subset and all clusters
spec_biclusters_sel, spec_biclusters_all = (
    pd.read_csv(path, sep="\t").fillna("")
    for path in ("spec_coclustering_time_selected.tsv", "spec_coclustering_time_all.tsv")
)

In [None]:
# Sample-based heatmaps, using the same bicluster order as the gene-based ones
sample_heatmap_order = [3, 1, 4, 6, 0, 8, 7, 5, 2]
for bics, prefix, out_file in (
    (unpast_biclusters, "bic.", "results/unpast_heatmap_samples.png"),
    (spec_biclusters_sel, "cl.", "results/spec_coclus_heatmap_selected_samples.png"),
    (spec_biclusters_all, "cl.", "results/spec_coclus_heatmap_all_samples.png"),
):
    plot_heatmap(bics.loc[sample_heatmap_order, :], "samples", bic_prefix=prefix)
    plt.savefig(out_file, bbox_inches='tight', dpi=300)

In [None]:
# Sample-level similarity between UnPaSt biclusters and all spectral co-clusters
sign_matrix_samples, thresh_p_value_samples = check_cluster_similarity(unpast_biclusters, spec_biclusters_all, feature="samples")

In [None]:
# Sample-level Jaccard similarities, masked by the BH-corrected Fisher tests
calculate_matrix(unpast_biclusters, spec_biclusters_all, sign_matrix_samples, thresh_p_value_samples, feature="samples")
plt.savefig("results/jaccard_matrix_samples.png", dpi=300)

## Evaluate clusters based on Pelz labels

In [None]:
# Per-DI measurements of the PR8 sheet. Only that sheet is parsed: sheet_name=None
# would parse every sheet of the workbook just to keep one.
orig_time_data = pd.read_excel(io="Pelz_by_timepoints.xlsx",
                               sheet_name="PR8",
                               header=0,
                               na_values=["", "None"],
                               keep_default_na=False)

def assign_label(row):
    '''
        Classify a DI by comparing its "VB3-Saat" (seed) and "VB3-48" values.

        :param row: mapping with "VB3-Saat" and "VB3-48" entries
        :return: "gain" if the last value exceeds the first, otherwise "loss";
                 prefixed with "de novo " when the seed value is 0
    '''
    first, last = row["VB3-Saat"], row["VB3-48"]
    direction = "gain" if last > first else "loss"
    if first == 0:
        return "de novo " + direction
    return direction

def calc_max_frac(row):
    '''
        Share of the largest entry in a row of counts, in percent.

        :param row: iterable of counts
        :return: max(row) / sum(row) * 100
    '''
    largest, total = max(row), sum(row)
    return largest / total * 100

def calc_max_comb_frac(row):
    '''
        Percentage of DIs in the combined class (gain + de novo gain, or
        loss + de novo loss) that contains the row's most frequent label.

        :param row: pd.Series of label counts plus a "max. fraction" entry (ignored)
        :return: combined fraction in percent
    '''
    counts = row.drop("max. fraction")
    top_label = counts.idxmax()  # renamed from `max`, which shadowed the builtin
    if top_label in ("gain", "de novo gain"):
        cols = ("gain", "de novo gain")
    else:
        cols = ("loss", "de novo loss")
    # .get tolerates a label that never occurred and therefore has no column
    return sum(counts.get(c, 0.0) for c in cols) / sum(counts) * 100

def create_label_table(orig_time_data, clusters):
    '''
        Tabulate how the DI labels are distributed in the full data set and in every cluster.

        :param orig_time_data: table with "DI" and "label" columns
        :param clusters: cluster table with a space-separated "genes" column
        :return: DataFrame with one row per set ("full", "cl.<index>"), the counts of
                 "gain", "de novo gain", "loss", "de novo loss", plus "max. fraction"
                 and "max. comb. fraction" (percent); all values rounded to 1 decimal
    '''
    label_cols = ["gain", "de novo gain", "loss", "de novo loss"]

    table_dict = {"full": orig_time_data["label"].value_counts()}
    for i, r in clusters.iterrows():
        cands = r["genes"].split(" ")
        table_dict[f"cl.{i}"] = orig_time_data.loc[orig_time_data["DI"].isin(cands), "label"].value_counts()

    df = pd.DataFrame(table_dict).T.fillna(0.0)
    # A label absent from every set would have no column and break the steps below
    for col in label_cols:
        if col not in df.columns:
            df[col] = 0.0

    df["max. fraction"] = df.apply(calc_max_frac, axis=1)
    # calc_max_comb_frac drops "max. fraction", so it must run after that column exists
    df["max. comb. fraction"] = df.apply(calc_max_comb_frac, axis=1)
    return df.round(1)[label_cols + ["max. fraction", "max. comb. fraction"]]

# Label every DI from its seed vs. last measurement, then summarise per UnPaSt bicluster
orig_time_data = orig_time_data.assign(label=orig_time_data.apply(assign_label, axis=1))
df = create_label_table(orig_time_data, clusters=unpast_biclusters)
df