In [194]:
import haiku as hk
import jax
import optax
from jax import random
from jax import numpy as jnp
from jax import jit

In [217]:
# Define the agent's policy networks, mirroring the PyTorch implementation.

class ContextPolicy(hk.Module):
    """Produces one sigmoid score per example.

    Two strided Conv2D -> BatchNorm -> ReLU stages, a 512-unit ReLU hidden
    layer and a 1-unit sigmoid head, ported from the torch implementation.
    """

    def __init__(self, name=None):
        super().__init__(name=name)
        # hk.Conv2D(output_channels, kernel_shape, stride=...); input channels are inferred.
        # NOTE(review): hk.Conv2D defaults to padding="SAME" whereas torch's nn.Conv2d
        # defaults to zero padding ("VALID") -- confirm which the torch version used.
        self.conv1 = hk.Conv2D(32, 2, stride=2)
        # Positional args: create_scale=False, create_offset=False, decay_rate=0.999.
        self.bn1 = hk.BatchNorm(False, False, 0.999)
        self.conv2 = hk.Conv2D(64, 2, stride=2)
        self.bn2 = hk.BatchNorm(False, False, 0.999)
        self.fc = hk.Linear(512)
        self.head = hk.Linear(1)

    def __call__(self, x, is_training=True):
        """Run the network.

        Args:
          x: batched input for the convolutions; presumably (N, H, W, C)
            channels-last -- TODO confirm against the caller. The leading axis
            is treated as the batch axis when flattening.
          is_training: forwarded to every hk.BatchNorm layer.

        Returns:
          Array of shape (N, 1) with values in (0, 1).
        """
        x = jax.nn.relu(self.bn1(self.conv1(x), is_training))
        x = jax.nn.relu(self.bn2(self.conv2(x), is_training))
        # Flatten all but the batch axis (torch: x.view(x.size(0), -1)). Without
        # this, hk.Linear only mixes the channel axis and the head would emit one
        # score per spatial position instead of one per example.
        x = jnp.reshape(x, (x.shape[0], -1))
        x = jax.nn.relu(self.fc(x))
        return jax.nn.sigmoid(self.head(x))

def _context_forward(x, is_training=True):
    """Haiku forward function for ContextPolicy (wrap with hk.transform_with_state).

    Args:
      x: input batch passed straight to ContextPolicy.
      is_training: forwarded to the module's BatchNorm layers. Defaults to True,
        so existing init/apply calls that omit it behave exactly as before.
    """
    module = ContextPolicy()
    return module(x, is_training)



# Input size in the PyTorch version: (1, 1, 4, 64), presumably NCHW; Haiku's Conv2D expects channels-last (NHWC) by default.

class TacPolicy(hk.Module):
    """Produces a softmax distribution over `action_size` actions per example.

    Two strided Conv2D -> BatchNorm -> ReLU stages, a 512-unit ReLU hidden
    layer and a softmax head, ported from the torch implementation.
    """

    def __init__(self, action_size, name=None):
        super().__init__(name=name)
        # hk.Conv2D(output_channels, kernel_shape, stride=...); input channels are inferred.
        # NOTE(review): Haiku's default padding is "SAME", torch's default is zero
        # padding ("VALID") -- confirm which the torch version used.
        self.conv1 = hk.Conv2D(32, 2, stride=2)
        # Positional args: create_scale=False, create_offset=False, decay_rate=0.999.
        self.bn1 = hk.BatchNorm(False, False, 0.999)
        self.conv2 = hk.Conv2D(64, 2, stride=2)
        self.bn2 = hk.BatchNorm(False, False, 0.999)
        self.fc = hk.Linear(512)
        self.head = hk.Linear(action_size)

    def __call__(self, x, is_training=True):
        """Run the network.

        Args:
          x: batched input for the convolutions; presumably (N, H, W, C)
            channels-last -- TODO confirm against the caller. The leading axis
            is treated as the batch axis when flattening.
          is_training: forwarded to every hk.BatchNorm layer.

        Returns:
          Array of shape (N, action_size); each row sums to 1.
        """
        x = jax.nn.relu(self.bn1(self.conv1(x), is_training))
        x = jax.nn.relu(self.bn2(self.conv2(x), is_training))
        # Flatten all but the batch axis (torch: x.view(x.size(0), -1)) so the
        # softmax is taken over actions per example, not per spatial position.
        x = jnp.reshape(x, (x.shape[0], -1))
        x = jax.nn.relu(self.fc(x))
        return jax.nn.softmax(self.head(x))

