# Entraîner les agents addi

## Settings

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
%reset -f

In [None]:
%cd /content/drive/MyDrive/RECHERCHE3
%ls

In [None]:
#ou bien 
#!cp -r drive/MyDrive/permanent/RECHERCHE2/ml_flux_euler/* .

In [None]:
import Euler2.core_solver as core
import Euler2.agent_addi as aga
import Euler2.neural_networks as nn

from Euler2.param import Param,Projecter
from Euler2.initial_conditions import *

#from Euler.backend import K
import matplotlib.pyplot as plt
import popup_lib.popup as pop
import numpy as np
import copy
import time

import tensorflow as tf

Attention: `BC_periodic` rajoute une discontinuité au bord pour chaque fonction non périodique.




In [None]:
# Shared configuration object used by every cell below; both the solver and
# the model start with reflexive boundary conditions.
param = Param(nx=1000,nx_ratio=10,
              BC_solver=Param.BC_reflexive,
              BC_model=Param.BC_reflexive)
# Passed as second argument to core.compute_solutions
# (presumably the number of time steps -- confirm against core_solver).
nb_t=800

## Génération des données

In [None]:
def show_w_init(W):
    """Plot the first nine samples of ``W`` on a shared 3x3 grid of axes.

    For every sample ``i`` in ``0..8`` the slices ``W[i, :, 0]``,
    ``W[i, :, 1]`` and ``W[i, :, 2]`` are drawn with the labels rho, rho V
    and E. ``W`` must hold at least nine samples along its first axis.
    """
    curve_labels = (r"$\rho$", r"$\rho V$", "$E$")
    _, grid = plt.subplots(3, 3, sharex="all", sharey="all", figsize=(6, 6))

    for sample, panel in enumerate(grid.flatten()):
        for component, label in enumerate(curve_labels):
            panel.plot(W[sample, :, component], label=label)

    # The legend is attached to the current axes only.
    plt.legend()
    plt.show()

In [None]:
# Visual check: build 9 non-periodic initial conditions and plot them.
res=init_non_periodic(param,9)
print(res.shape)
show_w_init(res)

In [None]:
# Warning: if the random Sod states get too close to zero, things blow up.
res= init_random_sod(param, batch_size=10, minimum_E=0.05, maximal_jump_E=2, minimum_rho=0.05, maximal_jump_rho=2)
print(res.shape)
# show_w_init only draws the first 9 of the 10 samples.
show_w_init(res)

* Il faut mettre des jeux d'entraînement assez gros pour limiter le hasard du sampling.
* Quand, au bout d'une longue stagnation, tous les agents s'améliorent d'un coup, c'est suspect !
* On ne peut pas non plus mettre des batchs trop gros à cause de la limite des GPU ; cela bloque très vite. Exemple de jeu qui ne passe pas :

        nb_t=1000, batch_size = 140, nx=1000.  

In [None]:
def data_for_train():
    """Return ``init_non_periodic(param, 120)``, the training initial states."""
    batch_size = 120
    return init_non_periodic(param, batch_size)

Il est tout à fait possible de mettre dans val des K-tests déterministes

In [None]:
def data_for_val():
    """Return ``init_non_periodic(param, 120)``, the validation initial states."""
    batch_size = 120
    return init_non_periodic(param, batch_size)

### Test du solveur

In [None]:
def generate_fine_solutions(for_train):
    """Run the core solver on freshly built initial conditions.

    Args:
        for_train: when truthy the initial states come from
            ``data_for_train()``, otherwise from ``data_for_val()``.

    Returns:
        Whatever ``core.compute_solutions(param, nb_t, w_init, False)``
        returns for those initial states.
    """
    make_initial_states = data_for_train if for_train else data_for_val
    return core.compute_solutions(param, nb_t, make_initial_states(), False)
# Quick solver check.
# NOTE(review): the result is named W_val but is built with for_train=True -- confirm this is intended.
W_val=generate_fine_solutions(True)
W_val.shape

## Entrainement

### Définition des familles d'agent

In [None]:
# Number of agents created for each family (see family_full).
fam_size=6
# A duration long enough for every agent to run all of its 'nb_optimization' optimizations.
# Caution: this depends a lot on watch_duration.
period_duration="6 steps"#"15 seconds" 
# Flat list of every agent of every family; load_data() feeds data to each of them.
all_agents=[]

In [None]:
def family_full(window_size,color):
    """Create ``fam_size`` agents and wrap them in a ``pop.Family_trainer``.

    Every agent gets its own ``nn.Difference_model_several_convo(param, 4)``
    model and is also appended to the module-level ``all_agents`` list, so
    that ``load_data`` can feed it.

    Args:
        window_size: only used to build the family name ``"full_<window_size>"``.
            NOTE(review): the model is built with the literal 4, not with
            ``window_size`` -- confirm this is intended.
        color: forwarded to ``pop.Family_trainer`` as ``color``.

    Returns:
        The ``pop.Family_trainer`` managing the newly created agents.
    """
    name = "full_" + str(window_size)
    agents = []
    for _ in range(fam_size):
        model = nn.Difference_model_several_convo(param, 4)

        agent = aga.Agent_addi(
            param,
            model,
            watch_duration=20,  # 20
            lossCoef_stab=100,  # 10.
            lossCoef_ridge=1e-3,
            lossCoef_disHLL=1,
            nb_optimization=15,
        )
        agents.append(agent)
        # Registered globally as well, so data loading reaches every family.
        all_agents.append(agent)

    return pop.Family_trainer(
        agents=agents,
        nb_bestweights_averaged=3,
        nb_strong=3,
        nb_weak=2,
        period_duration=period_duration,
        name=name,
        color=color,
    )

