In [None]:
import time
import os
import math
import multiprocessing as mp
import numpy as np
from numpy import dot
from numpy.linalg import norm
import networkx as nx
import torch
import pytorch_lightning as pl
import torch.nn.functional as F
from torch import tensor
from torch.optim import Adam
from sklearn.model_selection import StratifiedKFold
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader
from tqdm import tqdm
import pdb
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import scipy.sparse as ssp
import bisect
import seaborn as sns
from datamodules import CpGGraphDataModule
from graphcpg import CpGGraph, CpGGraphAnalysis
from matplotlib.colors import ListedColormap
import matplotlib.colors as mcolors
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.lines import Line2D
import matplotlib.gridspec as gridspec
from matplotlib.spines import Spine
import matplotlib.lines as lines
import matplotlib.collections as collections
# Figure size limits in millimetres, with the equivalent size in inches
# (matplotlib's figsize is given in inches).
halfwidth = 85 # mm = 3.34 inches
width = 170 # mm = 6.69 inches
maxheight = 225 # mm = 8.86 inches
# Figures in this notebook are rendered at 300 dpi.
from pathlib import Path
# Font file passed as `font=` to the matplotlib text calls that draw the italic
# node/edge symbols. NOTE(review): absolute, system-specific path -- adjust it if
# the font is installed elsewhere on your machine.
TimesItalic = Path("/usr/share/fonts/truetype/msttcorefonts/Times_New_Roman_Italic.ttf")

# HCC Example

In [None]:
# ---- Runtime configuration ---------------------------------------------------
# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Number of consecutive loci per graph segment.
segment_size = 21

# Chromosome keys held out for validation (chr13..chr19) and testing
# (even chromosomes chr2..chr12).
val_keys = ['chr%d' % k for k in range(13, 20)]
test_keys = ['chr%d' % k for k in range(2, 13, 2)]

# Data-loading settings.
batch_size = 10240
n_workers = 12
debug_mode = False

# Fill these in before running the notebook.
res_dir = ""    # your results saving path
data_path = ""  # your datasets path

# Per-dataset file names, keyed by dataset name (HCC shown as an example).
dataset_dict = {"HCC": "y_HCC.npz"}                    # methylation-state archive
ckpt_dict = {"HCC": "/HCC_epoch=9-step=2800.ckpt"}     # trained model checkpoint, appended to res_dir
pos_dict = {"HCC": "pos_HCC.npz"}                      # genomic-position archive


In [None]:
# Raw per-cell RRBS file names of the HCC dataset, in sample order Ca_01..Ca_25.
cell_text = [
    'G1593767_Ca_01_RRBS.single.CpG.txt',
    'G1593769_Ca_02_RRBS.single.CpG.txt',
    'G1593771_Ca_03_RRBS.single.CpG.txt',
    'G1593773_Ca_04_RRBS.single.CpG.txt',
    'G1593775_Ca_05_RRBS.single.CpG.txt',
    'G1593777_Ca_06_RRBS.single.CpG.txt',
    'G1593779_Ca_07_RRBS.single.CpG.txt',
    'G1593781_Ca_08_RRBS.single.CpG.txt',
    'G1593783_Ca_09_RRBS.single.CpG.txt',
    'G1593785_Ca_10_RRBS.single.CpG.txt',
    'G1593787_Ca_11_RRBS.single.CpG.txt',
    'G1593789_Ca_12_RRBS.single.CpG.txt',
    'G1593791_Ca_13_RRBS.single.CpG.txt',
    'G1593793_Ca_14_RRBS.single.CpG.txt',
    'G1593795_Ca_15_RRBS.single.CpG.txt',
    'G1593797_Ca_16_RRBS.single.CpG.txt',
    'G1593799_Ca_17_RRBS.single.CpG.txt',
    'G1593801_Ca_18_RRBS.single.CpG.txt',
    'G1593803_Ca_19_RRBS.single.CpG.txt',
    'G1593805_Ca_20_RRBS.single.CpG.txt',
    'G1593807_Ca_21_RRBS.single.CpG.txt',
    'G1593809_Ca_22_RRBS.single.CpG.txt',
    'G1593811_Ca_23_RRBS.single.CpG.txt',
    'G1593813_Ca_24_RRBS.single.CpG.txt',
    'G1593815_Ca_25_RRBS.single.CpG.txt',
]

