In [None]:
import os
import pandas as pd
import numpy as np
from sklearn.manifold import MDS
import networkx as nx
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d import proj3d

In [None]:
# Input data and output directories (outputs are created on first run).
OUT_PATH = "outs_final_nee"
DATA_PATH = "ny/fRNA/fRNA/"
os.makedirs(OUT_PATH, exist_ok=True)

# Figure geometry: ONE_COL_MM / 250 is the figure width in inches used below
# (presumably one journal column). SPINE_WIDTH is a line width for axis spines.
ONE_COL_MM = 890
SPINE_WIDTH = 0.5

# Global font: 7 pt, normal-weight Helvetica (sans-serif family).
# LaTeX text rendering stays off (rcParams['text.usetex'] is not set).
nee_fs = 7
font = dict([
    ("family", "sans-serif"),
    ("sans-serif", ["Helvetica"]),
    ("weight", "normal"),
    ("size", nee_fs),
])
plt.rc("font", **font)

figsize = (ONE_COL_MM / 250, 2.6/4.5 * ONE_COL_MM / 250)

In [None]:
# Worked example with sequence length L = 30, taken from fRNAdb:
# https://togodb.biosciencedbc.jp/togodb/view/frnadb_summary#en
# (accessions were found by BLAST-searching each sequence, then looking it up in fRNAdb).
#
# Target - accession DQ755244:
#   sequence  UGAGGAGCAUUUGUGAUGAGUCUUUCACGU
#   structure "............((((........)))).."  (test 12)
# Source - accession DQ697895:
#   sequence  TCGTGCTTTTCTTATTGCCAGTTCTGTTGT
#   structure "..(.((..........)))..........."  (sample 8)
accession = "DQ755244"
accession_source = "DQ697895"
test = 12
sample = 8
L = 30

# Build the stats path once and reuse it, so the printed path is exactly the one read.
stats_fname = os.path.join(
    DATA_PATH,
    "180414_r_assignment_parallel_L{1:d}_2e6_transitions/data_out/stats_L{1:d}_{0:d}.txt".format(test, L),
)
print(stats_fname)
df_stats = pd.read_csv(stats_fname, sep="\t")

# Row `sample` gives the source phenotype; the target is read from row 0
# (NOTE(review): assumes every row of this stats file shares the same Target).
source = df_stats["Source"].iloc[sample]
target = df_stats["Target"].iloc[0]

# L is reset to the phenotype length; later file names are built from this value.
L = len(target)
print("Source: {0:s}, Target: {1:s}".format(source, target))
df_stats.head(12)

In [None]:
# Source and target phenotypes, one per line.
print(source, target, sep="\n")

In [None]:
print("Test: {0:d}, Sample: {1:d}".format(test, sample))

# Both files live in the same run directory and share the same run tag,
# so build those pieces once instead of repeating two long templates.
transitions_dir = os.path.join(
    DATA_PATH,
    "180414_r_assignment_parallel_L{0:d}_2e6_transitions/phenotype_transitions".format(L),
)
run_tag = "RNA_K-4_L-{0:d}_D-{0:d}_test-{1:d}_sample-{2:d}.txt".format(L, test, sample)
fitness_fname = os.path.join(transitions_dir, "fitness_" + run_tag)
pts_fname = os.path.join(transitions_dir, "phenotype_transitions_" + run_tag)
print(fitness_fname)
print(pts_fname)

# Fitness table (no header): column 0 is the phenotype string, column 1 its fitness.
df_fitness = pd.read_csv(fitness_fname, sep="\t", header=None)
# Transition table (no header): columns 0/1 are the from/to phenotypes, column 2 a count.
df_pts = pd.read_csv(pts_fname, sep="\t", header=None)

# Every phenotype in the fitness file is kept (not only those that appear in a
# transition), so isolated phenotypes are still embedded and plotted.
# Row positions become the integer node ids used throughout the notebook.
df_fitness = df_fitness.reset_index(drop=True)
print(df_fitness.shape, df_pts.shape)

In [None]:
rule = "-" * 100
print(rule, "Checking", rule, sep="\n")
print(df_pts.shape, df_fitness.shape)

# Fitness rows matching the source, then the target phenotype.
for phenotype in (source, target):
    print(df_fitness[df_fitness[0] == phenotype])

# Transitions leaving (column 0) and entering (column 1) the source, then the target.
for group, (title, phenotype) in enumerate((("Source...", source), ("Target...", target)), start=1):
    print(title)
    for column in (0, 1):
        print(f"{group}.{column + 1}", df_pts[df_pts[column] == phenotype])
