# Minigame 12: Patched Evaluation

In this exploration, we're looking at the ability to take policies trained on small patches and to apply them to larger meshes by combining the logits/distributions/q-values from each of the NN evaluations.

In [1]:
import math
from math import sin,cos
import random

In [2]:
import gym
from gym import spaces, utils
import numpy as np
import ray
import ray.rllib.agents.ppo as ppo
import ray.rllib.agents.dqn as dqn

Instructions for updating:
non-resource variables are not supported in the long term


In [3]:
from glvis import glvis, to_stream
from ipywidgets import Layout

In [4]:
import matplotlib.pyplot as plt

In [5]:
from mfem import path
import mfem.ser as mfem

Define some synthetic test functions.

In [6]:
def rotate(x,theta):
    x0 = x[0]
    y0 = x[1]
    x1 = x0*cos(theta)-y0*sin(theta)
    y1 = x0*sin(theta)+y0*cos(theta)
    return [x1,y1]

In [7]:
def bump(x):
    rsq = x[0]**2 +x[1]**2
    return math.exp(-rsq)

In [8]:
def smooth_step(x):
    return 0.5*(1.0 +math.tanh(x[0]))

In [9]:
def rotated_smooth_step(x,theta):
    xr = rotate(x,theta)
    return smooth_step(xr)

Create classes where we can set the parameters and then eval a bunch of points.

In [10]:
class Step(mfem.PyCoefficient):
    
    def SetParams(self):
        self.theta = random.uniform(0.0, 2.0*math.pi)
        self.dx = [random.uniform(-1.0, 1.0),random.uniform(-1.0, 1.0)]
        
    def EvalValue(self, x):
        return rotated_step(x+self.dx, self.theta)

In [11]:
class Bump(mfem.PyCoefficient):
    
    def SetParams(self):
        self.width = random.uniform(0.05,0.1)
        self.xc = [0.5,0.5]
        self.dx = [random.uniform(-0.5, 0.5),random.uniform(-0.5, 0.5)]

    def EvalValue(self, x):
        return bump((x-self.xc+self.dx)/self.width)

In [12]:
class SmoothStep(mfem.PyCoefficient):
    
    def SetParams(self):
        self.width = random.uniform(5.0, 15.0)
        self.xc = [0.5,0.5]
        self.dx = random.uniform(-0.5,0.5)
        self.theta = random.uniform(0.0, 2.0*math.pi)

    def EvalValue(self, x):
        x -= self.xc
        x += self.dx
        return rotated_smooth_step(x*self.width, self.theta)

In [13]:
class BumpAndSmoothStep(mfem.PyCoefficient):
    
    def SetParams(self):
        self.bump = Bump()
        self.bump.SetParams()
        self.smooth_step = SmoothStep()
        self.smooth_step.SetParams()

    def EvalValue(self, x):
        return 0.5*self.bump.EvalValue(x)+0.5*self.smooth_step.EvalValue(x)

Visualize an instance of the test function. Note that each instance has randomly chosen parameters.  For the steps, it's a rotation angle and a displacement.  For the bumps, it's a width and a displacement.

In [14]:
mesh = mfem.Mesh('inline-quad.mesh')
mesh.UniformRefinement()
mesh.UniformRefinement()
fec = mfem.L2_FECollection(p=1, dim=2)
fes = mfem.FiniteElementSpace(mesh, fec)
u = mfem.GridFunction(fes)
c = Bump()
c.SetParams()
u.ProjectCoefficient(c)

In [15]:
glvis(to_stream(mesh,u) + 'keys Rcjm', 500, 500)

glvis()

Create the gym environment. Note that in this case, this can be just a dummy environment that only serves to define the observation and action spaces for the purposes of evaluation of the policy.

In [16]:
class AMRGameDummy(gym.Env):
        
    # In RLlib, you need the config arg
    def __init__(self,config):
        self.meshfile = 'inline-quad-5.mesh'
        self.mesh = mfem.Mesh(self.meshfile)
        
        # The only reason we need to create a fespace and gf here
        # is to find the sizes needed for the action and observation spaces
        dim = self.mesh.Dimension()
        self.order = 1
        self.fec = mfem.L2_FECollection(self.order, dim)
        self.fes = mfem.FiniteElementSpace(self.mesh, self.fec)
        self.u = mfem.GridFunction(self.fes);

        # actions are: refine each element, or do nothing
        self.action_space = spaces.Discrete(self.mesh.GetNE())
        
        # observation space: DOFs
        self.observation_space = spaces.Box(-1.0, 1.0, shape=(self.u.Size(),), dtype=np.float32)
        
    def step(self, action):
        pass
    
    def reset(self):
        pass
    
    def render(self):
        pass

