# Minigame 12: Batched Evaluation

In this exploration, we're looking at the ability to take policies trained on small patches and to apply them to larger meshes by amalgamating the logits from each of the NN evaluations.

In [None]:
import math
from math import sin,cos
import random

In [None]:
import gym
from gym import spaces, utils
import numpy as np
import ray
import ray.rllib.agents.ppo as ppo

In [None]:
from glvis import glvis, to_stream
from ipywidgets import Layout

In [None]:
import matplotlib.pyplot as plt

In [None]:
from mfem import path
import mfem.ser as mfem

Define some synthetic test functions.

In [None]:
def rotate(x, theta):
    """Rotate the 2D point ``x`` about the origin by the angle ``theta``.

    Applies the rotation matrix [[cos, -sin], [sin, cos]] to the first two
    components of ``x`` (``theta`` is passed straight to ``math.cos`` /
    ``math.sin``, i.e. it is in radians).

    Returns:
        A new two-element list ``[x', y']``; ``x`` itself is not modified.
    """
    c, s = cos(theta), sin(theta)
    px, py = x[0], x[1]
    return [px*c - py*s, px*s + py*c]

In [None]:
def step(x):
    """Sharp step in the first coordinate.

    Returns 1.0 when ``x[0]`` is strictly negative and 0.0 otherwise
    (so the value at exactly ``x[0] == 0`` is 0.0).
    """
    return 1.0 if x[0] < 0.0 else 0.0

In [None]:
def rotated_step(x, theta):
    """Evaluate the sharp step at the point ``x`` after rotating it by ``theta``."""
    return step(rotate(x, theta))

In [None]:
def bump(x):
    """Gaussian-shaped bump centred at the origin.

    Returns ``exp(-(x[0]**2 + x[1]**2))``, which is 1.0 at the origin and
    decays with the squared distance from it.
    """
    radius_squared = x[0]**2 + x[1]**2
    return math.exp(-radius_squared)

In [None]:
def smooth_step(x):
    """Smooth (tanh-based) step in the first coordinate.

    Maps ``x[0]`` to ``(1 + tanh(x[0])) / 2``: 0.5 at ``x[0] == 0``,
    approaching 0 for very negative and 1 for very positive ``x[0]``.
    """
    return (1.0 + math.tanh(x[0])) / 2.0

In [None]:
def rotated_smooth_step(x, theta):
    """Evaluate the smooth step at the point ``x`` after rotating it by ``theta``."""
    return smooth_step(rotate(x, theta))

Create coefficient classes whose parameters are first set (randomly) via `SetParams()` and which can then be evaluated at many points.

In [None]:
class Step(mfem.PyCoefficient):
    """Sharp step coefficient with a random orientation and offset.

    ``SetParams()`` must be called before evaluation: it creates the
    ``theta`` and ``dx`` attributes that ``EvalValue()`` reads.
    """

    def SetParams(self):
        """Draw a rotation angle in [0, 2*pi] and a 2-component shift in [-1, 1]."""
        self.theta = random.uniform(0.0, 2.0*math.pi)
        # two independent draws: x-shift first, then y-shift
        self.dx = [random.uniform(-1.0, 1.0) for _ in range(2)]

    def EvalValue(self, x):
        """Return the rotated sharp step at ``x`` shifted by ``self.dx``.

        NOTE(review): ``x + self.dx`` is an element-wise shift only if ``x``
        is a NumPy array (``self.dx`` is a plain list) -- presumably that is
        what mfem passes in; confirm.
        """
        shifted = x + self.dx
        return rotated_step(shifted, self.theta)

