In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import datetime
import pickle
import pathlib

In [None]:
import autoreload
import evotsc
import evotsc_plot
# Reload the two project modules, presumably so that edits made to them on disk
# are picked up without restarting the kernel.
# NOTE(review): `autoreload.reload` is taken as-is from the imported `autoreload`
# module; confirm it exists in the installed version (the stdlib equivalent is
# `importlib.reload`).
autoreload.reload(evotsc)
autoreload.reload(evotsc_plot)

In [None]:
# Root directory of the evolved experiment; the functions below scan it for
# "rep*" sub-directories and also write their figures into it.
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
exp_path = pathlib.Path('/Users/theotime/Desktop/evotsc/sigma_0.1_augustus/')
# Root directory of the neutral (no selection) control run.
neutral_exp_path = pathlib.Path('/Users/theotime/Desktop/evotsc/neutral_100k/')
# Generation whose saved population file (pop_gen_<gen>.evotsc) is analysed.
gen=500_000
gene_types = ['AB', 'A', 'B'] # Name of each gene type, indexed by gene.gene_type
gene_type_color = ['tab:blue', 'tab:red', 'tab:green'] # Node colour of each gene type, same order: AB, A, B
orient_name = ['leading', 'lagging'] # Name of each gene orientation, indexed by gene.orientation
# Relative orientations of a pair of neighbouring genes, in the order used for
# the x axis of the gene-pair bar plots.
rel_orient = ['conv', 'div', 'upstr', 'downstr']

In [None]:
# Shared figure styling constants.
# NOTE(review): the three font sizes are not referenced by any cell visible in
# this notebook; only `dpi` is used (by plot_interaction_graph).
label_fontsize=20
tick_fontsize=15
legend_fontsize=15
# Resolution (dots per inch) used when creating and saving figures.
dpi=300

In [None]:
def get_params(exp_path):
    """Read the run parameters of an experiment from its first replicate.

    The parameters are parsed from the `params.txt` file found in the first
    (alphabetically sorted) "rep*" sub-directory of `exp_path`. Each non-blank
    line is expected to look like `name: value`.

    Args:
        exp_path: `pathlib.Path` of the experiment directory.

    Returns:
        A dict mapping each parameter name to its value. The `commit` entry is
        kept as a string; every other value is converted to `float`.

    Raises:
        FileNotFoundError: if `exp_path` contains no "rep*" directory, or if the
            first replicate has no `params.txt`.
        ValueError: if a non-blank line has no ':' separator, or if a
            non-`commit` value is not a valid float.
    """
    rep_dirs = sorted(d for d in exp_path.iterdir() if d.is_dir() and d.name.startswith("rep"))
    if not rep_dirs:
        # Fail with an explicit message instead of an opaque IndexError below.
        raise FileNotFoundError(f"No 'rep*' directory found in {exp_path}")

    with open(rep_dirs[0].joinpath('params.txt'), 'r') as params_file:
        param_lines = params_file.readlines()

    params = {}
    for line in param_lines:
        # Blank lines (e.g. a trailing newline at the end of the file) carry no parameter.
        if not line.strip():
            continue

        fields = line.split(':')
        if len(fields) < 2:
            raise ValueError(f"Malformed parameter line (missing ':'): {line!r}")

        # Strip the name so that 'name : value' and 'name: value' give the same key.
        param_name = fields[0].strip()
        if param_name == 'commit':
            # The commit identifier is not numeric: keep it as text.
            param_val = fields[1].strip()
        else:
            param_val = float(fields[1])

        params[param_name] = param_val

    return params

In [None]:
# Parse the run parameters of the evolved experiment; `params` is read as a
# global by plot_best_genome_and_tsc below.
params = get_params(exp_path)

In [None]:
# Display the parsed parameters as the cell output.
params