# Families trained by the main loop; only the "full_5" family is active,
# the second one is left commented out.
family_trainers=[
                 family_full(5,"red"),
                 #family_full(7,"blue"),
                 ]



In [None]:
def load_data():
    """Generate new train/validation solutions and hand them to every agent.

    Both steps (generation, then loading into each agent of ``all_agents``)
    are timed and their durations printed on the current line.
    """
    start = time.time()
    W_train, W_val = generate_fine_solutions(True), generate_fine_solutions(False)
    elapsed = time.time() - start
    print(f"|génération des données, durée: {elapsed:.2f} ", end="")

    start = time.time()
    for agent in all_agents:
        agent.load_WY_train(W_train)
        agent.load_WY_valid(W_val)
    elapsed = time.time() - start
    print(f"|load données,durée: {elapsed:.2f}", end="")

### LA BOUCLE

In [None]:
# Main training loop: 50 rounds, each one reloading data into the agents
# (load_data) and then running one period for every family.
try:
    for i in range(50):
        load_data()
        for family_trainer in family_trainers:
            family_trainer.period()
except KeyboardInterrupt:
    # Manual interruption of the cell: notify each family trainer.
    for family_trainer in family_trainers:
        # so that training can be resumed later if wanted
        family_trainer.interupt_period()

### Historique des `famparams`

In [None]:
def plot_history_famparams():
    """Open one figure per family parameter and let each family trainer plot it."""
    # "lossCoef_Laplacian" used to be part of this list:
    # keys=["lossCoef_stab","lossCoef_ridge","lossCoef_disHLL","lossCoef_Laplacian","watch_duration","score"]
    metric_names = (
        "lossCoef_stab",
        "lossCoef_ridge",
        "lossCoef_disHLL",
        "watch_duration",
        "score",
    )

    for metric in metric_names:
        _, axis = plt.subplots()
        for trainer in family_trainers:
            trainer.plot_metric(metric, axis)

    plt.show()
plot_history_famparams()

### Historique des `loss`

In [None]:
def plot_history_losses():
    """Open one figure per loss/score metric and let each family trainer plot it."""
    # "loss_Laplacian" used to be part of this list:
    # keys=["loss_stab","loss_disHLL","loss_ridge","loss_Laplacian"]
    raw_losses = ("loss_stab", "loss_disHLL", "loss_ridge")
    weighted_losses = ("loss_stab*coef", "loss_disHLL*coef", "loss_ridge*coef")
    scores = ("score_l1", "score_l2", "score_linfty")

    for metric in raw_losses + weighted_losses + scores:
        _, axis = plt.subplots()
        for trainer in family_trainers:
            trainer.plot_metric(metric, axis)

    plt.show()
plot_history_losses()

## Choix d'une des familles

In [None]:
family_trainer=family_trainers[0]

### Valeurs des famparams

In [None]:

# Print, for every agent of the selected family, its name, famparams and score.
for name,agent_w in family_trainer.agents.items():
    print(name)
    print(agent_w.agent.famparams)
    print(agent_w.agent.agent_score())


# Pick the family's best agent; compare_along_the_time reads it through the global `best_agent`.
best_agent=family_trainer.get_best_agent()
# best_agent=family_trainer.agents['a'].agent
print("best agent")
print(best_agent.famparams)
# print(best_agent.agent_score())


## Testons le long du temps

In [None]:
def _plot_component(curves, t_to_plot, nb_batch_plot, component, title):
    """Draw one state component on a (time x sample) grid of axes.

    Args:
        curves: iterable of ``(values, color, label)``; ``values`` is indexed
            as ``values[t, sample, :, component]``.
        t_to_plot: time indices, one per row of the grid.
        nb_batch_plot: number of samples, one per column of the grid.
        component: index of the state component to draw.
        title: title set on the top-left axes.
    """
    nb_rows = len(t_to_plot)
    # squeeze=False keeps `ax` two-dimensional even with a single row or
    # column, so ax[i, j] stays valid when only one sample is plotted.
    fig, ax = plt.subplots(nb_rows, nb_batch_plot,
                           figsize=(15, nb_rows*2), squeeze=False)
    for i, t in enumerate(t_to_plot):
        for j in range(nb_batch_plot):
            for values, color, label in curves:
                ax[i, j].plot(values[t, j, :, component], color=color, label=label)
    ax[0, 0].legend()
    ax[0, 0].set_title(title)
    fig.tight_layout()


