## What does porting to JAX require?

- There are certain functionalities that need to be kept in mind. Replace the `numpy` imports with `jax.numpy`.

- JAX performs fast vectorized operations. Where possible, rewrite loops using `jax.vmap` (or `jax.lax.map`) and `jnp.vectorize`.

- Decorate core functions with `@jax.jit` to enable Just-In-Time compilation and optimization.

In [None]:
! pip install evojax

In [1]:
import os
import shutil
import jax

from dataclasses import dataclass
from evojax.task.slimevolley import SlimeVolley
from evojax.policy.mlp import MLPPolicy
from evojax import Trainer
from evojax import util
from hyp import hyp

from neat import NEATJax, loadHyp

# Set the fault-handler / tracemalloc environment variables.
# NOTE(review): CPython reads PYTHONFAULTHANDLER and PYTHONTRACEMALLOC at
# interpreter start-up, so assigning them here presumably only affects child
# processes started afterwards, not this already-running kernel — confirm.
os.environ['PYTHONFAULTHANDLER'] = '1'
os.environ['PYTHONTRACEMALLOC'] = '1'

In [2]:
@dataclass
class TrainingParams:
    """Settings for the SlimeVolley training run.

    Every attribute carries a type annotation so that ``@dataclass`` actually
    registers it as a field: un-annotated class attributes are ignored by the
    decorator, which leaves the generated ``__init__``/``__repr__``/``__eq__``
    empty. Defaults are unchanged, so ``TrainingParams()`` and class-level
    access behave as before, and individual values can now be overridden,
    e.g. ``TrainingParams(MAX_ITER=5)``.
    """

    NUM_TESTS: int = 100      # passed to Trainer as n_evaluations
    N_REPEATS: int = 16       # passed to Trainer as n_repeats
    MAX_ITER: int = 50        # passed to Trainer as max_iter
    TEST_INTERVAL: int = 50   # passed to Trainer as test_interval
    LOG_INTERVAL: int = 10    # passed to Trainer as log_interval
    MAX_STEPS: int = 3000     # passed to SlimeVolley as max_steps
    LOG_DIR: str = "./log/slimevolley"  # logger / trainer output directory

tp = TrainingParams()

@dataclass
class PolicyParams:
    """Settings for the MLP policy.

    Attributes are annotated so ``@dataclass`` registers them as real fields
    (un-annotated class attributes are ignored by the decorator). Defaults
    are unchanged; values can now be overridden per instance, e.g.
    ``PolicyParams(HIDDEN_SIZE=32)``.
    """

    HIDDEN_SIZE: int = 20     # width of the single hidden layer
    OUT_ACT_FN: str = 'tanh'  # name of the output activation function

pp = PolicyParams()

In [3]:
# Make sure the log directory exists. exist_ok=True already turns this into a
# no-op when the directory is present, so a separate os.path.exists() check
# is redundant (and would open a check-then-create race).
os.makedirs(tp.LOG_DIR, exist_ok=True)

# Logger used by the rest of the notebook; it is also handed to the Trainer.
logger = util.create_logger(
    name='SlimeVolley', log_dir=tp.LOG_DIR, debug=True)

logger.info('EvoJAX SlimeVolley')

logger.info('=' * 30)

INFO:SlimeVolley:EvoJAX SlimeVolley


In [4]:
# Separate task instances for training and evaluation, sharing the same
# episode length.
train_task = SlimeVolley(test=False, max_steps=tp.MAX_STEPS)
test_task = SlimeVolley(test=True, max_steps=tp.MAX_STEPS)

print("Input Shape: " ,train_task.obs_shape[0])
print("Output Shape: ", train_task.act_shape[0])

# The policy's input/output sizes come from the task. The output activation
# is read from PolicyParams instead of being hard-coded, so the configured
# OUT_ACT_FN value is the one actually used.
policy = MLPPolicy(
    input_dim=train_task.obs_shape[0],
    hidden_dims=[pp.HIDDEN_SIZE, ],
    output_dim=train_task.act_shape[0],
    output_act_fn=pp.OUT_ACT_FN,
)