In [None]:
def get_best_indiv(rep_path, gen):
    """Load a saved population and return its fittest individual.

    The population is unpickled from `pop_gen_<gen>.evotsc` (generation number
    zero-padded to 6 digits) inside `rep_path`, then evaluated.

    Args:
        rep_path: `pathlib.Path` of the replicate directory.
        gen: generation number of the save file to load.

    Returns:
        The first individual whose fitness is strictly greater than 0 and than
        that of every individual before it; the first individual of the
        population when no fitness exceeds 0 or when individuals have no
        `fitness` attribute.

    Note:
        `pickle.load` can execute arbitrary code: only open save files that
        come from a trusted source.
    """
    save_path = rep_path / f'pop_gen_{gen:06}.evotsc'
    with open(save_path, 'rb') as save_file:
        population = pickle.load(save_file)

    population.evaluate()

    # Default answer: the first individual, with a fitness threshold of 0.
    champion = population.individuals[0]
    top_fitness = 0

    try:
        for candidate in population.individuals:
            if not (candidate.fitness > top_fitness):
                continue
            top_fitness = candidate.fitness
            champion = candidate
    except AttributeError:
        # In the neutral control, individuals are not evaluated, so they have
        # no fitness field: keep whatever has been selected so far.
        pass

    return champion

# Plot genomes

In [None]:
def plot_best_genome(exp_path, gen):
    """Plot the genome of the best individual of every replicate.

    For each "rep*" sub-directory of `exp_path` (in sorted order), the best
    individual at generation `gen` is drawn with `evotsc_plot.plot_genome`, and
    the output name `genome_rep<NN>.pdf` (inside `exp_path`) is passed to it.
    `<NN>` is the position of the replicate in the sorted list.

    Args:
        exp_path: `pathlib.Path` of the experiment directory.
        gen: generation number of the saved populations to load.
    """
    replicate_dirs = [entry for entry in exp_path.iterdir()
                      if entry.is_dir() and entry.name.startswith("rep")]
    replicate_dirs.sort()

    for rep_number, replicate_dir in enumerate(replicate_dirs):
        fittest = get_best_indiv(replicate_dir, gen)
        output_file = exp_path.joinpath(f'genome_rep{rep_number:02}.pdf')
        evotsc_plot.plot_genome(fittest, name=output_file, print_ids=True)

In [None]:
# Draw the genome of the best individual of each replicate of the evolved run.
plot_best_genome(exp_path, gen)

In [None]:
def plot_best_genome_and_tsc(exp_path, gen):
    """Plot genome and supercoiling of each replicate's best individual, once per environment.

    For each "rep*" sub-directory of `exp_path` (in sorted order), the best
    individual at generation `gen` is passed to
    `evotsc_plot.plot_genome_and_tsc` twice: once with `params['sigma_A']` and
    once with `params['sigma_B']` (read from the global `params` dict). The
    output names are `genome_and_tsc_rep<NN>_env_A.pdf` and
    `genome_and_tsc_rep<NN>_env_B.pdf` inside `exp_path`.

    Args:
        exp_path: `pathlib.Path` of the experiment directory.
        gen: generation number of the saved populations to load.
    """
    replicate_dirs = [entry for entry in exp_path.iterdir()
                      if entry.is_dir() and entry.name.startswith("rep")]
    replicate_dirs.sort()

    for rep_number, replicate_dir in enumerate(replicate_dirs):
        fittest = get_best_indiv(replicate_dir, gen)

        # One figure per environment: (key of the sigma value in `params`, output file).
        environments = (
            ('sigma_A', exp_path.joinpath(f'genome_and_tsc_rep{rep_number:02}_env_A.pdf')),
            ('sigma_B', exp_path.joinpath(f'genome_and_tsc_rep{rep_number:02}_env_B.pdf')),
        )
        for sigma_key, output_file in environments:
            evotsc_plot.plot_genome_and_tsc(fittest, params[sigma_key], show_bar=True,
                                            name=output_file, print_ids=True)

In [None]:
# Draw genome + supercoiling figures (environments A and B) for each replicate.
plot_best_genome_and_tsc(exp_path, gen)

# Plot gene pairs

