# Minigame 10: Choose One Element To Refine

This is essentially the global refinement environment, where each action chooses one element to refine. However, the observation is the DOF vector itself, not sampled function values.

Some things to explore:

* PPO vs DQN vs ?
* CNN vs MLP vs ?
* order=1 vs order=2 vs ?
* H1 space vs DG space vs ?

Set up PyMFEM:

In [None]:
import math

In [None]:
import copy

In [None]:
import sys
import gym
from gym import spaces, utils
import numpy as np
import ray
import ray.rllib.agents.ppo as ppo
from os.path import expanduser, join
import os

In [None]:
from pyglvis import GlvisWidget

In [None]:
from mfem import path
import mfem.ser as mfem

Start up RLlib:

In [None]:
# Shut down any Ray instance left over from a previous run of this cell,
# then start a fresh one (ignore_reinit_error is only a safety net here).
ray.shutdown()
ray.init(ignore_reinit_error=True)
# Shallow copy of PPO's defaults, so overriding top-level keys below does not
# modify ppo.DEFAULT_CONFIG itself.
config = ppo.DEFAULT_CONFIG.copy()
# Also used by the convergence plot below to convert iterations to env steps.
config['train_batch_size'] = int(5e4)
# RLlib framework backend ('tfe' = TensorFlow in eager mode).
config['framework'] = 'tfe'
# Last expression: the notebook displays the resulting config dict.
config

Create the gym environment. This is essentially the "global refinement environment".

In [None]:
def get_solnstream(mesh,soln):
    """Return a GLVis "solution" stream for *soln* on *mesh*.

    The mesh and grid function are written to scratch files via their own
    ``Print``/``Save`` methods, read back as text, and concatenated behind
    the ``solution`` keyword that GLVis expects.

    Args:
        mesh: object with a ``Print(path)`` method (e.g. ``mfem.Mesh``).
        soln: object with a ``Save(path)`` method (e.g. ``mfem.GridFunction``).

    Returns:
        str: ``"solution\\n" + <mesh text> + <solution text>``.
    """
    import os
    import tempfile

    # Use a private temporary directory rather than fixed ",tmpmesh"/",tmpsoln"
    # files in the CWD: RLlib may construct AMRGame (which calls this) in
    # several workers at once, and fixed names would let them overwrite each
    # other's files. The directory and its contents are removed on exit.
    with tempfile.TemporaryDirectory() as tmpdir:
        meshpath = os.path.join(tmpdir, "mesh")
        solnpath = os.path.join(tmpdir, "soln")
        mesh.Print(meshpath)
        with open(meshpath, "r") as f:
            meshdata = f.read()
        soln.Save(solnpath)
        with open(solnpath, "r") as f:
            solndata = f.read()
    return "solution\n" + meshdata + solndata

