In [None]:
import datetime
import os
import pathlib
import pickle

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import networkx as nx
import numpy as np
import pandas as pd

In [None]:
import autoreload
import evotsc
import evotsc_plot
# Reload the project modules, presumably so that edits made to evotsc /
# evotsc_plot since the kernel started are picked up when this cell is re-run.
# NOTE(review): `autoreload` is not a standard-library module; confirm which
# package provides `autoreload.reload` (importlib.reload is the stdlib equivalent).
autoreload.reload(evotsc)
autoreload.reload(evotsc_plot)

In [None]:
# Root directory of the experiment (one `rep*` sub-directory per replicate).
# Can be overridden with the EVOTSC_EXP_PATH environment variable so the notebook
# is not tied to one machine; defaults to the original location.
exp_path = pathlib.Path(os.environ.get('EVOTSC_EXP_PATH',
                                       '/Users/theotime/Desktop/evotsc/new-sigma-c/inter-500/'))
gen = 500_000  # Generation whose saved population is analysed
gene_types = ['AB', 'A', 'B']  # Name of each gene type (indexed by gene.gene_type)
gene_type_color = ['tab:blue', 'tab:red', 'tab:green']  # Plot color of each gene type: AB, A, B
orient_name = ['leading', 'lagging']  # Name of each gene orientation (indexed by gene.orientation)
rel_orient = ['conv', 'div', 'upstr', 'downstr']  # Relative orientations of adjacent gene pairs

In [None]:
# Font sizes used for axis labels, tick labels and legends, and the resolution
# (dots per inch) used when drawing / saving figures.
label_fontsize, tick_fontsize, legend_fontsize = 20, 15, 15
dpi = 300

In [None]:
def get_params(exp_path):
    """Read the simulation parameters of an experiment.

    The parameters are read from ``params.txt`` in the first (sorted) replicate
    directory, i.e. the first sub-directory of ``exp_path`` whose name starts with
    ``rep``; all replicates are assumed to share the same parameters.

    Each non-blank line has the form ``name: value``. ``commit`` and
    ``selection_method`` are kept as strings, ``neutral`` is parsed as a boolean
    and every other parameter is parsed as a float.

    Args:
        exp_path: pathlib.Path of the experiment directory.

    Returns:
        dict mapping parameter names to their parsed values.

    Raises:
        FileNotFoundError: if ``exp_path`` contains no replicate directory.
        ValueError: if a numeric parameter cannot be parsed as a float.
    """
    rep_dirs = sorted([d for d in exp_path.iterdir() if (d.is_dir() and d.name.startswith("rep"))])
    if not rep_dirs:
        raise FileNotFoundError(f"No replicate directory ('rep*') found in {exp_path}")

    with open(rep_dirs[0].joinpath('params.txt'), 'r') as params_file:
        param_lines = params_file.readlines()

    params = {}
    for line in param_lines:
        if not line.strip():  # Tolerate blank lines (e.g. a trailing empty line)
            continue

        # Split on the first ':' only, so that values containing ':' stay whole.
        param_name, _, raw_val = line.partition(':')
        param_name = param_name.strip()
        raw_val = raw_val.strip()

        if param_name in ('commit', 'selection_method'):
            param_val = raw_val
        elif param_name == 'neutral':
            # The value is text: it must be compared to a string, not to the
            # bool True (which is never equal to a str).
            param_val = raw_val.lower() in ('true', '1')
        else:
            param_val = float(raw_val)

        params[param_name] = param_val

    return params

In [None]:
params = get_params(exp_path=exp_path)  # Parameters shared by all replicates of the experiment

In [None]:
# Display the parsed parameters for a quick sanity check.
params

In [None]:
def get_best_indiv(rep_path, gen):
    """Return the fittest individual of the population saved at generation ``gen``.

    Loads ``pop_gen_<gen as 6 digits>.evotsc`` from ``rep_path`` with pickle
    (only use on trusted files: unpickling can execute arbitrary code), calls
    the population's ``evaluate()`` and scans its individuals.

    Only fitnesses strictly greater than 0 are considered and the first maximum
    wins; if none is greater than 0, the first individual is returned. If an
    individual has no ``fitness`` attribute (neutral control), the scan stops and
    the best individual found so far is returned.
    """
    save_path = rep_path / f'pop_gen_{gen:06}.evotsc'
    with open(save_path, 'rb') as save_file:
        population = pickle.load(save_file)

    population.evaluate()

    champion = population.individuals[0]
    champion_fitness = 0

    try:
        for candidate in population.individuals:
            if candidate.fitness > champion_fitness:
                champion_fitness = candidate.fitness
                champion = candidate
    except AttributeError:  # Neutral control: individuals carry no fitness field
        pass

    return champion

# Plot genomes