In [None]:
def get_gene_stats(indiv):
    """Describe every gene of an individual together with its two neighbours.

    The genome is treated as circular: the neighbour before the first gene is
    the last gene, and the neighbour after the last gene is the first one.
    Each gene yields two rows, first the (previous neighbour, gene) pair and
    then the (gene, next neighbour) pair.

    Args:
        indiv: object exposing `genes` (a sequence of genes with `id`,
            `gene_type`, `orientation` and `intergene` attributes) and
            `nb_genes`.

    Returns:
        A `pandas.DataFrame` with columns `gene_id`, `gene_type`,
        `gene_orient`, `other_id`, `other_type`, `other_orient`, `rel_orient`
        and `distance`. Types and orientations are translated through the
        global `gene_types` and `orient_name` lists. `distance` is the
        `intergene` value of the left-hand gene of the pair.
    """
    columns = ('gene_id', 'gene_type', 'gene_orient',
               'other_id', 'other_type', 'other_orient',
               'rel_orient', 'distance')
    table = {column: [] for column in columns}

    def classify(left_orient, right_orient, focal_on_right):
        """Relative orientation of the focal gene with respect to its neighbour."""
        if left_orient == 0 and right_orient == 1:    # ---> <---
            return 'conv'
        if left_orient == 1 and right_orient == 0:    # <--- --->
            return 'div'
        # Remaining cases: ---> ---> (both forward) or, by default, <--- <---.
        # Both forward: the right-hand gene is downstream of the left-hand one;
        # otherwise the roles are swapped.
        both_forward = left_orient == 0 and right_orient == 0
        return 'downstr' if both_forward == focal_on_right else 'upstr'

    def record(focal, neighbour, relation, distance):
        """Append one (focal gene, neighbour) row to the table."""
        row = (focal.id,
               gene_types[focal.gene_type],
               orient_name[focal.orientation],
               neighbour.id,
               gene_types[neighbour.gene_type],
               orient_name[neighbour.orientation],
               relation,
               distance)
        for column, value in zip(columns, row):
            table[column].append(value)

    for position, gene in enumerate(indiv.genes):
        # neighbour - gene (index -1 wraps around to the last gene)
        previous = indiv.genes[position - 1]
        record(gene, previous,
               classify(previous.orientation, gene.orientation, True),
               previous.intergene)

        # gene - neighbour
        following = indiv.genes[(position + 1) % indiv.nb_genes]
        record(gene, following,
               classify(gene.orientation, following.orientation, False),
               gene.intergene)

    return pd.DataFrame.from_dict(table)

In [None]:
def get_full_stats(exp_path, gen):
    """Gather the gene-pair statistics of the best individual of every replicate.

    Args:
        exp_path: `pathlib.Path` of the experiment directory; its "rep*"
            sub-directories are processed in sorted order.
        gen: generation number of the saved populations to load.

    Returns:
        A `pandas.DataFrame` stacking the output of `get_gene_stats` for each
        replicate, with an extra leading `Replicate` column holding the
        position of the replicate in the sorted list. The per-replicate row
        index is kept as-is (it restarts at 0 for each replicate). An empty
        `DataFrame` is returned when there is no replicate.
    """
    rep_dirs = sorted(d for d in exp_path.iterdir() if d.is_dir() and d.name.startswith("rep"))

    # Collect the per-replicate frames and concatenate them once at the end:
    # growing a DataFrame with pd.concat inside the loop copies all previous
    # rows at every iteration.
    per_rep_stats = []
    for rep, rep_dir in enumerate(rep_dirs):
        best_indiv = get_best_indiv(rep_dir, gen)

        indiv_stats = get_gene_stats(best_indiv)
        indiv_stats.insert(0, 'Replicate', rep)
        per_rep_stats.append(indiv_stats)

    if not per_rep_stats:
        # pd.concat refuses an empty list; keep returning an empty frame.
        return pd.DataFrame()

    return pd.concat(per_rep_stats)

In [None]:
# Gene-pair statistics of the best individual of every replicate of the evolved run.
evol_stats = get_full_stats(exp_path, gen)

In [None]:
#neutral_stats = get_full_stats(neutral_exp_path, 100000)

