In [None]:
import sys
import pathlib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import pickle
import itertools
import networkx as nx
from scipy import stats

In [None]:
from numba import jit

In [None]:
import met_brewer

In [None]:
import importlib
import evotsc
import evotsc_lib
import evotsc_plot
importlib.reload(evotsc)
importlib.reload(evotsc_lib)
importlib.reload(evotsc_plot)

In [None]:
# Shared figure styling, read as globals by the plotting functions of this notebook
label_fontsize=20  # Font size of axis labels
tick_fontsize=15  # Font size of major tick labels
legend_fontsize=15  # Font size of legend entries
dpi=300  # Resolution used when creating and saving figures

In [None]:
# NOTE(review): absolute, machine-specific paths — adjust them before running elsewhere
# Experiment plotted under the 'SC mutations' label
exp_path = pathlib.Path('/Users/theotime/Desktop/evotsc/phd/epistasis/wt-with-sc/')
# Reference experiment plotted under the 'No SC mutations' label
main_path = pathlib.Path('/Users/theotime/Desktop/evotsc/phd/param-explor/main/')
gen = 1_000_000 # Last generation analyzed (old note said "but switch to 1e6": already the case)
gene_types = ['AB', 'A', 'B'] # Name of each gene type
gene_type_color = ['tab:blue', 'tab:red', 'tab:green'] # Plot color of each gene type: AB, A, B
orient_name = ['leading', 'lagging'] # Name of each gene orientation
rel_orients = ['conv', 'div', 'downstr', 'upstr'] # In alphabetical order
envs = ['A', 'B'] # Environment names

In [None]:
# Seeded generator shared by the notebook; make_random_indivs passes it to evotsc_lib.make_random_indiv
rng = np.random.default_rng(seed=123456)

In [None]:
# Replicate directories ("rep*") of the experiment, sorted by path
exp_rep_dirs = sorted(entry for entry in exp_path.iterdir()
                      if entry.is_dir() and entry.name.startswith("rep"))
nb_exp_reps = len(exp_rep_dirs)

# Parameters are read from the first replicate directory of the experiment
exp_params = evotsc_lib.read_params(exp_rep_dirs[0])
exp_params['m'] = 2.5 # Temporary fix because the parameter wasn't saved

# Same for the reference ("main") experiment
main_rep_dirs = sorted(entry for entry in main_path.iterdir()
                       if entry.is_dir() and entry.name.startswith("rep"))
main_params = evotsc_lib.read_params(main_rep_dirs[0])
main_params['m'] = 2.5 # Temporary fix because the parameter wasn't saved

In [None]:
# Display the number of replicate directories found for the experiment
nb_exp_reps

In [None]:
# Display the number of replicate directories found for the reference experiment
len(main_rep_dirs)

In [None]:
# Display the experiment parameters (including the manually added 'm')
exp_params

In [None]:
# Number of genes of each type, assuming an even split; true division, so the result is a float
genes_per_type = exp_params["nb_genes"] / len(gene_types)

# Make random individuals (reused throughout)

In [None]:
def make_random_indivs(nb_indiv, params):
    """Create `nb_indiv` random individuals from a set of experiment parameters.

    Each individual is built by `evotsc_lib.make_random_indiv` with the
    notebook-level `rng` and `nb_mutations=100`, and gets its `inter_matrix`
    attribute filled from `compute_inter_matrix()` before being returned.

    Args:
        nb_indiv: number of individuals to create.
        params: mapping of experiment parameters; the keys read are the ones
            listed in the factory call below.

    Returns:
        A list of `nb_indiv` individuals, in creation order.
    """
    # A single mutation operator is shared by all the individuals
    shared_mutation = evotsc.Mutation(inversion_poisson_lam=params['inversion_poisson_lam'])

    def _spawn():
        # Build one individual and precompute its interaction matrix
        new_indiv = evotsc_lib.make_random_indiv(
            intergene=int(params['intergene']),
            gene_length=int(params['gene_length']),
            nb_genes=int(params['nb_genes']),
            default_basal_expression=params['default_basal_expression'],
            interaction_dist=params['interaction_dist'],
            interaction_coef=params['interaction_coef'],
            sigma_basal=params['sigma_basal'],
            sigma_opt=params['sigma_opt'],
            epsilon=params['epsilon'],
            m=params['m'],
            selection_coef=params['selection_coef'],
            mutation=shared_mutation,
            rng=rng,
            nb_mutations=100,
        )
        new_indiv.inter_matrix = new_indiv.compute_inter_matrix()
        return new_indiv

    return [_spawn() for _ in range(nb_indiv)]