print(rule)

In [None]:
# Integer node ids are the row positions of df_fitness:
#   labels:        node id -> phenotype string
#   labels_lookup: phenotype string -> node id
labels = dict(zip(df_fitness.index, df_fitness[0]))
labels_lookup = {phenotype: node for node, phenotype in labels.items()}

# The two maps are only inverses if every phenotype appears once in the fitness
# file; a duplicate would otherwise silently collapse onto its last row.
assert len(labels_lookup) == len(labels), "duplicate phenotypes in the fitness file"

source_p = labels_lookup[source]
target_p = labels_lookup[target]

In [None]:
def LoadGraph2(df, ll):
    """Translate a phenotype-transition table into integer node ids.

    Parameters
    ----------
    df : pandas.DataFrame
        Read positionally: column 0 is the "from" phenotype, column 1 the
        "to" phenotype and column 2 the transition count. Further columns
        are ignored.
    ll : mapping
        Phenotype -> integer node id. A phenotype missing from it raises
        ``KeyError``.

    Returns
    -------
    pandas.DataFrame
        One row per input row, with columns ``from``, ``to`` and ``count``.
    """
    records = [[ll[row[0]], ll[row[1]], row[2]] for row in df.to_numpy()]
    return pd.DataFrame(records, columns=["from", "to", "count"])

In [None]:
# Edge table with integer node ids, plus `Np`: the ids of every phenotype taking
# part in at least one transition (a list of ids, not a count).
dft = LoadGraph2(df_pts, labels_lookup)
Np = list(set([*dft.iloc[:, 0].to_numpy(), *dft.iloc[:, 1].to_numpy()]))

In [None]:
def get_hamming_distances(df_fitness):
    """Pairwise normalised Hamming distances between phenotype strings.

    Parameters
    ----------
    df_fitness : pandas.DataFrame
        Column ``0`` holds the phenotypes as equal-length strings (dot-bracket
        structures in this notebook, but any alphabet is accepted).

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(n, n)``. Entry ``[i, j]`` is the fraction of
        positions at which phenotype ``i`` and phenotype ``j`` differ.
        Rows and columns follow the row order of ``df_fitness``.

    Raises
    ------
    ValueError
        If the phenotypes do not all share one non-zero length.
    """
    structures = df_fitness[0].tolist()
    n = len(structures)
    if n == 0:
        return np.zeros((0, 0))
    lengths = {len(s) for s in structures}
    if len(lengths) != 1 or 0 in lengths:
        raise ValueError(f"phenotypes must share one non-zero length, got lengths {sorted(lengths)}")
    (length,) = lengths

    # Compare characters directly instead of mapping symbols to digits, so that
    # any symbol (e.g. pseudoknot brackets) is handled correctly. Accumulating one
    # position at a time keeps memory at O(n^2) instead of O(n^2 * length).
    chars = np.array([list(s) for s in structures])
    mismatches = np.zeros((n, n), dtype=np.int64)
    for column in chars.T:
        mismatches += column[:, None] != column[None, :]
    return mismatches / length

# Normalised pairwise Hamming distances; rows/columns follow df_fitness row order.
hamming = get_hamming_distances(df_fitness=df_fitness)
print(hamming)

In [None]:
# 2D MDS embedding of the precomputed Hamming distances, cached on disk so a
# re-run does not have to refit. random_state=0 makes the fit reproducible.
np.random.seed(0)
embedding_path = f"{OUT_PATH}/fig4a_embedding.npy"

X_transformed_backup = None
if os.path.exists(embedding_path):
    cached = np.load(embedding_path)
    # The cache name does not encode test/sample/L, so a file written for another
    # run may be found; only reuse it if it has one 2D point per phenotype here.
    if cached.shape == (hamming.shape[0], 2):
        X_transformed_backup = cached
    else:
        print(f"Ignoring stale embedding cache {embedding_path} (shape {cached.shape})")

if X_transformed_backup is None:
    embedding = MDS(n_components=2, dissimilarity='precomputed', random_state=0)
    X_transformed_backup = embedding.fit_transform(hamming)
    np.save(embedding_path, X_transformed_backup)

# Working copy; X_transformed_backup is kept as the unmodified reference.
X_transformed = X_transformed_backup.copy()

In [None]:
# Swap the two MDS axes (reads from the untouched backup, so re-running is safe).
X_transformed[:, [0, 1]] = X_transformed_backup[:, [1, 0]]