In [None]:
def plot_gene_stats(stats, plot_name, count_bars, annotate=False):
    """Plot gene-pair statistics as a grid of bar charts, one panel per pair of gene types.

    Rows of the grid are the type of the focal gene, columns the type of its
    neighbour (both in the order of the global `gene_types`). Each panel has
    one bar per relative orientation, in the order of the global `rel_orient`.

    Args:
        stats: `pandas.DataFrame` as returned by `get_full_stats`; the columns
            `Replicate`, `gene_type`, `other_type`, `rel_orient` and
            `distance` are used.
        plot_name: path or file name the figure is saved to.
        count_bars: if true, the bars show the number of pairs of each kind
            averaged over replicates; if false, the bars show the average
            distance between the two genes of each kind of pair.
        annotate: if true, label each bar with the complementary statistic
            (average distance when `count_bars` is true, raw pair count
            otherwise). Off by default.

    The figure is saved to `plot_name` and then shown.
    """
    group_cols = ['gene_type', 'other_type', 'rel_orient']

    idx = pd.MultiIndex.from_product([gene_types, gene_types, rel_orient], names=group_cols)

    # Aggregate only the columns that are needed: averaging the whole frame
    # would also try to average the text columns, which recent pandas rejects.
    # Pair kinds that never occur are filled with 0.
    grouped = stats.groupby(group_cols)
    pair_counts = grouped['Replicate'].count().reindex(idx, fill_value=0).sort_index()
    mean_dists = grouped['distance'].mean().reindex(idx, fill_value=0).sort_index()

    nb_reps = stats['Replicate'].nunique()

    colors = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red']

    x = np.arange(len(rel_orient))

    nb_types = len(gene_types)
    fig, axs = plt.subplots(nb_types, nb_types, sharey='all', figsize=(12, 12),
                            dpi=300, squeeze=False)

    for i_gene_type, gene_type in enumerate(gene_types):
        for i_other_type, other_type in enumerate(gene_types):

            # sort_index() above orders the orientations alphabetically; put
            # them back in the order of `rel_orient` so that every bar sits
            # under its own tick label.
            panel_counts = pair_counts.loc[(gene_type, other_type)].reindex(rel_orient)
            panel_dists = mean_dists.loc[(gene_type, other_type)].reindex(rel_orient)

            if count_bars:
                bar_data = panel_counts / nb_reps
                text_data = panel_dists
            else:
                # The mean is already taken over all pairs of all replicates,
                # so it must not be divided by the number of replicates again.
                bar_data = panel_dists
                text_data = panel_counts

            ax = axs[i_gene_type][i_other_type]

            rects = ax.bar(x, bar_data, color=colors)

            if count_bars:
                ax.set_ylim(0, 20)

            ax.set_xticks(x)
            ax.set_xticklabels(labels=rel_orient, fontsize="large")
            ax.yaxis.set_tick_params(labelsize="large")

            if annotate:
                for i_rect, rect in enumerate(rects):
                    ax.annotate(f"{text_data.iloc[i_rect]:.1f}",
                                xy=(rect.get_x() + rect.get_width()/2, 0),
                                xytext=(0, 3),
                                ha='center',
                                textcoords="offset points",
                                color='black')

            # Positional flag: the keyword was renamed from `b` to `visible`
            # across Matplotlib versions, the positional form works with both.
            ax.grid(True, axis='y', linestyle=':')

            if i_gene_type == 0: # First line
                ax.set_title(other_type, fontsize='xx-large')
            if i_other_type == 0: # First column
                ax.set_ylabel(gene_type, rotation='horizontal', ha='right', fontsize='xx-large')

    fig.savefig(plot_name, dpi=300, bbox_inches='tight')
    plt.show()

In [None]:
# Bars = number of gene pairs of each kind (averaged over replicates).
plot_gene_stats(evol_stats, count_bars=True, plot_name=exp_path.joinpath('gene_pair_counts.pdf'))

In [None]:
# Bars = distance statistic of each kind of gene pair (count_bars=False mode).
plot_gene_stats(evol_stats, count_bars=False, plot_name=exp_path.joinpath('gene_pair_distances.pdf'))

In [None]:
#plot_gene_stats(neutral_stats, count_bars=True, plot_name=neutral_exp_path.joinpath('gene_pair_counts'))