In [None]:
# One random individual per experiment replicate, built with the experiment parameters
rand_indivs = make_random_indivs(nb_indiv=nb_exp_reps, params=exp_params)

# Evolutionary stats

In [None]:
def get_stats(exp_name, params, gen):
    """Load the per-generation statistics of every replicate of an experiment.

    Reads `stats.csv` from each sub-directory of `exp_name` whose name starts
    with "rep" (in sorted order) and stacks them into a single table.

    Args:
        exp_name: path (str or `pathlib.Path`) of the experiment directory.
        params: mapping of experiment parameters. 'Genome size' is loaded only
            when 'intergene_poisson_lam' is set and non-zero, and 'Basal SC'
            only when 'basal_sc_mutation_prob' is set and non-zero. A missing
            key is treated like `None`.
        gen: last generation to keep (inclusive).

    Returns:
        A DataFrame with a 'Replicate' column (position of the replicate in
        the sorted directory list), the columns read from the CSV files, and
        a 'Log Fitness' column, restricted to rows with 'Gen' <= `gen`. The
        row index of each CSV file is kept as is, so index values repeat
        across replicates.

    Raises:
        FileNotFoundError: if `exp_name` contains no "rep*" directory.
    """
    exp_name = pathlib.Path(exp_name)

    rep_dirs = sorted(d for d in exp_name.iterdir()
                      if d.is_dir() and d.name.startswith("rep"))
    if not rep_dirs:
        raise FileNotFoundError(f'No "rep*" directory found in {exp_name}')

    def _is_enabled(key):
        # An optional feature is on when its parameter is present, not None and not 0
        value = params.get(key)
        return not (value is None or value == 0.0)

    cols = ['Gen', 'Fitness', 'ABon_A', 'ABon_B', 'Aon_A', 'Aon_B', 'Bon_A', 'Bon_B']
    if _is_enabled('intergene_poisson_lam'):
        cols.append('Genome size')
    if _is_enabled('basal_sc_mutation_prob'):
        cols.append('Basal SC')

    # Collect the per-replicate tables and concatenate once: growing a
    # DataFrame inside the loop copies all previous rows at each iteration.
    rep_frames = []
    for i_rep, rep_dir in enumerate(rep_dirs):
        rep_stats = pd.read_csv(rep_dir.joinpath('stats.csv'), usecols=cols)
        rep_stats.insert(0, 'Replicate', i_rep)
        rep_frames.append(rep_stats)

    res = pd.concat(rep_frames)

    res['Log Fitness'] = np.log(res['Fitness'])

    return res[res['Gen'] <= gen]

In [None]:
# Per-replicate statistics of the experiment, up to generation `gen`
exp_stats = get_stats(exp_path, exp_params, gen)

In [None]:
# Per-replicate statistics of the reference experiment, up to generation `gen`
main_stats = get_stats(main_path, main_params, gen)

## Plot fitness over evolutionary time