# 1-based sample numbers in the order obtained when 1..25 are sorted as strings
# ("1", "10", "11", ..., "19", "2", "20", ..., "25", "3", ..., "9").
cell_rank = sorted(range(1, len(cell_text) + 1), key=str)

# Drop the trailing "_RRBS.single.CpG.txt" (20 characters) to keep only the
# sample identifier, e.g. "G1593767_Ca_01".
cell_text = [file_name[:-len("_RRBS.single.CpG.txt")] for file_name in cell_text]

# Sample identifiers re-ordered according to cell_rank.
cell_text_ranked = [cell_text[sample_no - 1] for sample_no in cell_rank]

In [None]:
def PyGGraph_to_nx(data):
    """Convert a single PyG graph into an undirected networkx graph.

    Reads ``data.edge_index``, ``data.edge_type``, ``data.x``, ``data.num_nodes``
    and ``data.y``. The returned graph carries:

    * edge attribute ``'methylation'`` -- ``data.edge_type`` of the edge,
    * node attribute ``'lociandcells'`` -- argmax over dim 1 of ``data.x``,
    * graph attribute ``'states'`` -- the scalar ``data.y``.
    """
    sources = data.edge_index[0, :].tolist()
    targets = data.edge_index[1, :].tolist()
    edge_pairs = list(zip(sources, targets))

    graph = nx.from_edgelist(edge_pairs)
    # Nodes without any edge are absent from the edge list; add them explicitly.
    graph.add_nodes_from(range(len(data.x)))

    # Edge attribute: the type stored for each (u, v) pair.
    methylation = {}
    for idx, pair in enumerate(edge_pairs):
        methylation[pair] = data.edge_type[idx].item()
    nx.set_edge_attributes(graph, name='methylation', values=methylation)

    # Node attribute: index of the largest entry of each node's feature row.
    feature_argmax = torch.argmax(data.x, 1).tolist()
    node_kinds = {node: kind for node, kind in zip(range(data.num_nodes), feature_argmax)}
    nx.set_node_attributes(graph, name='lociandcells', values=node_kinds)

    graph.graph['states'] = data.y.item()
    return graph

In [None]:
def _cosine_similarity_to_target(states, target_idx):
    """Cosine similarity of every row of ``states`` to row ``target_idx``.

    The raw cosine (in [-1, 1]) is rescaled to [0, 1] via ``0.5 + 0.5 * cos``.
    Returns a Python list with one value per row of ``states``.
    """
    target = states[target_idx, :]
    return [0.5 + 0.5 * (dot(row, target) / (norm(row) * norm(target)))
            for row in states]


def _draw_similarity_bars(ax, bar_keys, similarities, target_idx, cmap,
                          invert_x=False, y_tick_pad=None):
    """Draw a horizontal 'Similarity' bar chart on ``ax`` and highlight one bar.

    Bars are coloured by ``cmap(similarity ** 3)``; the bar at ``target_idx``
    is recoloured to mark the target. ``invert_x`` mirrors the x axis and
    ``y_tick_pad`` (when given) sets the y tick label size/padding.
    Returns the bar container produced by ``ax.barh``.
    """
    ax.grid(True)
    if invert_x:
        ax.invert_xaxis()
    ax.set_xticks([-1, 0, 1])
    ax.tick_params(axis='x', labelsize=5)
    if y_tick_pad is not None:
        ax.tick_params(axis='y', labelsize=5, pad=y_tick_pad)
    ax.yaxis.set_ticklabels([])
    bar_colors = cmap([s * s * s for s in similarities])
    ax.set_title("Similarity", fontsize=5)
    bars = ax.barh(bar_keys, similarities, color=bar_colors)
    bars[target_idx].set_color('#ff8696')
    bars[target_idx].set_edgecolor('white')
    return bars


# White -> light-red colormap used for the similarity bars; built once because it
# does not depend on the dataset or the graph being drawn.
cmap_dict = {'red':   [(0.0,  1.0, 1.0),
                       (1.0,  1.0, 1.0),
                       ],
             'green': [(0.0,  1.0, 1.0),
                       (1.0,  0.3804, 0.3804),
                       ],
             'blue':  [(0.0,  1.0, 1.0),
                       (1.0,  0.3804, 0.3804),
                       ]}
cmap = LinearSegmentedColormap('white_to_red', cmap_dict)