In [None]:
# Node id -> (x, y) position in the embedding.
pos = {node: tuple(coords) for node, coords in enumerate(X_transformed)}

In [None]:
def MakeGraph(dft):
    """Build a directed graph from an edge table.

    Each row of ``dft`` adds the edge ``from -> to`` carrying the attribute
    ``weight = count``. If a (from, to) pair is repeated, the last row's count wins.
    """
    graph = nx.DiGraph()
    i_from = dft.columns.get_loc("from")
    i_to = dft.columns.get_loc("to")
    i_count = dft.columns.get_loc("count")
    for row in dft.to_numpy():
        graph.add_edge(row[i_from], row[i_to], weight=row[i_count])
    return graph

def DrawGraph(G, pos, fitness, X, ax, cmap_name='jet', pad=1.2, node_size=50, edge_alpha=0.5):
    """Draw G as a flat 2D network with nodes coloured by fitness.

    Parameters
    ----------
    G : networkx graph
        Its nodes must be keys of both ``fitness`` and ``pos``.
    pos : dict
        node -> (x, y) position, passed to networkx.
    fitness : dict
        node -> value fed to the colormap.
    X : array-like
        2D coordinates used only for the axis limits: ``pad`` times the largest
        absolute value in each of the first two columns, symmetric about 0.
    ax : matplotlib Axes
        Axes to draw into; its axis frame is switched off.
    cmap_name : str or Colormap
        Passed to ``plt.get_cmap``.
    pad, node_size, edge_alpha : float
        Limit padding factor, node marker size and edge transparency.
    """
    coords = np.abs(np.asarray(X)[:, :2])
    x_max, y_max = pad * coords.max(axis=0, initial=0)

    cmap = plt.get_cmap(cmap_name)
    nodelist = list(G.nodes())
    # RGB only: the colormap's alpha channel is dropped.
    colors = [tuple(cmap(fitness[node])[:3]) for node in nodelist]
    nx.draw_networkx_nodes(G, pos=pos, ax=ax, nodelist=nodelist, node_color=colors, node_size=node_size)
    # `edge_color` is the networkx keyword for edge colour: `color` was silently
    # ignored by old networkx versions and is rejected by newer ones.
    nx.draw_networkx_edges(G, pos=pos, ax=ax, width=0.1, edge_color="grey", style="-",
                           arrows=False, alpha=edge_alpha)
    ax.set_xlim(-x_max, x_max)
    ax.set_ylim(-y_max, y_max)
    ax.axis('off')

In [None]:
# Node id -> fitness (column 1 of the fitness table), and the transition graph.
fitness = dict(zip(df_fitness.index, df_fitness[1]))
G = MakeGraph(dft)
# Make every phenotype a node, including those without any recorded transition.
G.add_nodes_from(range(len(df_fitness)))

In [None]:
def _edge_xyz(X, fitness, edge):
    """Return ([x0, x1], [y0, y1], [f0, f1]) for drawing `edge` as a 3D segment."""
    u, v = edge
    return [X[u, 0], X[v, 0]], [X[u, 1], X[v, 1]], [fitness[u], fitness[v]]


def _first_uphill_path_edges(G, fitness, source_node, target_node, verbose=False):
    """Return the edges of the first uphill source->target path, or [] if none exists.

    Paths are drawn lazily from ``nx.shortest_simple_paths`` (shortest first), so
    the search stops at the first path along which fitness never decreases
    instead of enumerating every simple path. The returned edges follow the
    order of ``G.edges()``.
    """
    try:
        for path in nx.shortest_simple_paths(G, source_node, target_node):
            steps = list(zip(path[:-1], path[1:]))
            uphill = True
            for u, v in steps:
                if verbose:
                    print(u, v, fitness[u], fitness[v])
                if not fitness[u] <= fitness[v]:
                    uphill = False
                    break
            if uphill:
                on_path = set(steps)
                path_edges = [edge for edge in G.edges() if edge in on_path]
                if verbose:
                    print(path_edges)
                return path_edges
    except nx.NetworkXException:
        # NetworkXNoPath / NodeNotFound: there is no path to highlight.
        pass
    return []