In [None]:
#plot_gene_stats(neutral_stats, count_bars=False, plot_name=neutral_exp_path.joinpath('gene_pair_distances'))

# Plot intergene distribution

In [None]:
def get_intergene_distances(exp_path, gen):
    """Collect the intergenic distances of the best individual of every replicate.

    Args:
        exp_path: `pathlib.Path` of the experiment directory; its "rep*"
            sub-directories are processed in sorted order.
        gen: generation number of the saved populations to load.

    Returns:
        A `pandas.DataFrame` with one row per gene and two columns:
        `Replicate` (position of the replicate in the sorted list) and
        `Intergene` (the gene's `intergene` attribute, requested as integer
        dtype). The row index restarts at 0 for each replicate. An empty
        `DataFrame` is returned when there is no replicate.
    """
    rep_dirs = sorted(d for d in exp_path.iterdir() if d.is_dir() and d.name.startswith("rep"))

    # Collect the per-replicate frames and concatenate them once at the end:
    # growing a DataFrame with pd.concat inside the loop copies all previous
    # rows at every iteration.
    per_rep_res = []
    for rep, rep_dir in enumerate(rep_dirs):

        best_indiv = get_best_indiv(rep_dir, gen)

        # Assumes `best_indiv.genes` holds exactly `best_indiv.nb_genes` genes.
        intergenes = np.array([gene.intergene for gene in best_indiv.genes], dtype=float)
        indiv_res = pd.DataFrame(data={'Intergene': intergenes}, dtype=int)
        indiv_res.insert(0, 'Replicate', rep)

        per_rep_res.append(indiv_res)

    if not per_rep_res:
        # pd.concat refuses an empty list; keep returning an empty frame.
        return pd.DataFrame()

    return pd.concat(per_rep_res)

In [None]:
def plot_intergenes(exp_path, neutral_exp_path, gen, cutoff, plot_name, neutral_gen=100000):
    """Plot the distribution of intergenic distances with and without selection.

    Two side-by-side histograms of log10(intergenic distance) are drawn: the
    left one for the best individuals of `exp_path`, the right one for those
    of `neutral_exp_path`. A dashed vertical line marks log10(`cutoff`).

    Args:
        exp_path: `pathlib.Path` of the experiment run with selection.
        neutral_exp_path: `pathlib.Path` of the neutral control run.
        gen: generation loaded for `exp_path`.
        cutoff: distance at which the cutoff line is drawn.
        plot_name: output path or file name WITHOUT extension (`str` or
            `pathlib.Path`); '.pdf' is appended.
        neutral_gen: generation loaded for `neutral_exp_path`.

    The figure is saved to `<plot_name>.pdf` and then shown.
    """
    intergenes = get_intergene_distances(exp_path, gen)
    neutral_intergenes = get_intergene_distances(neutral_exp_path, neutral_gen)

    bins = np.linspace(0.0, 4.0, 40)

    panels = [('With selection', intergenes),
              ('Without selection', neutral_intergenes)]

    fig, axs = plt.subplots(1, 2, figsize=(9, 4), dpi=300)

    for ax, (title, data) in zip(axs, panels):
        ax.set_ylim(0, 1)
        ax.set_xlim(-0.2, 4.2)
        ax.set_xlabel('Distance (log)')
        ax.set_ylabel('Density')
        # Plot intergene distances
        ax.hist(np.log10(data['Intergene']), bins=bins, density=True)
        # Plot cutoff line
        ax.vlines(np.log10(cutoff), 0, 1, linestyle='--', linewidth=1,
                  color='tab:red', label='Cutoff distance')
        ax.grid(linestyle=':')
        ax.set_title(title)
        ax.legend(loc='upper left')

    # Format instead of `plot_name + '.pdf'`: `+` fails when plot_name is a
    # pathlib.Path, whereas formatting works for both str and Path.
    fig.savefig(f'{plot_name}.pdf', dpi=300)
    plt.show()

In [None]:
#plot_intergenes(exp_path, neutral_exp_path, gen=gen, cutoff=params['interaction_dist'],
#                plot_name=exp_path.joinpath('intergene_distr'))

# Explore the gene interaction graph