In [None]:
def plot_best_genome_and_tsc(exp_path, gen):
    """Draw the genome and supercoiling of each replicate's best individual.

    For every replicate directory (``rep*``, in sorted order) of ``exp_path``, the
    best individual at generation ``gen`` is drawn twice with
    ``evotsc_plot.plot_genome_and_tsc``: at ``params['sigma_A']`` and at
    ``params['sigma_B']`` (global ``params``). Output files are
    ``genome_and_tsc_repXX_env_A.pdf`` / ``..._env_B.pdf`` inside ``exp_path``.
    """
    replicate_dirs = sorted(
        d for d in exp_path.iterdir() if d.is_dir() and d.name.startswith("rep")
    )

    for rep, rep_dir in enumerate(replicate_dirs):
        champion = get_best_indiv(rep_dir, gen)
        env_a_pdf = exp_path.joinpath(f'genome_and_tsc_rep{rep:02}_env_A.pdf')
        env_b_pdf = exp_path.joinpath(f'genome_and_tsc_rep{rep:02}_env_B.pdf')

        evotsc_plot.plot_genome_and_tsc(champion, params['sigma_A'], show_bar=True,
                                        name=env_a_pdf, print_ids=True)
        evotsc_plot.plot_genome_and_tsc(champion, params['sigma_B'], show_bar=True,
                                        name=env_b_pdf, print_ids=True)

In [None]:
plot_best_genome_and_tsc(exp_path=exp_path, gen=gen)  # One PDF per replicate and environment

# Get the genome stats

In [None]:
def get_gene_stats(indiv):
    """Describe every pair of adjacent genes on the (circular) genome of ``indiv``.

    Each gene is paired with its left neighbour (``other - gene``) and with its
    right neighbour (``gene - other``). The genome wraps around: the first gene's
    left neighbour is the last gene. For each pair, one row records the ids,
    types and orientations of both genes, the relative orientation of the focal
    gene ('conv', 'div', 'upstr' or 'downstr') and the distance between them.

    As in evotsc.compute_inter_matrix(), the distance is the one between the
    promoter of the focal gene and the middle of the other gene. It is the
    ``intergene`` of the left gene of the pair, plus half the other gene's
    length, plus the focal gene's length for 'conv' and 'upstr' pairs.

    Returns:
        pandas.DataFrame with two rows per gene and the columns gene_id,
        gene_type, gene_orient, other_id, other_type, other_orient, rel_orient
        and distance.
    """
    columns = ['gene_id', 'gene_type', 'gene_orient', 'other_id',
               'other_type', 'other_orient', 'rel_orient', 'distance']
    table = {column: [] for column in columns}

    # Relative orientation of the focal gene, looked up from the orientations of
    # the (left gene, right gene) of the pair: 0 = leading |---->, 1 = lagging <----|.
    # Focal gene on the right of the pair (other - gene); any other combination
    # is 'upstr' (<--|-- <----|).
    right_focal_orient = {(0, 1): 'conv',     # --|--> <----|
                          (0, 0): 'downstr',  # --|--> |---->
                          (1, 0): 'div'}      # <--|-- |---->
    # Focal gene on the left of the pair (gene - other); any other combination
    # is 'downstr' (<----| <--|--).
    left_focal_orient = {(0, 1): 'conv',      # |----> <--|--
                         (0, 0): 'upstr',     # |----> --|-->
                         (1, 0): 'div'}       # <----| --|-->

    def add_row(focal, neighbour, relation, left_gene):
        # Half of the neighbour, plus the whole focal gene when its promoter is
        # on the far side of the focal gene ('conv' and 'upstr' pairs).
        span = neighbour.length // 2
        if relation in ('conv', 'upstr'):
            span = span + focal.length

        table['gene_id'].append(focal.id)
        table['gene_type'].append(gene_types[focal.gene_type])
        table['gene_orient'].append(orient_name[focal.orientation])
        table['other_id'].append(neighbour.id)
        table['other_type'].append(gene_types[neighbour.gene_type])
        table['other_orient'].append(orient_name[neighbour.orientation])
        table['rel_orient'].append(relation)
        table['distance'].append(left_gene.intergene + span)

    for position, focal in enumerate(indiv.genes):
        left = indiv.genes[position - 1]  # position 0 wraps around to the last gene
        add_row(focal, left,
                right_focal_orient.get((left.orientation, focal.orientation), 'upstr'),
                left_gene=left)

        right = indiv.genes[(position + 1) % indiv.nb_genes]
        add_row(focal, right,
                left_focal_orient.get((focal.orientation, right.orientation), 'downstr'),
                left_gene=focal)

    return pd.DataFrame.from_dict(table)

In [None]:
def get_full_stats(exp_path, gen):
    """Concatenate the gene-pair statistics of every replicate's best individual.

    For each replicate directory (``rep*``, sorted) of ``exp_path``, the best
    individual at generation ``gen`` is summarised with get_gene_stats(). A
    leading ``Replicate`` column holds the replicate's position in that sorted
    order.

    Returns:
        pandas.DataFrame with a fresh 0..n-1 index, or an empty DataFrame when
        ``exp_path`` has no replicate directory.
    """
    rep_dirs = sorted([d for d in exp_path.iterdir() if (d.is_dir() and d.name.startswith("rep"))])

    per_rep_stats = []
    for rep, rep_dir in enumerate(rep_dirs):
        best_indiv = get_best_indiv(rep_dir, gen)
        indiv_stats = get_gene_stats(best_indiv)
        indiv_stats.insert(0, 'Replicate', rep)
        per_rep_stats.append(indiv_stats)

    if not per_rep_stats:  # pd.concat() refuses an empty list
        return pd.DataFrame()

    # One concat (linear, not quadratic). ignore_index=True because every
    # per-replicate frame restarts its index at 0, which would otherwise produce
    # duplicate row labels.
    return pd.concat(per_rep_stats, ignore_index=True)