In [None]:
class Bump(mfem.PyCoefficient):
    """Gaussian-shaped bump coefficient with a random width and offset.

    ``SetParams()`` must be called before evaluation: it creates the
    ``width``, ``xc`` and ``dx`` attributes that ``EvalValue()`` reads.
    """

    def SetParams(self):
        """Draw a width in [0.1, 1.0] and a 2-component shift in [-0.5, 0.5]."""
        self.width = random.uniform(0.1, 1.0)
        self.xc = [0.5, 0.5]
        # two independent draws: x-shift first, then y-shift
        self.dx = [random.uniform(-0.5, 0.5) for _ in range(2)]

    def EvalValue(self, x):
        """Return ``bump`` at ``(x - xc + dx) / width``.

        NOTE(review): the list arithmetic is element-wise only if ``x`` is a
        NumPy array -- presumably what mfem passes in; confirm.
        """
        offset = x - self.xc + self.dx
        return bump(offset / self.width)

In [None]:
class TwoBump(mfem.PyCoefficient):
    """Average of two Gaussian-shaped bumps, each with its own random width and offset.

    ``SetParams()`` must be called before evaluation: it creates the
    attributes that ``EvalValue()`` reads.
    """

    def SetParams(self):
        """Draw two widths in [0.1, 0.5] and two 2-component shifts in [-0.5, 0.5]."""
        # Draw order: width1, width2, dx1 (x, y), dx2 (x, y).
        self.width1 = random.uniform(0.1, 0.5)
        self.width2 = random.uniform(0.1, 0.5)
        self.xc1 = [0.5, 0.5]
        self.xc2 = [0.5, 0.5]
        self.dx1 = [random.uniform(-0.5, 0.5) for _ in range(2)]
        self.dx2 = [random.uniform(-0.5, 0.5) for _ in range(2)]

    def EvalValue(self, x):
        """Return the mean of the two bumps evaluated at ``x``.

        An alternative that was tried here is the pointwise maximum of the two
        bumps instead of their mean.

        NOTE(review): the list arithmetic is element-wise only if ``x`` is a
        NumPy array -- presumably what mfem passes in; confirm.
        """
        first = bump((x - self.xc1 + self.dx1) / self.width1)
        second = bump((x - self.xc2 + self.dx2) / self.width2)
        return 0.5 * (first + second)

In [None]:
class SmoothStep(mfem.PyCoefficient):
    """Smooth step coefficient with a random steepness, offset and orientation.

    ``SetParams()`` must be called before evaluation: it creates the
    ``width``, ``xc``, ``dx`` and ``theta`` attributes that ``EvalValue()`` reads.
    """

    def SetParams(self):
        """Draw the random parameters.

        ``width`` is drawn from [5, 10] and *multiplies* the shifted point in
        ``EvalValue``; ``dx`` is a single scalar in [-0.5, 0.5] added to every
        component; ``theta`` is a rotation angle in [0, 2*pi].
        """
        self.width = random.uniform(5.0, 10.0)
        self.xc = [0.5, 0.5]
        self.dx = random.uniform(-0.5, 0.5)
        self.theta = random.uniform(0.0, 2.0*math.pi)

    def EvalValue(self, x):
        """Return the rotated smooth step at ``(x - xc + dx) * width``.

        The shifted point is built as a new value instead of updating ``x``
        with ``-=`` / ``+=``: ``x`` belongs to the caller, and modifying it in
        place would change the point seen by any code that uses ``x``
        afterwards.

        NOTE(review): the arithmetic is element-wise only if ``x`` is a NumPy
        array -- presumably what mfem passes in; confirm.
        """
        shifted = x - self.xc + self.dx
        return rotated_smooth_step(shifted*self.width, self.theta)

In [None]:
class BumpsAndSmoothStep(mfem.PyCoefficient):
    """Equal-weight blend of a random ``Bump`` and a random ``SmoothStep``.

    ``SetParams()`` must be called before evaluation: it creates the two
    component coefficients that ``EvalValue()`` reads.
    """

    def SetParams(self):
        """Create and randomize the two components (bump first, then smooth step)."""
        bump_part = Bump()
        bump_part.SetParams()
        step_part = SmoothStep()
        step_part.SetParams()
        self.bump = bump_part
        self.smooth_step = step_part

    def EvalValue(self, x):
        """Return ``0.5 * bump(x) + 0.5 * smooth_step(x)``."""
        # The two terms are evaluated in a fixed order: bump, then smooth step.
        bump_value = self.bump.EvalValue(x)
        step_value = self.smooth_step.EvalValue(x)
        return 0.5*bump_value + 0.5*step_value