Input Shape:  12
Output Shape:  3


INFO:MLPPolicy:MLPPolicy.num_params = 323


## We'll have our custom solver here. We'll see how that goes.
Only the CMA part needs to be changed a little bit. Otherwise, there is no change.

In [5]:
# Build the custom NEAT solver from the hyper-parameter object.
# NOTE(review): at this point `hyp` is the object imported via
# `from hyp import hyp`; the `hyp = loadHyp()` statement in the following cell
# rebinds the name only after this solver has been constructed — confirm
# which configuration is meant to reach NEATJax.
solver = NEATJax(hyp)

{'task': 'slimevolley', 'maxGen': 8, 'pop_size': 32, 'alg_nReps': 2, 'alg_speciate': 'neat', 'alg_probMoo': 0.0, 'alg_act': 5, 'prob_addConn': 0.05, 'prob_addNode': 0.03, 'prob_crossover': 0.8, 'prob_enable': 0.01, 'prob_mutAct': 0.0, 'prob_mutConn': 0.8, 'prob_initEnable': 1.0, 'select_cullRatio': 0.1, 'select_eliteRatio': 0.1, 'select_rankWeight': 'exp', 'select_tournSize': 2, 'spec_compatMod': 0.25, 'spec_dropOffAge': 64, 'spec_target': 4, 'spec_thresh': 2.0, 'spec_threshMin': 2.0, 'spec_geneCoef': 1, 'spec_weightCoef': 0.5, 'save_mod': 8, 'bestReps': 20}


In [6]:
SEED = 69  # Seed handed to the Trainer via `seed=SEED`; the value is arbitrary.

# NOTE(review): this rebinds `hyp` after `NEATJax(hyp)` has already been
# called in the preceding cell; nothing visible in this notebook reads the new
# value afterwards — confirm whether this should run before the solver is
# created.
hyp = loadHyp()

In [7]:

# Wire the policy, the solver and both task variants into the EvoJAX Trainer;
# schedule, repeat and evaluation counts all come from the TrainingParams
# instance `tp`.
trainer = Trainer(
    policy=policy,
    solver=solver,
    train_task=train_task,
    test_task=test_task,
    max_iter=tp.MAX_ITER,
    log_interval=tp.LOG_INTERVAL,
    test_interval=tp.TEST_INTERVAL,
    n_repeats=tp.N_REPEATS,
    n_evaluations=tp.NUM_TESTS,
    seed=SEED,
    log_dir=tp.LOG_DIR,
    logger=logger,
)


INFO:SlimeVolley:use_for_loop=False


In [8]:
# Start training (demo_mode=False).
# NOTE(review): the recorded output below shows this call ending in a
# TypeError because an `ind.Ind` object reached JAX where a numeric array is
# required ("Problem in ask" / "Problem in _initPop").
trainer.run(demo_mode=False)

INFO:SlimeVolley:Start to train for 50 iterations.


Problem in ask
Problem in _initPop


TypeError: Value '<ind.Ind object at 0x790764b73be0>' with dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX.

In [None]:
# Copy the best checkpoint to 'model.npz' in the same log directory, then
# point the trainer at that directory and run in demo mode (presumably this
# is the file name the demo run loads — confirm against EvoJAX).
src_file, tar_file = (
    os.path.join(tp.LOG_DIR, fname) for fname in ('best.npz', 'model.npz')
)
shutil.copy(src_file, tar_file)

trainer.model_dir = tp.LOG_DIR
trainer.run(demo_mode=True)

INFO:SlimeVolley:Loaded model parameters from ./log/slimevolley.
INFO:SlimeVolley:Start to test the parameters.
INFO:SlimeVolley:[TEST] #tests=100, max=-3.0000, avg=-4.8300, min=-5.0000, std=0.4256


-4.83