In [None]:
def plot_fitness(exp_stats, main_stats, exp_path):
    """Plot fitness over generations for the reference run and the experiment.

    For each of the two data sets, draws the exponential of the mean
    log-fitness per generation (i.e. the geometric mean over replicates) as a
    thick line, and the first and last deciles of the fitness as faint lines.
    Both axes are logarithmic and generation 0 is left out.

    Args:
        exp_stats: statistics of the experiment (label 'SC mutations'); needs
            the 'Gen', 'Fitness' and 'Log Fitness' columns.
        main_stats: statistics of the reference run (label 'No SC mutations'),
            with the same columns.
        exp_path: directory in which 'fitness_all_with_main.pdf' is written.
    """
    kept_cols = ['Gen', 'Log Fitness', 'Fitness']
    palette = met_brewer.met_brew(name='Hokusai3', n=6, brew_type='continuous')

    # (legend label, line color, statistics without generation 0); reference run first
    curves = [
        ('No SC mutations', palette[5], main_stats[main_stats["Gen"] > 0][kept_cols].copy()),
        ('SC mutations', palette[3], exp_stats[exp_stats["Gen"] > 0][kept_cols].copy()),
    ]

    plt.figure(figsize=(9, 4), dpi=dpi)

    plt.xscale('log')
    plt.yscale('log')
    plt.grid(linestyle=':')
    plt.grid(visible=True, which="minor", axis='x', linestyle=':')

    plt.xlabel('Generation', fontsize=label_fontsize)
    plt.ylabel('Fitness', fontsize=label_fontsize)

    plt.tick_params(axis='both', which='major', labelsize=tick_fontsize)

    for label, color, frame in curves:
        per_gen = frame.groupby('Gen')
        mean_data = per_gen.mean().reset_index()
        generations = mean_data['Gen']

        # Central curve: exp of the mean log-fitness
        plt.plot(generations,
                 np.exp(mean_data['Log Fitness']),
                 color=color,
                 linewidth=2,
                 zorder=10,
                 label=label)

        # First and last deciles of the raw fitness
        for decile in (0.1, 0.9):
            plt.plot(generations,
                     per_gen.quantile(decile)['Fitness'],
                     color=color,
                     alpha=0.3)

    plt.legend(loc='lower right', fontsize=legend_fontsize)

    plt.savefig(f'{exp_path}/fitness_all_with_main.pdf', dpi=dpi, bbox_inches='tight')

In [None]:
# Draw the fitness comparison and save it inside exp_path
plot_fitness(exp_stats, main_stats, exp_path)

In [None]:
# Two-sample t-test on the fitness at generation `gen`: experiment vs. reference run.
# NOTE(review): this compares raw (not log) fitness with scipy's default options — confirm this is intended.
stats.ttest_ind(exp_stats[exp_stats['Gen'] == gen]['Fitness'],
                main_stats[main_stats['Gen'] == gen]['Fitness'])

### Plot mean basal supercoiling

In [None]:
def plot_basal_sc(full_stats):
    """Plot the basal supercoiling level over generations.

    Draws the per-generation mean of the 'Basal SC' column as a thick line,
    its first and last deciles as faint lines, and a horizontal line at
    `main_params['sigma_basal']` for the reference run. Only generations
    that are > 0 and reached by every replicate are used.

    Relies on the notebook-level `main_params`, `gen` and `exp_path`; the
    figure is saved as 'basal_sc_all.pdf' inside `exp_path`.

    Args:
        full_stats: statistics with 'Replicate', 'Gen' and 'Basal SC' columns.
    """
    # Last generation that every replicate has reached
    last_common_gen = full_stats.groupby('Replicate').max()['Gen'].min()

    in_range = (full_stats["Gen"] > 0) & (full_stats['Gen'] <= last_common_gen)
    basal_sc = full_stats[in_range][['Gen', 'Basal SC']].copy()

    per_gen = basal_sc.groupby('Gen')
    mean_data = per_gen.mean().reset_index()
    generations = mean_data['Gen']

    # Note: a mean is not guaranteed to lie between the two deciles
    deciles = [per_gen.quantile(0.1), per_gen.quantile(0.9)]

    palette = met_brewer.met_brew(name='Hokusai3', n=6, brew_type='continuous')
    main_color, exp_color = palette[5], palette[3]

    plt.figure(figsize=(9,4), dpi=dpi)

    plt.xscale('log')
    plt.grid(linestyle=':')
    plt.grid(visible=True, which="minor", axis='x', linestyle=':')

    plt.xlabel('Generation', fontsize=label_fontsize)
    plt.ylabel('Basal supercoiling', fontsize=label_fontsize)

    plt.tick_params(axis='both', which='major', labelsize=tick_fontsize)

    # Reference run: constant level, drawn from generation 1 to `gen`
    plt.hlines(main_params['sigma_basal'], 1e0, gen, linewidth=2, color=main_color,
               zorder=10, label='No SC mutations')

    # Mean over replicates
    plt.plot(generations,
             mean_data['Basal SC'],
             color=exp_color,
             linewidth=2,
             zorder=10,
             label='SC mutations')

    # First and last deciles
    for decile_data in deciles:
        plt.plot(generations,
                 decile_data['Basal SC'],
                 color=exp_color,
                 alpha=0.3)

    plt.legend(fontsize=legend_fontsize)

    # Fixed y-range; a previous run reported automatic limits of
    # (-0.06835996833066642, -0.059463889737949686)
    plt.ylim(-0.069, -0.059)
    print(f'Limits: {plt.ylim()}')

    plt.savefig(f'{exp_path}/basal_sc_all.pdf', dpi=dpi, bbox_inches='tight')

