In [164]:
import sys
import jax.numpy as jnp
from jax import random
import jax
import time
import haiku as hk
from collections import namedtuple
from jax import jit
from typing import NamedTuple,Any,Callable,Tuple,List
sys.path.append('../')

In [3]:
from src.agent import GVF



In [137]:
class GVFObject(NamedTuple):
    """Immutable record bundling everything needed to evaluate one GVF.

    Field order is part of the interface (positional construction and
    tuple unpacking both rely on it).
    """
    # Parameter pytree for this GVF.
    params: Any
    # Callable that evaluates the GVF given its parameters and inputs.
    apply_fn: Callable
    # The GVF class itself (not an instance); used to look up ``initial_state``.
    class_def: GVF
    
class TestGVF(GVF):
    """Stub GVF whose outputs are fixed offsets of a constant ``value``.

    It ignores its inputs and carries no recurrent state, which makes the
    outputs trivially predictable when exercising code that consumes GVFs.
    """
    def __init__(self,value) -> None:
        """Store the constant ``value`` that every output is derived from."""
        super().__init__()
        self.value=value
    
    def __call__(self, obs, act, prev_state)-> Tuple[Any,Any,Any,Any]:
        """Return ``(value, value+1, value+2, None)``.

        ``obs``, ``act`` and ``prev_state`` are accepted but unused. The
        fourth element is the next state, which is always ``None`` here.
        """
        return self.value,self.value+1,self.value+2,None
    
    @staticmethod
    def initial_state():
        """Return the initial recurrent state; this stub is stateless, so ``None``."""
        return None 

In [138]:
def initialize_gvf(gvf:GVF,key,*args,**kwargs):
    """Transform a GVF class with Haiku and initialise its parameters.

    Args:
        gvf: The GVF class (not an instance). It must be constructible as
            ``gvf(*args, **kwargs)``, callable as
            ``network(obs, act, prev_state)`` and expose ``initial_state()``.
        key: JAX PRNG key passed to the transformed function's ``init``.
        *args, **kwargs: Forwarded unchanged to the ``gvf`` constructor.

    Returns:
        A ``GVFObject`` holding the initialised parameters, the transformed
        ``apply`` function (wrapped so it takes no rng argument) and the
        ``gvf`` class itself.
    """
    def forward(obs,act,prev_state):
        # The network is built inside the function handed to ``hk.transform``.
        network=gvf(*args,**kwargs)
        return network(obs,act,prev_state)
    gvf_trf=hk.without_apply_rng(hk.transform(forward))
    # Scalar dummy observation/action are used only to trace ``init``.
    # The initial state comes from the class being initialised, so GVF
    # classes with a non-trivial initial state are traced correctly.
    params=gvf_trf.init(key,jnp.array(1),jnp.array(1),gvf.initial_state())
    return GVFObject(class_def=gvf,params=params,apply_fn=gvf_trf.apply)

In [139]:
# Root PRNG key (seed 0) passed to the GVF initialisation calls.
key = random.PRNGKey(0)

In [140]:
# Single initialised GVF: a TestGVF constructed with value 1.
ele = initialize_gvf(TestGVF, key, 1)

In [231]:
# Build NUM_GVFS constant-valued GVFs; GVF ``i`` is constructed with value ``i``.
NUM_GVFS=1000
# Give every GVF its own sub-key instead of reusing ``key`` for all of them.
gvf_keys=random.split(key,NUM_GVFS)
gvf_list=[initialize_gvf(TestGVF,gvf_key,i) for i,gvf_key in enumerate(gvf_keys)]