def compare_along_the_time(W_val):
    """Compare the fine solution, HLL and the model of ``best_agent`` over time.

    ``W_val`` is loaded into the global ``best_agent`` as validation data
    (replacing whatever it held), then ``best_agent.predict()`` provides the
    three results to compare. Components 0 ("Rho") and 2 ("Energie") are
    plotted for up to 5 samples at 3 times spread from the first to the last
    index of ``W_val``, followed by the mean absolute error of HLL and of
    the model along the first axis.

    Args:
        W_val: solutions indexed first by time, then by sample
            (``len(W_val)`` and ``W_val.shape[1]`` are used as such).
    """
    nb_t = len(W_val)

    best_agent.load_WY_valid(W_val)
    res_fine_proj, res_HLL, res_model = best_agent.predict()

    # Never plot more samples than the batch contains.
    nb_batch_plot = min(5, W_val.shape[1])

    nb_t_plot = 3
    # Evenly spaced so that both the first and the last time are included.
    t_to_plot = np.floor(np.linspace(0, nb_t-1, nb_t_plot)).astype(int)

    curves = (
        (res_fine_proj, "k", "fine"),
        (res_HLL, "b", "HLL"),
        (res_model, "r", "model"),
    )
    _plot_component(curves, t_to_plot, nb_batch_plot, 0, "Rho")
    _plot_component(curves, t_to_plot, nb_batch_plot, 2, "Energie")

    # Mean absolute error against the fine solution, one value per time index.
    errors_HLL = tf.reduce_mean(tf.abs(res_fine_proj-res_HLL), axis=[1, 2, 3])
    errors_model = tf.reduce_mean(tf.abs(res_fine_proj-res_model), axis=[1, 2, 3])

    fig, ax = plt.subplots()
    ax.plot(errors_HLL, color="b", label="HLL")
    ax.plot(errors_model, color="r", label="model")

    # if (np.max(errors_model)>0.1):
    #     ax.set_ylim(0,0.1)

    ax.legend()
    ax.set_xlabel("time")

    plt.show()

### kind Periodic

In [None]:
# Periodic initial data with periodic boundary conditions for both the model
# and the solver. This mutates the shared `param`, so later cells see it too.
param.BC_model=param.BC_solver=Param.BC_periodic 
w_init=init_periodic(param,50)
W_val = core.compute_solutions(param, nb_t, w_init,False)
print(W_val.shape)
compare_along_the_time(W_val)

In [None]:
# Same periodic initial-data generator, but with reflexive boundary
# conditions for both the model and the solver (mutates the shared `param`).
param.BC_model=param.BC_solver=Param.BC_reflexive
w_init=init_periodic(param,50)
W_val = core.compute_solutions(param, nb_t, w_init,False)
print(W_val.shape)
compare_along_the_time(W_val)

### kind non-periodic

In [None]:
param.BC_model=param.BC_solver=Param.BC_neumann 

In [None]:
# 50 non-periodic initial conditions solved with the boundary conditions
# currently stored in `param`.
w_init=init_non_periodic(param,50)
W_val = core.compute_solutions(param, nb_t, w_init,False)
W_val.shape

In [None]:
compare_along_the_time(W_val)

### kind Sod

In [None]:
def deterministic_SOD(batch_size):
    """Build ``batch_size`` Sod-like initial states with a jump at x = 0.5.

    Density is the same step profile for every sample; the energy profile is
    the same step scaled by coefficients evenly spaced from 1 to 3 across
    the batch; momentum is zero.

    Returns:
        ``np.stack([rho, rhoV, E], axis=2)`` -- presumably of shape
        (batch_size, nx, 3), assuming the grid built from
        ``param.xmin/xmax/dx`` has ``param.nx`` points (to be confirmed).
    """
    grid = tf.range(param.xmin, param.xmax, param.dx)
    left_of_jump = grid < 0.5

    # Density may be fairly small: 0.1 is fine, 0.01 -> a few oscillations
    base_rho = tf.where(left_of_jump, 1., 0.5) + 1  # 0.125
    # Pressure, however, must not get close to zero: 0.01 breaks immediately
    base_P = tf.where(left_of_jump, 1., 0.5) + 1  # 0.1

    scalings = np.linspace(1, 3, batch_size)

    rho = [base_rho] * batch_size
    # rho=[rho_0*coef for coef in P_coefs]
    P = np.array([base_P * scale for scale in scalings])

    rhoV = tf.zeros([batch_size, param.nx])
    E = P  # /(param.gamma-1)

    return np.stack([rho, rhoV, E], axis=2)

In [None]:
#w_init=random_SOD(5)
# Five deterministic Sod-like initial states (see deterministic_SOD).
w_init=deterministic_SOD(5)

# Neumann boundary conditions for both the model and the solver (mutates the shared `param`).
param.BC_model=param.BC_solver=param.BC_neumann
W_val = core.compute_solutions(param, nb_t, w_init,False)
W_val.shape

In [None]:
compare_along_the_time(W_val)