Visualize an instance of the test function. Note that each instance has randomly chosen parameters.  For the steps, it's a rotation angle and a displacement.  For the bumps, it's a width and a displacement.

In [None]:
# Load the mesh from 'inline-quad.mesh' and uniformly refine it twice.
mesh = mfem.Mesh('inline-quad.mesh')
mesh.UniformRefinement()
mesh.UniformRefinement()
# Build an L2 finite element space (p=1, dim=2) on the mesh and a grid function on it.
fec = mfem.L2_FECollection(p=1, dim=2)
fes = mfem.FiniteElementSpace(mesh, fec)
u = mfem.GridFunction(fes)
# Randomize a test function (SetParams must precede evaluation) and project it onto u.
c = BumpsAndSmoothStep()
c.SetParams()
u.ProjectCoefficient(c)

In [None]:
# Display the mesh together with the projected function
# (500, 500 is presumably the widget width/height -- confirm against glvis).
glvis((mesh, u), 500, 500)

Create the gym environment. Note that in this case, this can be just a dummy environment that only serves to define the observation and action spaces for the purposes of evaluation of the policy.

In [None]:
class AMRGameDummy(gym.Env):
    """Placeholder environment that only defines the action/observation spaces.

    The mesh, finite element space and grid function are created solely to
    size the spaces; ``step``, ``reset`` and ``render`` do nothing.
    """

    # In RLlib, you need the config arg
    def __init__(self, config):
        """Build the spaces from 'inline-quad-7.mesh'; ``config`` is not read."""
        self.order = 1
        self.meshfile = 'inline-quad-7.mesh'
        self.mesh = mfem.Mesh(self.meshfile)

        # The fespace and gf exist only so we can query the sizes needed
        # for the action and observation spaces.
        self.fec = mfem.L2_FECollection(self.order, self.mesh.Dimension())
        self.fes = mfem.FiniteElementSpace(self.mesh, self.fec)
        self.u = mfem.GridFunction(self.fes)

        # one discrete action per mesh element
        num_elements = self.mesh.GetNE()
        self.action_space = spaces.Discrete(num_elements)

        # observation: one value in [-1, 1] per DOF of the grid function
        num_dofs = self.u.Size()
        self.observation_space = spaces.Box(-1.0, 1.0, shape=(num_dofs,), dtype=np.float32)

    def step(self, action):
        """Do nothing (returns None)."""

    def reset(self):
        """Do nothing (returns None)."""

    def render(self):
        """Do nothing (returns None)."""

Now we want to load a trained policy, and apply it in a strided way.

In [None]:
# Shut Ray down first, then initialize it again; ignore_reinit_error=True is
# passed so a re-run of this cell is not treated as an error.
ray.shutdown()
ray.init(ignore_reinit_error=True)

In [None]:
# Start from a copy of the default PPO config and build a trainer on the dummy env.
config = ppo.DEFAULT_CONFIG.copy()
# 'tfe' is presumably RLlib's TensorFlow-eager framework setting -- confirm.
config['framework'] = 'tfe'
agent = ppo.PPOTrainer(config, env=AMRGameDummy)

Restore a policy

In [None]:
# Restore the trainer state from a saved checkpoint.
# NOTE(review): this is an absolute, machine-specific path; it must exist on the
# machine running the notebook.
agent.restore("/home/rwa/ray_results/PPO_AMRGame_2021-02-15_22-07-44llq8wv6i/checkpoint_20/checkpoint-20")

In [None]:
# Keep a handle on the trainer's policy; the logit computations below call
# policy.compute_single_action on it.
policy = agent.get_policy()