In [None]:
# Inputs for a single roll-out: `[None, :]` prepends a length-1 leading axis
# to both the PRNG key and the best parameter vector.
key = jax.random.PRNGKey(0)[None, :]
best_params = trainer.solver.best_params[None, :]

# JIT-compiled entry points of the test task ...
task_reset_fn, step_fn = jax.jit(test_task.reset), jax.jit(test_task.step)
# ... and of the policy.
policy_reset_fn, action_fn = jax.jit(policy.reset), jax.jit(policy.get_actions)

In [None]:
# Roll the best policy out on the test task and collect one frame per step.
task_state = task_reset_fn(key)
policy_state = policy_reset_fn(task_state)
screens = []
# MAX_STEPS and LOG_DIR are attributes of the TrainingParams instance `tp`;
# no module-level names with those spellings are defined in this notebook, so
# the bare names would raise NameError.
for _ in range(tp.MAX_STEPS):
    action, policy_state = action_fn(task_state, best_params, policy_state)
    task_state, reward, done = step_fn(task_state, action)
    screens.append(SlimeVolley.render(task_state))

# Write all frames into a single looping GIF (40 ms per frame).
gif_file = os.path.join(tp.LOG_DIR, 'slimevolley.gif')
screens[0].save(gif_file, save_all=True, append_images=screens[1:],
                duration=40, loop=0)
logger.info('GIF saved to {}.'.format(gif_file))

INFO:SlimeVolley:GIF saved to ./log/slimevolley/slimevolley.gif.


This is working! We are getting a GIF as intended. The next task is to write an EvoJAX algorithm for NEAT, which will require some careful tweaking. Will update as things go...

In [None]:
## The base class for NE Algos.

import copy
from abc import ABC
from abc import abstractmethod
from typing import Any
from typing import Dict
from typing import Union
import numpy as np
import jax.numpy as jnp

class NEAlgorithm(ABC):
    """Interface of all Neuro-evolution algorithms in EvoJAX.

    Concrete algorithms must implement ``ask`` and ``tell``; state
    saving/loading is optional and the ``best_params`` property raises
    ``NotImplementedError`` unless overridden.
    """

    # Number of candidate solutions produced by each call to ``ask``.
    pop_size: int

    @abstractmethod
    def ask(self) -> jnp.ndarray:
        """Ask the algorithm for a population of parameters.

        Returns
            A Jax array of shape (population_size, param_size).
        """
        raise NotImplementedError()

    @abstractmethod
    def tell(self, fitness: Union[np.ndarray, jnp.ndarray]) -> None:
        """Report the fitness of the population to the algorithm.

        Args:
            fitness - The fitness scores array (NumPy or JAX).
        """
        raise NotImplementedError()

    def save_state(self) -> Any:
        """Optionally, save the state of the algorithm.

        Returns
            Saved state (``None`` in this default implementation).
        """
        return None

    def load_state(self, saved_state: Any) -> None:
        """Optionally, load the saved state of the algorithm.

        Args:
            saved_state - The result of self.save_state().
        """
        pass

    @property
    def best_params(self) -> jnp.ndarray:
        """Best parameters found so far; subclasses must override."""
        raise NotImplementedError()

    @best_params.setter
    def best_params(self, params: Union[np.ndarray, jnp.ndarray]) -> None:
        raise NotImplementedError()


In [None]:
## Necessary Functions :

