In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.ticker as ticker
import networkx as nx
import datetime
import pickle
import pathlib

In [None]:
import importlib
import evotsc
import evotsc_run
import evotsc_plot
importlib.reload(evotsc)
importlib.reload(evotsc_run)
importlib.reload(evotsc_plot)

In [None]:
# Experiment directory; the replicate sub-directories (named `rep*`) are looked up inside it.
# NOTE(review): absolute, machine-specific path -- adjust before running on another machine.
exp_path = pathlib.Path('/Users/theotime/Desktop/evotsc/pci/main/')
# Generation whose saved population is analysed (one million)
gen = 1000_000
gene_types = ['AB', 'A', 'B'] # Name of each gene type
gene_type_color = ['tab:blue', 'tab:red', 'tab:green'] #AB, A, B
orient_name = ['leading', 'lagging'] # Name of each gene orientation
# Relative-orientation labels used by the pair statistics below
rel_orients = ['conv', 'div', 'downstr', 'upstr'] # In alphabetical order

In [None]:
# Shared plot appearance settings (font sizes and the resolution used when saving figures)
label_fontsize=20
tick_fontsize=15
legend_fontsize=15
dpi=300

In [None]:
# Replicate directories: every sub-directory of exp_path whose name starts with "rep", sorted
rep_dirs = sorted([d for d in exp_path.iterdir() if (d.is_dir() and d.name.startswith("rep"))])
# Parameters are read from the first replicate only
# NOTE(review): presumably all replicates share the same parameters -- confirm
params = evotsc_run.read_params(rep_dirs[0])

In [None]:
# Display the parameters read from the first replicate
params

In [None]:
def get_best_indiv(rep_path, gen):
    """Load the population saved at generation `gen` and return its fittest individual.

    The population is read from `pop_gen_<gen>.evotsc` (generation zero-padded to
    six digits) inside `rep_path`, then `evaluate()` is called on it.

    The first individual is the default answer: it is returned when no individual
    has a fitness strictly above 0, and it is kept as soon as an individual without
    a `fitness` attribute is met (the scan stops there).  On ties, the earliest
    individual wins.

    Raises FileNotFoundError if the save file does not exist.
    """
    save_path = rep_path.joinpath(f'pop_gen_{gen:06}.evotsc')

    # NOTE: pickle can execute arbitrary code; only load save files you trust.
    with open(save_path, 'rb') as save_file:
        population = pickle.load(save_file)

    population.evaluate()

    champion = population.individuals[0]
    champion_fitness = 0

    try:
        for candidate in population.individuals:
            if candidate.fitness > champion_fitness:
                champion, champion_fitness = candidate, candidate.fitness
    except AttributeError:
        # In the neutral control, individuals are not evaluated so there is no fitness field
        pass

    return champion

# Plot genomes

In [None]:
def plot_best_genome_and_tsc(exp_path, gen):
    """Plot the genome of the best individual of every replicate, once per environment.

    For each `rep*` directory of `exp_path`, the best individual of generation `gen`
    is drawn with `evotsc_plot.plot_genome_and_tsc` using `params['sigma_A']` and then
    `params['sigma_B']`; the figures are written into `exp_path`.  Replicates for
    which a FileNotFoundError is raised are skipped.
    """
    replicate_dirs = sorted(d for d in exp_path.iterdir()
                            if d.is_dir() and d.name.startswith("rep"))

    for rep, rep_dir in enumerate(replicate_dirs):
        try:
            best_indiv = get_best_indiv(rep_dir, gen)

            # Options shared by the two environment plots
            shared_options = dict(show_bar=True,
                                  id_interval=best_indiv.nb_genes / 12,
                                  print_ids=True)

            evotsc_plot.plot_genome_and_tsc(
                best_indiv, params['sigma_A'],
                plot_name=exp_path.joinpath(f'genome_and_tsc_rep{rep:02}_env_A.pdf'),
                **shared_options)

            evotsc_plot.plot_genome_and_tsc(
                best_indiv, params['sigma_B'],
                plot_name=exp_path.joinpath(f'genome_and_tsc_rep{rep:02}_env_B.pdf'),
                **shared_options)
        except FileNotFoundError:
            # Skip missing data
            continue

In [None]:
#plot_best_genome_and_tsc(exp_path, gen)

# Get the genome stats