Now we want to create the larger problem we'll be applying this local indicator on.

In [None]:
# Load the larger source mesh and report its element count.
mesh = mfem.Mesh('inline-quad-20.mesh')
print(mesh.GetNE())
# L2 space (p=1, dim=2) on the source mesh; u is the working function and
# u0 a second grid function on the same space used as a backup copy.
fec = mfem.L2_FECollection(p=1, dim=2)
fes = mfem.FiniteElementSpace(mesh, fec)
u = mfem.GridFunction(fes)
u0 = mfem.GridFunction(fes)
# Randomize a test function and project it onto u.
coeff = BumpsAndSmoothStep()
coeff.SetParams()
u.ProjectCoefficient(coeff)
u0.Assign(u) # save so we can restore later if desired
    
def new_function():
    """Rebuild the source problem with a freshly randomized test function.

    Rebinds the module-level ``mesh``, ``fec``, ``fes``, ``u`` and ``u0``:
    the mesh is reloaded from 'inline-quad-20.mesh', a new random
    ``BumpsAndSmoothStep`` is projected onto ``u``, and ``u0`` receives a
    copy of ``u`` so it can be restored later.

    Returns:
        The glvis widget showing the new mesh and function.
    """
    global mesh, fec, fes, u, u0

    mesh = mfem.Mesh('inline-quad-20.mesh')
    fec = mfem.L2_FECollection(p=1, dim=2)
    fes = mfem.FiniteElementSpace(mesh, fec)

    # project a new random function onto u
    coefficient = BumpsAndSmoothStep()
    coefficient.SetParams()
    u = mfem.GridFunction(fes)
    u.ProjectCoefficient(coefficient)

    # save so we can restore later if desired
    u0 = mfem.GridFunction(fes)
    u0.Assign(u)

    widget_layout = Layout(width='100%', height='400px')
    return glvis((mesh, u), 400, 400, layout=widget_layout)

In [None]:
def restore_function():
    """Reset the source problem to the function saved in ``u0``.

    Rebinds the module-level ``mesh``, ``fec``, ``fes`` and ``u``: the mesh is
    reloaded from 'inline-quad-20.mesh' and the values of the module-level
    ``u0`` are assigned into a new ``u``.

    Returns:
        The glvis widget showing the restored mesh and function.
    """
    global mesh, fec, fes, u

    # reload the mesh and rebuild the L2 space on it
    mesh = mfem.Mesh('inline-quad-20.mesh')
    fec = mfem.L2_FECollection(p=1, dim=2)
    fes = mfem.FiniteElementSpace(mesh, fec)

    # fresh grid function filled from the saved copy
    u = mfem.GridFunction(fes)
    u.Assign(u0)

    widget_layout = Layout(width='100%', height='400px')
    return glvis((mesh, u), 400, 400, layout=widget_layout)

Build a map from each element to the elements that make up the "stencil" around it. Since not every element has a full stencil, use a dictionary whose keys are only the elements that have a full stencil.

In [None]:
def build_stencils(mesh, width):
    """Map each element with a complete stencil to the elements of that stencil.

    For every element ``k`` the stencil is the ``(2*hw+1) x (2*hw+1)`` block of
    points ``center(k) + (i*dx, j*dx)`` with ``hw = int(width/2)``; each point
    is resolved to an element id with ``mesh.FindPoints``. Points are visited
    with ``j`` (second coordinate) in the outer loop and ``i`` in the inner one.

    The spacing is ``dx = 1/sqrt(NE)`` and a stencil counts as complete only if
    all of its points lie in ``[0, 1] x [0, 1]``, i.e. this assumes a uniform
    ``n x n`` mesh of the unit square.

    Args:
        mesh: the source mesh.
        width: stencil width in elements.

    Returns:
        dict mapping element id -> list of stencil element ids, containing
        only the elements whose stencil is complete (keys in increasing order).
    """
    num_elements = mesh.GetNE()
    nx = math.sqrt(num_elements)
    dx = 1.0/nx
    hw = int(width/2)
    offsets = range(-hw, hw+1)
    center = mfem.Vector(mesh.Dimension())

    stencils = {}
    for k in range(num_elements):
        mesh.GetElementCenter(k, center)
        cx, cy = center[0], center[1]
        points = [(cx + i*dx, cy + j*dx) for j in offsets for i in offsets]

        # A stencil poking outside the unit square is discarded, so decide
        # that first and skip the point searches for it entirely.
        if any(px < 0.0 or px > 1.0 or py < 0.0 or py > 1.0 for px, py in points):
            continue

        stencil = []
        for px, py in points:
            _, found, _ = mesh.FindPoints([[px, py]])
            stencil.append(found[0])
        stencils[k] = stencil
    return stencils