def getFronts(objVals):
  """Fast non-dominated sort over two objectives (both maximized).

  Args:
    objVals - (np_array) - Objective values of each individual
              [nInds X nObjectives]; only columns 0 and 1 are read

  Returns:
    front   - [list of lists] - One list for each front:
                                list of indices of individuals in front

  Todo:
    * Extend to N objectives

  [adapted from: https://github.com/haris989/NSGA-II]
  """
  first = objVals[:, 0]
  second = objVals[:, 1]
  nInds = len(first)

  def dominates(a, b):
    # True when individual a is at least as good as b on both objectives
    # and strictly better on at least one.
    return ((first[a] > first[b] and second[a] > second[b])
            or (first[a] >= first[b] and second[a] > second[b])
            or (first[a] > first[b] and second[a] >= second[b]))

  dominated = [[] for _ in range(nInds)]  # dominated[p]: indices p dominates
  nDominators = [0] * nInds               # how many individuals dominate p
  front = [[]]

  # Get domination relations
  for p in range(nInds):
    for q in range(nInds):
      if dominates(p, q):
        dominated[p].append(q)
      elif dominates(q, p):
        nDominators[p] += 1
    if nDominators[p] == 0:
      front[0].append(p)

  # Assign fronts: removing one front releases the individuals that were
  # dominated only by members of the fronts processed so far.
  level = 0
  while front[level] != []:
    nextFront = []
    for p in front[level]:
      for q in dominated[p]:
        nDominators[q] -= 1
        if nDominators[q] == 0 and q not in nextFront:
          nextFront.append(q)
    level += 1
    front.append(nextFront)

  front.pop()  # the loop always ends by appending one empty front
  return front

def getCrowdingDist(objVector):
  """Returns crowding distance of a vector of values, used once on each front.

  Note: Crowding distance of individuals at each end of front is infinite, as they don't have a neighbor.

  Changes relative to the NumPy-style original:
    * JAX arrays are immutable, so the original order is restored with a
      functional `.at[key].set(...)` update instead of item assignment.
    * The distance is normalized by the value range `max - min`; the previous
      expression `1/sortedObj[-1]-sortedObj[0]` evaluated `(1/max) - min`
      because of operator precedence.
    * The global `warnings.filterwarnings(...)` call is dropped: `warnings`
      was not imported here and the `jnp` arithmetic below does not go
      through NumPy's RuntimeWarning machinery.
    * Empty input returns an empty array instead of failing on `[-1]`.

  The range test uses a Python `if` on a concrete value, so this function is
  meant to be called outside of `jax.jit`.

  Args:
    objVector - (jnp_array) - Objective values of each individual
                [nInds]

  Returns:
    dist      - (jnp_array) - Crowding distance of each individual
                [nInds]
  """
  objVector = jnp.asarray(objVector)
  if objVector.shape[0] == 0:
    return jnp.zeros((0,))

  # Order by objective value
  key = jnp.argsort(objVector)
  sortedObj = objVector[key]

  # Distance from values on either side
  shiftVec = jnp.r_[jnp.inf,sortedObj,jnp.inf] # Edges have infinite distance
  prevDist = jnp.abs(sortedObj-shiftVec[:-2])
  nextDist = jnp.abs(sortedObj-shiftVec[2:])
  crowd = prevDist+nextDist

  # Normalize by fitness range (skipped when all values are equal)
  objRange = sortedObj[-1]-sortedObj[0]
  if objRange > 0:
    crowd = crowd / objRange

  # Restore original order
  dist = jnp.zeros_like(crowd).at[key].set(crowd)

  return dist


def nsga_sort(objVals, returnFronts=False):
  """Returns ranking of objective values based on non-dominated sorting.
  Optionally returns fronts (useful for visualization).

  NOTE: Assumes maximization of objective function

  The rank vector is built with a functional `.at[...].set(...)` update:
  JAX arrays do not support item assignment, and `jnp.empty_like` does not
  accept a plain Python list.

  Args:
    objVals      - (array) - Objective values of each individual
                   [nInds X nObjectives]; only columns 0 and 1 are used
    returnFronts - (bool)  - When True, also return the sorted fronts

  Returns:
    rank    - (jnp_array) - Rank in population of each individual
            int([nIndividuals])
    front   - [list of lists] - Only when returnFronts is True: indices of
            the individuals in each front, ordered by decreasing crowding
            distance

  Todo:
    * Extend to N objectives
  """
  fronts = getFronts(objVals)

  # Rank each individual in each front by crowding distance
  for f in range(len(fronts)):
    x1 = objVals[fronts[f],0]
    x2 = objVals[fronts[f],1]
    crowdDist = getCrowdingDist(x1) + getCrowdingDist(x2)
    frontRank = jnp.argsort(-crowdDist)
    # int(...) turns each JAX scalar into a plain list index, so the fronts
    # keep holding ordinary Python ints.
    fronts[f] = [fronts[f][int(i)] for i in frontRank]

  # Convert to ranking: position in the flattened front order -> individual
  tmp = jnp.asarray([ind for front in fronts for ind in front], dtype=int)
  rank = jnp.zeros_like(tmp).at[tmp].set(jnp.arange(len(tmp), dtype=tmp.dtype))

  if returnFronts is True:
    return rank, fronts
  else:
    return rank