def _tac_forward(x, action_size, is_training=True):
    """Haiku forward function for TacPolicy (wrap with hk.transform_with_state).

    Args:
      x: input batch passed straight to TacPolicy.
      action_size: width of the softmax head.
      is_training: forwarded to the module's BatchNorm layers. Defaults to True,
        so existing init/apply calls that omit it behave exactly as before.
    """
    module = TacPolicy(action_size)
    return module(x, is_training)





class ArgPolicy(hk.Module):
    """Scores candidate arguments and advances an LSTM state by one step.

    The candidate path is three strided Conv2D -> BatchNorm -> ReLU stages, a
    128-unit ReLU hidden layer and a 1-unit sigmoid head. Independently, an
    hk.LSTM consumes `x` to produce the next recurrent state.
    """

    def __init__(self, hidden_dim, name=None):
        super().__init__(name=name)
        self.lstm = hk.LSTM(hidden_dim)

        # hk.Conv2D(output_channels, kernel_shape, stride=...); input channels are inferred.
        # NOTE(review): Haiku's default padding is "SAME", torch's default is zero
        # padding ("VALID") -- confirm which the torch version used.
        self.conv1 = hk.Conv2D(32, 2, stride=2)
        # Positional args: create_scale=False, create_offset=False, decay_rate=0.999.
        self.bn1 = hk.BatchNorm(False, False, 0.999)
        self.conv2 = hk.Conv2D(64, 2, stride=2)
        self.bn2 = hk.BatchNorm(False, False, 0.999)
        self.conv3 = hk.Conv2D(128, 2, stride=2)
        self.bn3 = hk.BatchNorm(False, False, 0.999)

        self.fc = hk.Linear(128)
        self.head = hk.Linear(1)

    def __call__(self, x, candidates, hidden, is_training=True):
        """Score candidates and step the LSTM.

        Args:
          x: LSTM input; presumably the embedding of the previously predicted
            argument / tactic -- TODO confirm against the caller.
          candidates: batched conv input of possible arguments, presumably already
            concatenated with the hidden states by the caller (nothing here
            concatenates). The leading axis is treated as the batch axis.
          hidden: previous hk.LSTM state.
          is_training: forwarded to every hk.BatchNorm layer.

        Returns:
          (hidden, scores): the LSTM's next state, and an array of shape
          (num_candidates, 1) with values in (0, 1).
        """
        s = jax.nn.relu(self.bn1(self.conv1(candidates), is_training))
        s = jax.nn.relu(self.bn2(self.conv2(s), is_training))
        s = jax.nn.relu(self.bn3(self.conv3(s), is_training))
        # Flatten all but the batch axis (torch: s.view(s.size(0), -1)) so each
        # candidate gets a single score rather than one per spatial position.
        s = jnp.reshape(s, (s.shape[0], -1))
        s = jax.nn.relu(self.fc(s))
        scores = jax.nn.sigmoid(self.head(s))

        # Only the new recurrent state is used; the LSTM's step output is discarded.
        _, hidden = self.lstm(x, hidden)

        return hidden, scores

def _arg_forward(x, hidden_dim, candidates, hidden, is_init=False, is_training=True):
    """Haiku forward function for ArgPolicy (wrap with hk.transform_with_state).

    Args:
      x: LSTM input passed to ArgPolicy; when it has 2+ dims its leading axis is
        taken as the batch size for the initial LSTM state.
      hidden_dim: LSTM hidden size.
      candidates: conv input passed to ArgPolicy.
      hidden: previous LSTM state; ignored when `is_init` is True.
      is_init: if True, start from the LSTM's initial state instead of `hidden`.
      is_training: forwarded to the module's BatchNorm layers (default True, as
        before).

    Returns:
      (hidden, scores) as returned by ArgPolicy.
    """
    module = ArgPolicy(hidden_dim)
    if is_init:
        # Size the initial state from x rather than a hard-coded batch of 10;
        # None yields an unbatched state for unbatched x.
        # TODO: initialise the initial state as g (presumably the goal embedding).
        batch_size = x.shape[0] if x.ndim >= 2 else None
        hidden = module.lstm.initial_state(batch_size)
    return module(x, candidates, hidden, is_training)