Create a function and build the stencils for it.

In [None]:
# Generate a fresh random function on the source mesh, then build the
# element -> stencil map for stencils that are `width` elements wide.
new_function()
width=7
els = build_stencils(mesh, width)

Create the local observation mesh into which we will copy the dofs for the purposes of creating an observation vector.

In [None]:
# Observation mesh loaded from 'inline-quad-7.mesh' (the same file AMRGameDummy
# loads), with an L2 (p=1, dim=2) space and a grid function to hold the
# transferred DOF values.
obs_mesh = mfem.Mesh('inline-quad-7.mesh')
obs_fec = mfem.L2_FECollection(p=1, dim=2)
obs_fes = mfem.FiniteElementSpace(obs_mesh, obs_fec)
obs_u = mfem.GridFunction(obs_fes)
# Display the observation mesh on its own.
glvis((obs_mesh), 400, 400,layout = Layout(width='100%', height='400px'))

Also build a 0th order L2 field to look at per-element quantities (like logits or prob dist).

In [None]:
# Order-0 L2 space on the observation mesh; obs_u0 is the grid function that
# show_logits fills with the policy logits.
fec0 = mfem.L2_FECollection(p=0, dim=2)
fes0 = mfem.FiniteElementSpace(obs_mesh, fec0)
obs_u0 = mfem.GridFunction(fes0)

Now we need a mapping from the "logical" space of the observation mesh into element ids. This has the same ordering as the stencil elements, so we can form a mapping for the purposes of data transfer from the src mesh to the obs mesh.

In [None]:
def build_map(obs_mesh, width):
    """List the observation-mesh element ids in stencil traversal order.

    Probes the points ``(0.5 + i/width, 0.5 + j/width)`` for ``i, j`` in
    ``-hw..hw`` (``hw = int(width/2)``, ``j`` in the outer loop, ``i`` in the
    inner one) and records the first element id that
    ``obs_mesh.FindPoints`` reports for each.

    Args:
        obs_mesh: the observation mesh.
        width: number of probe points per direction.

    Returns:
        list of element ids, one per probe point.
    """
    center = [0.5, 0.5]
    spacing = 1./width
    half = int(width/2)
    offsets = range(-half, half+1)

    id_map = []
    for j in offsets:
        for i in offsets:
            probe = [[center[0] + i*spacing, center[1] + j*spacing]]
            _, found, _ = obs_mesh.FindPoints(probe)
            id_map.append(found[0])
    return id_map

In [None]:
# Observation-mesh element ids, listed in the same order as the stencil entries.
id_map = build_map(obs_mesh, width)

Create a function to transfer from the stencil associated with a src element k into the observation gf.

In [None]:
def transfer_stencil(k):
    """Copy the stencil of source element ``k`` into the observation function.

    For the n-th element of ``els[k]`` the DOF values of the module-level ``u``
    are copied, DOF by DOF, into the DOFs of observation element ``id_map[n]``
    of the module-level ``obs_u``. ``k`` must be a key of ``els``.
    """
    for slot, src_el in enumerate(els[k]):
        dst_el = id_map[slot]
        src_dofs = fes.GetElementDofs(src_el)
        dst_dofs = obs_fes.GetElementDofs(dst_el)
        # local DOF d of the source element goes to local DOF d of the target
        for local in range(len(src_dofs)):
            obs_u[dst_dofs[local]] = u[src_dofs[local]]