def Draw3DGraph2(G, fitness, X, ax, cmap='jet', labels=None, bbox=None, scale_edges=False,
                simple_paths=False, label_st=None, min_width=0.5, scaler=5,
                pad=1.2, node_size=50, edge_alpha=0.5, text_color="k", view_init=20,
                 extra=0.01, colorbar=False, shortest_simple_path=False, plot_all_points=True, plot_all_lines=False,
                 fs=nee_fs, shrink=0.5, aspect=20, cpad=None,
                 source_node=None, target_node=None, verbose=False):
    """Plot a phenotype-transition graph as a 3D fitness landscape.

    Node ``n`` is placed at ``(X[n, 0], X[n, 1], fitness[n])``. The colorbar
    norm and the z ticks both assume fitness values in [0, 1].

    Parameters
    ----------
    G : networkx.DiGraph
        Integer nodes (row indices into ``X``); edges carry a ``"weight"``.
    fitness : dict
        node -> fitness (z coordinate and colormap input).
    X : array-like, shape (n_nodes, >= 2)
        2D embedding; also sets the symmetric, unpadded x/y limits.
    ax : mplot3d Axes3D
        Target axes (created with ``projection='3d'``).
    cmap : str or Colormap
        Colormap for the colorbar and the highlighted path nodes.
    labels, bbox, scale_edges, label_st : unused
        Kept for signature compatibility.
    simple_paths : bool
        Draw in red every edge lying on a simple source->target path of at
        most 40 edges.
    min_width, scaler : float
        Edge width is ``scaler * normalised_weight + min_width``. When
        ``scaler == 1``, every edge width is 1.
    pad : float
        Unused; the axis limits are the unpadded extent of ``X``.
    node_size, edge_alpha : float
        Marker size, and alpha for the background points and grey edges.
    text_color : color
        Colour of the "S"/"T" labels.
    view_init : float
        Elevation passed to ``ax.view_init``; the azimuth is fixed at -45.
    extra : float
        Offset unit used to nudge the "S" label below the source.
    colorbar, shrink, aspect, cpad
        Whether to add a [0, 1] colorbar, and its size and padding.
        ``cpad=None`` uses matplotlib's default padding.
    shortest_simple_path : bool
        Highlight (thick red edges, coloured nodes) the first path, shortest
        first, along which fitness never decreases.
    plot_all_points, plot_all_lines : bool
        Draw every node as a grey point, and every non-path edge as a grey line.
    fs : float
        Font size of the "S"/"T" labels.
    source_node, target_node : int, optional
        Path endpoints. They default to the notebook globals ``source_p`` and
        ``target_p``.
    verbose : bool
        Print the steps examined by the uphill-path search.
    """
    if source_node is None:
        source_node = source_p
    if target_node is None:
        target_node = target_p
    cmap = plt.get_cmap(cmap)  # accepts a name or an existing Colormap
    X = np.asarray(X)

    ax.view_init(view_init, -45)
    ax.set_proj_type('ortho')
    ax.dist = 12  # camera distance; deprecated in matplotlib 3.6 and ignored in later versions

    if plot_all_points:
        # Background cloud: every phenotype in grey, drawn behind everything else.
        for node in G.nodes():
            ax.scatter(X[node, 0], X[node, 1], fitness[node],
                       color="grey", alpha=edge_alpha, s=node_size, zorder=-1000, linewidth=0)

    if colorbar:
        from matplotlib import cm
        from matplotlib.colors import Normalize

        mappable = cm.ScalarMappable(norm=Normalize(vmin=0, vmax=1), cmap=cmap)
        cbar_kwargs = dict(ax=ax, shrink=shrink, aspect=aspect)
        if cpad is not None:
            # colorbar() does arithmetic on `pad`, so an explicit None would fail;
            # omit it to get matplotlib's default instead.
            cbar_kwargs["pad"] = cpad
        cb = plt.colorbar(mappable, **cbar_kwargs)
        cb.ax.tick_params(length=1.5)

    # Edges on any simple source->target path; only enumerated when requested,
    # because the number of simple paths can grow exponentially.
    path_edge_set = set()
    if simple_paths:
        try:
            for path in nx.all_simple_paths(G, source_node, target_node, cutoff=40):
                path_edge_set.update(zip(path[:-1], path[1:]))
        except nx.NetworkXException:
            pass  # source/target not in G: nothing to highlight
    grey_edges = [edge for edge in G.edges() if edge not in path_edge_set]
    blue_edges = [edge for edge in G.edges() if edge in path_edge_set]

    sss_edges = (_first_uphill_path_edges(G, fitness, source_node, target_node, verbose)
                 if shortest_simple_path else [])

    # Edge widths: weights min-max normalised to (0, 1] (the 1e-6 avoids 0/0), then scaled.
    edge_list = list(G.edges())
    weights = np.array([G.edges[edge]["weight"] for edge in edge_list], dtype=float)
    edge_widths = {}
    if len(weights):
        norm_weights = (weights - weights.min() + 1e-6) / (weights.max() - weights.min() + 1e-6)
        for edge, w in zip(edge_list, norm_weights):
            edge_widths[edge] = scaler * w + min_width if scaler != 1 else 1

    sss_col = "red"
    path_col = "red"
    nopath_col = "grey"

    if plot_all_lines:
        for edge in grey_edges:
            ax.plot(*_edge_xyz(X, fitness, edge), '-', color=nopath_col,
                    linewidth=edge_widths[edge], alpha=edge_alpha)

    for edge in blue_edges:
        ax.plot(*_edge_xyz(X, fitness, edge), '-', color=path_col,
                linewidth=edge_widths[edge], alpha=0.5)

    if sss_edges:
        sss_nodes = {node for edge in sss_edges for node in edge}
        for node in sss_nodes:
            ax.scatter(X[node, 0], X[node, 1], fitness[node],
                       color=cmap(fitness[node]), s=node_size, zorder=102)
        for edge in sss_edges:
            ax.plot(*_edge_xyz(X, fitness, edge), '-', color=sss_col,
                    linewidth=edge_widths[edge] * 3, alpha=1, zorder=1, solid_capstyle='round')

    # Symmetric limits: the largest |coordinate| along each embedding axis.
    x_lim, y_lim = np.abs(X[:, :2]).max(axis=0, initial=0)
    ax.set_xlim(-x_lim, x_lim)
    ax.set_ylim(-y_lim, y_lim)

    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_zticks([0.0, 0.5, 1.0])
    ax.set_zticklabels([0.0, 0.5, 1.0])

    # "S" hangs just below the source node; "T" sits on the target (only if it is in G).
    ax.text(X[source_node, 0],
            X[source_node, 1] - 4 * extra,
            fitness[source_node],
            'S', size=fs, zorder=1e10, color=text_color, ha="center", va="top")
    if target_node in G.nodes():
        ax.text(X[target_node, 0],
                X[target_node, 1],
                fitness[target_node],
                'T', size=fs, zorder=1e10, color=text_color, ha="center", va="bottom")