The estimator game has a different action space

In [47]:
class EstimatorGameDummy(gym.Env):
    
    class u0_coeff(mfem.PyCoefficient):
        
        def SetParams(self):
            #self.fn = BumpAndSmoothStep()
            self.fn = Bump()

            self.fn.SetParams()
            
        def Print(self):
            self.fn.Print()
            
        def EvalValue(self, x):
            v = self.fn.EvalValue(x)
            #assert v >= 0.0
            #assert v <= 1.0
            return self.fn.EvalValue(x)
    
    # precompute the observation points and the elements and integration points we need
    def get_obs_points(self):
        n = math.sqrt(self.mesh.GetNE())
        dx = 1.0/self.obsx
        dy = 1.0/self.obsy
        self.sample_pts = []
        self.sample_els = []
        self.sample_ips = []
        for j in range(self.obsy):
            for i in range(self.obsx):
                pt = [i*dx+0.5*dx,j*dy+0.5*dy]
                self.sample_pts.append(pt)
                n, el, ip = self.mesh.FindPoints([pt])
                #assert n == 1
                #assert ip[0].x > 0.0
                #assert ip[0].x < 1.0
                #assert ip[0].y > 0.0
                #assert ip[0].y < 1.0
                #assert el[0] >= 0
                #assert el[0] < self.mesh.GetNE()
                # copy these so they won't be destroyed when mesh goes away?
                ip0 = mfem.IntegrationPoint()
                ip0.x = ip[0].x
                ip0.y = ip[0].y
                self.sample_els.append(el[0])
                self.sample_ips.append(ip0)
                
    def get_obs(self):
        state = np.empty((self.obsx,self.obsy,1))
        k = 0
        for j in range(self.obsy):
            for i in range(self.obsx):
                #assert k < len(self.sample_els)
                #assert k < len(self.sample_ips)
                el = self.sample_els[k]
                #assert el >= 0
                #assert el < self.mesh.GetNE()
                ip = self.sample_ips[k]
                #assert ip.x >= 0.0, "k={}, i={}, j={}, x is {}".format(k,i,j,ip.x)
                #assert ip.x <= 1.0, "k={}, i={}, j={}, x is {}".format(k,i,j,ip.x)
                #assert ip.y >= 0.0, "k={}, i={}, j={}, y is {}".format(k,i,j,ip.y)
                #assert ip.y <= 1.0, "k={}, i={}, j={}, y is {}".format(k,i,j,ip.y)
                v = self.u.GetValue(self.sample_els[k],self.sample_ips[k])
                state[i][j] = v
                if (v > 2.0 or v < -1.0):
                    print("element %d" % self.sample_els[k])
                    print("ip.x = %f" % self.sample_ips[k].x)
                    print("ip.y = %f" % self.sample_ips[k].y)
                    print("%d,%d -> %f" % (i,j,v))
                    print("%d,%d -> %f" % (i,j,state[i][j]))
                    self.u0.Print()
                k += 1
        self.state = state
        return state
        
    # In RLlib, you need the config arg
    def __init__(self,config):
        self.meshfile = 'inline-quad-5.mesh'
        self.mesh = mfem.Mesh(self.meshfile)
        
        # The only reason we need to create a fespace and gf here
        # is to find the sizes needed for the action and observation spaces
        dim = self.mesh.Dimension()
        self.order = 1
        self.fec = mfem.L2_FECollection(self.order, dim)
        self.fes = mfem.FiniteElementSpace(self.mesh, self.fec)
        self.u = mfem.GridFunction(self.fes);

        # actions are: do nothing (0) or refine center element (1)
        self.action_space = spaces.Discrete(2)
        
        self.obsx = 42
        self.obsy = 42
        
        # observation space: 42x42 image
        self.observation_space = spaces.Box(-1.0, 2.0, shape=(self.obsx,self.obsy,1))

    def step(self, action):
        pass
    
    def reset(self):
        pass
    
    def render(self):
        pass