Visualize the observation mesh and function.

In [None]:
def show_obs():
    """Return a glvis widget showing the observation mesh with ``obs_u`` on it."""
    widget_layout = Layout(width='100%', height='400px')
    return glvis((obs_mesh, obs_u), 400, 400, layout=widget_layout)

Compute the logits for each element in the observation mesh and visualize as a p=0 L2 function.

In [None]:
def show_logits():
    """Evaluate the policy on the current observation and display its logits.

    The values of ``obs_u`` are fed to ``policy.compute_single_action`` (with
    ``explore=False``); the ``'action_dist_inputs'`` entry of the returned info
    is assigned into the order-0 function ``obs_u0``.

    Returns:
        The glvis widget showing the observation mesh with ``obs_u0`` on it.
    """
    observation = np.array(obs_u.GetDataArray())
    _, _, info = policy.compute_single_action(observation, explore=False)

    # convert to float64 before wrapping the values in an mfem.Vector
    logits = np.array(info['action_dist_inputs'], dtype=np.float64)
    obs_u0.Assign(mfem.Vector(logits))

    widget_layout = Layout(width='100%', height='400px')
    return glvis((obs_mesh, obs_u0), 400, 400, layout=widget_layout)

Test it out on a specific src element:

In [None]:
# Start from a freshly randomized function on the source mesh.
new_function()

In [None]:
# Copy the stencil around source element 14 into the observation function and show it.
# NOTE(review): transfer_stencil reads els[14], so 14 must be an element with a
# full stencil (a key of els) -- confirm for this mesh's element numbering.
transfer_stencil(14)
show_obs()

In [None]:
# Show the policy logits for the observation currently held in obs_u.
show_logits()

In [None]:
# Same again for the stencil around source element 207.
# NOTE(review): 207 must be a key of els (an element with a full stencil) -- confirm.
transfer_stencil(207)
show_obs()

In [None]:
# Show the policy logits for the observation currently held in obs_u.
show_logits()

Iterate over all the elements with full stencils in the src mesh and record 'center' logits for each observation:

In [None]:
# Locate the observation-mesh element containing the point (0.5, 0.5).
# center_el is the element-id sequence returned by FindPoints (one entry,
# since a single point is queried), not a bare integer.
pt = [[0.5,0.5]]
n, center_el, ip = obs_mesh.FindPoints(pt)

In [None]:
def compute_center_logits(mesh):
    """Collect, for each full-stencil element, the logit of the stencil's center.

    For every key ``k`` of the module-level ``els`` the stencil is transferred
    into ``obs_u``, the policy is evaluated on it (``explore=False``), and the
    logit of the observation element containing (0.5, 0.5) is stored at
    position ``k``.

    Args:
        mesh: source mesh; only its element count is used.

    Returns:
        list with one entry per mesh element; elements without a full stencil
        keep the initial value 0.0.
    """
    logits = [0.0]*mesh.GetNE()

    # center_el is the one-entry id sequence returned by FindPoints. Index the
    # logits with the id itself so a scalar is stored, rather than the
    # length-1 array that indexing with the whole sequence produces.
    center = center_el[0]

    for k in els:
        transfer_stencil(k)
        obs = np.array(obs_u.GetDataArray())
        _, _, info = policy.compute_single_action(obs, explore=False)
        logits[k] = info['action_dist_inputs'][center]
    return logits