In [None]:
# Symmetric plot extent: the largest |coordinate| along each of the two MDS axes.
x_max = y_max = 0
for x_abs, y_abs in np.abs(X_transformed[:, :2]):
    x_max = x_abs if x_abs > x_max else x_max
    y_max = y_abs if y_abs > y_max else y_max

In [None]:
# Figure 4A: 3D fitness landscape of the phenotype-transition graph, with the first
# uphill (fitness never decreasing) source -> target path highlighted.
# Fonts are configured once in the configuration cell at the top of the notebook.
scaling = 1.5
fig = plt.figure()
fig.set_size_inches(ONE_COL_MM/250 * scaling, ONE_COL_MM / 250 * 5/6 * scaling)
ax0 = fig.add_subplot(111, projection='3d')

cmap = plt.get_cmap('Reds')

Draw3DGraph2(G, fitness, X_transformed, ax0, cmap=cmap,
             edge_alpha=0.05,
             node_size=8, extra=0.01,
             colorbar=True, simple_paths=False, shortest_simple_path=True,
             plot_all_points=True,
             pad=1.2,
             shrink=0.5, aspect=17, cpad=0.025
            )

# White panes. `xaxis`/`yaxis`/`zaxis` replace the `w_xaxis`/`w_yaxis`/`w_zaxis`
# aliases, which were deprecated in matplotlib 3.1 and removed in 3.8.
for axis3d in (ax0.xaxis, ax0.yaxis, ax0.zaxis):
    axis3d.set_pane_color((1.0, 1.0, 1.0, 1.0))

ax0.grid(True)
ax0.zaxis.gridlines.set_alpha(0.2)

ax0.set_xlabel("MDS1", labelpad=-10)
ax0.set_ylabel("MDS2", labelpad=-10)
ax0.set_zlabel("Fitness", labelpad=-4)
ax0.tick_params(axis="z", direction="out", pad=-1)

for ext, save_kwargs in (("pdf", {}), ("svg", {}), ("png", {"dpi": 300})):
    fig.savefig(f"{OUT_PATH}/fig4A.{ext}", transparent=True, bbox_inches='tight', **save_kwargs)

plt.show()