In [None]:
def get_pair_stats(indiv):
    """Describe every gene's relation to its two immediate neighbours on the genome.

    Two rows are produced per gene: one for the gene located just before it
    (`indiv.genes[i - 1]`, wrapping around) and one for the gene located just
    after it.  Each row stores the ids, types and orientations of both genes, the
    relative orientation label, the intergenic distance separating them and the
    interaction distance.

    As in evotsc.compute_inter_matrix(), the relevant interaction distance is the
    distance between the promoter of the focal gene and the middle of the other gene.

    Returns a pandas DataFrame with one row per (gene, neighbour) pair.
    """
    columns = ('gene_id', 'gene_type', 'gene_orient',
               'other_id', 'other_type', 'other_orient',
               'rel_orient', 'intergene_dist', 'interaction_dist')
    table = {column: [] for column in columns}

    def add_row(focal, neighbour, label, gap, reach):
        # Append one (focal gene, neighbour) record to the column lists
        row = (focal.id,
               gene_types[focal.gene_type],
               orient_name[focal.orientation],
               neighbour.id,
               gene_types[neighbour.gene_type],
               orient_name[neighbour.orientation],
               label, gap, reach)
        for column, value in zip(columns, row):
            table[column].append(value)

    for position, gene in enumerate(indiv.genes):

        # --- Neighbour located before the focal gene:        before  -  gene
        before = indiv.genes[position - 1]
        half_before = before.length // 2
        orients = (before.orientation, gene.orientation)

        if orients == (0, 1):      # --|--> <----|
            label = 'conv'
            reach = half_before + before.intergene + (gene.length - 1)
        elif orients == (0, 0):    # --|--> |---->
            label = 'downstr'
            reach = half_before + before.intergene
        elif orients == (1, 0):    # <--|-- |---->
            label = 'div'
            reach = (half_before + 1) + before.intergene
        else:                      # <--|-- <----|
            label = 'upstr'
            reach = (half_before + 1) + before.intergene + (gene.length - 1)

        add_row(gene, before, label, before.intergene, reach)

        # --- Neighbour located after the focal gene:         gene  -  after
        after = indiv.genes[(position + 1) % indiv.nb_genes]
        half_after = after.length // 2
        orients = (gene.orientation, after.orientation)

        if orients == (0, 1):      # |----> <--|--
            label = 'conv'
            reach = gene.length + gene.intergene + (half_after - 1)
        elif orients == (0, 0):    # |----> --|-->
            label = 'upstr'
            reach = gene.length + gene.intergene + half_after
        elif orients == (1, 0):    # <----| --|-->
            label = 'div'
            reach = (gene.intergene + 1) + half_after
        else:                      # <----| <--|--
            label = 'downstr'
            reach = (gene.intergene + 1) + (half_after - 1)

        add_row(gene, after, label, gene.intergene, reach)

    return pd.DataFrame.from_dict(table)