In [None]:
# Draw the basal supercoiling of the experiment and save it inside exp_path
plot_basal_sc(exp_stats)

# Plot the number of active genes of each type over evolutionary time

In [None]:
def plot_gene_activity_all(full_stats, exp_path, var_type='quantile'):
    """Plot, for each environment, the number of activated genes of each type.

    One figure is produced per environment ("A" then "B"). For every gene
    type, the mean over replicates of the '<type>on_<env>' column is drawn as
    a thick line, with two faint lines showing the spread between replicates.
    Each figure is saved as 'gene_activity_env_<env>.pdf' in `exp_path`,
    shown, then closed.

    Relies on the notebook-level `exp_params`, `gene_types`,
    `gene_type_color` and font-size/dpi settings.

    Args:
        full_stats: statistics with a 'Gen' column and one '<type>on_<env>'
            column per gene type and environment.
        exp_path: directory in which the figures are saved.
        var_type: how the spread is shown: 'sigma' (mean +/- 2 standard
            deviations), 'quantile' (first and last deciles) or 'minmax'
            (minimum and maximum). Any other value draws the means only.
    """
    per_gen = full_stats.groupby('Gen')

    # Every summary table stays indexed by 'Gen'. The mean used to be
    # re-indexed from 0 with reset_index(), so 'mean +/- 2 * std' aligned
    # row positions against generation numbers and the 'sigma' bands came
    # out as NaN / with the wrong length.
    mean_data = per_gen.mean()
    generations = mean_data.index.to_numpy()

    if var_type == 'sigma':
        std_data = per_gen.std()
        lower, upper = mean_data - 2 * std_data, mean_data + 2 * std_data
    elif var_type == 'quantile':
        lower, upper = per_gen.quantile(0.1), per_gen.quantile(0.9)
    elif var_type == 'minmax':
        lower, upper = per_gen.min(), per_gen.max()
    else:
        # Unknown spread type: draw the means only
        lower = upper = None

    for env in ["A", "B"]:

        fig, ax1 = plt.subplots(figsize=(9, 4), dpi=dpi)
        # Leave a 5% margin around the [0, genes per type] range
        delta_y = exp_params["nb_genes"] / 3 * 0.05
        ax1.set_ylim(-delta_y, exp_params["nb_genes"] / 3 + delta_y)
        ax1.set_ylabel('Activated genes', fontsize=label_fontsize)
        ax1.set_xlabel('Generation', fontsize=label_fontsize)
        ax1.set_xscale('log')
        ax1.grid(linestyle=':')
        ax1.grid(visible=True, which="minor", axis='x', linestyle=':')

        for i_gene_type, gene_type in enumerate(gene_types):
            column = f"{gene_type}on_{env}"
            color = gene_type_color[i_gene_type]

            # Mean over replicates
            ax1.plot(generations, mean_data[column],
                     color=color,
                     linewidth=2,
                     label=gene_type)

            # Spread between replicates
            if lower is not None:
                for bound in (lower, upper):
                    ax1.plot(generations, bound[column],
                             color=color,
                             alpha=0.3)

        ax1.tick_params(axis='both', which='major', labelsize=tick_fontsize)

        fig.legend(bbox_to_anchor=(0, 1),
                   bbox_transform=ax1.transAxes,
                   loc="upper left",
                   fontsize=legend_fontsize)

        plt.savefig(f'{exp_path}/gene_activity_env_{env}.pdf', dpi=dpi, bbox_inches='tight')

        plt.show()
        # Release the figure so repeated calls do not accumulate open figures
        plt.close(fig)

In [None]:
# Draw gene activity per environment with the default spread display (first and last deciles)
plot_gene_activity_all(exp_stats, exp_path)

# Influence of environmental supercoiling on final gene expression levels