In [None]:
full_stats = get_full_stats(exp_path=exp_path, gen=gen)  # Gene-pair stats of all replicates

In [None]:
# Keep only the gene pairs at most `interaction_dist` apart.
filtered_stats = full_stats.loc[full_stats['distance'] <= int(params["interaction_dist"])]

# Plot the average number and distance of gene pairs per gene type

In [None]:
def plot_gene_stats(stats, plot_name, plot_counts, annotate=False):
    """Plot per-relative-orientation bars for every (gene type, other type) pair.

    The figure is a grid with one row per focal gene type and one column per
    neighbouring gene type. Each panel has one bar per relative orientation, in
    the order of the global ``rel_orient`` list (which is also the tick order).

    If ``plot_counts`` is true, the bars are the average number of pairs of each
    kind per replicate and the optional annotations are the average distances.
    If ``plot_counts`` is false it is the opposite: the bars are average
    distances and the annotations are average counts.

    Args:
        stats: gene-pair DataFrame as returned by get_full_stats() (needs the
            Replicate, gene_type, other_type, rel_orient and distance columns).
        plot_name: path the figure is saved to.
        plot_counts: bars show counts (True) or distances (False).
        annotate: label each bar with the other quantity (off by default).
    """
    nb_reps = stats['Replicate'].nunique()
    nb_gene_types = len(gene_types)

    fig, axs = plt.subplots(nb_gene_types, nb_gene_types, sharey='all',
                            figsize=(12, 12), dpi=300, squeeze=False)

    group_cols = ['gene_type', 'other_type', 'rel_orient']
    idx = pd.MultiIndex.from_product([gene_types, gene_types, rel_orient], names=group_cols)
    grouped = stats.groupby(group_cols)

    # Average number of pairs of each kind per replicate (absent kinds count as 0).
    pair_numbers = grouped.size().reindex(idx, fill_value=0).sort_index() / nb_reps

    # Average distance between the genes of each kind of pair (over the pairs of all
    # individuals). Only 'distance' is averaged: a plain .mean() over every column
    # fails on the string columns with pandas >= 2.0.
    pair_distances = grouped['distance'].mean().reindex(idx, fill_value=0).sort_index()

    colors = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red']
    x = np.arange(len(rel_orient))

    for i_gene_type, gene_type in enumerate(gene_types):
        for i_other_type, other_type in enumerate(gene_types):

            # sort_index() orders the rel_orient level alphabetically
            # (conv, div, downstr, upstr); reindex to rel_orient so that each bar
            # matches its tick label.
            counts = pair_numbers.loc[(gene_type, other_type)].reindex(rel_orient)
            distances = pair_distances.loc[(gene_type, other_type)].reindex(rel_orient)
            if plot_counts:
                bar_data, text_data = counts, distances
            else:  # plot distances
                bar_data, text_data = distances, counts

            ax = axs[i_gene_type][i_other_type]

            rects = ax.bar(x, bar_data, color=colors)

            if plot_counts:
                ax.set_ylim(0, 20)

            ax.set_xticks(x)
            ax.set_xticklabels(labels=rel_orient, fontsize="large")
            ax.yaxis.set_tick_params(labelsize="large")

            if annotate:
                for rect, value in zip(rects, text_data):
                    ax.annotate(f"{value:.1f}",
                                xy=(rect.get_x() + rect.get_width() / 2, 0),
                                xytext=(0, 3),
                                ha='center',
                                textcoords="offset points",
                                color='black')

            # Pass visibility positionally: the old 'b' keyword is deprecated since
            # Matplotlib 3.5 and rejected by recent versions.
            ax.grid(True, axis='y', linestyle=':')

            if i_gene_type == 0:  # First line
                if plot_counts:
                    title = f"... of {other_type} genes"
                else:
                    title = f"to ... {other_type} genes"
                ax.set_title(title, fontsize='xx-large')

            if i_other_type == 0:  # First column
                if plot_counts:
                    ylabel = f"Number of {gene_type} genes"
                else:
                    ylabel = f"Distance of {gene_type} genes"
                ax.set_ylabel(ylabel, rotation='vertical', fontsize='xx-large')

    fig.suptitle(exp_path.name, size='xx-large')

    fig.savefig(plot_name, dpi=300, bbox_inches='tight')
    plt.show()

In [None]:
plot_gene_stats(stats=full_stats,
                plot_name=exp_path / 'gene_pair_counts.pdf',
                plot_counts=True)

In [None]:
#plot_gene_stats(filtered_stats, plot_counts=True, plot_name=exp_path.joinpath('gene_pair_counts_filtered.pdf'))

In [None]:
plot_gene_stats(stats=full_stats,
                plot_name=exp_path / 'gene_pair_distances.pdf',
                plot_counts=False)

In [None]:
#plot_gene_stats(filtered_stats, plot_counts=False, plot_name=exp_path.joinpath('gene_pair_distances_filtered.pdf'))

## Plot the average contribution of each type of gene on each type of gene