In [None]:
def get_whole_stats(indiv):
    """Collect statistics on every ordered pair of distinct genes within interaction range.

    For each focal gene and each other gene, the circular distance between the
    focal gene's position (as returned by `indiv.compute_gene_positions`) and the
    middle of the other gene is computed; pairs whose distance exceeds
    `indiv.interaction_dist` are skipped.

    Returns a pandas DataFrame with one row per retained (gene, other) pair and the
    columns gene_id, gene_type, gene_orient, other_id, other_type, other_orient,
    rel_orient, intergene_dist and interaction_dist.  `interaction_dist` is the
    distance described above; `intergene_dist` is that distance minus the
    orientation-dependent gene-length corrections applied at the end of the loop.
    """

    # Count all pairs of interacting genes

    # NOTE(review): presumably gene_positions[i] is the promoter position of gene i
    # -- confirm against evotsc
    gene_positions, genome_size = indiv.compute_gene_positions(include_coding=True)

    result_dict = {'gene_id': [],
                   'gene_type': [],
                   'gene_orient': [],
                   'other_id': [],
                   'other_type': [],
                   'other_orient': [],
                   'rel_orient': [],
                   'intergene_dist': [],
                   'interaction_dist': []}

    for i_gene, gene in enumerate(indiv.genes):
        for i_other, other in enumerate(indiv.genes):

            # A gene is not paired with itself
            if i_other == i_gene:
                continue

            # As in evotsc.compute_inter_matrix(), the relevant distance is
            # the distance between the promoter of the focal gene and the
            # middle of the other gene.

            # Position of the middle of the other gene
            if other.orientation == 0:  # Leading
                pos_other = gene_positions[i_other] + other.length // 2
            else:  # Lagging
                pos_other = gene_positions[i_other] - other.length // 2


            # In the diagrams below, 1 is the focal gene and 2 is the middle of the other gene
            pos_1_minus_2 = gene_positions[i_gene] - pos_other
            pos_2_minus_1 = - pos_1_minus_2

            # We want to know whether gene i comes before or after gene j,
            # following the shortest path along the circular genome
            # Before: -------1--2-------- or -2---------------1-
            # After:  -------2--1-------- or -1---------------2-

            if pos_1_minus_2 < 0: # -------1--2-------- or -1---------------2-
                if pos_2_minus_1 < genome_size + pos_1_minus_2: # -------1--2--------
                    distance = pos_2_minus_1
                    i_before_j = True
                else: # -1---------------2-
                    distance = genome_size + pos_1_minus_2
                    i_before_j = False

            else: # -------2--1-------- or -2---------------1-
                if pos_1_minus_2 < genome_size + pos_2_minus_1: # -------2--1--------
                    distance = pos_1_minus_2
                    i_before_j = False
                else: # -2---------------1-
                    distance = genome_size + pos_2_minus_1
                    i_before_j = True


            # Exit early if genes are too far: such pairs are not recorded
            if distance > indiv.interaction_dist:
                continue

            result_dict['gene_id'].append(gene.id)
            result_dict['gene_type'].append(gene_types[gene.gene_type])
            result_dict['gene_orient'].append(orient_name[gene.orientation])
            result_dict['other_id'].append(other.id)
            result_dict['other_type'].append(gene_types[other.gene_type])
            result_dict['other_orient'].append(orient_name[other.orientation])

            # We want to know if gene is ... of other.
            # The lookup table is indexed by [gene.orientation][other.orientation].
            # Note: this local `rel_orients` shadows the module-level list of the same name.
            if i_before_j:
                rel_orients = [['upstr', 'conv'],
                               ['div', 'downstr']]
            else:
                rel_orients = [['downstr', 'div'],
                               ['conv', 'upstr']]
            rel_orient = rel_orients[gene.orientation][other.orientation]

            result_dict['rel_orient'].append(rel_orient)


            # Remove the gene-length contributions from the distance; the corrections
            # are the same terms that get_pair_stats() adds to the intergenic distance
            # of adjacent genes.
            if i_before_j:
                if gene.orientation == 0 and other.orientation == 0:
                    intergene_dist = distance - gene.length - other.length // 2
                elif gene.orientation == 0 and other.orientation == 1:
                    intergene_dist = distance - gene.length - (other.length // 2 - 1)
                elif gene.orientation == 1 and other.orientation == 0:
                    intergene_dist = (distance - 1) - other.length // 2
                else: #gene.orientation == 1 and other.orientation == 1:
                    intergene_dist = (distance - 1) - (other.length // 2 - 1)

            else:
                if other.orientation == 0 and gene.orientation == 0:
                    intergene_dist = distance - other.length // 2
                elif other.orientation == 0 and gene.orientation == 1:
                    intergene_dist = distance - other.length // 2 - (gene.length - 1)
                elif other.orientation == 1 and gene.orientation == 0:
                    intergene_dist = distance - (other.length // 2 + 1)
                else: #other.orientation == 1 and gene.orientation == 1:
                    intergene_dist = distance - (other.length // 2 + 1) - (gene.length - 1)


            result_dict['intergene_dist'].append(intergene_dist)
            result_dict['interaction_dist'].append(distance)

    return pd.DataFrame.from_dict(result_dict)

In [None]:
def get_full_pair_stats(exp_path, gen, only_pairs=True):
    """Gather the gene-pair statistics of the best individual of every replicate.

    Parameters
    ----------
    exp_path : pathlib.Path
        Experiment directory containing the `rep*` replicate directories.
    gen : int
        Generation of the saved population to analyse.
    only_pairs : bool, default True
        If True, use `get_pair_stats` (adjacent genes only); otherwise use
        `get_whole_stats` (all genes within interaction range).

    Returns
    -------
    pandas.DataFrame
        The per-replicate tables stacked on top of each other, with a leading
        'Replicate' column holding the rank of the replicate directory in sorted
        order.  Each replicate keeps its own row index.  An empty DataFrame is
        returned when there is no replicate directory.
    """
    rep_dirs = sorted([d for d in exp_path.iterdir() if (d.is_dir() and d.name.startswith("rep"))])

    compute_stats = get_pair_stats if only_pairs else get_whole_stats

    # Collect the per-replicate frames and concatenate once at the end: growing a
    # DataFrame with pd.concat inside the loop re-copies all previous rows at each
    # iteration, and starts from an empty frame.
    per_rep_stats = []

    for rep, rep_dir in enumerate(rep_dirs):

        best_indiv = get_best_indiv(rep_dir, gen)

        indiv_stats = compute_stats(best_indiv)
        indiv_stats.insert(0, 'Replicate', rep)
        per_rep_stats.append(indiv_stats)

    if not per_rep_stats:
        # pd.concat raises on an empty list
        return pd.DataFrame()

    return pd.concat(per_rep_stats)

In [None]:
# Statistics on adjacent gene pairs only (only_pairs defaults to True)
pair_stats = get_full_pair_stats(exp_path, gen)

In [None]:
# Statistics on all gene pairs within interaction range, not only adjacent genes
whole_stats = get_full_pair_stats(exp_path, gen, only_pairs=False)

# Plot the average number and distance of gene pairs per gene type

In [None]:
def plot_gene_stats(stats, plot_counts, name):
    """Plot gene-pair statistics on a 3x3 grid: one panel per (gene type, other type).

    Each panel has one bar per relative orientation.  If `plot_counts` is true, the
    bars show the number of pairs of each kind, averaged over replicates; if
    `plot_counts` is false, the bars show the average intergenic distance of the
    pairs of each kind.  Per-replicate values are overlaid as a boxplot when there
    are more than 10 replicates, and as individual markers otherwise.

    Parameters
    ----------
    stats : pandas.DataFrame
        Table as returned by `get_full_pair_stats`; must contain the columns
        'Replicate', 'gene_type', 'other_type', 'rel_orient', 'gene_id' and
        'intergene_dist'.  Replicates are expected to be numbered from 0.
    plot_counts : bool
        Selects counts (True) or distances (False).
    name : str
        Base name of the PDF written into `exp_path` (without extension).
    """

    fig, axs = plt.subplots(3, 3, sharey='all', figsize=(12, 12), dpi=300)


    ## Massage the data into usable form
    nb_reps = stats['Replicate'].nunique()
    group_cols = ['gene_type', 'other_type', 'rel_orient']
    group_dims = [gene_types, gene_types, rel_orients]
    group_cols_rep = [*group_cols, 'Replicate']
    group_dims_rep = [*group_dims, range(nb_reps)]
    # Complete indexes, used to add the pair kinds that are absent from the data
    idx = pd.MultiIndex.from_product(group_dims, names=group_cols)
    idx_rep = pd.MultiIndex.from_product(group_dims_rep, names=group_cols_rep)

    by_kind = stats.groupby(group_cols)
    by_kind_rep = stats.groupby(group_cols_rep)

    # The column of interest is selected *before* aggregating: aggregating the whole
    # frame would also try to average the string columns (gene_orient, other_orient),
    # which raises a TypeError with pandas >= 2.0.

    # Average number of pairs of each type (characterized by the values from the `group_cols` columns)
    pair_numbers = by_kind['gene_id'].count().reindex(idx, fill_value=0).sort_index() / nb_reps
    pair_numbers_rep = by_kind_rep['gene_id'].count().reindex(idx_rep, fill_value=0).sort_index()

    # Average distance between the genes in the pairs (global average over the pairs from all individuals)
    pair_distances = by_kind['intergene_dist'].mean().reindex(idx, fill_value=0).sort_index()
    pair_distances_rep = by_kind_rep['intergene_dist'].mean().reindex(idx_rep, fill_value=0).sort_index()

    # Average distances are computed correctly, even when some replicates do not have a
    # specific kind of pair present (they are not counted as a 0 in the average).


    colors = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red']

    nb_orient = len(rel_orients)
    x_bar = np.arange(nb_orient)

    ## Plot
    for i_gene_type, gene_type in enumerate(gene_types):
        for i_other_type, other_type in enumerate(gene_types):

            ax = axs[i_gene_type][i_other_type]


            # Plot the replicate-averaged data
            if plot_counts:
                bar_data = pair_numbers.loc[(gene_type, other_type)]
                rep_data = pair_numbers_rep.loc[(gene_type, other_type)]

            else: # plot distances
                bar_data = pair_distances.loc[(gene_type, other_type)]
                rep_data = pair_distances_rep.loc[(gene_type, other_type)]

            ax.bar(x_bar, bar_data, color=colors)

            # Plot the corresponding boxplot, or individual replicates
            if nb_reps > 10: # If we have many replicates, draw a boxplot over the means
                rep_data = rep_data.unstack(level='rel_orient')
                ax.boxplot(rep_data, positions=x_bar, manage_ticks=False, medianprops={'color':'black'})

            else: # Else, just draw every replicate.
                rep_data = rep_data.unstack(level='Replicate')
                bar_width = 0.8 # default value
                # Spread the replicate markers evenly inside the width of each bar
                x_rep = (np.tile(x_bar, (nb_reps, 1)).T +
                         np.tile(np.linspace(-bar_width/2, bar_width/2, nb_reps+2)[1:-1], (nb_orient, 1)))

                if plot_counts:
                    ax.plot(x_rep, rep_data,
                            marker='o', linestyle='', markeredgecolor='black', markerfacecolor='none')
                # If we're plotting distances, we don't consider pairs that are not present when computing
                # the average distance, so we use a different marker to show that
                else:
                    rep_data_nonzero = rep_data.where(rep_data > 0)
                    rep_data_zero = rep_data.where(rep_data == 0)
                    ax.plot(x_rep, rep_data_nonzero,
                            marker='o', linestyle='', markeredgecolor='black', markerfacecolor='none')
                    ax.plot(x_rep, rep_data_zero,
                            marker='x', linestyle='', markeredgecolor='black', markerfacecolor='none')


            # Plot setup
            if plot_counts:
                ax.set_ylim(0, 32)

            ax.set_xticks(x_bar)
            ax.set_xticklabels(labels=[orient + '.' for orient in rel_orients], fontsize="large")
            ax.yaxis.set_tick_params(labelsize="large")

            ax.grid(axis='y', linestyle=':')

            if i_gene_type == 0: # First line
                ax.set_title(f"{other_type}", fontsize='xx-large')

            if i_other_type == 0: # First column
                ax.set_ylabel(f"{gene_type}",rotation='vertical', fontsize='xx-large')

    #fig.suptitle(exp_path.name, size='xx-large')

    plt.savefig(exp_path.joinpath(f'{name}.pdf'), dpi=300, bbox_inches='tight')
    plt.show()

In [None]:
# Number of adjacent gene pairs of each kind
plot_gene_stats(pair_stats, plot_counts=True, name='gene_pair_counts')

In [None]:
# Number of gene pairs of each kind, counting all genes within interaction range
plot_gene_stats(whole_stats, plot_counts=True, name='gene_pair_counts_whole')

In [None]:
# Average intergenic distance of adjacent gene pairs of each kind
plot_gene_stats(pair_stats, plot_counts=False, name='gene_pair_distances')

In [None]:
# Average intergenic distance of gene pairs of each kind, counting all genes within interaction range
plot_gene_stats(whole_stats, plot_counts=False, name='gene_pair_distances_whole')

In [None]:
#plot_gene_stats(filtered_stats, plot_counts=True, name='gene_pair_counts_filtered')

In [None]:
#plot_gene_stats(filtered_stats, plot_counts=False, name='gene_pair_distances_filtered')

## Plot the positive / negative supercoiling due to each gene type in each environment

In [None]:
def get_pos_neg_supercoiling(exp_path, gen, pairs_only=False):
    """Sum the supercoiling contributions between gene types, split by sign and environment.

    For the best individual of each replicate and for each environment ('A', 'B'),
    the contribution of gene j on gene i is `inter_matrix[i][j]` times the final
    expression level of gene j in that environment.  Contributions are labelled
    'pos' when strictly positive and 'neg' otherwise, summed per
    (gene type, other type, environment, sign, replicate) and divided by the number
    of genes per type.

    If `pairs_only` is true, only contributions between genes that are adjacent on
    the (circular) genome are kept.

    Returns a DataFrame with a single 'delta_sc' column, indexed by the complete
    product of gene types, gene types, environments, signs and replicates (missing
    combinations are filled with 0).
    """
    replicate_dirs = sorted(d for d in exp_path.iterdir()
                            if d.is_dir() and d.name.startswith("rep"))
    nb_reps = len(replicate_dirs)

    nb_genes = int(params['nb_genes'])
    genes_per_type = nb_genes / len(gene_types)

    envs = ['A', 'B']
    signs = ['neg', 'pos']

    records = {'Replicate': [],
               'env': [],
               'gene_type': [],
               'other_type': [],
               'delta_sc': [],
               'sign': [] }

    for rep, rep_dir in enumerate(replicate_dirs):

        best_indiv = get_best_indiv(rep_dir, gen)
        inter_matrix = best_indiv.compute_inter_matrix()

        for env in envs:

            # Expression levels at the last step of the simulation in this environment
            final_expr = best_indiv.run_system(params[f'sigma_{env}'])[-1, :]

            for j_gene, other in enumerate(best_indiv.genes):

                # Indices of the two genes adjacent to gene j on the circular genome
                adjacent = ((j_gene + 1) % nb_genes,
                            (j_gene + nb_genes - 1) % nb_genes) if pairs_only else ()

                for i_gene, gene in enumerate(best_indiv.genes):

                    if pairs_only and i_gene not in adjacent:
                        continue

                    delta_sc = inter_matrix[i_gene][j_gene] * final_expr[j_gene]

                    records['Replicate'].append(rep)
                    records['env'].append(env)
                    records['gene_type'].append(gene_types[gene.gene_type])
                    records['other_type'].append(gene_types[other.gene_type])
                    records['delta_sc'].append(delta_sc)
                    records['sign'].append(signs[1 if delta_sc > 0 else 0])

    # Complete index, used to add the combinations that are absent from the data
    col_names = ['gene_type', 'other_type', 'env', 'sign', 'Replicate']
    full_index = pd.MultiIndex.from_product(
        [gene_types, gene_types, envs, signs, range(nb_reps)], names=col_names)

    # Value per replicate (averaged over the genes of that type)
    per_replicate = pd.DataFrame.from_dict(records).groupby(col_names).sum() / genes_per_type

    return per_replicate.reindex(full_index, fill_value=0).sort_index()

In [None]:
def plot_pos_neg_supercoiling(data, plot_name):
    """Plot the positive and negative supercoiling contributions between gene types.

    Draws a 3x3 grid with one panel per (other type, gene type): rows follow
    `other_type` and columns follow `gene_type`.  In each panel, the
    replicate-averaged positive (red) and negative (green) contributions are drawn
    as bars for environments A and B, with a boxplot of the per-replicate values on
    top.  The y axis is inverted, so negative values point upwards.

    Parameters
    ----------
    data : pandas.DataFrame
        Table as returned by `get_pos_neg_supercoiling`: a 'delta_sc' column indexed
        by (gene_type, other_type, env, sign, Replicate).
    plot_name : str
        File name of the figure, written into `exp_path`.
    """

    fig, axs = plt.subplots(3, 3, sharey=True, figsize=(12, 12), dpi=300)

    x = [0, 1]

    rep_avg = data.groupby(['gene_type', 'other_type', 'env', 'sign']).mean()

    # Per-replicate values with one column per environment; this does not depend on
    # the panel, so it is computed once instead of twice per panel.
    per_env = data.unstack(level='env')

    for i_other_type, other_type in enumerate(gene_types):
        for i_gene_type, gene_type in enumerate(gene_types):

            ax = axs[i_other_type][i_gene_type]

            # Bars: replicate average in each environment
            pos_bars = [rep_avg.loc[(gene_type, other_type, 'A', 'pos')]['delta_sc'],
                        rep_avg.loc[(gene_type, other_type, 'B', 'pos')]['delta_sc']]
            neg_bars = [rep_avg.loc[(gene_type, other_type, 'A', 'neg')]['delta_sc'],
                        rep_avg.loc[(gene_type, other_type, 'B', 'neg')]['delta_sc']]

            ax.bar(x, pos_bars, color='tab:red')
            ax.bar(x, neg_bars, color='tab:green')

            # Boxplots: distribution over replicates
            ax.boxplot(per_env.loc[(gene_type, other_type, 'pos')], positions=x,
                       manage_ticks=False, medianprops={'color':'black'})
            ax.boxplot(per_env.loc[(gene_type, other_type, 'neg')], positions=x,
                       manage_ticks=False, medianprops={'color':'black'})

            ax.set_xticks(x)
            ax.set_xticklabels(labels=["Env. A", "Env. B"], fontsize="large")
            ax.yaxis.set_tick_params(labelsize="large")
            #ax.set_ylim(-0.032, 0.032)

            ax.grid(axis='y', linestyle=':')

            if i_other_type == 0: # First line
                ax.set_title(f"{gene_type}", fontsize='xx-large')
            if i_gene_type == 0:
                ax.set_ylabel(f"{other_type}", rotation='vertical', fontsize='xx-large')

    # Invert the y axis once all the data is drawn.  The panels share their y axis, so
    # an unconditional invert_yaxis() per panel flips the shared limits back and forth
    # and the final orientation depends on the number of panels; checking the current
    # state first makes the result independent of it.
    for ax in axs.flat:
        if not ax.yaxis_inverted():
            ax.invert_yaxis()

    #fig.suptitle(f'{exp_path.name}', size='xx-large')

    plt.savefig(exp_path.joinpath(plot_name), dpi=300, bbox_inches='tight')
    plt.show()

In [None]:
# Supercoiling contributions between all genes
plot_pos_neg_supercoiling(get_pos_neg_supercoiling(exp_path, gen, pairs_only=False),
                          'pos_neg_supercoiling_whole.pdf')

In [None]:
# Supercoiling contributions restricted to genes that are adjacent on the genome
plot_pos_neg_supercoiling(get_pos_neg_supercoiling(exp_path, gen, pairs_only=True),
                          'pos_neg_supercoiling_pairs.pdf')

# Plot the distribution of intergenic distances

In [None]:
def get_intergene_distances(exp_path, gen):
    """Return the intergenic distances of the best individual of every replicate.

    Parameters
    ----------
    exp_path : pathlib.Path
        Experiment directory containing the `rep*` replicate directories.
    gen : int
        Generation of the saved population to analyse.

    Returns
    -------
    pandas.DataFrame
        One row per gene, with the columns 'Replicate' (rank of the replicate
        directory in sorted order) and 'Intergene' (the gene's `intergene` value,
        as an integer).  Each replicate keeps its own row index.  An empty
        DataFrame is returned when there is no replicate directory.
    """
    rep_dirs = sorted([d for d in exp_path.iterdir() if (d.is_dir() and d.name.startswith("rep"))])

    # Collect the per-replicate frames and concatenate once at the end: growing a
    # DataFrame with pd.concat inside the loop re-copies all previous rows at each
    # iteration, and starts from an empty frame.
    per_rep_res = []

    for rep, rep_dir in enumerate(rep_dirs):

        best_indiv = get_best_indiv(rep_dir, gen)

        intergenes = np.zeros(best_indiv.nb_genes)
        for i_gene, gene in enumerate(best_indiv.genes):
            intergenes[i_gene] = gene.intergene
        indiv_res = pd.DataFrame(data={'Intergene':intergenes}, dtype=int)
        indiv_res.insert(0, 'Replicate', rep)

        per_rep_res.append(indiv_res)

    if not per_rep_res:
        # pd.concat raises on an empty list
        return pd.DataFrame()

    return pd.concat(per_rep_res)

In [None]:
def plot_intergenes(exp_path, gen, cutoff, plot_name):
    """Plot the distribution of intergenic distances of the best individuals.

    The histogram is drawn over log10(distance), using 50 bin edges between 0 and 5
    (i.e. 1 bp to 10^5 bp), and the tick labels are rewritten as powers of ten.  A
    dashed vertical line marks `cutoff`.

    Parameters
    ----------
    exp_path : pathlib.Path
        Experiment directory; its name is used as the figure title.
    gen : int
        Generation of the saved population to analyse.
    cutoff : number
        Distance (in bp) at which the vertical line is drawn.
    plot_name : path-like
        Full path of the figure file to write.
    """

    intergenes = get_intergene_distances(exp_path, gen)

    # Bin edges, in log10(bp)
    bins = np.linspace(0.0, 5.0, 50)

    fig, ax = plt.subplots(figsize=(9, 4), dpi=300)


    plt.ylim(0, 1)
    #plt.xlim(-0.2, 4.2)
    plt.xlabel('Distance (bp)')
    plt.ylabel('Density')
    # Plot intergene distances
    plt.hist(np.log10(intergenes['Intergene']), bins=bins, density=True)
    # Plot cutoff line
    plt.vlines(np.log10(cutoff), 0, 1, linestyle='--', linewidth=1,
               color='tab:red', label='Cutoff distance')
    plt.grid(linestyle=':')

    # Write the x ticks in log scale.  The tick positions are frozen first so that the
    # labels stay attached to them.  The exponent is wrapped in braces so that mathtext
    # typesets it as a whole (without braces, only the first character of a
    # multi-character exponent such as "-1" or "10" would be raised).
    ax.xaxis.set_major_locator(ticker.FixedLocator(ax.get_xticks()))
    ax.set_xticklabels([f'$10^{{{int(x)}}}$' for x in ax.get_xticks()])

    plt.title(exp_path.name)

    plt.legend(loc='upper left')

    plt.savefig(plot_name, dpi=300)
    plt.show()

In [None]:
# The cutoff line is drawn at the interaction distance read from the run parameters
plot_intergenes(exp_path, gen=gen, cutoff=params['interaction_dist'], 
                plot_name=exp_path.joinpath('intergene_distr.pdf'))

# Explore the gene interaction graph

In [None]:
def get_interaction_graph(rep_path, gen):
    """Build the gene interaction graph of the best individual of generation `gen`.

    The graph is directed, with one node per gene (keyed by the gene's index and
    carrying the gene object in its `gene` attribute).  For every pair of distinct
    genes (i, j) with a non-zero `indiv.inter_matrix[i][j]`, an edge j -> i is
    added, since the coefficient is the influence of gene j on gene i; the
    coefficient is stored in the edge's `inter` attribute.

    Returns the tuple (best individual, interaction graph).
    """
    indiv = get_best_indiv(rep_path=rep_path, gen=gen)

    inter_graph = nx.DiGraph()

    # One node per gene
    for node_id, gene in enumerate(indiv.genes):
        inter_graph.add_node(node_id, gene=gene)

    # One edge per non-zero off-diagonal coefficient of the interaction matrix
    for target in range(indiv.nb_genes):
        for source in range(indiv.nb_genes):
            if target == source:
                continue

            strength = indiv.inter_matrix[target][source]
            if strength == 0.0:
                continue

            # The source gene influences the target gene
            inter_graph.add_edge(source, target, inter=strength)

    return indiv, inter_graph


## Number of neighbors per type of gene

In [None]:
def get_neighbors_stats(exp_path, gen):
    """Count, per replicate, the neighbors each gene type activates or inhibits.

    For the best individual of each replicate, every edge `node -> neighbor` of the
    interaction graph is counted as an activation when its `inter` coefficient is
    negative and as an inhibition otherwise.  Counts are grouped by
    (type of the node, type of the neighbor) and divided by the number of genes per
    type.

    Parameters
    ----------
    exp_path : pathlib.Path
        Experiment directory containing the `rep*` replicate directories.
    gen : int
        Generation of the saved population to analyse.

    Returns
    -------
    pandas.DataFrame
        One row per (replicate, gene type).  'Type' is the index of the gene type
        in `gene_types`; '<T>_activ' / '<T>_inhib' are the average numbers of
        neighbors of type <T> that a gene of the row's type activates / inhibits,
        and '<T>_total' is their sum.
    """
    rep_dirs = sorted([d for d in exp_path.iterdir() if (d.is_dir() and d.name.startswith("rep"))])

    nb_types = len(gene_types)

    result_dict = {'Replicate': [],
                   'Type': [],
                   'AB_activ': [],
                   'A_activ': [],
                   'B_activ': [],
                   'AB_inhib': [],
                   'A_inhib': [],
                   'B_inhib': []}

    for rep, rep_dir in enumerate(rep_dirs):

        indiv, graph = get_interaction_graph(rep_dir, gen)

        genes_per_type = indiv.nb_genes / nb_types

        # The counters are created afresh for every replicate: sharing them across
        # replicates would carry the (already normalized) counts of the previous
        # replicates over into this one.
        nb_activ_neighbors = np.zeros((nb_types, nb_types))
        nb_inhib_neighbors = np.zeros((nb_types, nb_types))

        # Count the number of activated / inhibited neighbors for this replicate
        for node in graph:
            node_type = indiv.genes[node].gene_type

            for neighbor in graph[node]:
                neighbor_type = indiv.genes[neighbor].gene_type

                if graph[node][neighbor]['inter'] < 0: # node activates neighbor
                    nb_activ_neighbors[node_type][neighbor_type] += 1

                else: # node inhibits neighbor
                    nb_inhib_neighbors[node_type][neighbor_type] += 1

        nb_activ_neighbors /= genes_per_type
        nb_inhib_neighbors /= genes_per_type

        # Add results to the dataframe: one row per gene type, one column per type
        # of neighbor.  Each value goes into the column of the *neighbor* type, so
        # that it lands in the row being built (labelled by 'Type').
        for i_gene_type, gene_type in enumerate(gene_types):

            result_dict['Replicate'].append(rep)
            result_dict['Type'].append(i_gene_type)

            for i_other_type, other_type in enumerate(gene_types):
                result_dict[other_type + '_activ'].append(nb_activ_neighbors[i_gene_type][i_other_type])
                result_dict[other_type + '_inhib'].append(nb_inhib_neighbors[i_gene_type][i_other_type])

    res_df = pd.DataFrame.from_dict(result_dict)

    for gene_type in gene_types:
        res_df[f'{gene_type}_total'] = res_df[f'{gene_type}_activ'] + res_df[f'{gene_type}_inhib']

    return res_df

In [None]:
# Neighbor statistics table, kept for inspection (plot_number_of_neighbors recomputes its own)
neighbor_stats = get_neighbors_stats(exp_path, gen)

In [None]:
def plot_number_of_neighbors(exp_path, gen, rel_type):
    """Plot, for each gene type, a stacked bar of its average number of neighbors.

    The statistics come from `get_neighbors_stats`; the mean over replicates gives
    the bar heights and the standard deviation gives the error bars.  The three
    stacked layers are the 'AB', 'A' and 'B' columns of the requested kind.

    Parameters
    ----------
    exp_path : pathlib.Path
        Experiment directory; the figure is saved inside it.
    gen : int
        Generation of the saved population to analyse.
    rel_type : str
        Suffix of the columns to plot: 'activ', 'inhib' or 'total'.
    """
    # Compute the stats
    neighbor_stats = get_neighbors_stats(exp_path, gen)

    per_type = neighbor_stats.groupby('Type')
    mean_stats = per_type.mean()
    std_stats = per_type.std()

    # Plot
    plt.figure(figsize=(9, 4), dpi=300)

    # (column, color, legend label) of each layer, from bottom to top
    layers = [(f'AB_{rel_type}', 'tab:blue', 'AB'),
              (f'A_{rel_type}', 'tab:red', 'A'),
              (f'B_{rel_type}', 'tab:green', 'B')]

    stack_base = None  # the first layer starts from the x axis
    for column, color, label in layers:
        plt.bar(gene_types, mean_stats[column], yerr=std_stats[column],
                bottom=stack_base,
                capsize=5, color=color, label=label)
        # The next layer is drawn on top of all the previous ones
        stack_base = mean_stats[column] if stack_base is None else stack_base + mean_stats[column]

    plt.title(f'{exp_path.name} ({rel_type})')
    plt.ylabel('Average number of neighbors')
    plt.xlabel('Gene type')
    #plt.ylim(0, 9)
    plt.grid(linestyle=':', axis='y')
    plt.legend()

    plt.savefig(exp_path.joinpath(f'neighbors_per_type_{rel_type}.pdf'), dpi=300)

    plt.show()

In [None]:
# Activated and inhibited neighbors together
plot_number_of_neighbors(exp_path, gen, rel_type='total')

In [None]:
# Activated neighbors only
plot_number_of_neighbors(exp_path, gen, rel_type='activ')

In [None]:
# Inhibited neighbors only
plot_number_of_neighbors(exp_path, gen, rel_type='inhib')

## (Strongly) connected components

In [None]:
def plot_scc_stats(exp_path, gen):
    """Plot how many genes belong to strongly connected components of each size.

    For the best individual of every replicate, the strongly connected components
    of the interaction graph are computed; each component of size s adds s genes to
    the bar of size s.  The totals are divided by the number of replicates, and the
    figure is saved into `exp_path`.
    """
    replicate_dirs = sorted(d for d in exp_path.iterdir()
                            if d.is_dir() and d.name.startswith("rep"))
    nb_reps = len(replicate_dirs)

    # The largest possible scc has all genes, hence one slot per size from 0 to nb_genes
    genes_by_scc_size = np.zeros(int(params['nb_genes']) + 1)

    for rep_dir in replicate_dirs:
        _, graph = get_interaction_graph(rep_dir, gen)

        for component in nx.algorithms.strongly_connected_components(graph):
            size = len(component)
            genes_by_scc_size[size] += size

    # Normalize by the number of replicas
    genes_by_scc_size /= nb_reps

    plt.figure(figsize=(8, 4), dpi=300)
    plt.grid(linestyle=':')
    plt.xlabel('Number of genes in the CC')
    #plt.ylabel('Number of CCs')
    plt.ylabel('Number of genes in a CC of this size')
    #plt.ylim(0.0, 3.3)
    plt.title(exp_path.name)

    sizes = range(len(genes_by_scc_size))
    plt.bar(sizes, genes_by_scc_size)

    plt.savefig(exp_path.joinpath('connected_components_genes.pdf'), dpi=300)

    plt.show()

In [None]:
#plot_scc_stats(exp_path, gen)

## Plot the interaction graph

In [None]:
def plot_interaction_graph(exp_path, gen, method='spring'):
    """Draw the gene interaction graph of the best individual of every replicate.

    Nodes are colored by gene type.  Edges with a negative coefficient are drawn in
    green (activation) and edges with a positive coefficient in red (inhibition),
    with a width proportional to the absolute value of the coefficient.  One PNG
    per replicate is written into `exp_path`.

    Parameters
    ----------
    exp_path : pathlib.Path
        Experiment directory containing the `rep*` replicate directories.
    gen : int
        Generation of the saved population to analyse.
    method : str, default 'spring'
        Node layout: 'spring', 'circular' or 'graphviz'.

    Raises
    ------
    ValueError
        If `method` is not one of the supported layouts.
    """

    rep_dirs = sorted([d for d in exp_path.iterdir() if (d.is_dir() and d.name.startswith("rep"))])

    for rep, rep_dir in enumerate(rep_dirs):

        indiv, inter_graph = get_interaction_graph(rep_dir, gen)

        ## Draw the figure
        plt.figure(figsize=(16,16), dpi=dpi)
        plt.box(False)

        # Choose the layout
        if method == 'spring':
            layout = nx.spring_layout(inter_graph, k=0.6)
        elif method == 'circular':
            layout = nx.circular_layout(inter_graph)
        elif method == 'graphviz':
            layout = nx.nx_agraph.graphviz_layout(inter_graph)
        else:
            raise ValueError(f"Unknown graph layout '{method}'")

        # Draw the nodes
        nx.draw_networkx_nodes(inter_graph, layout, node_size=600,
                               node_color=[gene_type_color[gene.gene_type] for gene in indiv.genes])
        nx.draw_networkx_labels(inter_graph, layout)

        # Draw the edges
        # A negative value means we lower sigma at the other gene, hence increasing expression
        activ_edges = [e for e in inter_graph.edges if inter_graph[e[0]][e[1]]['inter'] < 0]
        inhib_edges = [e for e in inter_graph.edges if inter_graph[e[0]][e[1]]['inter'] > 0]

        # Edge widths scale with the strength of the interaction.  Activation
        # coefficients are negative, so the absolute value is needed there to obtain
        # a valid (positive) line width.
        coef = 100
        activ_widths = [abs(inter_graph[e[0]][e[1]]['inter']) * coef for e in activ_edges]
        inhib_widths = [abs(inter_graph[e[0]][e[1]]['inter']) * coef for e in inhib_edges]

        nx.draw_networkx_edges(inter_graph, layout, edgelist=inhib_edges,
                               width=inhib_widths, edge_color='tab:red', connectionstyle='arc3,rad=0.1')
        nx.draw_networkx_edges(inter_graph, layout, edgelist=activ_edges,
                               width=activ_widths, edge_color='tab:green', connectionstyle='arc3,rad=0.1')


        plt.savefig(exp_path.joinpath(f'genome_graph_rep{rep:02}_{method}.png'), dpi=dpi)
        plt.close()

In [None]:
#plot_interaction_graph(exp_path, gen, method='graphviz')