class TermPolicy(hk.Module):
    """Produces one sigmoid score per example.

    Two stride-1 Conv2D -> BatchNorm -> ReLU stages, a 128-unit ReLU hidden
    layer and a 1-unit sigmoid head, ported from the torch implementation.
    """

    def __init__(self, name=None):
        super().__init__(name=name)
        # hk.Conv2D(output_channels, kernel_shape); stride defaults to 1 and input
        # channels are inferred.
        # NOTE(review): Haiku's default padding is "SAME", torch's default is zero
        # padding ("VALID") -- with stride 1 this changes the spatial size; confirm.
        self.conv1 = hk.Conv2D(32, 2)
        # Positional args: create_scale=False, create_offset=False, decay_rate=0.999.
        self.bn1 = hk.BatchNorm(False, False, 0.999)
        self.conv2 = hk.Conv2D(64, 2)
        self.bn2 = hk.BatchNorm(False, False, 0.999)
        self.fc = hk.Linear(128)
        self.head = hk.Linear(1)

    def __call__(self, x, is_training=True):
        """Run the network.

        Args:
          x: batched input for the convolutions; presumably (N, H, W, C)
            channels-last -- TODO confirm against the caller. The leading axis
            is treated as the batch axis when flattening.
          is_training: forwarded to every hk.BatchNorm layer.

        Returns:
          Array of shape (N, 1) with values in (0, 1).
        """
        x = jax.nn.relu(self.bn1(self.conv1(x), is_training))
        x = jax.nn.relu(self.bn2(self.conv2(x), is_training))
        # Flatten all but the batch axis (torch: x.view(x.size(0), -1)) so the
        # head emits one score per example, not one per spatial position.
        x = jnp.reshape(x, (x.shape[0], -1))
        x = jax.nn.relu(self.fc(x))
        return jax.nn.sigmoid(self.head(x))
    
    
    
def _term_forward(x, is_training=True):
    """Haiku forward function for TermPolicy (wrap with hk.transform_with_state).

    Args:
      x: input batch passed straight to TermPolicy.
      is_training: forwarded to the module's BatchNorm layers. Defaults to True,
        so existing init/apply calls that omit it behave exactly as before.
    """
    module = TermPolicy()
    return module(x, is_training)



In [218]:
# Turn each pure forward function into an (init, apply) pair. The "with_state"
# transform is used because the networks' BatchNorm layers carry state
# alongside the parameters.
context_net = hk.transform_with_state(_context_forward)
init_context, apply_context = context_net.init, context_net.apply

tac_net = hk.transform_with_state(_tac_forward)
init_tac, apply_tac = tac_net.init, tac_net.apply

arg_net = hk.transform_with_state(_arg_forward)
init_arg, apply_arg = arg_net.init, arg_net.apply

term_net = hk.transform_with_state(_term_forward)
init_term, apply_term = term_net.init, term_net.apply


In [235]:
rng_key = random.PRNGKey(100)

batch_size = 10
goal_dim = 256

# JAX keys must never be reused: drawing every placeholder from the same key made
# all of them identical arrays. Split off one key per input and rebind rng_key
# to a fresh key so later cells (e.g. network init) don't reuse a split key.
(rng_key, term_key, tac_key, cand_key,
 hidden_key, arg_key, context_key) = random.split(rng_key, 7)

# Goal inputs for the termination and tactic networks.
x_term = random.normal(term_key, (batch_size, goal_dim))
x_tac = random.normal(tac_key, (batch_size, goal_dim))

# Candidate input for the argument network.
# TODO: make sure shapes match what the old implementation expects (e.g. MAX_LEN, MAX_CONTEXT).
# NOTE(review): hk.Conv2D only accepts (N, H, W, C) or (H, W, C) inputs, so these
# 2-D placeholders must be reshaped before feeding the conv-based policies.
c_arg = random.normal(cand_key, (batch_size, goal_dim))

h1 = random.normal(hidden_key, (batch_size, goal_dim))

x_arg = random.normal(arg_key, (batch_size, goal_dim))

x_context = random.normal(context_key, (batch_size, goal_dim))


In [237]:
# initial_params_context, initial_state_context = init_context(rng_key, x_context)
# initial_params_tac, initial_state_tac = init_tac(rng_key, x_tac, 4)
# initial_params_arg, initial_state_arg = init_arg(rng_key, x_arg, 128, c_arg, h1, True)
# initial_params_term, initial_state_term = init_term(rng_key, x_term)

In [238]:
# out_context, new_state_context = apply_context(initial_params_context, initial_state_context, rng_key, x_context)
# out_tac, new_state_tac = apply_tac(initial_params_tac, initial_state_tac, rng_key, x_tac, 4)
# out_arg, new_state_arg = apply_arg(initial_params_arg, initial_state_arg, rng_key, x_arg, 128, c_arg, h1, True)
# out_term, new_state_term = apply_term(initial_params_term, initial_state_term, rng_key, x_term)

# print (out_arg)