In [None]:
def gene_effect(rel_orient, distance):
    """Return the supercoiling (SC) effect of the gene on the other gene of a pair.

    The magnitude decreases linearly from ``params['interaction_coef']`` at
    distance 0 to 0 at ``params['interaction_dist']``, and is 0 beyond that. The
    sign is negative (activation) for 'div' and 'downstr' pairs and positive
    (inhibition) otherwise.

    Always returns a float. Returning the int 0 beyond the cutoff made
    np.vectorize(gene_effect) infer an integer output dtype whenever the first
    pair was out of range, which then truncated every other effect to 0.
    """
    max_dist = params['interaction_dist']
    if distance >= max_dist:
        return 0.0

    strength = (1 - distance / max_dist) * params['interaction_coef']
    # We activate the other gene if we are downstream of it (or divergent).
    if rel_orient == 'div' or rel_orient == 'downstr':
        return -strength
    return strength

In [None]:
# Number of recorded gene pairs for each (gene type, other type) combination
# (every column holds the same count, assuming no missing values).
full_stats.groupby(['gene_type', 'other_type']).count()

In [None]:
def plot_interaction_level_pairs(stats):
    """Plot the average activation / inhibition between gene types from pair stats.

    The effect of every gene pair in ``stats`` is computed with gene_effect()
    (negative = activation, positive = inhibition). For each
    (gene_type, other_type) combination, activations and inhibitions are summed
    separately and divided by the number of replicates times the number of genes
    per type (``params['nb_genes'] / len(gene_types)``). Rows of the grid are
    ``gene_type``, columns ``other_type``; activations are drawn as positive bars.
    The figure is saved to ``interaction_level_pairs.pdf`` in the global ``exp_path``.

    ``stats`` itself is left unmodified: the effects are added to a copy.
    """
    nb_gene_types = len(gene_types)
    genes_per_type = params["nb_genes"] / nb_gene_types

    nb_reps = stats['Replicate'].nunique()
    normalization = nb_reps * genes_per_type

    group_cols = ['gene_type', 'other_type']

    idx = pd.MultiIndex.from_product([gene_types, gene_types], names=group_cols)

    # Effect of each gene on the other gene's SC level. otypes=[float] keeps the
    # output floating-point even if the first pair is out of interaction range
    # (np.vectorize otherwise infers the dtype from the first call).
    effects = np.vectorize(gene_effect, otypes=[float])(stats['rel_orient'], stats['distance'])
    stats_with_effect = stats.assign(effect=effects)

    # Summarize the stats (only the 'effect' column is summed)
    activ_mask = stats_with_effect['effect'] < 0
    inhib_mask = stats_with_effect['effect'] > 0
    activ_stats = (stats_with_effect[activ_mask].groupby(group_cols)['effect'].sum()
                   .reindex(idx, fill_value=0).sort_index() / normalization)
    inhib_stats = (stats_with_effect[inhib_mask].groupby(group_cols)['effect'].sum()
                   .reindex(idx, fill_value=0).sort_index() / normalization)

    # Plot
    fig, axs = plt.subplots(nb_gene_types, nb_gene_types, sharey='all',
                            figsize=(12, 12), dpi=300, squeeze=False)

    x = [0, 1]
    colors = ['tab:green', 'tab:red']

    for i_gene_type, gene_type in enumerate(gene_types):
        for i_other_type, other_type in enumerate(gene_types):

            ax = axs[i_gene_type][i_other_type]

            # Activations are negative effects: flip the sign to draw a positive bar.
            bar_data = [-activ_stats.loc[(gene_type, other_type)],
                        inhib_stats.loc[(gene_type, other_type)]]

            ax.bar(x, bar_data, color=colors)

            ax.set_xticks(x)
            ax.set_xticklabels(labels=["activ", "inhib"], fontsize="large")
            ax.yaxis.set_tick_params(labelsize="large")
            ax.set_ylim(0, 0.29)

            # Positional visibility: the 'b' keyword is rejected by recent Matplotlib.
            ax.grid(True, axis='y', linestyle=':')

            if i_gene_type == 0:  # First line
                ax.set_title(other_type, fontsize='xx-large')
            if i_other_type == 0:
                ax.set_ylabel(gene_type, rotation='horizontal', ha='right', fontsize='xx-large')

    fig.suptitle(exp_path.name, size='xx-large')

    fig.savefig(exp_path.joinpath('interaction_level_pairs.pdf'), dpi=300, bbox_inches='tight')
    plt.show()

In [None]:
plot_interaction_level_pairs(stats=full_stats)  # Activation / inhibition between gene types