In [None]:
def compute_avg_logits(mesh):
    """Average, per source element, the logits from every stencil that covers it.

    For every key ``k`` of the module-level ``els`` the stencil is transferred
    into ``obs_u`` and the policy is evaluated on it (``explore=False``). The
    logit of observation element ``id_map[j]`` is added to source element
    ``els[k][j]``, and each sum is finally divided by its number of
    contributions.

    Args:
        mesh: source mesh; only its element count is used.

    Returns:
        list with one averaged logit per mesh element. An element that no
        stencil covers is reported with a printed message and keeps 0.0
        (previously this case printed the message and then raised
        ZeroDivisionError).
    """
    num_elements = mesh.GetNE()
    logits = [0.0]*num_elements
    count = [0]*num_elements

    # accumulate logit sums
    for k in els:
        transfer_stencil(k)
        obs = np.array(obs_u.GetDataArray())
        _, _, info = policy.compute_single_action(obs, explore=False)
        obs_logits = info['action_dist_inputs']
        for j, obs_el in enumerate(id_map):
            mesh_el = els[k][j]
            logits[mesh_el] += obs_logits[obs_el]
            count[mesh_el] += 1

    # average
    for idx, hits in enumerate(count):
        if hits == 0:
            print('zero count at %d' % idx)
            continue  # nothing accumulated; leave the 0.0 in place
        logits[idx] /= hits

    return logits

In [None]:
def compute_max_logits(mesh):
    """Take, per source element, the largest logit over every stencil covering it.

    For every key ``k`` of the module-level ``els`` the stencil is transferred
    into ``obs_u`` and the policy is evaluated on it (``explore=False``). The
    logit of observation element ``id_map[j]`` competes for source element
    ``els[k][j]``.

    The running maximum starts out "unset" rather than at 0.0, so an element
    whose logits are all negative reports its true (negative) maximum instead
    of being clamped to 0.

    Args:
        mesh: source mesh; only its element count is used.

    Returns:
        list with one logit per mesh element; an element that no stencil
        covers gets 0.0.
    """
    best = [None]*mesh.GetNE()

    # keep the running maximum per source element
    for k in els:
        transfer_stencil(k)
        obs = np.array(obs_u.GetDataArray())
        _, _, info = policy.compute_single_action(obs, explore=False)
        obs_logits = info['action_dist_inputs']
        for j, obs_el in enumerate(id_map):
            mesh_el = els[k][j]
            candidate = obs_logits[obs_el]
            current = best[mesh_el]
            if current is None or candidate > current:
                best[mesh_el] = candidate

    return [0.0 if value is None else value for value in best]

Re-normalize the collected logits into a probability distribution that sums to 1.

In [None]:
def compute_distribution(mesh, u, method):
    """Turn the per-element logits into a probability distribution (softmax).

    Args:
        mesh: source mesh, forwarded to the selected logit routine.
        u: not used by this function; kept for interface compatibility.
        method: 1 -> compute_center_logits, 2 -> compute_avg_logits,
            3 -> compute_max_logits.

    Returns:
        list with one probability per mesh element, summing to 1
        (an empty list if there are no logits).

    Raises:
        ValueError: if ``method`` is not 1, 2 or 3 (previously this surfaced
            as an unbound-variable error further down).
    """
    # Validate up front so a bad selector fails with a clear message.
    if method not in (1, 2, 3):
        raise ValueError('method must be 1 (center), 2 (avg) or 3 (max), got %r' % (method,))

    if method == 1:
        logits = compute_center_logits(mesh)
    elif method == 2:
        logits = compute_avg_logits(mesh)
    else:
        logits = compute_max_logits(mesh)

    if not logits:
        return []

    # Subtract the largest logit before exponentiating: the result is
    # mathematically the same softmax, but math.exp can no longer overflow.
    peak = max(logits)
    weights = [math.exp(logit - peak) for logit in logits]
    total = sum(weights)
    return [weight/total for weight in weights]

Create a similar function that returns elementwise errors via the dg indicator.