In [None]:
# Sweep of environmental supercoiling shifts: np.linspace(sigma_min, sigma_max, nb_sigmas)
nb_sigmas = 250  # Number of sampled values
sigma_min = -0.061  # Lowest shift sampled
sigma_max = 0.061  # Highest shift sampled

In [None]:
def compute_activity_sigma_per_type(indiv, sigmas):
    """Average final expression level of each gene type, for each sigma.

    Args:
        indiv: individual exposing `evaluate`, `run_system`, `genes` (each
            gene having a `gene_type` usable as a row index in 0..2) and
            `nb_genes`.
        sigmas: sequence of supercoiling values passed one by one to
            `indiv.run_system`.

    Returns:
        Array of shape (3, len(sigmas)): for each gene type (row) and sigma
        (column), the sum of the last-row expression levels of the genes of
        that type, divided by `indiv.nb_genes / 3`.
    """
    # The individual has to be evaluated once before run_system is called
    indiv.evaluate(0.0, 0.0)

    totals = np.zeros((3, len(sigmas)))

    for col, sigma_env in enumerate(sigmas):
        # Only the final state of the system is used
        final_expr = indiv.run_system(sigma_env)[-1]

        for gene_pos, gene in enumerate(indiv.genes):
            totals[gene.gene_type, col] += final_expr[gene_pos]

    # Turn sums into per-type averages (one third of the genes per type)
    return totals / (indiv.nb_genes / 3)

In [None]:
def compute_avg_best_activ_by_sigma():
    """Average, over replicates, the per-type expression of the best individual.

    For every directory in the notebook-level `exp_rep_dirs`, loads the best
    individual at generation `gen` and computes its expression per gene type
    over `np.linspace(sigma_min, sigma_max, nb_sigmas)`. Replicates for which
    `evotsc_lib.get_best_indiv` raises `FileNotFoundError` are left out.

    Returns:
        Array of shape (3, nb_sigmas): mean over the loaded individuals of
        `compute_activity_sigma_per_type`.

    Raises:
        FileNotFoundError: if no individual could be loaded for any
            replicate (the average would otherwise silently be NaN).
    """
    sigmas = np.linspace(sigma_min, sigma_max, nb_sigmas)
    activ = np.zeros((3, len(sigmas)))

    nb_indivs = 0

    for rep_dir in exp_rep_dirs:
        try:
            indiv = evotsc_lib.get_best_indiv(rep_dir, gen=gen)
        except FileNotFoundError:
            # Best effort: skip replicates with nothing to load at this generation
            continue

        activ += compute_activity_sigma_per_type(indiv, sigmas)
        nb_indivs += 1

    if nb_indivs == 0:
        raise FileNotFoundError(
            f'No best individual found at generation {gen} in any replicate of {exp_path}')

    return activ / nb_indivs

In [None]:
# Mean expression per gene type of the best individuals, as a function of sigma
activ = compute_avg_best_activ_by_sigma()

In [None]:
def generate_rand_activ_by_sigma():
    """Average expression of the random individuals, for each sigma.

    Uses the notebook-level `rand_indivs` and the sweep
    `np.linspace(sigma_min, sigma_max, nb_sigmas)`. For each individual, the
    per-type expression levels are averaged over the three gene types.

    Returns:
        1-D array of length `nb_sigmas`: mean over `rand_indivs` of that
        per-individual average.
    """
    sigmas = np.linspace(sigma_min, sigma_max, nb_sigmas)

    # One curve per individual, averaged over gene types (not kept per type)
    per_indiv_curves = (np.mean(compute_activity_sigma_per_type(one_indiv, sigmas), axis=0)
                        for one_indiv in rand_indivs)

    # Sum the curves in order, starting from zeros, then average
    total = sum(per_indiv_curves, np.zeros(len(sigmas)))

    return total / len(rand_indivs)

In [None]:
# Mean expression of the random individuals, as a function of sigma
rand_activ = generate_rand_activ_by_sigma()