In [None]:
def plot_interaction_level_total(exp_path, gen):
    """Plot how strongly each gene type activates / inhibits each gene type.

    Uses the full interaction matrix of the best individual of every replicate.
    For genes ``gene`` and ``other``, ``inter_matrix[i_other][i_gene]`` is the
    effect of ``gene`` on ``other``. Negative entries (activation) are summed, as
    positive numbers, into ``total_activ[gene type][other type]``; the other
    entries are summed into ``total_inhib``. Both are divided by the number of
    replicates times the number of genes per type. In the grid, rows are the
    acting gene type and columns the affected one. The figure is saved to
    ``interaction_level_total.pdf`` in ``exp_path``.
    """
    nb_gene_types = len(gene_types)
    genes_per_type = params['nb_genes'] / nb_gene_types

    total_activ = np.zeros((nb_gene_types, nb_gene_types))
    total_inhib = np.zeros((nb_gene_types, nb_gene_types))

    rep_dirs = sorted([d for d in exp_path.iterdir() if (d.is_dir() and d.name.startswith("rep"))])
    nb_reps = len(rep_dirs)

    for rep_dir in rep_dirs:

        best_indiv = get_best_indiv(rep_dir, gen)

        inter_matrix = best_indiv.compute_inter_matrix()

        # total_activ[AB][A]: how much do AB genes activate A genes
        # -> when inter_matrix[A][AB] is negative (meaning activation)
        for i_gene, gene in enumerate(best_indiv.genes):  # AB in the example
            for i_other, other in enumerate(best_indiv.genes):  # A in the example
                effect = inter_matrix[i_other][i_gene]
                if effect < 0:
                    total_activ[gene.gene_type][other.gene_type] -= effect  # stored as positive
                else:
                    total_inhib[gene.gene_type][other.gene_type] += effect

    total_activ /= (nb_reps * genes_per_type)
    total_inhib /= (nb_reps * genes_per_type)

    # Now plot it
    fig, axs = plt.subplots(nb_gene_types, nb_gene_types, sharey='all',
                            figsize=(12, 12), dpi=300, squeeze=False)

    x = [0, 1]
    colors = ['tab:green', 'tab:red']

    for i_gene_type, gene_type in enumerate(gene_types):
        for i_other_type, other_type in enumerate(gene_types):

            ax = axs[i_gene_type][i_other_type]

            bar_data = [total_activ[i_gene_type][i_other_type], total_inhib[i_gene_type][i_other_type]]

            ax.bar(x, bar_data, color=colors)

            ax.set_xticks(x)
            ax.set_xticklabels(labels=["activ", "inhib"], fontsize="large")
            ax.yaxis.set_tick_params(labelsize="large")

            # Positional visibility: the 'b' keyword is rejected by recent Matplotlib.
            ax.grid(True, axis='y', linestyle=':')

            if i_gene_type == 0:  # First line
                ax.set_title(other_type, fontsize='xx-large')
            if i_other_type == 0:
                ax.set_ylabel(gene_type, rotation='horizontal', ha='right', fontsize='xx-large')

    fig.suptitle(exp_path.name, size='xx-large')

    fig.savefig(exp_path.joinpath('interaction_level_total.pdf'), dpi=300, bbox_inches='tight')
    plt.show()

In [None]:
plot_interaction_level_total(exp_path=exp_path, gen=gen)  # From the full interaction matrices

# Plot the positive / negative supercoiling due to each gene type in each environment

In [None]:
def plot_pos_neg_supercoiling(exp_path, gen):
    """Plot the positive / negative supercoiling each gene type causes on each type.

    For the best individual of every replicate and each environment (A and B),
    ``inter_matrix[i][j]`` (influence of gene j on gene i) is weighted by gene j's
    expression level. That level is the last column of ``run_system``'s output,
    presumably the final time step. Negative and positive contributions are
    summed separately per (affected type, source type) and divided by the number
    of replicates times the number of genes per type.

    The grid has one row per affected gene type and one column per source gene
    type, with one bar per environment. The y axis is inverted so that negative
    supercoiling points up. The figure is saved to ``pos_neg_supercoiling.pdf``.
    """
    nb_gene_types = len(gene_types)
    genes_per_type = params['nb_genes'] / nb_gene_types

    envs = ['A', 'B']

    pos_supercoiling = {env: np.zeros((nb_gene_types, nb_gene_types)) for env in envs}
    neg_supercoiling = {env: np.zeros((nb_gene_types, nb_gene_types)) for env in envs}

    rep_dirs = sorted([d for d in exp_path.iterdir() if (d.is_dir() and d.name.startswith("rep"))])
    nb_reps = len(rep_dirs)

    for rep_dir in rep_dirs:

        best_indiv = get_best_indiv(rep_dir, gen)

        inter_matrix = best_indiv.compute_inter_matrix()

        for env in envs:

            sigma_env = params[f'sigma_{env}']
            final_expr = best_indiv.run_system(sigma_env)[:, -1]

            for i_gene, gene in enumerate(best_indiv.genes):
                for j_gene, other in enumerate(best_indiv.genes):
                    # Contribution of `other` (scaled by its expression) to the SC at `gene`
                    delta_sc = inter_matrix[i_gene][j_gene] * final_expr[j_gene]
                    if delta_sc < 0:
                        neg_supercoiling[env][gene.gene_type][other.gene_type] += delta_sc
                    else:
                        pos_supercoiling[env][gene.gene_type][other.gene_type] += delta_sc

    for env in envs:
        pos_supercoiling[env] /= (nb_reps * genes_per_type)
        neg_supercoiling[env] /= (nb_reps * genes_per_type)

    # Plotting
    fig, axs = plt.subplots(nb_gene_types, nb_gene_types, sharey=True,
                            figsize=(12, 12), dpi=300, squeeze=False)

    x = np.arange(len(envs))

    for i_gene_type, gene_type in enumerate(gene_types):
        for i_other_type, other_type in enumerate(gene_types):

            ax = axs[i_gene_type][i_other_type]

            pos_bars = [pos_supercoiling[env][i_gene_type][i_other_type] for env in envs]
            neg_bars = [neg_supercoiling[env][i_gene_type][i_other_type] for env in envs]

            ax.bar(x, pos_bars, color='tab:red')
            ax.bar(x, neg_bars, color='tab:green')

            ax.set_xticks(x)
            ax.set_xticklabels(labels=[f"Env. {env}" for env in envs], fontsize="large")
            ax.yaxis.set_tick_params(labelsize="large")
            ax.xaxis.set_tick_params(labelsize="x-large")

            # Positional visibility: the 'b' keyword is rejected by recent Matplotlib.
            ax.grid(True, axis='y', linestyle=':')

            if i_gene_type == 0:  # First line
                ax.set_title(f"Effect of {other_type} genes...", fontsize='xx-large')
            if i_other_type == 0:
                ax.set_ylabel(f"... on {gene_type} genes", rotation='vertical', fontsize='xx-large')

    # The y axes are shared, so inverting one inverts all of them. Calling
    # invert_yaxis() on every panel toggles the shared axis once per panel and
    # would only end up inverted for an odd number of panels.
    if not axs[0][0].yaxis_inverted():
        axs[0][0].invert_yaxis()

    fig.suptitle(f'{exp_path.name}', size='xx-large')

    fig.savefig(exp_path.joinpath('pos_neg_supercoiling.pdf'), dpi=300, bbox_inches='tight')
    plt.show()