In [None]:
class AMRGame(gym.Env):
    """Single-step refinement game: pick one element of star.mesh to refine.

    Each episode (see reset()) interpolates
    u0(x, y) = sin(w*pi*x) * sin(w*pi*y), with w drawn uniformly from [1, 6),
    into an order-1 H1 space on the unrefined mesh. The action is the index of
    one element to refine; the reward is the resulting drop in L2 error, and
    the episode ends after that single step. The observation is the DOF vector
    of the grid function itself, not sampled function values.
    """

    class u0_coeff(mfem.PyCoefficient):
        """Analytic target sin(w*pi*x) * sin(w*pi*y); call setw() before use."""

        def setw(self,w):
            self.w = w

        def EvalValue(self, x):
            return math.sin(self.w*math.pi*x[0]) * math.sin(self.w*math.pi*x[1])

    # In RLlib, you need the config arg
    def __init__(self,config):
        """Load the mesh, build the FE space, and define the gym spaces.

        Args:
            config: RLlib env config; required by RLlib's constructor
                signature but not used.
        """
        import mfem.ser as mfem

        self.meshfile = 'star.mesh'
        self.mesh = mfem.Mesh(self.meshfile)
        #self.mesh.UniformRefinement()

        dim = self.mesh.Dimension()
        order = 1
        # fec depends only on order and dimension, which never change between
        # episodes, so it is built once and reused by _reload_mesh().
        self.fec = mfem.H1_FECollection(order, dim)
        self.fes = mfem.FiniteElementSpace(self.mesh, self.fec)
        self.u = mfem.GridFunction(self.fes)

        # Both spaces are sized from the *unrefined* mesh.
        # NOTE(review): the observation returned by step() lives on the refined
        # mesh, so it has more entries than observation_space; its dtype is
        # presumably float64 (GridFunction data), not float32 -- confirm RLlib
        # tolerates both.
        self.action_space = spaces.Discrete(self.mesh.GetNE())
        self.observation_space = spaces.Box(-1.0, 1.0, shape=(self.u.Size(),), dtype=np.float32)
        self.state = None

        self.gl = GlvisWidget(get_solnstream(self.mesh,self.u))

    def get_ne(self):
        """Return the number of elements in the current mesh."""
        return self.mesh.GetNE()

    def get_size(self):
        """Return the number of DOFs (length of the grid function u)."""
        return self.u.Size()

    def get_error(self):
        """Return the L2 error of u against the analytic target self.u0.

        self.u0 is created by reset(), so reset() must be called first.
        """
        return self.u.ComputeL2Error(self.u0)

    def refine_elems(self, elems):
        """Refine the elements listed in *elems*, then re-interpolate u0.

        The FE space and grid function are updated to the refined mesh before
        u0 is projected onto it.
        """
        self.mesh.GeneralRefinement(mfem.intArray(elems))
        self.fes.Update()
        self.u.Update()
        self.u.ProjectCoefficient(self.u0)

    def _reload_mesh(self):
        """Re-read the unrefined mesh and rebuild fes and u on it (reusing fec).

        u is newly allocated; callers are responsible for filling it.
        """
        mesh = mfem.Mesh(self.meshfile)
        #mesh.UniformRefinement()  # keep in sync with __init__
        fes = mfem.FiniteElementSpace(mesh, self.fec)
        u = mfem.GridFunction(fes)
        # Rebind dependents first (u, then fes, then mesh) so that, as far as
        # these attributes' references go, the old objects are released in
        # that order instead of the old mesh going first while the old space
        # still refers to it.
        self.u = u
        self.fes = fes
        self.mesh = mesh

    def reset_to(self,u0):
        """Restore the unrefined mesh and set the DOF vector to *u0*.

        self.u0 (the analytic target chosen by the last reset()) is kept, so
        rewards from later step() calls are measured against the same
        function. This lets several actions be tried from one starting state.

        Args:
            u0: DOF values for the unrefined mesh (e.g. an ``mfem.Vector``
                built from an observation returned by reset()).

        Raises:
            ValueError: if *u0* does not have one entry per DOF of the
                unrefined space (e.g. it was taken from a refined mesh).
        """
        self._reload_mesh()
        nvals = u0.Size() if hasattr(u0, 'Size') else len(u0)
        if nvals != self.u.Size():
            raise ValueError(
                "reset_to: got %d DOF values, expected %d" % (nvals, self.u.Size()))
        self.u.Assign(u0)

    # action is the number of the element to refine
    def step(self, action):
        """Refine element *action*; return (obs, reward, done, info).

        reward is the L2 error before refinement minus the error after it;
        done is always True (one refinement per episode). obs is a copy of the
        DOF vector on the refined mesh.
        """
        err1 = self.get_error()
        self.refine_elems([action])
        err2 = self.get_error()
        reward = err1-err2
        done = True
        self.state = self.u.GetDataArray()
        return np.array(self.state), reward, done, {}

    def reset(self):
        """Draw a new frequency w in [1, 6) and interpolate on the unrefined mesh.

        Returns:
            numpy.ndarray: a copy of the DOF vector of the interpolant of u0.
        """
        self.u0 = self.u0_coeff()
        self.u0.setw(1.0+5.0*np.random.rand())

        # reread the mesh from file - probably want a faster way to do this
        self._reload_mesh()
        self.u.ProjectCoefficient(self.u0)
        self.state = self.u.GetDataArray()
        return np.array(self.state)

    def render(self, mode='human'):
        """Return a GlvisWidget showing the current mesh and solution.

        *mode* is accepted for compatibility with gym's ``render(mode=...)``
        calling convention and is ignored.
        """
        return GlvisWidget(get_solnstream(self.mesh,self.u))