In [None]:
import networkx as nx

In [None]:
def plot_interaction_graph(exp_path, gen, method='spring', seed=None):
    """Draw the gene interaction graph of the best individual of every replicate.

    Nodes are the genes (numbered by position in `indiv.genes`, coloured by
    gene type through the global `gene_type_color`). A directed edge j -> i is
    drawn for every non-zero `indiv.inter_matrix[i][j]` with i != j. Edges with
    a negative value are drawn in green, edges with a positive value in red,
    and the line width is proportional to the absolute value.

    One figure per replicate is saved as
    `genome_graph_rep<NN>_<method>.png` inside `exp_path`, then shown and closed.

    Args:
        exp_path: `pathlib.Path` of the experiment directory.
        gen: generation number of the saved populations to load.
        method: node layout, either 'spring' or 'circular'.
        seed: optional seed forwarded to the spring layout to make it
            reproducible; ignored by the circular layout.

    Raises:
        ValueError: if `method` is not a known layout.
    """
    # Reject a bad layout up front, before loading data or opening figures.
    if method not in ('spring', 'circular'):
        raise ValueError(f"Unknown graph layout '{method}'")

    rep_dirs = sorted(d for d in exp_path.iterdir() if d.is_dir() and d.name.startswith("rep"))

    coef = 10  # Scaling factor from interaction strength to line width

    for rep, rep_dir in enumerate(rep_dirs):

        indiv = get_best_indiv(rep_path=rep_dir, gen=gen)

        ## Build the graph
        inter_graph = nx.DiGraph()

        # Nodes
        for i_gene, gene in enumerate(indiv.genes):
            inter_graph.add_node(i_gene, gene=gene)

        # Edges
        for i_gene in range(indiv.nb_genes):
            for j_gene in range(indiv.nb_genes):
                if i_gene == j_gene:
                    continue
                if indiv.inter_matrix[i_gene][j_gene] == 0.0:
                    continue
                # influence of gene j on gene i, so the edge is j -> i
                inter_graph.add_edge(j_gene, i_gene, inter=indiv.inter_matrix[i_gene][j_gene])

        ## Draw the figure
        fig = plt.figure(figsize=(16,16), dpi=dpi)
        plt.box(False)

        # Choose the layout
        if method == 'spring':
            layout = nx.spring_layout(inter_graph, k=0.6, seed=seed)
        else:
            layout = nx.circular_layout(inter_graph)

        # Draw the nodes
        nx.draw_networkx_nodes(inter_graph, layout, node_size=600,
                               node_color=[gene_type_color[gene.gene_type] for gene in indiv.genes])
        nx.draw_networkx_labels(inter_graph, layout)

        # Draw the edges
        weighted_edges = list(inter_graph.edges(data='inter'))
        activ_edges = [(src, dst) for src, dst, inter in weighted_edges if inter < 0] # Negative sigma
        inhib_edges = [(src, dst) for src, dst, inter in weighted_edges if inter > 0]

        # Widths must be positive: take the absolute value for both kinds of
        # edges (activating edges carry negative interaction values).
        activ_widths = [abs(inter) * coef for _, _, inter in weighted_edges if inter < 0]
        inhib_widths = [abs(inter) * coef for _, _, inter in weighted_edges if inter > 0]

        nx.draw_networkx_edges(inter_graph, layout, edgelist=inhib_edges,
                               width=inhib_widths, edge_color='tab:red', connectionstyle='arc3,rad=0.1')
        nx.draw_networkx_edges(inter_graph, layout, edgelist=activ_edges,
                               width=activ_widths, edge_color='tab:green', connectionstyle='arc3,rad=0.1')

        fig.savefig(exp_path.joinpath(f'genome_graph_rep{rep:02}_{method}.png'), dpi=dpi)

        # Show then release the figure: one large figure per replicate would
        # otherwise stay open until the end of the loop.
        plt.show()
        plt.close(fig)

In [None]:
# Draw the interaction graph of each replicate with the default ('spring') layout.
plot_interaction_graph(exp_path, gen)

In [None]:
#plot_interaction_graph(exp_path, gen, method='circular')