In [None]:
plot_pos_neg_supercoiling(exp_path=exp_path, gen=gen)  # Per environment, per gene type pair

# Plot the distribution of intergenic distances

In [None]:
def get_intergene_distances(exp_path, gen):
    """Return the intergenic distances of every replicate's best individual.

    For each replicate directory (``rep*``, sorted) of ``exp_path``, the
    ``intergene`` attribute of every gene of the best individual at generation
    ``gen`` becomes one row, tagged with the replicate's position in that
    sorted order.

    Returns:
        pandas.DataFrame with columns Replicate and Intergene (integers) and a
        fresh 0..n-1 index, or an empty DataFrame when there is no replicate.
    """
    rep_dirs = sorted([d for d in exp_path.iterdir() if (d.is_dir() and d.name.startswith("rep"))])

    per_rep_res = []
    for rep, rep_dir in enumerate(rep_dirs):

        best_indiv = get_best_indiv(rep_dir, gen)

        intergenes = np.zeros(best_indiv.nb_genes)
        for i_gene, gene in enumerate(best_indiv.genes):
            intergenes[i_gene] = gene.intergene
        indiv_res = pd.DataFrame(data={'Intergene': intergenes}, dtype=int)
        indiv_res.insert(0, 'Replicate', rep)

        per_rep_res.append(indiv_res)

    if not per_rep_res:  # pd.concat() refuses an empty list
        return pd.DataFrame()

    # One concat (linear, not quadratic), without duplicate per-replicate labels.
    return pd.concat(per_rep_res, ignore_index=True)

In [None]:
def plot_intergenes(exp_path, gen, cutoff, plot_name):
    """Plot the distribution of intergenic distances on a log10 axis.

    Draws a density histogram of log10(intergenic distance) over the best
    individuals of all replicates at generation ``gen``, with a dashed vertical
    line at ``cutoff``. The figure is saved to ``plot_name``.
    """
    intergenes = get_intergene_distances(exp_path, gen)['Intergene']

    # log10(0) is -inf, which falls outside the [0, 5] bins anyway; drop such
    # values up front to avoid numpy's divide-by-zero warning.
    log_intergenes = np.log10(intergenes[intergenes > 0])

    bins = np.linspace(0.0, 5.0, 50)

    fig, ax = plt.subplots(figsize=(9, 4), dpi=300)

    ax.set_ylim(0, 1)
    ax.set_xlabel('Distance (bp)')
    ax.set_ylabel('Density')
    # Plot intergene distances
    ax.hist(log_intergenes, bins=bins, density=True)
    # Plot cutoff line
    ax.vlines(np.log10(cutoff), 0, 1, linestyle='--', linewidth=1,
              color='tab:red', label='Cutoff distance')
    ax.grid(linestyle=':')

    # The x values are exponents: put ticks on integers and label them 10^k.
    # (Freezing the pre-draw ticks and labelling them with int(x) mislabels
    # non-integer ticks; the braces also keep multi-digit exponents together.)
    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.xaxis.set_major_formatter(ticker.FuncFormatter(lambda x, pos: f'$10^{{{int(round(x))}}}$'))

    ax.set_title(exp_path.name)

    ax.legend(loc='upper left')

    fig.savefig(plot_name, dpi=300)
    plt.show()

In [None]:
plot_intergenes(exp_path=exp_path,
                gen=gen,
                cutoff=params['interaction_dist'],  # Maximal interaction distance
                plot_name=exp_path / 'intergene_distr.pdf')

# Explore the gene interaction graph