Instantiate the environment and sanity check it.

In [None]:
# AMRGame ignores its config argument, so None is fine outside RLlib.
env = AMRGame(None)
env.get_ne()  # not the last expression, so the notebook does not display it
obs = env.reset()

In [None]:
# Number of elements in the initial mesh (= size of the action space).
env.get_ne()

In [None]:
# Number of DOFs (= length of the initial observation vector).
env.get_size()

In [None]:
# Keep a copy of the initial DOF vector so the env can be restored with reset_to().
obs0 = copy.copy(obs)
u0 = mfem.Vector(obs0)

In [None]:
# Refine element 0; reward is the resulting drop in L2 error.
state, reward, done, info = env.step(0)
reward

Show the mesh after refining element 0. Then we'll test resetting it to the original state; we'll need this to search for the best actions.

In [None]:
# Show the mesh and solution after refining element 0.
env.render()

In [None]:
env.reset_to(u0) # puts the mesh back in the orig state, and sets the DOF vector to u0
env.render()

OK, now try training a policy:

In [None]:
agent = ppo.PPOTrainer(config, env=AMRGame)

%%time
for n in range(2):
    result = agent.train()
    print("episode reward mean: %f " % result["episode_reward_mean"])


In [None]:
policy = agent.get_policy()
model = policy.model
# Keras' summary() prints the table itself and returns None; wrapping it in
# print() would add a stray "None" line.
model.base_model.summary()

In [None]:
def apply_policy(model, obs):
    """Let the trained agent choose an element for *obs* and apply it.

    The action comes from the global ``agent`` and is applied to the global
    ``env``; ``model`` is accepted for call-site compatibility but unused.

    Returns:
        tuple: (action, reward) from the single env.step() call.
    """
    chosen = agent.compute_action(obs)
    _, step_reward, _, _ = env.step(chosen)
    return chosen, step_reward

In [None]:
# Draw a fresh random instance and keep its initial DOF vector for reset_to().
obs = env.reset()
obs0 = copy.copy(obs)
u0 = mfem.Vector(obs0)

The original state:

In [None]:
# Render the unrefined mesh with the freshly interpolated solution.
env.render()

Brute-force search for the best choice by trying each element in turn, resetting the environment before each trial action and again once we're done.

In [None]:
def find_best_el(obs):
    """Find by brute force which single element refinement gives most reward.

    Every element is tried from the state *obs* (restored with env.reset_to
    before each trial), and the env is restored to *obs* again at the end.

    Args:
        obs: DOF vector of the unrefined mesh, as returned by env.reset().

    Returns:
        tuple: (maxel, maxr) -- the element with the largest reward (first one
        on ties) and that reward. This is a true argmax: a valid element is
        returned even if no refinement has positive reward. Only an empty mesh
        yields (-1, -inf).
    """
    u0 = mfem.Vector(obs)
    env.reset_to(u0)
    ne = env.get_ne()
    maxel, maxr = -1, -math.inf
    for n in range(ne):
        env.reset_to(u0)
        _, reward, _, _ = env.step(n)
        if reward > maxr:
            maxr, maxel = reward, n
    env.reset_to(u0)
    return maxel, maxr

# Best element and its reward for this instance; find_best_el leaves env restored to obs.
maxel, maxr = find_best_el(obs)

In [None]:
# Apply the brute-force optimum and show it (assumes maxel is a valid element index, i.e. >= 0).
env.refine_elems([maxel])
env.render()

Compare with what the policy does:

In [None]:
# Restore the unrefined state, let the policy choose an element, and show the result.
env.reset_to(u0)
apply_policy(model,obs0)
env.render()

Let's run a more systematic evaluation:

In [None]:
def eval_ensemble(model, ntrials):
    """Compare the policy against brute-force search on random instances.

    For each trial a fresh instance is drawn with env.reset(); the optimal
    element comes from find_best_el and the policy's choice from
    apply_policy.

    Args:
        model: forwarded to apply_policy.
        ntrials: number of random instances; must be at least 1.

    Returns:
        tuple: (rms, corr) -- RMS of (best reward - policy reward) over the
        trials, and the percentage of trials in which the policy picked the
        optimal element.

    Raises:
        ValueError: if ntrials < 1 (the averages would divide by zero).
    """
    if ntrials < 1:
        raise ValueError("ntrials must be a positive integer, got %r" % (ntrials,))
    ncorrect = 0
    sumsq = 0.0
    for _ in range(ntrials):
        obs = env.reset()
        bestaction, bestreward = find_best_el(obs)
        action, reward = apply_policy(model, obs)
        err = bestreward - reward
        sumsq += err * err
        if bestaction == action:
            ncorrect += 1
    rms = math.sqrt(sumsq / ntrials)
    corr = 100.0 * ncorrect / ntrials
    print("rms error: ", rms, flush=True)
    print("% correct: ", corr, flush=True)
    return rms, corr

# Evaluate the current PPO policy on 100 random instances; prints and displays (rms, % correct).
eval_ensemble(model, 100)

Let's see if the training process is converging:

In [None]:
nsteps = 20
neval = 200

# Start from a fresh, untrained PPO agent so the curve shows learning from scratch.
del agent
agent = ppo.PPOTrainer(config, env=AMRGame)
# Refresh `model`: the global set in earlier cells belongs to the deleted agent.
model = agent.get_policy().model

rms = [0.0] * nsteps
cor = [0.0] * nsteps
for n in range(nsteps):
    print("training batch %d" % n)
    agent.train()
    print("evaluating on %d instances..." %  neval)
    rms[n], cor[n] = eval_ensemble(model, neval)

In [None]:
%matplotlib inline
isteps = list(range(nsteps))
asteps = [i*config['train_batch_size'] for i in isteps]
import matplotlib.pyplot as plt
ax = plt.subplot(211)
ax.set_ylim(0.001,0.1)
ax.set_ylabel('RMS error')
plt.semilogy(asteps,rms[:10], marker='o')
plt.ticklabel_format(style='sci', axis='x', scilimits=(0,0))
ax = plt.subplot(212)
ax.set_ylim(0,100)
ax.set_ylabel('% correct')
plt.plot(asteps,cor[:10], marker='o')

plt.ticklabel_format(style='sci', axis='x', scilimits=(0,0))

In [None]:
# Display the per-iteration RMS errors from the PPO convergence run.
rms

## DQN

In [None]:
import ray.rllib.agents.dqn as dqn

In [None]:
del agent
agent = dqn.DQNTrainer(config, env=AMRGame)

In [None]:
result = agent.train()
policy = agent.get_policy()
model = policy.model
print(model.base_model.summary())

In [None]:
nsteps = 10
neval = 200

del agent
agent = dqn.DQNTrainer(config, env=AMRGame)

rms = [0.0] * nsteps
cor = [0.0] * nsteps
for n in range(10):
    agent.train()
    print("evaluating on %d instances..." %  neval)
    rms[n], cor[n] = eval_ensemble(model, neval)

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
ax = plt.subplot(211)
ax.set_ylim(0.0001,0.1)
ax.set_ylabel('RMS error')
plt.semilogy(nsteps,rms, marker='o')
plt.ticklabel_format(style='sci', axis='x', scilimits=(0,0))
ax = plt.subplot(212)
ax.set_ylim(0,100)
ax.set_ylabel('% correct')
plt.plot(nsteps,cor, marker='o')

plt.ticklabel_format(style='sci', axis='x', scilimits=(0,0))