In [None]:
# See how gene activity levels depend on environmental supercoiling
def plot_activity_sigma_per_type(activ, rand_activ, params, plot_title=None, plot_name=None):
    """Plot average gene expression per type against background supercoiling.

    The x-axis is `sigma_basal + delta_sigma_env`, where the shift takes the
    values `np.linspace(sigma_min, sigma_max, nb_sigmas)` (notebook-level
    constants). A second x-axis at the top shows the shift alone. The figure
    contains the curve of each gene type, the curve of random genomes, the
    analytical curve of an isolated gene, a dotted horizontal line halfway
    between the minimal (`exp(-m)`) and maximal (1) expression levels, and
    dashed vertical lines at the two environments.

    Args:
        activ: array of shape (3, nb_sigmas), one row per entry of
            `gene_types`.
        rand_activ: array of length nb_sigmas for random genomes.
        params: mapping providing 'sigma_basal', 'sigma_opt', 'epsilon',
            'm', 'sigma_A' and 'sigma_B'.
        plot_title: accepted for compatibility but currently not used.
        plot_name: if given, path where the figure is saved.
    """
    sigma_basal = params['sigma_basal']
    sigma_opt = params['sigma_opt']

    sigmas_env = np.linspace(sigma_min, sigma_max, nb_sigmas)
    # Total background supercoiling, used as x coordinate of every curve
    sigmas_total = sigmas_env + sigma_basal

    fig, ax = plt.subplots(figsize=(7, 4), dpi=dpi)

    # Raw strings: '\s' and '\d' are not valid escape sequences in Python
    plt.xlabel(r'Background supercoiling ($\sigma_{basal} + \delta\sigma_{env}$)')
    plt.ylabel('Average gene expression by type')
    plt.ylim(-0.05, 1.10)
    plt.xlim(sigmas_total[0], sigmas_total[-1])
    plt.grid(linestyle=':')

    # Add 1/2 expression level (midpoint between exp(-m) and 1)
    half_expr = (1 + np.exp(- params['m'])) / 2
    plt.hlines(half_expr, sigmas_total[0], sigmas_total[-1],
               linestyle=':', linewidth=1.5, color='tab:pink')

    # Add average expression per gene type
    for i_gene_type, gene_type in enumerate(gene_types):
        plt.plot(sigmas_total, activ[i_gene_type, :],
                 color=gene_type_color[i_gene_type],
                 linewidth=2,
                 label=gene_type)

    # Add sigma_A and sigma_B
    y_min, y_max = plt.ylim()
    plt.vlines(params['sigma_A'] + sigma_basal, y_min, y_max, linestyle='--', linewidth=1, color='black')
    plt.vlines(params['sigma_B'] + sigma_basal, y_min, y_max, linestyle='--', linewidth=1, color='black')

    plt.text(params['sigma_A'] + sigma_basal, y_max + 0.005, r'$\sigma_A$',
             va='bottom', ha='center', fontsize='large') # Use \mathbf{} for bold
    plt.text(params['sigma_B'] + sigma_basal, y_max + 0.005, r'$\sigma_B$',
             va='bottom', ha='center', fontsize='large')
    # Restore the limits read above, after adding the lines and labels
    plt.ylim(y_min, y_max)

    # Add expression for a random genome
    plt.plot(sigmas_total, rand_activ,
             linewidth=2, color='tab:cyan', zorder=0, linestyle=(0, (3, 1, 1, 1)), label='Random')

    # Add expression for an isolated gene
    activities = 1.0 / (1.0 + np.exp((sigmas_total - sigma_opt)/ params['epsilon']))
    plt.plot(sigmas_total, np.exp(params['m'] * (activities - 1)),
             linewidth=2, color='tab:cyan', zorder=0, linestyle='--', label='Isolated gene')

    plt.legend(loc='lower left')

    # Add a second x-axis showing the environmental shift alone
    ax2 = ax.twiny()
    xmin, xmax = ax.get_xlim()
    ax2.set_xlim(xmin - params['sigma_basal'], xmax - params['sigma_basal'])
    ax2.set_xlabel(r'Environmental shift in supercoiling ($\delta\sigma_{env}$)')

    # Wrap up
    if plot_name:
        plt.savefig(plot_name, dpi=dpi, bbox_inches='tight')

    plt.show()
    plt.close()

In [None]:
# Draw expression vs. supercoiling for the experiment and save it as a PDF inside exp_path
plot_activity_sigma_per_type(activ, rand_activ, params=exp_params,
                             plot_name=exp_path.joinpath(f'activity_sigmas_avg.pdf'))