In [None]:
def get_interaction_graph(rep_path, gen):
    """Build the interaction graph of the best individual of a replicate.

    Node ``k`` is the k-th gene of the best individual at generation ``gen``
    (stored under the ``gene`` attribute). For every non-zero, off-diagonal
    entry ``indiv.inter_matrix[i][j]`` (influence of gene j on gene i), there is
    an edge j -> i whose ``inter`` attribute holds that entry.

    Returns:
        (indiv, graph): the best individual and its networkx.DiGraph.
    """
    indiv = get_best_indiv(rep_path=rep_path, gen=gen)

    inter_graph = nx.DiGraph()
    inter_graph.add_nodes_from((i_gene, {'gene': gene}) for i_gene, gene in enumerate(indiv.genes))

    for target in range(indiv.nb_genes):
        for source in range(indiv.nb_genes):
            if source == target:
                continue
            strength = indiv.inter_matrix[target][source]
            if strength != 0.0:
                # Influence of `source` on `target`, hence the edge source -> target
                inter_graph.add_edge(source, target, inter=strength)

    return indiv, inter_graph


## Number of neighbors per type of gene

In [None]:
def get_neighbors_stats(exp_path, gen):
    """Count, per gene type, the activated / inhibited neighbours of each type.

    For the interaction graph of every replicate's best individual, each edge
    node -> neighbour is counted as an activation when its ``inter`` value is
    negative and as an inhibition otherwise. Counts are stored in
    ``[node type][neighbour type]`` and divided by the number of genes per type
    (``nb_genes / len(gene_types)``), giving an average per gene.

    Returns:
        pandas.DataFrame with one row per (replicate, gene type). ``Type`` is the
        index of the acting gene type. ``X_activ`` / ``X_inhib`` are the average
        number of type-X neighbours it activates / inhibits, and
        ``X_total = X_activ + X_inhib``.
    """
    nb_gene_types = len(gene_types)

    rep_dirs = sorted([d for d in exp_path.iterdir() if (d.is_dir() and d.name.startswith("rep"))])

    # Column order: Replicate, Type, <type>_activ for each type, <type>_inhib for each type
    result_dict = {'Replicate': [], 'Type': []}
    for suffix in ('activ', 'inhib'):
        for gene_type in gene_types:
            result_dict[f'{gene_type}_{suffix}'] = []

    for rep, rep_dir in enumerate(rep_dirs):

        indiv, graph = get_interaction_graph(rep_dir, gen)

        genes_per_type = indiv.nb_genes / nb_gene_types

        # Fresh counters for every replicate (they used to accumulate across
        # replicates and were divided again at each iteration).
        nb_activ_neighbors = np.zeros((nb_gene_types, nb_gene_types))
        nb_inhib_neighbors = np.zeros((nb_gene_types, nb_gene_types))

        # Count the number of activated / inhibited neighbors for this replicate
        for node in graph:
            node_type = indiv.genes[node].gene_type
            for neighbor in graph[node]:
                neighbor_type = indiv.genes[neighbor].gene_type

                if graph[node][neighbor]['inter'] < 0:  # node activates neighbor
                    nb_activ_neighbors[node_type][neighbor_type] += 1
                else:  # node inhibits neighbor
                    nb_inhib_neighbors[node_type][neighbor_type] += 1

        nb_activ_neighbors /= genes_per_type
        nb_inhib_neighbors /= genes_per_type

        # One row per acting gene type; the columns are keyed by the type of the
        # neighbour (other_type), not by the row's own type.
        for i_gene_type in range(nb_gene_types):

            result_dict['Replicate'].append(rep)
            result_dict['Type'].append(i_gene_type)

            for i_other_type, other_type in enumerate(gene_types):
                result_dict[f'{other_type}_activ'].append(nb_activ_neighbors[i_gene_type][i_other_type])
                result_dict[f'{other_type}_inhib'].append(nb_inhib_neighbors[i_gene_type][i_other_type])

    res_df = pd.DataFrame.from_dict(result_dict)

    for gene_type in gene_types:
        res_df[f'{gene_type}_total'] = res_df[f'{gene_type}_activ'] + res_df[f'{gene_type}_inhib']

    return res_df

In [None]:
neighbor_stats = get_neighbors_stats(exp_path=exp_path, gen=gen)  # One row per replicate and gene type

In [None]:
def plot_number_of_neighbors(exp_path, gen, rel_type):
    """Stacked bar plot of the average number of neighbours per gene type.

    Each x position is an acting gene type. Its bar stacks, per neighbour gene
    type (colored with ``gene_type_color``), the average number of neighbours
    that genes of this type act on, according to get_neighbors_stats().
    ``rel_type`` selects 'activ', 'inhib' or 'total'. Values are averaged over
    replicates, with their standard deviation as error bars. The figure is saved
    to ``neighbors_per_type_<rel_type>.pdf`` in ``exp_path``.
    """
    # Compute the stats
    neighbor_stats = get_neighbors_stats(exp_path, gen)

    mean_stats = neighbor_stats.groupby('Type').mean()
    std_stats = neighbor_stats.groupby('Type').std()

    # Plot one stacked segment per neighbour gene type (driven by the
    # configuration lists instead of hard-coded AB / A / B columns).
    fig, ax = plt.subplots(figsize=(9, 4), dpi=300)

    bottom = 0
    for other_type, color in zip(gene_types, gene_type_color):
        column = f'{other_type}_{rel_type}'
        ax.bar(gene_types, mean_stats[column], yerr=std_stats[column],
               bottom=bottom, capsize=5, color=color, label=other_type)
        bottom = bottom + mean_stats[column]

    ax.set_title(f'{exp_path.name} ({rel_type})')
    ax.set_ylabel('Average number of neighbors')
    ax.set_xlabel('Gene type')
    ax.set_ylim(0, 9)
    ax.grid(linestyle=':', axis='y')
    ax.legend()

    fig.savefig(exp_path.joinpath(f'neighbors_per_type_{rel_type}.pdf'), dpi=300)

    plt.show()