Now we want to load a trained policy, and apply it in a strided way.

In [25]:
ray.shutdown()
ray.init(ignore_reinit_error=True)

2021-02-26 21:16:03,121	INFO services.py:1174 -- View the Ray dashboard at [1m[32mhttp://127.0.0.1:8267[39m[22m


{'node_ip_address': '192.168.1.201',
 'raylet_ip_address': '192.168.1.201',
 'redis_address': '192.168.1.201:35154',
 'object_store_address': '/tmp/ray/session_2021-02-26_21-16-02_494368_5464/sockets/plasma_store',
 'raylet_socket_name': '/tmp/ray/session_2021-02-26_21-16-02_494368_5464/sockets/raylet',
 'webui_url': '127.0.0.1:8267',
 'session_dir': '/tmp/ray/session_2021-02-26_21-16-02_494368_5464',
 'metrics_export_port': 63789,
 'node_id': '50123b4473b6fcf8b984d6084d70eed04d0d64144c8233325c967135'}

In [48]:
ppo_config = ppo.DEFAULT_CONFIG.copy()
dqn_config = dqn.DEFAULT_CONFIG.copy()

#ppo_config['framework'] = 'tfe'
#dqn_config['num_workers'] = 1

ppo_config['train_batch_size'] = 1000
ppo_config['num_workers'] = 3
ppo_config['num_gpus'] = 0

agent = ppo.PPOTrainer(ppo_config, env=EstimatorGameDummy)

#agent = dqn.DQNTrainer(dqn_config, env=AMRGameDummy)
ppo_config

[2m[36m(pid=6763)[0m Instructions for updating:
[2m[36m(pid=6763)[0m non-resource variables are not supported in the long term
[2m[36m(pid=6763)[0m Instructions for updating:
[2m[36m(pid=6763)[0m non-resource variables are not supported in the long term
[2m[36m(pid=6763)[0m Instructions for updating:
[2m[36m(pid=6763)[0m non-resource variables are not supported in the long term
[2m[36m(pid=6763)[0m Instructions for updating:
[2m[36m(pid=6763)[0m If using Keras pass *_constraint arguments to layers.
[2m[36m(pid=6763)[0m Instructions for updating:
[2m[36m(pid=6763)[0m If using Keras pass *_constraint arguments to layers.
[2m[36m(pid=6763)[0m Instructions for updating:
[2m[36m(pid=6763)[0m If using Keras pass *_constraint arguments to layers.
[2m[36m(pid=6763)[0m Instructions for updating:
[2m[36m(pid=6763)[0m Use tf.where in 2.0, which has the same broadcast rule as np.where
[2m[36m(pid=6763)[0m Instructions for updating:
[2m[36m(pid=6763)[

{'num_workers': 3,
 'num_envs_per_worker': 1,
 'create_env_on_driver': False,
 'rollout_fragment_length': 200,
 'batch_mode': 'truncate_episodes',
 'num_gpus': 0,
 'train_batch_size': 1000,
 'model': {'fcnet_hiddens': [256, 256],
  'fcnet_activation': 'tanh',
  'conv_filters': None,
  'conv_activation': 'relu',
  'free_log_std': False,
  'no_final_linear': False,
  'vf_share_layers': False,
  'use_lstm': False,
  'max_seq_len': 20,
  'lstm_cell_size': 256,
  'lstm_use_prev_action': False,
  'lstm_use_prev_reward': False,
  '_time_major': False,
  'use_attention': False,
  'attention_num_transformer_units': 1,
  'attention_dim': 64,
  'attention_num_heads': 1,
  'attention_head_dim': 32,
  'attention_memory_inference': 50,
  'attention_memory_training': 50,
  'attention_position_wise_mlp_dim': 32,
  'attention_init_gru_gate_bias': 2.0,
  'num_framestacks': 'auto',
  'dim': 84,
  'grayscale': False,
  'zero_mean': True,
  'custom_model': None,
  'custom_model_config': {},
  'custom_actio

Restore a policy

In [49]:
agent.restore("/home/rwa/ray_results/PPO_EstimatorGame_2021-02-26_20-55-36iny1vi4_/checkpoint_5/checkpoint-5")

2021-02-26 22:01:00,314	INFO trainable.py:372 -- Restored on 192.168.1.201 from checkpoint: /home/rwa/ray_results/PPO_EstimatorGame_2021-02-26_20-55-36iny1vi4_/checkpoint_5/checkpoint-5
2021-02-26 22:01:00,315	INFO trainable.py:379 -- Current state after restoring: {'_iteration': 5, '_timesteps_total': None, '_time_total': 468.9241576194763, '_episodes_total': 6000}


In [50]:
policy = agent.get_policy()

[2m[36m(pid=27087)[0m Instructions for updating:
[2m[36m(pid=27087)[0m non-resource variables are not supported in the long term
[2m[36m(pid=27087)[0m Instructions for updating:
[2m[36m(pid=27087)[0m non-resource variables are not supported in the long term
[2m[36m(pid=27087)[0m Instructions for updating:
[2m[36m(pid=27087)[0m non-resource variables are not supported in the long term
[2m[36m(pid=27086)[0m Instructions for updating:
[2m[36m(pid=27086)[0m non-resource variables are not supported in the long term
[2m[36m(pid=27086)[0m Instructions for updating:
[2m[36m(pid=27086)[0m non-resource variables are not supported in the long term
[2m[36m(pid=27086)[0m Instructions for updating:
[2m[36m(pid=27086)[0m non-resource variables are not supported in the long term
[2m[36m(pid=27087)[0m Instructions for updating:
[2m[36m(pid=27087)[0m If using Keras pass *_constraint arguments to layers.
[2m[36m(pid=27087)[0m Instructions for updating:
[2m[3

Now we want to create the larger problem we'll be applying this local indicator on.

In [29]:
mesh = mfem.Mesh('inline-quad-30.mesh')
print(mesh.GetNE())
fec = mfem.L2_FECollection(p=1, dim=2)
fes = mfem.FiniteElementSpace(mesh, fec)
u = mfem.GridFunction(fes)
u0 = mfem.GridFunction(fes)
coeff = Bump()
coeff.SetParams()
u.ProjectCoefficient(coeff)
u0.Assign(u) # save so we can restore later if desired
    
def new_function():
    global mesh, fec, fes, u, u0, coeff
    mesh = mfem.Mesh('inline-quad-30.mesh')
    fec = mfem.L2_FECollection(p=1, dim=2)
    fes = mfem.FiniteElementSpace(mesh, fec)
    u = mfem.GridFunction(fes)
    u0 = mfem.GridFunction(fes)
    coeff = Bump()
    coeff.SetParams()
    u.ProjectCoefficient(coeff)
    u0.Assign(u) # save so we can restore later if desired
    return glvis(to_stream(mesh,u) + 'keys Rcjm', 500, 500)

900


In [30]:
def restore_function():
    global mesh, fec, fes, u
    mesh = mfem.Mesh('inline-quad-30.mesh')
    fec = mfem.L2_FECollection(p=1, dim=2)
    fes = mfem.FiniteElementSpace(mesh, fec)
    u = mfem.GridFunction(fes)
    u.Assign(u0)
    #return glvis(to_stream(mesh,u) + 'keys Rcjm', 500, 500)

Build a map from each element to the elements which consist of the "stencil" around it. Since not every element has a full stencil, use a dictionary that only contains the elements containing full stencils as keys.

In [31]:
def build_stencils(mesh, width):
    els = {}
    nx = math.sqrt(mesh.GetNE())
    dx = 1.0/nx
    dim = mesh.Dimension()
    els = {}
    hw = int(width/2)
    c = mfem.Vector(dim)
    x = mfem.Vector(dim)
    for k in range(0,mesh.GetNE()):
        els[k] = []
        mesh.GetElementCenter(k,c)
        full = True
        for j in range(-hw,hw+1):
            for i in range(-hw,hw+1):
                x[0] = c[0]+i*dx
                x[1] = c[1]+j*dx
                if (x[0] < 0.0): full = False
                if (x[0] > 1.0): full = False
                if (x[1] < 0.0): full = False
                if (x[1] > 1.0): full = False
                pt = [[x[0],x[1]]]
                n, el, ip = mesh.FindPoints(pt)
                els[k].append(el[0])
        if (not full):
            els.pop(k)
    return els

Create a function and build the stencils for it.

In [32]:
new_function()
width=5
els = build_stencils(mesh, width)

Create the local observation mesh into which we will copy the dofs for the purposes of creating an observation vector.

In [33]:
obs_mesh = mfem.Mesh('inline-quad-5.mesh')
obs_fec = mfem.L2_FECollection(p=1, dim=2)
obs_fes = mfem.FiniteElementSpace(obs_mesh, obs_fec)
obs_u = mfem.GridFunction(obs_fes)
glvis((obs_mesh), 400, 400,layout = Layout(width='100%', height='400px'))

glvis(layout=Layout(height='400px', width='100%'))

Also build a 0th order L2 field to look at per-element quantities (like logits or prob dist).

In [34]:
fec0 = mfem.L2_FECollection(p=0, dim=2)
fes0 = mfem.FiniteElementSpace(obs_mesh, fec0)
obs_u0 = mfem.GridFunction(fes0)

Now we need a mapping from the "logical" space of the observation mesh into element ids. This has the same ordering as the stencil elements, so we can form a mapping for the purposes of data transfer from the src mesh to the obs mesh.

In [35]:
def build_map(obs_mesh, width):
    id_map = []
    c = [0.5, 0.5]
    x = [0.0, 0.0]
    dx = 1./width
    hw = int(width/2)
    for j in range(-hw,hw+1):
        for i in range(-hw,hw+1):
            x[0] = c[0]+i*dx
            x[1] = c[1]+j*dx
            pt = [[x[0],x[1]]]
            n, el, ip = obs_mesh.FindPoints(pt)
            id_map.append(el[0])
    return id_map

In [36]:
id_map = build_map(obs_mesh, width)

Create a function to transfer from the stencil associated with a src element k into the observation gf.

In [37]:
def transfer_stencil(k):
    global obs_u
    for n in range(len(els[k])):
        dst_el = id_map[n]
        src_el = els[k][n]
        #print("el %d -> el %d" % (src_el,dst_el))
        src_dofs = fes.GetElementDofs(src_el)
        dst_dofs = obs_fes.GetElementDofs(dst_el)
        for d in range(len(src_dofs)):
            obs_u[dst_dofs[d]] = u[src_dofs[d]]

Visualize the observation mesh and function.

In [38]:
def show_obs():
    return glvis(to_stream(obs_mesh,obs_u) + 'keys Rcjm', 500, 500)

Compute the logits for each element in the observation mesh and visualize as a p=0 L2 function.

In [39]:
def show_patch_logits():
    obs = np.array(obs_u.GetDataArray())
    action, _, info = policy.compute_single_action(obs, explore=False)
    logits = np.array(info['action_dist_inputs'],dtype=np.float64)
    obs_u0.Assign(mfem.Vector(logits))
    return glvis(to_stream(obs_mesh,obs_u0) + 'keys Rcjm', 500, 500)

In [40]:
def show_patch_qvalues():
    obs = np.array(obs_u.GetDataArray())
    action, _, info = policy.compute_single_action(obs, explore=False)
    print(info)
    qvalues = np.array(info['q_values'],dtype=np.float64)
    obs_u0.Assign(mfem.Vector(qvalues))
    return glvis(to_stream(obs_mesh,obs_u0) + 'keys Rcjm', 500, 500)

Test it out on a specific src element:

In [41]:
new_function()

glvis()

In [None]:
transfer_stencil(10)
show_obs()

In [None]:
show_patch_logits()

show_patch_qvalues()

In [None]:
transfer_stencil(889)
show_obs()

In [None]:
show_patch_logits()

Iterate over all the elements with full stencils in the src mesh and record 'center' logits for each observation:

In [None]:
pt = [[0.5,0.5]]
n, center_el, ip = obs_mesh.FindPoints(pt)
center_el

In [None]:
def compute_center_logits(mesh):
    logits = [-1.0]*mesh.GetNE()
    for k in els:
        transfer_stencil(k)
        obs = np.array(obs_u.GetDataArray())
        action, _, info = policy.compute_single_action(obs, explore=False)
        obs_logits = info['q_values']
        #value = info['vf_preds']
        #print("value is %f" % value)
        
        # find minimum to use as datum
        #min_logit = 1.e6
        #for j in range(len(id_map)):
        #    src_el = id_map[j]
        #    min_logit = min(min_logit,obs_logits[src_el])
        
        logits[k] = obs_logits[center_el][0] # -min_logit
    return logits

In [42]:
def get_center_actions(mesh):
    actions = [0.0]*mesh.GetNE()
    for k in els:
        transfer_stencil(k)
        obs = env.get_obs()
        actions[k], _, info = policy.compute_single_action(obs, explore=False)
    return actions

In [46]:
new_function()

glvis()

In [44]:
actions = get_center_actions(mesh)
actions

2021-02-26 21:17:47,360	ERROR tf_run_builder.py:47 -- Error fetching: [<tf.Tensor 'default_policy/cond_1/Merge:0' shape=(?,) dtype=int64>, {'action_prob': <tf.Tensor 'default_policy/Exp:0' shape=(?,) dtype=float32>, 'action_logp': <tf.Tensor 'default_policy/cond_2/Merge:0' shape=(?,) dtype=float32>, 'action_dist_inputs': <tf.Tensor 'default_policy/Squeeze:0' shape=(?, 2) dtype=float32>, 'vf_preds': <tf.Tensor 'default_policy/Reshape_1:0' shape=(?,) dtype=float32>}], feed_dict={<tf.Tensor 'default_policy/obs:0' shape=(?, 42, 42, 1) dtype=float32>: [array([5.43666893e-62, 2.84434875e-60, 1.09953042e-60, 5.75250769e-59,
       4.86857089e-59, 2.18028851e-57, 9.84636341e-58, 4.40948966e-56,
       8.40469204e-57, 3.76386704e-55, 1.45498554e-55, 6.51585101e-54,
       9.38540879e-60, 4.91024489e-58, 1.62476317e-58, 8.50041295e-57,
       1.23764228e-57, 6.47507934e-56, 1.83397967e-56, 9.59498888e-55,
       1.24669068e-55, 6.52241865e-54, 1.58132306e-54, 8.27314360e-53,
       9.59276764e-5

ValueError: Cannot feed value of shape (1, 100) for Tensor 'default_policy/obs:0', which has shape '(?, 42, 42, 1)'

In [43]:
fec0fm = mfem.L2_FECollection(p=0, dim=2)
fes0fm = mfem.FiniteElementSpace(mesh, fec0fm)
log_fm = mfem.GridFunction(fes0)
log_fm.Assign(mfem.Vector(np.array(actions)))
glvis(to_stream(mesh, log_fm), 500, 500)

2021-02-26 21:17:14,637	ERROR tf_run_builder.py:47 -- Error fetching: [<tf.Tensor 'default_policy/cond_1/Merge:0' shape=(?,) dtype=int64>, {'action_prob': <tf.Tensor 'default_policy/Exp:0' shape=(?,) dtype=float32>, 'action_logp': <tf.Tensor 'default_policy/cond_2/Merge:0' shape=(?,) dtype=float32>, 'action_dist_inputs': <tf.Tensor 'default_policy/Squeeze:0' shape=(?, 2) dtype=float32>, 'vf_preds': <tf.Tensor 'default_policy/Reshape_1:0' shape=(?,) dtype=float32>}], feed_dict={<tf.Tensor 'default_policy/obs:0' shape=(?, 42, 42, 1) dtype=float32>: [array([5.43666893e-62, 2.84434875e-60, 1.09953042e-60, 5.75250769e-59,
       4.86857089e-59, 2.18028851e-57, 9.84636341e-58, 4.40948966e-56,
       8.40469204e-57, 3.76386704e-55, 1.45498554e-55, 6.51585101e-54,
       9.38540879e-60, 4.91024489e-58, 1.62476317e-58, 8.50041295e-57,
       1.23764228e-57, 6.47507934e-56, 1.83397967e-56, 9.59498888e-55,
       1.24669068e-55, 6.52241865e-54, 1.58132306e-54, 8.27314360e-53,
       9.59276764e-5

ValueError: Cannot feed value of shape (1, 100) for Tensor 'default_policy/obs:0', which has shape '(?, 42, 42, 1)'

In [None]:
def compute_avg_logits(mesh):
    logits = [0.0]*mesh.GetNE()
    count = [0]*mesh.GetNE()
    
    # accumulate logit sums
    for k in els:
        print ("el %d" % k)
        transfer_stencil(k)
        obs = np.array(obs_u.GetDataArray())
        action, _, info = policy.compute_single_action(obs, explore=False)
        obs_logits = info['q_values']
        
        # find minimum to use as datum
#        min_logit = 1.e6
#        for j in range(len(id_map)):
#            src_el = id_map[j]
#            min_logit = min(min_logit,obs_logits[src_el])

        for j in range(len(id_map)):
            dst_el = els[k][j]
            src_el = id_map[j]
            logits[dst_el] += obs_logits[src_el] #-min_logit
            count[dst_el] += 1
    
    # average
    for idx,val in enumerate(logits):
        logits[idx] /= count[idx]
        if (count[idx] < 16):
            logits[idx] = -10
        #logits[idx] /= obs_mesh.GetNE()

    return logits

In [None]:
logits = compute_avg_logits(mesh)
fec0fm = mfem.L2_FECollection(p=0, dim=2)
fes0fm = mfem.FiniteElementSpace(mesh, fec0fm)
log_fm = mfem.GridFunction(fes0)
log_fm.Assign(mfem.Vector(np.array(logits)))
glvis(to_stream(mesh, log_fm), 500, 500)

In [None]:
def compute_min_logits(mesh):
    logits = [100.]*mesh.GetNE()
    
    # choose the max at each element
    for k in els:
        transfer_stencil(k)
        obs = np.array(obs_u.GetDataArray())
        action, _, info = policy.compute_single_action(obs, explore=False)
        #print("action is %d" % action)
        obs_logits = info['action_dist_inputs']
        #print(obs_logits)
        for j in range(len(id_map)):
            dst_el = els[k][j]
            src_el = id_map[j]
            logits[dst_el] = min(logits[dst_el],obs_logits[src_el])
        
    return logits

In [None]:
def compose_max_q(mesh):
    logits = [-100.]*mesh.GetNE()
    
    # choose the max *in each patch*
    for k in els:
        transfer_stencil(k)
        obs = np.array(obs_u.GetDataArray())
        action, _, info = policy.compute_single_action(obs, explore=False)
        #print("action is %d" % action)
        obs_logits = info['q_values']
        #print(obs_logits)
        maxq = -1.e6
        for j in range(len(id_map)):
            dst_el = els[k][j]
            src_el = id_map[j]
            q = obs_logits[src_el]
            if (q > maxq):
                maxq = q
                src = src_el
                dst = dst_el
        logits[dst] = obs_logits[src_el]
        
    return logits

Re-normalize the collected logits into a probability distribution that sums to 1.

In [None]:
def compute_distribution(mesh, u, method):
    if (method == 1):
        logits = compute_center_logits(mesh)
    elif (method == 2):
        logits = compute_avg_logits(mesh)
    elif (method == 3):
        logits = compute_min_logits(mesh)
    elif (method == 4):
        logits = compose_max_q(mesh)
    sumexp = 0.0
    dist = [0.0] * mesh.GetNE()
    for k in range(mesh.GetNE()):        
        logit = logits[k]
        sumexp += math.exp(logit)
    for k in range(mesh.GetNE()):
        logit = logits[k]
        dist[k] = math.exp(logit)/sumexp
    return dist

Create a similar function that returns elementwise errors via the dg indicator.

In [None]:
def compute_dg_indicator(mesh, u):
    
    # put the L2 gridfunction into a coefficient so we can project it into H1
    u_disc_coeff = mfem.GridFunctionCoefficient(u)
    h1_fec = mfem.H1_FECollection(p=1, dim=2)
    h1_fes = mfem.FiniteElementSpace(mesh, h1_fec)
    u_h1 = mfem.GridFunction(h1_fes)
    u_h1.ProjectDiscCoefficient(u_disc_coeff, mfem.GridFunction.ARITHMETIC)
    
    # put the H1 smoothed function into a coefficient
    u_h1_coeff = mfem.GridFunctionCoefficient(u_h1)
    
    # create a 0-order L2 field to hold errors
    l2_0_fec = mfem.L2_FECollection(p=0,dim=2)
    l2_0_fes = mfem.FiniteElementSpace(mesh,l2_0_fec)

    # Compute elementwise "errors" between continuous and discontinuous fields
    err_gf = mfem.GridFunction(l2_0_fes);
    u.ComputeElementL2Errors(u_h1_coeff, err_gf);
    
    return np.array(err_gf.GetDataArray())

In [None]:
def compute_actual_error_reduction(mesh, u, u0):
    init_err = u.ComputeL2Error(coeff)
    print("init_err: %f" % init_err)
    delta_elem_err = []
    for k in range(mesh.GetNE()):
        mesh = mfem.Mesh('inline-quad-30.mesh')
        fec = mfem.L2_FECollection(p=1, dim=2)
        fes = mfem.FiniteElementSpace(mesh, fec)
        u = mfem.GridFunction(fes)
        u.Assign(u0)
    
        refine_els = []
        refine_els.append(k)
        mesh.GeneralRefinement(mfem.intArray(refine_els))
        u.FESpace().Update()
        u.Update()
        u.ProjectCoefficient(coeff)
        new_err = u.ComputeL2Error(coeff)
        #print("delta for %d is %e" % (k,init_err-new_err))
        delta_elem_err.append(init_err -new_err)

    return delta_elem_err

Given an indicator on each element, refine everything over the threshold

In [None]:
def refine_via_indicator(ind, mesh, u, thresh):
    refine_els = []
    for k,e in enumerate(ind):
        if (e > thresh):
            refine_els.append(k)
    mesh.GeneralRefinement(mfem.intArray(refine_els))
    u.FESpace().Update()
    u.Update()
    return glvis(to_stream(mesh,u) + 'keys Rcjm', 500, 500)

Refine everywhere the policy is over a threshold.

In [None]:
def refine_via_policy_threshold(mesh, u, thresh, method):
    dist = compute_distribution(mesh, u, method)
    return refine_via_indicator(dist, mesh, u, thresh)

Refine everywhere the DG indicator is over a threshold.

In [None]:
def refine_via_dg_threshold(mesh, u, thresh):
    ind = compute_dg_indicator(mesh, u)
    return refine_via_indicator(ind, mesh, u, thresh)

In [None]:
def refine_topk_via_indicator(ind, mesh, u, K):
    refine_els = []
    indsort = np.argsort(ind)[::-1]
    for i in range(K):
        refine_els.append(indsort[i])
    mesh.GeneralRefinement(mfem.intArray(refine_els))
    u.FESpace().Update()
    u.Update()
    return glvis(to_stream(mesh,u) + 'keys Rjm', 400, 400,layout = Layout(width='100%', height='400px'))
 
def refine_topk_via_policy(mesh, u, K, method):
    dist = compute_distribution(mesh, u, method)
    return refine_topk_via_indicator(dist, mesh, u, K)
 
def refine_topk_via_dg(mesh, u, K):
    ind = compute_dg_indicator(mesh, u)
    return refine_topk_via_indicator(ind, mesh, u, K)

def refine_topk_via_delta_elem_err(mesh, u, K):
    return refine_topk_via_indicator(delta_elem_err, mesh, u, K)

In [None]:
new_function()

In [None]:
delta_elem_err = compute_actual_error_reduction(mesh, u, u0)

In [None]:
restore_function()
refine_topk_via_delta_elem_err(mesh, u, 100)

In [None]:
restore_function()
refine_topk_via_dg(mesh, u, 100)

In [None]:
restore_function()
method = 1 # center
refine_topk_via_policy(mesh, u, 100, method)

In [None]:
restore_function()
method = 2 # avg
refine_topk_via_policy(mesh, u, 100, method)

In [None]:
restore_function()
method = 4 # max q
refine_topk_via_policy(mesh, u, 100, method)

In [None]:
new_function()
refine_topk_via_dg(mesh, u, 100)

In [None]:
restore_function()
method = 1 # center
refine_topk_via_policy(mesh, u, 100, method)

In [None]:
restore_function()
method = 2 # avg
refine_topk_via_policy(mesh, u, 100, method)