def rankArray(X):
  """Returns the rank of every element of X.

  Ranks follow `jnp.argsort`, i.e. ascending order: the smallest value gets
  rank 0. Callers that want the best individual first pass the negated
  values. Ties keep their input order (assuming the default stable
  `jnp.argsort`).

  The result is built with `.at[...].set(...)` because JAX arrays are
  immutable and reject item assignment.

  Args:
    X    - (array-like) - Values to rank [nInds]

  Returns:
    rank - (jnp_array)  - Integer rank of each element [nInds]
  """
  X = jnp.asarray(X)
  tmp = jnp.argsort(X)
  rank = jnp.zeros_like(tmp).at[tmp].set(jnp.arange(len(X), dtype=tmp.dtype))
  return rank

In [None]:
class NEAT(NEAlgorithm):
  def __init__(self, hyp):
      """Set up an empty NEAT run.

      Args:
        hyp - (dict) - Hyper-parameters; stored as `self.p` and indexed by
              key in the other methods (e.g. `p['popSize']` in `initPop`)
      """
      # ABC defines no __init__ of its own, but calling super() keeps the
      # class cooperative if further base classes are added.
      super().__init__()
      self.p       = hyp
      # `pop_size` is declared on the NEAlgorithm interface; take it from the
      # same 'popSize' entry that initPop uses to size the population.
      self.pop_size = hyp['popSize']
      self.pop     = []   # current population (filled by initPop)
      self.species = []
      self.innov   = []   # innovation record (filled by initPop)
      self.gen     = 0

  def ask(self):
    """Returns newly evolved population.

    The first call (empty population) creates the initial population; every
    later call ranks, speciates and evolves the existing one.
    """
    if len(self.pop) > 0:
      self.probMoo()
      self.speciate()
      self.evolvePop()
    else:
      self.initPop()

    # NOTE(review): this hands back the population container itself, whereas
    # NEAlgorithm.ask is documented as returning a JAX array of shape
    # (population_size, param_size).
    return self.pop

  def tell(self,reward):
    """Stores the reported rewards as fitness on the current population.

    Args:
      reward - (array-like) - One entry per individual along the first
               dimension; entry i is written to `self.pop[i].fitness`
    """
    nRewards = np.shape(reward)[0]
    for idx in range(nRewards):
      member = self.pop[idx]
      member.fitness = reward[idx]
      # NOTE(review): self-assignment carried over unchanged; it only has an
      # effect if `nConn` is a property with side effects — confirm whether
      # it can be dropped.
      member.nConn   = member.nConn

  def initPop(self) :
    """Creates the initial population and the matching innovation record.

    Every individual shares the same minimal topology (bias and inputs fully
    connected to the outputs) and differs only in its connection weights and
    enabled flags.

    All arrays are assembled functionally (`jnp.stack`, `.at[].set`) because
    JAX arrays are immutable and reject item assignment. Random numbers come
    from `np.random`, since `jax.numpy` has no `random` sub-module. Each
    individual is therefore constructed directly from its final connection
    matrix rather than from a NaN-weighted template patched afterwards.

    Reads from self.p: 'ann_nInput', 'ann_nOutput', 'ann_initAct',
    'ann_absWCap', 'prob_initEnable', 'popSize'.
    Sets: self.pop (list of Ind) and self.innov ([5 X nConn] array).
    """
    ##  Create base individual
    p = self.p # readability
    nInput  = p['ann_nInput']
    nOutput = p['ann_nOutput']
    nNode   = nInput + nOutput + 1

    # - Create Nodes -
    nodeId = jnp.arange(nNode, dtype=float)

    # Node types as assigned here: 1 = input, 2 = output, 4 = bias
    # (3 is presumably hidden — confirm against Ind).
    nodeType = jnp.concatenate([
        jnp.full((1,), 4.0),        # Bias
        jnp.full((nInput,), 1.0),   # Input Nodes
        jnp.full((nOutput,), 2.0),  # Output Nodes
    ])

    # Node Activations
    nodeAct = jnp.full((nNode,), p['ann_initAct'], dtype=float)
    node = jnp.stack([nodeId, nodeType, nodeAct])

    # - Create Conns -
    nConn = (nInput+1) * nOutput
    ins   = jnp.arange(0,nInput+1,1)            # Input and Bias Ids
    outs  = (nInput+1) + jnp.arange(0,nOutput)  # Output Ids

    connId = jnp.arange(nConn, dtype=float)               # Connection Id
    src    = jnp.tile(ins, len(outs)).astype(float)       # Source Nodes
    dest   = jnp.repeat(outs, len(ins)).astype(float)     # Destination Nodes

    # Create population of individuals with varied weights
    pop = []
    for i in range(p['popSize']):
        # Weights uniform in [-ann_absWCap, ann_absWCap); enabled with
        # probability prob_initEnable (stored as 0.0 / 1.0).
        weight  = (2*(np.random.rand(nConn)-0.5))*p['ann_absWCap']
        enabled = np.random.rand(nConn) < p['prob_initEnable']
        conn = jnp.stack([
            connId,
            src,
            dest,
            jnp.asarray(weight, dtype=float),
            jnp.asarray(enabled, dtype=float),
        ])
        newInd = Ind(conn, node)
        newInd.express()
        newInd.birth = 0
        pop.append(copy.deepcopy(newInd))

    # - Create Innovation Record -
    # Rows 0-2 copy connection id / source / destination from the first
    # individual, row 3 is -1 and row 4 stays 0.
    innov = jnp.zeros([5,nConn])
    innov = innov.at[0:3,:].set(pop[0].conn[0:3,:])
    innov = innov.at[3,:].set(-1)

    self.pop = pop
    self.innov = innov

  def probMoo(self):
    """Rank population according to Pareto dominance.

    Builds a two-column objective matrix (fitness, 1/nConn) and writes a
    `rank` attribute on every individual, using either `nsga_sort` on both
    objectives or `rankArray` on fitness alone.

    JAX arrays are immutable, so the zero-connection guard uses `jnp.where`
    instead of masked assignment, and the random draw comes from `np.random`
    because `jax.numpy` has no `random` sub-module.
    """
    meanFit = jnp.asarray([ind.fitness for ind in self.pop])
    nConns  = jnp.asarray([ind.nConn   for ind in self.pop])
    # No connections is pareto optimal but boring... (also avoids 1/0 below)
    nConns  = jnp.where(nConns == 0, 1, nConns)
    objVals = jnp.c_[meanFit,1/nConns] # Maximize

    # Alternate between two objectives and single objective
    # NOTE(review): the multi-objective branch is taken when
    # alg_probMoo < rand(), i.e. with probability 1 - alg_probMoo — confirm
    # this matches the intended meaning of 'alg_probMoo'.
    if self.p['alg_probMoo'] < np.random.rand():
      rank = nsga_sort(objVals[:,[0,1]])

    else: # Single objective
      rank = rankArray(-objVals[:,0])

    # Assign ranks
    for i in range(len(self.pop)):
      self.pop[i].rank = rank[i]