In [None]:
plot_number_of_neighbors(exp_path=exp_path, gen=gen, rel_type='total')  # Activated + inhibited

In [None]:
plot_number_of_neighbors(exp_path=exp_path, gen=gen, rel_type='activ')  # Activated neighbours only

In [None]:
plot_number_of_neighbors(exp_path=exp_path, gen=gen, rel_type='inhib')  # Inhibited neighbours only

## (Strongly) connected components

In [None]:
def plot_scc_stats(exp_path, gen):
    """Bar plot of how many genes lie in strongly connected components of each size.

    For the interaction graph of every replicate's best individual, each
    strongly connected component of size k adds k to bin k. Bins are then
    averaged over replicates. The figure is saved to
    ``connected_components_genes.pdf`` in ``exp_path``.
    """
    replicate_dirs = sorted(
        d for d in exp_path.iterdir() if d.is_dir() and d.name.startswith("rep")
    )
    nb_reps = len(replicate_dirs)

    # Indexed by component size; the largest possible component holds every gene.
    genes_per_scc_size = np.zeros(int(params['nb_genes']) + 1)

    for rep_dir in replicate_dirs:
        _, graph = get_interaction_graph(rep_dir, gen)
        for component in nx.algorithms.strongly_connected_components(graph):
            size = len(component)
            genes_per_scc_size[size] += size

    genes_per_scc_size /= nb_reps  # Average over the replicates

    plt.figure(figsize=(8, 4), dpi=300)
    plt.grid(linestyle=':')
    plt.xlabel('Number of genes in the CC')
    plt.ylabel('Number of genes in a CC of this size')
    plt.title(exp_path.name)

    plt.bar(range(len(genes_per_scc_size)), genes_per_scc_size)

    plt.savefig(exp_path.joinpath('connected_components_genes.pdf'), dpi=300)

    plt.show()

In [None]:
#plot_scc_stats(exp_path, gen)

## Plot the interaction graph

In [None]:
def plot_interaction_graph(exp_path, gen, method='spring'):
    """Draw the interaction graph of the best individual of every replicate.

    Nodes are genes colored by type (``gene_type_color``). Green edges are
    activations (negative ``inter``) and red edges inhibitions (positive
    ``inter``), each with a width proportional to the interaction's magnitude.
    One PNG per replicate, ``genome_graph_repXX_<method>.png``, is written to
    ``exp_path``.

    Args:
        method: graph layout, one of 'spring', 'circular' or 'graphviz'.

    Raises:
        ValueError: for an unknown ``method``.
    """
    rep_dirs = sorted([d for d in exp_path.iterdir() if (d.is_dir() and d.name.startswith("rep"))])

    for rep, rep_dir in enumerate(rep_dirs):

        indiv, inter_graph = get_interaction_graph(rep_dir, gen)

        ## Draw the figure
        plt.figure(figsize=(16, 16), dpi=dpi)
        plt.box(False)

        # Choose the layout
        if method == 'spring':
            layout = nx.spring_layout(inter_graph, k=0.6)
        elif method == 'circular':
            layout = nx.circular_layout(inter_graph)
        elif method == 'graphviz':
            layout = nx.nx_agraph.graphviz_layout(inter_graph)
        else:
            raise ValueError(f"Unknown graph layout '{method}'")

        # Draw the nodes
        nx.draw_networkx_nodes(inter_graph, layout, node_size=600,
                               node_color=[gene_type_color[gene.gene_type] for gene in indiv.genes])
        nx.draw_networkx_labels(inter_graph, layout)

        # Draw the edges
        # A negative value means we lower sigma at the other gene, hence increasing expression
        activ_edges = [e for e in inter_graph.edges if inter_graph[e[0]][e[1]]['inter'] < 0]
        inhib_edges = [e for e in inter_graph.edges if inter_graph[e[0]][e[1]]['inter'] > 0]

        coef = 10
        # Activation values are negative: use their magnitude (like for
        # inhibitions), since a line width must be positive.
        activ_widths = [abs(inter_graph[u][v]['inter']) * coef for u, v in activ_edges]
        inhib_widths = [abs(inter_graph[u][v]['inter']) * coef for u, v in inhib_edges]

        nx.draw_networkx_edges(inter_graph, layout, edgelist=inhib_edges,
                               width=inhib_widths, edge_color='tab:red', connectionstyle='arc3,rad=0.1')
        nx.draw_networkx_edges(inter_graph, layout, edgelist=activ_edges,
                               width=activ_widths, edge_color='tab:green', connectionstyle='arc3,rad=0.1')

        plt.savefig(exp_path.joinpath(f'genome_graph_rep{rep:02}_{method}.png'), dpi=dpi)
        # Show and release each figure rather than keeping one open per replicate.
        plt.show()

In [None]:
#plot_interaction_graph(exp_path, gen, method='graphviz')