In [232]:
class Horde:
    """A collection of GVFs evaluated together through one jitted forward pass.

    The per-GVF parameters, apply functions and recurrent states are kept in
    parallel lists, indexed in the same order as ``gvfs``.
    """
    def __init__(self,gvfs:List[GVFObject]):
        """Collect parameters, apply functions and initial states of ``gvfs``.

        Args:
            gvfs: Initialised GVFs. Each must provide ``params``, ``apply_fn``
                and a ``class_def`` exposing ``initial_state()``.
        """
        self.gvfs=gvfs
        self.params=[gvf.params for gvf in gvfs]
        self.apply_fn=[gvf.apply_fn for gvf in gvfs]
        self.last_states=[gvf.class_def.initial_state() for gvf in gvfs]
        @jit
        def forward(params,last_states,obs,last_act):
            # Gather per-GVF outputs in Python lists and build each output
            # vector once, instead of tracing one scatter update per GVF.
            # New states go into a fresh list; the argument is left untouched.
            policies,cumulants,gammas,new_states=[],[],[],[]
            for apply_fn,gvf_params,state in zip(self.apply_fn,params,last_states):
                policy,cumulant,gamma,new_state=apply_fn(gvf_params,obs,last_act,state)
                policies.append(policy)
                cumulants.append(cumulant)
                gammas.append(gamma)
                new_states.append(new_state)
            # ``dtype=float`` resolves to JAX's default float dtype, the same
            # dtype ``jnp.zeros(n)`` produces; an empty horde yields shape (0,).
            return (jnp.asarray(policies,dtype=float),
                    jnp.asarray(cumulants,dtype=float),
                    jnp.asarray(gammas,dtype=float),
                    new_states)
        self.forward=forward
                
    
    def step(self,obs,last_act):
        """
            Has the same interface as a predictor

            Evaluates every GVF on ``obs`` and ``last_act``, stores the
            returned states in ``self.last_states`` and returns the tuple
            ``(policies, cumulants, gammas)`` with one entry per GVF.
        """
        policies,cumulants,gammas,self.last_states=self.forward(self.params,self.last_states,obs,last_act)
        return policies,cumulants,gammas
        

In [233]:
# Wrap all initialised GVFs in a single Horde.
horde = Horde(gvf_list)

In [264]:
# Time one ``horde.step`` call. The first execution after (re)building the
# horde also includes jit tracing/compilation; re-run the cell for the
# steady-state cost.
start=time.perf_counter()
value=horde.step(jnp.array(1.0),jnp.array(1.0))
# JAX dispatches work asynchronously, so wait for every output array to be
# computed before stopping the clock; otherwise only dispatch time is measured.
jax.tree_util.tree_map(lambda out:out.block_until_ready(),value)
time.perf_counter()-start

0.0023908615112304688

In [265]:
# Display the (policies, cumulants, gammas) tuple returned by the timed ``horde.step`` call.
value

(DeviceArray([  0.,   1.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.,
               10.,  11.,  12.,  13.,  14.,  15.,  16.,  17.,  18.,  19.,
               20.,  21.,  22.,  23.,  24.,  25.,  26.,  27.,  28.,  29.,
               30.,  31.,  32.,  33.,  34.,  35.,  36.,  37.,  38.,  39.,
               40.,  41.,  42.,  43.,  44.,  45.,  46.,  47.,  48.,  49.,
               50.,  51.,  52.,  53.,  54.,  55.,  56.,  57.,  58.,  59.,
               60.,  61.,  62.,  63.,  64.,  65.,  66.,  67.,  68.,  69.,
               70.,  71.,  72.,  73.,  74.,  75.,  76.,  77.,  78.,  79.,
               80.,  81.,  82.,  83.,  84.,  85.,  86.,  87.,  88.,  89.,
               90.,  91.,  92.,  93.,  94.,  95.,  96.,  97.,  98.,  99.,
              100., 101., 102., 103., 104., 105., 106., 107., 108., 109.,
              110., 111., 112., 113., 114., 115., 116., 117., 118., 119.,
              120., 121., 122., 123., 124., 125., 126., 127., 128., 129.,
              130., 131., 132., 133., 

In [230]:
# Inspect the per-GVF states currently held by the horde.
# NOTE(review): this cell's execution count (230) is lower than that of the
# cell building the 1000-element ``gvf_list`` (231), so the saved output
# appears to come from an earlier horde — re-run to confirm.
horde.last_states

[None]