## Popup

In [None]:
import numpy as np
import time
from typing import List,Dict,Set
from abc import ABC, abstractmethod
from collections import deque
# Short alias for the built-in print; it is not referenced in this section of the file.
pp=print
import copy

# TODO: support multi-dimensional metrics

class Param(ABC):
    """Abstract base for a named hyper-parameter that knows how to perturb itself.

    Every instance starts with the placeholder name ``"toto"`` and the value
    ``0``; concrete subclasses must implement :meth:`perturb`.
    """

    def __init__(self):
        # Placeholder defaults, set in one step.
        self.name, self.value = "toto", 0

    @abstractmethod
    def perturb(self):
        """Perturb this parameter; the base implementation does nothing."""
        return None


class History:
    """Store metric values time-stamped against a pausable "local" clock.

    The local clock only advances between ``start_local_time`` and
    ``stop_local_time``; the time elapsed over successive active intervals is
    accumulated.

    Important: the local time must be started as soon as the agent becomes
    active, and stopped as soon as it becomes inactive.
    """

    def __init__(self):

        # Local time accumulated over all completed active intervals.
        self.localTime_previous_cumulated_values = 0
        # Wall-clock instant of the last start; None until the first start.
        self.localTime_restart = None
        self.score_times =  []
        # Per metric name: local time of each recorded value.
        self.metrics_times:Dict[str,List[float]] = {}
        # Per metric name: indices (into the value list) of the tagged records.
        self.metrics_flags: Dict[str, Set[int]] = {}
        # Per metric name: the recorded values, in recording order.
        self.metrics_values:Dict[str,List[float]]={}
        self.localTime_isActive=False

    def start_local_time(self):
        """Start (or resume) the local clock.

        Raises:
            AssertionError: if the clock is already running. The check is done
                before touching any state, so a failed call leaves the running
                clock intact.
        """
        assert not self.localTime_isActive , "on a oublié d'arrêter le localTime"
        self.localTime_restart=time.time()
        self.localTime_isActive =True

    def stop_local_time(self):
        """Pause the local clock and fold the elapsed interval into the total.

        Raises:
            AssertionError: if the clock is not running (state is untouched).
        """
        assert self.localTime_isActive , "le localTime n'est pas actif"
        # Must be read while the clock is still flagged active, so that the
        # current interval is included.
        self.localTime_previous_cumulated_values=self.get_local_time()
        self.localTime_isActive =False

    def get_local_time(self):
        """Return the accumulated local time, including the running interval if any."""
        if not self.localTime_isActive:
            return self.localTime_previous_cumulated_values
        return time.time()-self.localTime_restart+self.localTime_previous_cumulated_values


    def record_metric(self, name,value,addTag=False):
        """Append ``value`` to the series ``name``, stamped with the current local time.

        Args:
            name: key of the metric series (created on first use).
            value: the number to record.
            addTag: when true, the index of this record is added to
                ``metrics_flags[name]``.

        Raises:
            AssertionError: if ``value`` is not a number according to ``is_number``.
        """
        assert is_number(value) ,f"Error when recording {name} value must be a number, but is: {value} of type {type(value)}"
        times=self.metrics_times.setdefault(name,[])
        values=self.metrics_values.setdefault(name,[])
        times.append(self.get_local_time())
        values.append(value)

        if addTag:
            # Tag the record that was just appended.
            self.metrics_flags.setdefault(name, set()).add(len(times)-1)


def is_number(x):
    """Return True if ``x`` is a real scalar number.

    Accepts Python ``int``/``float`` and any NumPy integer or floating scalar
    (``np.integer`` / ``np.floating`` cover every width, signed or unsigned,
    e.g. ``np.float16`` or ``np.uint8``). Arrays, strings, ``None`` and complex
    numbers are rejected.
    """
    return isinstance(x, (int, float, np.integer, np.floating))




class Abstract_Agent(ABC):
    """Interface an agent must implement to be driven by a trainer.

    The first six methods are mandatory (abstract); the last three are
    optional hooks with neutral default implementations.
    """

    @abstractmethod
    def optimize_return_score(self)->float:
        """Run one optimization step and return the resulting score."""
        return None


    @abstractmethod
    def set_famparams(self, dico):
        """Replace the agent's famparams with those given in ``dico``."""
        return None

    @abstractmethod
    def get_famparams(self):
        """Return the agent's current famparams."""
        return None

    @abstractmethod
    def perturb_famparams(self):
        """Perturb the agent's famparams."""
        return None

    # Open question from the author: should the user rather be asked for a dict?
    @abstractmethod
    def set_weights(self,weights:List):
        """Load ``weights`` into the agent."""
        return None

    @abstractmethod
    def get_copy_of_weights(self)->List:
        """Return a copy of the agent's weights."""
        return None

    # ---- optional hooks ----

    def weights_info_to_register(self)->Dict[str,float]:
        """Optional hook; returns a fresh empty dictionary by default."""
        return {}

    def to_register_at_period_end(self)->Dict[str,float]:
        """Optional hook; returns a fresh empty dictionary by default."""
        return {}

    def on_overfitting(self):
        """Optional hook, a no-op by default.

        Meant to be called when the validation score is much lower than the
        training score; typically one would then increase the famparams that
        reduce over-fitting.
        """
        return None