In [None]:
def compute_dg_indicator(mesh, u):
    """Return per-element L2 differences between ``u`` and an H1 projection of it.

    ``u`` is projected into an H1 (p=1, dim=2) space on ``mesh`` with
    ``ProjectDiscCoefficient(..., ARITHMETIC)``, and the element-wise L2
    errors of ``u`` against that projection are computed into an order-0 L2
    grid function.

    Args:
        mesh: mesh on which ``u`` lives.
        u: the (discontinuous) grid function to assess.

    Returns:
        NumPy array copy of the order-0 error grid function's data.
    """
    # Wrap the L2 grid function in a coefficient so it can be projected into H1.
    discontinuous = mfem.GridFunctionCoefficient(u)
    smooth_fec = mfem.H1_FECollection(p=1, dim=2)
    smooth_fes = mfem.FiniteElementSpace(mesh, smooth_fec)
    smooth_u = mfem.GridFunction(smooth_fes)
    smooth_u.ProjectDiscCoefficient(discontinuous, mfem.GridFunction.ARITHMETIC)

    # The H1 function, as a coefficient, is the reference for the error computation.
    smooth_coeff = mfem.GridFunctionCoefficient(smooth_u)

    # Order-0 L2 field holding the element-wise "errors" between the
    # continuous and discontinuous fields.
    error_fec = mfem.L2_FECollection(p=0, dim=2)
    error_fes = mfem.FiniteElementSpace(mesh, error_fec)
    errors = mfem.GridFunction(error_fes)
    u.ComputeElementL2Errors(smooth_coeff, errors)

    return np.array(errors.GetDataArray())

Given an indicator value on each element, refine every element whose value is over the threshold.

In [None]:
def refine_via_indicator(ind, mesh, u, thresh):
    """Refine every element whose indicator value is strictly above ``thresh``.

    Args:
        ind: per-element indicator values; position ``k`` belongs to element ``k``.
        mesh: mesh to refine (modified in place).
        u: grid function on ``mesh``; its space and then ``u`` itself are
            updated after the refinement.
        thresh: refinement threshold.

    Returns:
        The glvis widget showing the refined mesh and updated function.
    """
    marked = [k for k, value in enumerate(ind) if value > thresh]
    mesh.GeneralRefinement(mfem.intArray(marked))

    # update the space first, then the function living on it
    u.FESpace().Update()
    u.Update()

    widget_layout = Layout(width='100%', height='400px')
    return glvis((mesh, u), 400, 400, layout=widget_layout)

Refine everywhere the policy is over a threshold.

In [None]:
def refine_via_policy_threshold(mesh, u, thresh, method):
    """Refine the elements whose policy probability is above ``thresh``.

    The probabilities come from ``compute_distribution(mesh, u, method)``.
    Returns the glvis widget produced by ``refine_via_indicator``.
    """
    return refine_via_indicator(compute_distribution(mesh, u, method), mesh, u, thresh)

Refine everywhere the DG indicator is over a threshold.

In [None]:
def refine_via_dg_threshold(mesh, u, thresh):
    """Refine the elements whose DG indicator value is above ``thresh``.

    The indicator comes from ``compute_dg_indicator(mesh, u)``.
    Returns the glvis widget produced by ``refine_via_indicator``.
    """
    return refine_via_indicator(compute_dg_indicator(mesh, u), mesh, u, thresh)

In [None]:
# New random function, refined wherever the DG indicator exceeds 7e-6.
new_function()
refine_via_dg_threshold(mesh, u, 7.e-6)

In [None]:
# Restore the saved function, then refine where the policy probability built
# from the center logits (method 1) exceeds 0.0035.
restore_function()
method = 1
refine_via_policy_threshold(mesh, u, 0.0035, method)

In [None]:
# Restore the saved function, then refine where the policy probability built
# from the averaged logits (method 2) exceeds 0.0035.
restore_function()
method = 2
refine_via_policy_threshold(mesh, u, 0.0035, method)

In [None]:
# Restore the saved function, then refine where the policy probability built
# from the max logits (method 3) exceeds 0.0035.
restore_function()
method = 3
refine_via_policy_threshold(mesh, u, 0.0035, method)