# For every dataset: (1) obtain predictions for all test graphs (cached on disk),
# (2) pick the graphs with the highest and lowest predicted score, (3) obtain the
# model's hidden node states for those graphs (cached on disk), and (4) draw each
# graph as a bipartite loci/cells network flanked by two similarity bar charts.
for set_i in dataset_dict.keys():
    # Number of highest / lowest graphs to show. The grid layout below
    # (12 columns, 6 per graph) has room for exactly one of each.
    num = 1

    y = np.load(data_path + dataset_dict[set_i])
    pos_raw = np.load(data_path + pos_dict[set_i])

    datamodule = CpGGraphDataModule(y, segment_size=segment_size,
                                    val_keys=val_keys, test_keys=test_keys,
                                    batch_size=batch_size, n_workers=n_workers,
                                    debug_mode=debug_mode, cell_nums=False, save_cell_id=True)
    datamodule.setup()
    test_graphs = datamodule.test
    graphs = test_graphs
    data_name = set_i + "_visual"
    sort_by = 'prediction'  # one of 'random', 'true', 'prediction'

    # ---- Stage 1: predictions for every test graph (load cache or generate) ----
    preload = set_i + '_RYCHRPOSCELL'
    preload_path = res_dir + "/vanilla/" + preload
    preload_path_npz = preload_path + '.npz'
    model_ckpt = res_dir + ckpt_dict[set_i]
    if os.path.exists(preload_path_npz):
        print("Loading ...")
        # NOTE(review): allow_pickle=True executes pickled payloads; only load
        # cache files that this notebook wrote itself.
        npzfile = np.load(preload_path_npz, allow_pickle=True)
        R = npzfile['arr_0']
        Y = npzfile['arr_1']
        CHR = npzfile['arr_2']
        POS = npzfile['arr_3']
        CELL = npzfile['arr_4']
    else:
        print('Generating ...')
        model = CpGGraph.load_from_checkpoint(model_ckpt)
        model.eval()
        model.to(device)
        R = []
        Y = []
        CHR = []
        POS = []
        CELL = []
        graph_loader = DataLoader(graphs,
                                  batch_size=batch_size, shuffle=False,
                                  pin_memory=True)
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            for data in tqdm(graph_loader):
                data = data.to(device)
                batch_scores = torch.sigmoid(model(data))
                R.extend(batch_scores.view(-1).tolist())
                Y.extend(data.y.view(-1).tolist())
                CHR.extend(data.chr.view(-1).tolist())
                POS.extend(data.m.view(-1).tolist())
                # Store plain Python numbers like the other lists; a tensor would
                # otherwise be extended element-wise as 0-d (possibly GPU) tensors.
                batch_cells = data.n
                if torch.is_tensor(batch_cells):
                    batch_cells = batch_cells.view(-1).tolist()
                CELL.extend(batch_cells)
        # Write the cache once, after all batches (not once per batch).
        os.makedirs(os.path.dirname(preload_path), exist_ok=True)
        np.savez(preload_path, R, Y, CHR, POS, CELL)

    if sort_by == 'true':  # sort graphs by their true labels
        order = np.argsort(Y).tolist()
    elif sort_by == 'prediction':  # sort graphs by predicted score
        order = np.argsort(R).tolist()
    elif sort_by == 'random':  # randomly select graphs to visualize
        order = np.random.permutation(range(len(R))).tolist()
    else:
        raise ValueError("Unknown sort_by: {!r}".format(sort_by))

    # Indices of the graphs to draw. These MUST match the graphs whose hidden
    # states are stored in CONCAT_STATES (highest first, then lowest).
    num_highest = order[-num:][::-1]
    num_lowest = order[:num]

    # ---- Stage 2: hidden node states of the selected graphs (cache or generate) ----
    # NOTE(review): the cache name does not encode sort_by/num, so a cache written
    # with a different selection will not correspond to the graphs drawn below.
    preload_2 = set_i + '_CONCAT_STATES'
    preload_path_2 = res_dir + "/details/" + preload_2
    preload_path_npz_2 = preload_path_2 + '.npz'
    if os.path.exists(preload_path_npz_2):
        print("Loading details ...")
        npzfile_2 = np.load(preload_path_npz_2, allow_pickle=True)
        CONCAT_STATES = npzfile_2['arr_0']
    else:
        print('Generating details ...')
        model2 = CpGGraphAnalysis.load_from_checkpoint(model_ckpt)
        model2.eval()
        model2.to(device)
        CONCAT_STATES = []
        select_graphs = [graphs[i] for i in num_highest] + [graphs[i] for i in num_lowest]
        print('select_graphs: ', select_graphs)
        selected_graph_loader = DataLoader(select_graphs,
                                           batch_size=batch_size, shuffle=False,
                                           pin_memory=True)
        with torch.no_grad():
            for data in tqdm(selected_graph_loader):
                data = data.to(device)
                concat_states = model2(data)
                CONCAT_STATES.extend(concat_states.detach().tolist())
        os.makedirs(os.path.dirname(preload_path_2), exist_ok=True)
        np.savez(preload_path_2, CONCAT_STATES)

    # ---- Stage 3: collect per-graph metadata for the selected graphs ----
    num_nodes = graphs[0].num_nodes
    cell_num = num_nodes - segment_size
    half = segment_size // 2  # index of the centre (target) locus inside a segment
    highest = [PyGGraph_to_nx(graphs[i]) for i in num_highest]
    lowest = [PyGGraph_to_nx(graphs[i]) for i in num_lowest]
    selected = list(num_highest) + list(num_lowest)
    highest_scores = [100 * R[i] for i in num_highest]
    lowest_scores = [100 * R[i] for i in num_lowest]
    scores = np.around((highest_scores + lowest_scores), 2)
    label_dict = {True: "Methylated", False: "Unmethylated"}
    ys = [label_dict[Y[i]] for i in selected]
    chrs = [CHR[i] for i in selected]
    poses = [POS[i] for i in selected]
    cells = [CELL[i] for i in selected]

    # Node colour by 'lociandcells' value: 0 -> white, 1..segment_size -> blue shades.
    node_to_color = dict(zip(list(range(segment_size + 1)),
                             ['white'] + [plt.cm.Blues(0.3 + 0.7 * (seg / segment_size))
                                          for seg in range(segment_size)]))
    # Edge colour / visibility by 'methylation' value; type 2 edges are hidden.
    edge_to_color = {0: '#fdbf9c', 1: '#5ac7a2', 2: 'white'}
    edge_to_hide = {0: 1, 1: 1, 2: 0}

    f = plt.figure(figsize=(6.69, 2.45), dpi=300)
    gs = gridspec.GridSpec(4, 12)
    gs.update(left=0.02, right=0.98, top=0.95, bottom=0.05, wspace=0.01)

    # ---- Stage 4a: bipartite network panels ----
    for i, g in enumerate(highest + lowest):
        target_cell = int(cells[i])
        # Target edge: centre locus node <-> target cell node (cell nodes follow
        # the segment_size locus nodes).
        u0, v0 = half, target_cell + segment_size
        v_nodes = [node for node, attrs in g.nodes(data=True) if attrs['lociandcells'] == 0]
        pos = nx.drawing.layout.bipartite_layout(g, v_nodes)
        set_axs = f.add_subplot(gs[:, 2 + 6 * i:4 + 6 * i])
        set_axs.axis('off')
        if ys[i] == 'Methylated':
            g[u0][v0]['methylation'] = 1
        edge_types = nx.get_edge_attributes(g, 'methylation')
        edge_colors = [edge_to_color[edge_types[e]] for e in g.edges()]
        edge_alphas = [edge_to_hide[edge_types[e]] for e in g.edges()]
        node_labels = nx.get_node_attributes(g, 'lociandcells')
        node_colors = [node_to_color[node_labels[n]] for n in g.nodes()]

        # Read the chromosome's position array once; indexing the npz archive
        # re-reads the array from disk on every access.
        chrom_name = pos_raw.files[int(chrs[i])]
        chrom_positions = pos_raw[chrom_name]
        first_locus = int(poses[i]) - half
        chr_node_labels = [chrom_name + ":" + str(chrom_positions[p])
                           for p in range(first_locus, first_locus + segment_size)]

        nx.draw_networkx_nodes(g, pos,
                               node_size=50,
                               node_color=node_colors,
                               linewidths=1,
                               edgecolors='black',
                               ax=set_axs)
        nx.draw_networkx_edges(g, pos,
                               width=0.75,
                               edge_color=edge_colors,
                               alpha=edge_alphas,
                               ax=set_axs)
        # Overlay the target edge as a dashed red line.
        nx.draw_networkx_edges(g, {v0: pos[v0], u0: pos[u0]}, edgelist=[(v0, u0)], width=1,
                               edge_color='#ff6161', ax=set_axs, style='--')
        set_axs.text((pos[v0][0] + pos[u0][0]) / 2, (pos[v0][1] + pos[u0][1]) / 2,
                     s="m$_t$", font=TimesItalic, fontsize=5)

        # Locus labels; the centre locus is the target and is drawn in red.
        px_s, py_s = 0.15, -0.01
        for ip, locus_label in enumerate(chr_node_labels):
            px, py = pos[ip]
            if ip == half:
                set_axs.text(px + px_s, py + py_s, s=locus_label, color='#ff6161', fontsize=5)
                set_axs.text(px - 0.05, py - 0.02, s="u$_t$", font=TimesItalic, fontsize=5)
            else:
                set_axs.text(px + px_s, py + py_s, s=locus_label, fontsize=5)
        # Column header, placed relative to the last locus drawn.
        set_axs.text(px - 0.17, py + 0.05, s="Loci", fontsize=5, fontweight="bold")

        # Cell labels; the target cell is drawn in red.
        cx_s, cy_s = -1.32, -0.015
        for ic, cell_label in enumerate(cell_text_ranked):
            cx, cy = pos[ic + segment_size]
            if ic == target_cell:
                set_axs.text(cx + cx_s, cy + cy_s, s=cell_label, color='#ff6161', fontsize=5)
                set_axs.text(cx - 0.05, cy - 0.02, s="v$_t$", font=TimesItalic, fontsize=5)
            else:
                set_axs.text(cx + cx_s, cy + cy_s, s=cell_label, fontsize=5)
        # Column header, placed relative to the last cell drawn.
        set_axs.text(cx - 0.2, cy + 0.05, s="Cells", fontsize=5, fontweight="bold")
        set_axs.set_title('Target state: {:}'.format(ys[i]), loc='center', y=0.95, fontsize=8)

    # ---- Stage 4b: similarity bar charts next to each network ----
    sns.set_theme(context='paper')
    for i in range(len(highest) + len(lowest)):
        target_cell = int(cells[i])
        states = np.array(CONCAT_STATES[i]).squeeze()

        # Loci: similarity of every locus state to the centre (target) locus.
        corr_list = _cosine_similarity_to_target(states[:segment_size, :], half)
        x = list(np.arange(1, segment_size + 1, 1).astype(np.int8))
        gs_sub = gridspec.GridSpecFromSubplotSpec(100, 1, subplot_spec=gs[:, 5 + 6 * i:6 + 6 * i])
        ax = f.add_subplot(gs_sub[2:96])
        # Highlight the target locus bar (the centre locus, not the cell index).
        _draw_similarity_bars(ax, x, corr_list, half, cmap, y_tick_pad=0.05)

        # Cells: similarity of every cell state to the target cell.
        corr_list2 = _cosine_similarity_to_target(states[segment_size:, :], target_cell)
        x2 = ["cell" + cell_name[-2:] for cell_name in cell_text_ranked]
        gs_sub2 = gridspec.GridSpecFromSubplotSpec(100, 1, subplot_spec=gs[:, 0 + 6 * i:1 + 6 * i])
        ax2 = f.add_subplot(gs_sub2[3:97])
        _draw_similarity_bars(ax2, x2, corr_list2, target_cell, cmap, invert_x=True)

    legend_list = [Line2D([0], [0], linestyle='--', color='#ff6161', label='Target state', lw=1),
                   Line2D([0], [0], linestyle='-', color='#5ac7a2', label='Methylated', lw=1),
                   Line2D([0], [0], linestyle='-', color='#fdbf9c', label='Unmethylated', lw=1)]
    legend_name_list = ['Target state', 'Methylated', 'Unmethylated']
    leg = f.legend(legend_list, legend_name_list, ncol=4, loc='lower center',
                   bbox_to_anchor=(0, -0.01, 1, 1), prop={'size': 5})
    leg.get_frame().set_alpha(1.0)
    f.savefig(os.path.join(res_dir, "visualization_{}_{}.pdf".format(data_name, sort_by)))
    # Release the figure so repeated datasets do not accumulate open figures.
    plt.close(f)