class Agent_wraper:
    """Bookkeeping wrapper around an agent: best score, best weights, lineage.

    It keeps the last ``max_nb_best`` weight snapshots taken when the agent
    improved its best score, the famparams that were active at the latest
    improvement, and the names of the agents it was mutated from.
    """

    def __init__(self, name, agent: "Abstract_Agent", max_nb_best):
        """
        Args:
            name: base name of the agent.
            agent: the wrapped agent.
            max_nb_best: maximal number of best-weight snapshots kept.
        """
        self.name=name
        self.agent=agent
        self.max_nb_best=max_nb_best

        # Bounded queue: the oldest snapshot is dropped when it is full.
        self.best_weights = deque(maxlen=max_nb_best)
        self.best_score = None
        # Dict of famparams once save_at_best/load_from_another has run.
        self.best_famparams = None

        self.mutation_names=[]
        self.score=-float("inf")

    def get_name_suffixed(self):
        """Return the name followed by ``_<donor>`` for every mutation undergone."""
        return "_".join([self.name, *self.mutation_names])

    def save_at_best(self,new_best_score:float):
        """Record a new best: snapshot the weights, the score and the famparams."""
        self.best_weights.append(self.agent.get_copy_of_weights())
        self.best_score=new_best_score
        self.best_famparams=copy.deepcopy(self.agent.get_famparams())

    def mean_of_best_weights(self):
        """Return the element-wise mean of the stored best-weight snapshots.

        Each snapshot is a sequence of arrays; the result is a list with one
        averaged array per position, accumulated in a ``np.zeros`` buffer.

        Raises:
            IndexError: if no snapshot has been stored yet.
        """
        nb_best = len(self.best_weights)
        reference = self.best_weights[0]
        averaged = []
        for j in range(len(reference)):
            acc = np.zeros(reference[j].shape)
            for snapshot in self.best_weights:
                acc += snapshot[j]
            acc /= nb_best
            averaged.append(acc)
        return averaged


    #exploitation
    def load_from_another(self,other_agent_w:'Agent_wraper'):
        """Turn this agent into a copy of ``other_agent_w`` at its best.

        The wrapped agent receives the mean of the donor's best weights and a
        deep copy of the donor's best famparams. The "best" bookkeeping of this
        wrapper is replaced accordingly.
        """
        # The weights (or learning variables) of a network are always a list
        # or a tuple of tensors.
        new_wei_list=other_agent_w.mean_of_best_weights()

        self.agent.set_weights(new_wei_list)
        self.agent.set_famparams(copy.deepcopy(other_agent_w.best_famparams))

        # Not ideal: the mutant inherits the donor's best score although it
        # should prove itself before joining the top of the ranking. This is
        # done so that it is not mutated again right away (a bonus mechanism
        # for young mutants would be a better solution).
        # The ranking is done on `best_score`, so that is the attribute that
        # must be inherited; `score` is still set for backward compatibility.
        self.score=other_agent_w.best_score
        self.best_score=other_agent_w.best_score
        # Keep the stored famparams consistent with the stored weights.
        self.best_famparams=copy.deepcopy(other_agent_w.best_famparams)
        # Do not forget to empty the list of best weights: the old ones may be
        # quite different, since a change of famparams can induce abrupt
        # changes of the weights (e.g. a penalization coefficient).
        self.best_weights=deque(maxlen=self.max_nb_best)
        self.best_weights.append(new_wei_list)


class Family_trainer:
    """Train a family of agents by periods, mutating the weakest after each one.

    During a period every agent is optimized in turn (for a number of steps or
    of seconds). When the family has more than one agent, the weakest agents
    are then replaced by perturbed copies of the strongest ones.
    """

    # Number of trainers created so far; used to build default names.
    instance_count=0

    def __init__(self,
                 agents:List["Abstract_Agent"],
                 ratio_weak=0.4,
                 ratio_strong=0.4,
                 nb_bestweights_averaged=5,
                 color="k",
                 min_interessant_score=-float("inf"),  # below this score, agents are mutated at random
                 name=None,
                 # mutations are done at the end of each period;
                 # a period lasts "periode_duration"
                 periode_duration=10,
                 # "step" = one call to optimize_return_score
                 period_duration_unity="step"  # or "seconde"
                 ):
        """
        Args:
            agents: the agents of the family; they are named "0", "1", ...
            ratio_weak: fraction of the family mutated at the end of a period
                (at least one agent).
            ratio_strong: fraction of the family that can serve as donors
                (at least one agent).
            nb_bestweights_averaged: number of best-weight snapshots kept
                (and averaged) per agent.
            color: matplotlib color code used as prefix of the plot formats.
            min_interessant_score: a donor must have a best score strictly
                above this value.
            name: name of the family; defaults to "fam_<instance count>".
            periode_duration: duration of one period for each agent.
            period_duration_unity: "step" or "seconde".

        Raises:
            AssertionError: if ``period_duration_unity`` is not a valid unit.
        """

        Family_trainer.instance_count+=1

        self.min_interessant_score=min_interessant_score
        self.name=name if name is not None else "fam_"+str(Family_trainer.instance_count)
        self.periode_duration=periode_duration
        assert period_duration_unity in ("step", "seconde"), "period_duration_unity must be 'step' or 'seconde'"
        self.period_duration_unity=period_duration_unity

        self.ratio_weak=ratio_weak
        self.nb_bestweights_averaged=nb_bestweights_averaged
        self.ratio_strong=ratio_strong
        self.color=color

        self._period_count=-1

        # A dict, because we would like to be able to remove agents.
        self.agents:Dict[str,"Agent_wraper"] = {}
        self.history=History()

        for i,agent in enumerate(agents):
            agent_name=str(i)
            agent_w = Agent_wraper(agent_name, agent, self.nb_bestweights_averaged)
            self.agents[agent_name] = agent_w
            # Record the initial famparams of every agent.
            for k, v in agent.get_famparams().items():
                self.history.record_metric(k, v)


    def period(self):
        """Run one period: optimize every agent, then mutate if there are several.

        The local clock of the history runs for the duration of the call and
        is stopped even when an agent raises, so that the trainer can be used
        again afterwards.
        """
        self.history.start_local_time()
        try:
            self._period_count+=1

            if self.period_duration_unity== "seconde":
                print(f"\n{self.name}, period {self._period_count}, each agent turns {self.periode_duration} secondes:", end="")
            else:
                print(f"\n{self.name}, period {self._period_count}, each agent turns {self.periode_duration} steps:", end="")

            for agent_w in self.agents.values():
                self._period_one_agent(agent_w)

            if len(self.agents)>1:
                self.mutation()
        finally:
            self.history.stop_local_time()


    def _period_one_agent(self, agent_w:"Agent_wraper"):
        """Optimize one agent for one period and record its scores.

        At least one optimization step is always performed. A NaN score is
        replaced by minus infinity. Each time the agent beats its best score,
        its state is saved through ``save_at_best``.
        """

        ti0=time.time()
        optimization_step=0
        ok=True
        while ok:
            optimization_step+=1
            # The stop condition is evaluated before the step, which is
            # executed anyway: the loop ends after the current iteration.
            if self.period_duration_unity== "seconde":
                ok= time.time() - ti0 < self.periode_duration
            else:
                ok=optimization_step<self.periode_duration

            score=agent_w.agent.optimize_return_score()
            if np.isnan(score): score = -float("inf")

            self.history.record_metric("score", score)

            if  agent_w.best_score is None or score>agent_w.best_score:
                agent_w.save_at_best(score)
                print(" "+agent_w.name+"↗"+str(np.round(score,4)),end="")
            else:
                print(".",end="")

        for k, v in agent_w.agent.to_register_at_period_end().items():
            self.history.record_metric(k, v)


    def mutation(self):
        """Replace the weakest agents by perturbed copies of the strongest ones.

        Agents are ranked by ``best_score``. The ``nb_strong`` best agents
        whose best score is strictly above ``min_interessant_score`` are the
        possible donors. If there is none, every agent has its famparams
        perturbed with probability 1/2 and nothing else is done. Otherwise
        each of the ``nb_weak`` worst agents loads a randomly chosen donor and
        gets its famparams perturbed.
        """

        sorted_agents = sorted(self.agents.values(), key=lambda a_w: a_w.best_score)
        nb_weak=max(1, int(len(self.agents) * self.ratio_weak))
        nb_strong=max(1, int(len(self.agents) * self.ratio_strong))

        # The strongest agents are at the END of the ascending ranking;
        # they are listed from the best one downwards.
        interesting_strong=[
            agent_w for agent_w in reversed(sorted_agents[-nb_strong:])
            if agent_w.best_score>self.min_interessant_score
        ]

        if len(interesting_strong)==0:
            print("\nATTENTION: aucun des agents n'a un score interessant: on perturbe la moitité des agents" )
            for agent_w in self.agents.values():
                if np.random.random()>0.5:
                    agent_w.agent.perturb_famparams()
            return

        weaks = sorted_agents[:nb_weak]

        print(", mutations:", end="")
        for weak in weaks:
            strong_index=np.random.randint(len(interesting_strong))
            strong  = interesting_strong[strong_index]
            # This is a form of early stopping: we take the best weights and
            # the best famparams of the donor. It also avoids having exactly
            # the same agent twice.
            weak.load_from_another(strong)
            weak.agent.perturb_famparams()

            for k, v in weak.agent.get_famparams().items():
                # A tag is added so that these new famparams can be spotted
                # on the plots.
                self.history.record_metric(k, v, addTag=True)
            for k, v in weak.agent.weights_info_to_register().items():
                self.history.record_metric(k, v)

            weak.mutation_names.append(strong.name)
            print(f"{weak.get_name_suffixed()}|", end="")


    def plot_metric(self, metric: str, ax, transformation=None):
        """Plot a recorded metric against the local time on ``ax``.

        Args:
            metric: name of the recorded metric.
            ax: object providing ``set_xlabel``, ``set_ylabel`` and ``plot``
                (presumably a matplotlib Axes).
            transformation: optional callable applied to the whole list of
                values before plotting.

        Tagged records are drawn with "+", the others with ".".
        """

        ax.set_xlabel("local time")
        ax.set_ylabel(metric)

        x = self.history.metrics_times[metric]
        y = self.history.metrics_values[metric]
        if transformation is not None:
            y=transformation(y)

        flags=self.history.metrics_flags.get(metric,set())
        for i in range(len(x)):
            marker = "+" if i in flags else "."
            ax.plot(x[i], y[i], self.color + marker)

    def plot_two_metrics(self,metric0:str,metric1:str,ax):
        """Plot ``metric1`` against ``metric0`` on ``ax``.

        Records tagged in either metric are drawn with "o", the others with
        "."; the opacity grows with the record index.

        Raises:
            AssertionError: if the two metrics do not have the same number of
                records.
        """

        ax.set_xlabel(metric0)
        ax.set_ylabel(metric1)

        x = self.history.metrics_values[metric0]
        y = self.history.metrics_values[metric1]
        assert len(x)==len(y),f"les métricques {metric0} et {metric1} n'ont pas été enregistrée le même nombre de fois"

        flags0 = self.history.metrics_flags.get(metric0,set())
        flags1 = self.history.metrics_flags.get(metric1,set())
        flags=flags0.union(flags1)
        for i in range(len(x)):
            marker = "o" if i in flags else "."
            ax.plot(x[i], y[i], self.color + marker, alpha=i / len(x))


    def stats_of_best(self,nb_best=None):
        """Return, per famparam, its score-weighted mean over the best agents.

        Args:
            nb_best: number of top agents (by ``best_score``) taken into
                account. Defaults to ``ratio_strong`` times the family size,
                with a minimum of one agent.

        Returns:
            A dict mapping each famparam name to
            ``sum(value * best_score) / sum(best_score)`` over the selected
            agents; empty when ``nb_best`` is 0.

        Note:
            The weights are the raw best scores, so the result is a division
            by zero when they sum to zero, and is hard to interpret when some
            scores are negative.

        Raises:
            AssertionError: if ``nb_best`` is not between 0 and the family size.
        """
        if nb_best is None:
            nb_best = max(1, int(len(self.agents) * self.ratio_strong))
        assert 0<=nb_best<=len(self.agents)

        agent_w_sorted=sorted(self.agents.values(),key=lambda ag:ag.best_score)
        # Exactly the nb_best last (i.e. best) agents; an explicit start index
        # is used because [-0:] would select the whole list.
        best=agent_w_sorted[len(agent_w_sorted)-nb_best:]

        res=dict()
        sum_score=0
        for agent in best:
            sum_score+=agent.best_score
            for k,v in agent.best_famparams.items():
                res[k]=res.get(k,0)+v*agent.best_score

        for k in res:
            res[k]/=sum_score

        return res

    def get_best_agent(self,mean_its_weights=True)->"Abstract_Agent":
        """Return the agent with the highest best score, loaded with its best weights.

        Args:
            mean_its_weights: when true the agent receives the mean of its
                stored best weights, otherwise the most recent best snapshot.
        """
        agent_w_sorted = sorted(self.agents.values(), key=lambda ag: ag.best_score)
        best=agent_w_sorted[-1]

        if mean_its_weights:
            weights=best.mean_of_best_weights()
        else:
            weights=best.best_weights[-1]
        best.agent.set_weights(weights)

        return best.agent


## DDQN



## Entrainement