# Minigame 10b: Local Error Estimator

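In this environment, the game is to choose whether or not to refine a single element. Refining is rewarded by how much it decreases the element's error, counting decreases only down to (not below) a threshold; refining an element whose error is already below the threshold is penalized.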

In [1]:
import math
from math import cos,sin
import random

In [2]:
import sys
import gym
from gym import spaces, utils
import numpy as np
import ray
import ray.rllib.agents.ppo as ppo
import ray.rllib.agents.dqn as dqn
from os.path import expanduser, join
import os

Instructions for updating:
non-resource variables are not supported in the long term


In [3]:
from glvis import glvis, to_stream

In [4]:
from mfem import path
import mfem.ser as mfem

RLlib is started later, in the training-configuration cell.

Define some synthetic test functions: steps, bumps, waves, and smooth (tanh) steps.

In [5]:
def rotate(x,theta):
    """Rotate the 2D point ``x`` counter-clockwise by ``theta`` radians.

    Only ``x[0]`` and ``x[1]`` are read; ``x`` itself is left untouched and a
    new two-element list ``[x', y']`` is returned.
    """
    c, s = cos(theta), sin(theta)
    px, py = x[0], x[1]
    return [c*px - s*py, s*px + c*py]

In [6]:
def step(x):
    """Unit step in the first coordinate: 1.0 where x[0] < 0, else 0.0."""
    return 1.0 if x[0] < 0.0 else 0.0

In [7]:
def rotated_step(x, theta):
    """Evaluate ``step`` at the point ``x`` rotated by ``theta`` radians."""
    return step(rotate(x, theta))

In [8]:
def bump(x):
    """Gaussian bump ``exp(-(x[0]**2 + x[1]**2))``; value lies in (0, 1]."""
    return math.exp(-(x[0]**2 + x[1]**2))

In [9]:
def wave(x):
    """Unit-period sine in x[0], shifted and scaled into the range [0, 1]."""
    phase = 2.*math.pi*x[0]
    return 0.5*(1.0 + sin(phase))

In [10]:
def smooth_step(x):
    """tanh-smoothed step in x[0], rising from 0 (x[0] -> -inf) to 1 (x[0] -> +inf)."""
    t = math.tanh(x[0])
    return 0.5*(1.0 + t)

In [11]:
def rotated_smooth_step(x,theta):
    """Evaluate ``smooth_step`` at the point ``x`` rotated by ``theta`` radians."""
    return smooth_step(rotate(x, theta))

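Wrap the test functions in MFEM coefficient classes: `SetParams` draws a random set of parameters, and `EvalValue` evaluates the function at a point, so MFEM can evaluate it at many points.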

In [12]:
class Step(mfem.PyCoefficient):
    """Discontinuous step with a random orientation and a random offset.

    Call ``SetParams`` before evaluating.
    """

    def SetParams(self):
        # Orientation of the discontinuity, in radians.
        self.theta = random.uniform(0.0, 2.0*math.pi)
        # Per-axis shift applied to the evaluation point (drawn x first, then y).
        shift_x = random.uniform(-1.0, 1.0)
        shift_y = random.uniform(-1.0, 1.0)
        self.dx = [shift_x, shift_y]

    def EvalValue(self, x):
        # NOTE(review): relies on x being a numpy array so "+" is elementwise
        # (a plain list would concatenate instead).
        shifted = x + self.dx
        return rotated_step(shifted, self.theta)

In [13]:
class Wave(mfem.PyCoefficient):
    """Sine wave between two random levels, with random scale, offset and orientation.

    value(x) = floor + height * 0.5*(1 + sin(2*pi*xr[0])),
    where xr = rotate(x*wavel - dx, theta). Call ``SetParams`` before evaluating.
    """

    def SetParams(self):
        self.theta = random.uniform(0.0, 2.0*math.pi)
        self.dx = [random.uniform(-0.5, 0.5),random.uniform(-0.5, 0.5)]
        # Multiplies the coordinates, so it acts as a frequency scale.
        self.wavel = random.uniform(0.0, 1.0)
        y1 = random.uniform(0.1,0.9)
        y2 = random.uniform(0.1,0.9)
        self.floor = min(y1,y2)
        self.ceiling = max(y1,y2)
        self.height = self.ceiling -self.floor

    def EvalValue(self, x):
        # Build a transformed copy instead of updating x in place: x is the
        # caller's coordinate array, and mutating it corrupts any further use of
        # it (e.g. a second evaluation at the "same" point).
        xs = np.asarray(x, dtype=float)*self.wavel - self.dx
        xr = rotate(xs, self.theta)
        return self.floor + self.height*wave(xr)

In [14]:
class Linear(mfem.PyCoefficient):
    """Linear ramp through the domain centre (0.5, 0.5) with a random direction.

    Along the rotated first axis the ramp has slope ``m = y1 - y0`` and value
    ``b = (y0 + y1)/2`` at the centre. Call ``SetParams`` before evaluating.
    """

    def SetParams(self):
        self.theta = random.uniform(0.0, 2.0*math.pi)
        self.y0 = random.uniform(0.0,1.0)
        self.y1 = random.uniform(0.0,1.0)

        # slope and centre value of the ramp
        self.m = self.y1 - self.y0
        self.b = 0.5*(self.y0 + self.y1)

    def EvalValue(self, x):
        # distance along the ramp direction, measured from the domain centre
        along = rotate(x - [0.5, 0.5], self.theta)[0]
        return self.m*along + self.b

In [15]:
class Bump(mfem.PyCoefficient):
    """Gaussian bump on a random floor.

    value(x) = floor + height * bump((x - xc + dx) / width)
    with independent per-axis widths, a random offset ``dx``, and floor/ceiling
    drawn from [0.1, 0.9]. Since ``bump`` lies in (0, 1], the value stays in
    (floor, ceiling]. Call ``SetParams`` before evaluating.
    """

    def SetParams(self):
        self.width = [random.uniform(0.1,1.0),random.uniform(0.1,1.0)]
        self.xc = [0.5,0.5]
        self.dx = [random.uniform(-0.5, 0.5),random.uniform(-0.5, 0.5)]
        y1 = random.uniform(0.1,0.9)
        y2 = random.uniform(0.1,0.9)
        self.floor = min(y1,y2)
        self.ceiling = max(y1,y2)
        self.height = self.ceiling -self.floor

    def Print(self):
        """Print the drawn parameters (used for diagnostics)."""
        # width is a per-axis pair and needs two format fields; applying a
        # single "%f" to the list raises TypeError.
        print("width = %f,%f" % (self.width[0], self.width[1]))
        print("xc = %f,%f" % (self.xc[0],self.xc[1]))
        print("dx = %f,%f" % (self.dx[0],self.dx[1]))
        print("floor = %f" % self.floor)
        print("ceil = %f"  % self.ceiling)
        print("height = %f"% self.height)

    def EvalValue(self, x):
        # NOTE(review): x is expected to be a numpy array so the list
        # arithmetic below broadcasts elementwise.
        return self.floor + self.height*bump((x-self.xc+self.dx)/self.width)

In [16]:
class TwoBump(mfem.PyCoefficient):
    """Average of two Gaussian bumps with random widths and random offsets.

    Call ``SetParams`` before evaluating.
    """

    def SetParams(self):
        # widths first, then each bump's (x, y) offset, in that draw order
        self.width1 = random.uniform(0.1,0.5)
        self.width2 = random.uniform(0.1,0.5)
        self.xc1 = [0.5,0.5]
        self.xc2 = [0.5,0.5]
        offset1_x = random.uniform(-0.5, 0.5)
        offset1_y = random.uniform(-0.5, 0.5)
        self.dx1 = [offset1_x, offset1_y]
        offset2_x = random.uniform(-0.5, 0.5)
        offset2_y = random.uniform(-0.5, 0.5)
        self.dx2 = [offset2_x, offset2_y]

    def EvalValue(self, x):
        first = bump((x - self.xc1 + self.dx1)/self.width1)
        second = bump((x - self.xc2 + self.dx2)/self.width2)
        return 0.5*(first + second)

In [17]:
class SmoothStep(mfem.PyCoefficient):
    """tanh-smoothed step between two random levels.

    value(x) = floor + height * smooth_step(rotate((x - xc + dx)*width, theta)[0])
    Call ``SetParams`` before evaluating.
    """

    def SetParams(self):
        # Sharpness: coordinates are multiplied by this before the tanh.
        self.width = random.uniform(5.0, 10.0)
        self.xc = [0.5,0.5]
        # A single scalar added to both coordinates, i.e. a shift along the
        # (1, 1) diagonal. NOTE(review): other classes use a 2-vector offset;
        # confirm the scalar is intentional.
        self.dx = random.uniform(-0.5,0.5)
        self.theta = random.uniform(0.0, 2.0*math.pi)
        y1 = random.uniform(0.0,1.0)
        y2 = random.uniform(0.0,1.0)
        self.floor = min(y1,y2)
        self.ceiling = max(y1,y2)
        self.height = self.ceiling -self.floor

    def EvalValue(self, x):
        # Work on a shifted copy instead of updating x in place: x is the
        # caller's coordinate array, and mutating it corrupts any later use of it.
        xs = (np.asarray(x, dtype=float) - self.xc + self.dx)*self.width
        return self.floor + self.height*rotated_smooth_step(xs, self.theta)

In [18]:
class BumpAndSmoothStep(mfem.PyCoefficient):
    """Convex blend ``alpha*Bump + (1 - alpha)*SmoothStep`` with random alpha.

    Call ``SetParams`` before evaluating.
    """

    def SetParams(self):
        self.bump = Bump()
        self.bump.SetParams()
        self.smooth_step = SmoothStep()
        self.smooth_step.SetParams()
        self.alpha = random.uniform(0.0, 1.0)

    def EvalValue(self, x):
        a = self.alpha
        # Keep the bump evaluated first, in case SmoothStep.EvalValue modifies x.
        bump_val = self.bump.EvalValue(x)
        step_val = self.smooth_step.EvalValue(x)
        return a*bump_val + (1-a)*step_val

In [19]:
class BumpNarrowWide(mfem.PyCoefficient):
    """Gaussian bump that is either narrow and tall or wide and short (50/50).

    Call ``SetParams`` before evaluating.
    """

    def SetParams(self):
        if random.uniform(0.0,1.0) < 0.5:
            # narrow, tall bump
            self.width, self.height = 0.2, 1.0
        else:
            # wide, short bump
            self.width, self.height = 0.4, 0.1
        self.xc = [0.5,0.5]
        self.dx = [random.uniform(-0.5, 0.5),random.uniform(-0.5, 0.5)]

    def EvalValue(self, x):
        scaled = (x - self.xc + self.dx)/self.width
        return self.height*bump(scaled)

In [20]:
class DiscreteBump(mfem.PyCoefficient):
    """Gaussian bump centred at one of 100x100 cell centres on the unit square.

    The width is half the element size of a 5x5 coarse mesh. Call ``SetParams``
    before evaluating.
    """

    def SetParams(self):
        n_locations = 100   # candidate centres per axis
        n_mesh = 5          # coarse-mesh elements per axis
        ix = random.randrange(n_locations)
        iy = random.randrange(n_locations)
        cell = 1.0/n_locations
        mesh_h = 1.0/n_mesh
        self.xc = [ix*cell + 0.5*cell, iy*cell + 0.5*cell]
        self.width = mesh_h/2

    def Print(self):
        pass

    def EvalValue(self, x):
        v = bump((x-self.xc)/self.width)
        assert 0.0 <= v <= 1.0
        return v

The "library" training set randomly picks one primitive function (linear, bump, smooth step, or wave) and randomizes its parameters.

In [21]:
class Library(mfem.PyCoefficient):
    """Randomly chosen primitive: Linear, Bump, SmoothStep or Wave (equal odds).

    Call ``SetParams`` to choose the primitive and draw its parameters.
    """

    def SetParams(self):
        # Resolve the classes at call time so re-running their cells is picked up.
        # index: 0 - linear, 1 - bump, 2 - tanh, 3 - wave
        families = (Linear, Bump, SmoothStep, Wave)
        pick = random.randrange(len(families))
        self.fn = families[pick]()
        self.fn.SetParams()

    def Print(self):
        """Print which primitive was drawn and, if it supports it, its parameters."""
        # Needed by EstimatorGame.u0_coeff.Print; not every primitive has Print.
        print("library function: %s" % type(self.fn).__name__)
        fn_print = getattr(self.fn, "Print", None)
        if fn_print is not None:
            fn_print()

    def EvalValue(self, x):
        return self.fn.EvalValue(x)

Visualize an instance of a test function (here a `Wave`). Each instance has randomly chosen parameters: for the steps, a rotation angle and a displacement; for the bumps, a width and a displacement; for the waves, a rotation angle, a coordinate scale factor, and a displacement.

In [22]:
# Project one randomly parameterized Wave onto a linear (p=1) discontinuous
# L2 space on the coarse mesh, for a visual sanity check.
mesh1 = mfem.Mesh('inline-quad-5.mesh')
# Uncomment to view the same function on a uniformly refined mesh.
#mesh1.UniformRefinement()
fec1 = mfem.L2_FECollection(p=1, dim=2)
fes1 = mfem.FiniteElementSpace(mesh1, fec1)
u1 = mfem.GridFunction(fes1)
c1 = Wave()
c1.SetParams()
u1.ProjectCoefficient(c1)

In [23]:
# Render the projected function in a 600x600 GLVis widget; the 'keys ...'
# suffix sends GLVis keyboard commands to set up the view.
glvis(to_stream(mesh1,u1) + 'keys Rjlmc',600,600)

glvis()

Create the gym environment.

In [24]:
class EstimatorGame(gym.Env):
    """Single-decision gym environment: refine one mesh element or not.

    Every episode draws a random synthetic function (via ``Library``), projects
    it onto an order-1 L2 finite element space on ``inline-quad-5.mesh``, and
    observes it as an ``obsx`` x ``obsy`` x 1 image sampled on a uniform grid.

    Actions: 0 = do nothing (reward 0); 1 = refine element ``self.elem``.
    When refining, the reward is ``err_before - max(err_after, thresh)`` if the
    element's L2 error is at least ``thresh``, and ``err_before - thresh`` (a
    penalty) otherwise. Rewards are scaled by 1e5. Every step ends the episode.
    """

    class u0_coeff(mfem.PyCoefficient):
        """Exact-solution coefficient wrapping a randomly drawn ``Library`` function."""

        def SetParams(self):
            self.fn = Library()
            self.fn.SetParams()

        def Print(self):
            """Print the wrapped function's parameters, if it knows how."""
            # Not every function class defines Print, so fall back to the class name.
            fn_print = getattr(self.fn, "Print", None)
            if fn_print is not None:
                fn_print()
            else:
                print("u0 = %s" % type(self.fn).__name__)

        def EvalValue(self, x):
            # Evaluate exactly once: a primitive that modifies x in place would
            # see shifted coordinates on a second call with the same array.
            return self.fn.EvalValue(x)

    # In RLlib, you need the config arg
    def __init__(self, config):
        self.meshfile = 'inline-quad-5.mesh'

        # keep a copy of the unrefined mesh so we can restore it from memory
        self.mesh0 = mfem.Mesh(self.meshfile)
        self.mesh = mfem.Mesh(self.meshfile)

        # self.fec is reused by reinit(); self.fes and self.u are placeholders
        # that reinit() (called through reset() below) replaces.
        dim = self.mesh.Dimension()
        self.order = 1
        self.fec = mfem.L2_FECollection(self.order, dim)
        self.fes = mfem.FiniteElementSpace(self.mesh, self.fec)
        self.u = mfem.GridFunction(self.fes)

        # Element the agent decides about. After refinement its four children
        # are assumed to be numbered elem..elem+3 (see
        # get_refined_elements_error). NOTE(review): intended as the "center"
        # element of the 5x5 mesh under MFEM's element ordering — confirm.
        self.elem = 10

        self.action_space = spaces.Discrete(2) # refine the element (1) or don't (0)

        # observation is a set of sample points forming a 2D image
        self.obsx = 42
        self.obsy = 42

        # element L2 error below which refinement is not rewarded
        self.thresh = 1.e-4

        # add a little extra range on the space to account for interpolation errors
        self.observation_space = spaces.Box(-1.0, 2.0, shape=(self.obsx,self.obsy,1))
        self.get_obs_points()

        self.n = 0  # number of steps taken

        # call reset to create the first synthetic function
        self.reset()

    def get_obs_points(self):
        """Precompute sample points and their (element, reference point) pairs.

        Points are cell centres of a uniform obsx x obsy grid on the unit
        square, generated row by row (j outer, i inner). They are located once
        on the unrefined mesh; get_obs() is only called from reinit(), before
        any refinement, so the pairs stay valid.
        NOTE(review): assumes the mesh covers [0,1]^2; points not found would
        give an invalid element id — confirm.
        """
        dx = 1.0/self.obsx
        dy = 1.0/self.obsy
        self.sample_pts = []
        self.sample_els = []
        self.sample_ips = []
        for j in range(self.obsy):
            for i in range(self.obsx):
                pt = [i*dx+0.5*dx, j*dy+0.5*dy]
                self.sample_pts.append(pt)
                _, el, ip = self.mesh.FindPoints([pt])
                # copy the reference coordinates into our own IntegrationPoint
                # so they do not depend on objects returned by FindPoints
                ip0 = mfem.IntegrationPoint()
                ip0.x = ip[0].x
                ip0.y = ip[0].y
                self.sample_els.append(el[0])
                self.sample_ips.append(ip0)

    def get_obs(self):
        """Sample self.u at the precomputed points into an (obsx, obsy, 1) array.

        The result is also stored in self.state.
        """
        state = np.empty((self.obsx,self.obsy,1))
        for k, (el, ip) in enumerate(zip(self.sample_els, self.sample_ips)):
            # same ordering as get_obs_points: j outer, i inner
            i = k % self.obsx
            j = k // self.obsx
            v = self.u.GetValue(el, ip)
            state[i][j] = v
            # value outside the observation_space bounds [-1, 2]: print diagnostics
            if v > 2.0 or v < -1.0:
                print("element %d" % el)
                print("ip.x = %f" % ip.x)
                print("ip.y = %f" % ip.y)
                print("%d,%d -> %f" % (i,j,v))
                self.u0.Print()
        self.state = state
        return state

    def _element_errors(self):
        """Per-element L2 error of self.u w.r.t. self.u0 on the current mesh."""
        # size the result from the current mesh instead of hard-coding it
        el_err = mfem.Vector(self.mesh.GetNE())
        self.u.ComputeElementL2Errors(self.u0, el_err)
        return el_err

    # Compute L2 error wrt to the analytic fn definition
    def get_element_error(self):
        """L2 error on the decision element self.elem (unrefined mesh)."""
        return self._element_errors()[self.elem]

    def get_refined_elements_error(self):
        """Combined L2 error over the four children of the refined element."""
        el_err = self._element_errors()
        children = range(self.elem, self.elem + 4)
        return math.sqrt(sum(el_err[e]**2 for e in children))

    # Manually refine the elements in the array elems
    def refine_elems(self, elems):
        """Refine the listed elements, then update and re-project self.u."""
        self.mesh.GeneralRefinement(mfem.intArray(elems))
        self.fes.Update()
        self.u.Update()
        self.u.ProjectCoefficient(self.u0)

    def step(self, action):
        """Apply action (1 = refine self.elem, anything else = no-op).

        Returns (observation, reward*1e5, done=True, {}); the observation is the
        pre-action state recorded by get_obs().
        """
        self.n += 1

        reward = 0.0

        if action == 1:
            err1 = self.get_element_error()
            if err1 < self.thresh:
                # already accurate enough: penalize the wasted refinement
                reward = err1 - self.thresh
            else:
                self.refine_elems([self.elem])
                err2 = self.get_refined_elements_error()
                # only credit error reduction down to the threshold
                reward = err1 - max(err2, self.thresh)
        done = True

        # use old state
        return np.array(self.state), reward*1.e5, done, {}

    def reinit(self):
        """Restore the unrefined mesh and re-project the current u0 (no new function)."""
        # Tear down in reverse dependency order (u uses fes, fes uses mesh) so no
        # object outlives something it references.
        del self.u
        del self.fes
        del self.mesh

        self.mesh = mfem.Mesh(self.mesh0)
        self.fes = mfem.FiniteElementSpace(self.mesh, self.fec)
        self.u = mfem.GridFunction(self.fes)
        self.u.ProjectCoefficient(self.u0)

        obs = self.get_obs()

        return np.array(obs)

    def reset(self):
        """Draw a new synthetic function and return the initial observation."""
        self.u0 = self.u0_coeff()
        self.u0.SetParams()
        return self.reinit()

    def render(self, mode='human'):
        """Return a GLVis widget of the current mesh and projected function.

        ``mode`` is accepted for gym compatibility and ignored.
        """
        return glvis(to_stream(self.mesh,self.u) + 'keys Rjlmc',600,600)

Instantiate the environment and sanity check it.

In [25]:
# Build one environment directly (no RLlib) and render its initial function.
env = EstimatorGame(None)
env.render()

glvis()

In [26]:
# Draw a new function, refine the decision element (action 1), and render
# the refined mesh and re-projected function.
env.reset()
obs, reward, done, _ = env.step(1)
env.render()

glvis()

In [27]:
# Scaled (x1e5) reward for the refinement above.
reward

1.310585517081649

Create a convenience function that applies the trained policy to a given observation and steps the environment with the chosen action.

In [28]:
def apply_policy(model, obs, trainer=None, environment=None):
    """Choose an action deterministically for ``obs`` and step the environment with it.

    Parameters
    ----------
    model : object
        Unused; kept so existing callers keep working. The action comes from
        ``trainer.compute_action``.
    obs : array-like
        Observation to act on.
    trainer : optional
        Object with ``compute_action(obs, explore=...)``; defaults to the
        notebook-global ``agent``.
    environment : optional
        Environment to step; defaults to the notebook-global ``env``.

    Returns
    -------
    tuple
        ``(action, reward)``, with the reward returned by ``environment.step``.
    """
    if trainer is None:
        trainer = agent
    if environment is None:
        environment = env
    # explore=False: use deterministic mode
    action = trainer.compute_action(obs, explore=False)
    _, reward, _, _ = environment.step(action)
    return action, reward

Run a more systematic evaluation using an ensemble of samples:

In [29]:
def eval_ensemble(model, ntrials):
    """Score the RL policy and the DG-jump indicator against the optimal action.

    For each of ``ntrials`` freshly reset environments, compare the reward of
    the policy's action and of the DG-jump action with the optimal reward.
    Prints, then returns
    ``(rms, max_err, pct_correct, dg_rms, dg_max_err, dg_pct_correct)``.

    Uses the notebook globals ``env``, ``find_optimal`` and ``find_dgjumps``.
    NOTE(review): the last two are not defined in this notebook section.
    """
    rl_hits, rl_sumsq, rl_maxsq = 0.0, 0.0, 0.0
    dg_hits, dg_sumsq, dg_maxsq = 0.0, 0.0, 0.0
    for _ in range(ntrials):
        obs = env.reset()
        best_action, best_reward = find_optimal(obs)
        dg_action, dg_reward = find_dgjumps(env)
        rl_action, rl_reward = apply_policy(model, obs)

        rl_err = best_reward - rl_reward
        rl_maxsq = max(rl_err*rl_err, rl_maxsq)
        rl_sumsq += rl_err*rl_err

        dg_err = best_reward - dg_reward
        dg_maxsq = max(dg_err*dg_err, dg_maxsq)
        dg_sumsq += dg_err*dg_err

        if best_action == rl_action:
            rl_hits += 1
        if best_action == dg_action:
            dg_hits += 1

    rms = math.sqrt(rl_sumsq/ntrials)
    corr = 100.*rl_hits/ntrials
    print("policy rms error: ",rms,flush=True)
    print("policy max sq error: ",rl_maxsq,flush=True)
    print("policy % correct: ",corr,flush=True)

    dg_rms = math.sqrt(dg_sumsq/ntrials)
    dg_corr = 100.*dg_hits/ntrials
    print("dg rms error: ",dg_rms,flush=True)
    print("dg max sq error: ",dg_maxsq,flush=True)
    print("dg % correct: ",dg_corr,flush=True)

    return rms, math.sqrt(rl_maxsq), corr, dg_rms, math.sqrt(dg_maxsq), dg_corr

# NOTE(review): needs `model` from the training cell further down and the
# helpers find_optimal / find_dgjumps (not defined in this section); run after training.
eval_ensemble(model, 100)

Run the evaluation at a few sample sizes to see how many samples are needed for stable estimates of the policy metrics.

# Same evaluation with 200 samples (same prerequisites as the 100-sample run).
eval_ensemble(model, 200)

# Same evaluation with 400 samples (same prerequisites as the 100-sample run).
eval_ensemble(model, 400)

Let's see if the training process is making progress:

In [30]:
# Restart Ray so re-running this cell starts from a clean instance.
ray.shutdown()
ray.init(ignore_reinit_error=True)

# Every environment step ends its episode (done=True), so episodes == steps.
total_episodes = 2.e6
batch_size = 1000
nbatches = int(total_episodes/batch_size)

# Shallow copy: nested dicts such as config['model'] are still shared with
# ppo.DEFAULT_CONFIG; copy them before mutating.
config = ppo.DEFAULT_CONFIG.copy()
config['train_batch_size'] = batch_size
config['num_workers'] = 3
config['num_gpus'] = 0
#config['lr'] = 1.e-5
config

2021-02-27 20:58:29,655	INFO services.py:1174 -- View the Ray dashboard at [1m[32mhttp://127.0.0.1:8266[39m[22m


{'num_workers': 3,
 'num_envs_per_worker': 1,
 'create_env_on_driver': False,
 'rollout_fragment_length': 200,
 'batch_mode': 'truncate_episodes',
 'num_gpus': 0,
 'train_batch_size': 1000,
 'model': {'fcnet_hiddens': [256, 256],
  'fcnet_activation': 'tanh',
  'conv_filters': None,
  'conv_activation': 'relu',
  'free_log_std': False,
  'no_final_linear': False,
  'vf_share_layers': False,
  'use_lstm': False,
  'max_seq_len': 20,
  'lstm_cell_size': 256,
  'lstm_use_prev_action': False,
  'lstm_use_prev_reward': False,
  '_time_major': False,
  'use_attention': False,
  'attention_num_transformer_units': 1,
  'attention_dim': 64,
  'attention_num_heads': 1,
  'attention_head_dim': 32,
  'attention_memory_inference': 50,
  'attention_memory_training': 50,
  'attention_position_wise_mlp_dim': 32,
  'attention_init_gru_gate_bias': 2.0,
  'num_framestacks': 'auto',
  'dim': 84,
  'grayscale': False,
  'zero_mean': True,
  'custom_model': None,
  'custom_model_config': {},
  'custom_actio

In [31]:
# Build a PPO trainer for the EstimatorGame environment class.
agent = ppo.PPOTrainer(config, env=EstimatorGame)
# Uncomment to resume from an earlier checkpoint (absolute local path).
#agent.restore("/home/rwa/ray_results/PPO_EstimatorGame_2021-02-26_14-31-49nivcscnb/checkpoint_10/checkpoint-10")
policy = agent.get_policy()
# Passed to eval_ensemble; apply_policy itself queries the trainer (`agent`).
model = policy.model

# Evaluation history: one entry per evaluation, not per training batch.
rms = []
cor = []
maxerr = []

dg_rms = []
dg_cor = []
dg_maxerr = []

# Save a checkpoint every this many training episodes.
checkpoint_period = 5000

# Samples per evaluation, and episodes between evaluations (0 disables evaluation).
neval = 50
eval_period = 0

episode = 0
checkpoint_episode = 0
eval_episode = 0
for n in range(nbatches):
    print("training batch %d of size %d" % (n,config['train_batch_size']))
    agent.train()
    episode += config['train_batch_size']
    checkpoint_episode += config['train_batch_size']
    if (checkpoint_episode >= checkpoint_period):
        checkpoint_episode = 0
        checkpoint_path = agent.save()
        print(checkpoint_path)

    eval_episode += config['train_batch_size']
    if (eval_period and eval_episode >= eval_period):
        eval_episode = 0
        rms1, maxerr1, cor1, dg_rms1, dg_maxerr1, dg_cor1 = eval_ensemble(model, neval)
        rms.append(rms1)
        maxerr.append(maxerr1)
        cor.append(cor1)
        dg_rms.append(dg_rms1)
        dg_maxerr.append(dg_maxerr1)
        dg_cor.append(dg_cor1)

2021-02-27 20:58:32,178	INFO trainer.py:616 -- Tip: set framework=tfe or the --eager flag to enable TensorFlow eager execution
2021-02-27 20:58:32,179	INFO trainer.py:643 -- Current log_level is WARN. For more information, set 'log_level': 'INFO' / 'DEBUG' or use the -v and -vv flags.
[2m[36m(pid=19884)[0m Instructions for updating:
[2m[36m(pid=19884)[0m non-resource variables are not supported in the long term
[2m[36m(pid=19885)[0m Instructions for updating:
[2m[36m(pid=19885)[0m non-resource variables are not supported in the long term
[2m[36m(pid=19887)[0m Instructions for updating:
[2m[36m(pid=19887)[0m non-resource variables are not supported in the long term
[2m[36m(pid=19884)[0m Instructions for updating:
[2m[36m(pid=19884)[0m If using Keras pass *_constraint arguments to layers.
[2m[36m(pid=19885)[0m Instructions for updating:
[2m[36m(pid=19885)[0m If using Keras pass *_constraint arguments to layers.
[2m[36m(pid=19887)[0m Instructions for updat

Instructions for updating:
If using Keras pass *_constraint arguments to layers.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where




training batch 0 of size 1000
Instructions for updating:
Prefer Variable.assign which has equivalent behavior in 2.X.


[2m[36m(pid=19884)[0m Instructions for updating:
[2m[36m(pid=19884)[0m Prefer Variable.assign which has equivalent behavior in 2.X.
[2m[36m(pid=19885)[0m Instructions for updating:
[2m[36m(pid=19885)[0m Prefer Variable.assign which has equivalent behavior in 2.X.
[2m[36m(pid=19887)[0m Instructions for updating:
[2m[36m(pid=19887)[0m Prefer Variable.assign which has equivalent behavior in 2.X.


training batch 1 of size 1000
training batch 2 of size 1000
training batch 3 of size 1000
training batch 4 of size 1000
/home/rwa/ray_results/PPO_EstimatorGame_2021-02-27_20-58-32bp5n1ofm/checkpoint_5/checkpoint-5
training batch 5 of size 1000
training batch 6 of size 1000
training batch 7 of size 1000
training batch 8 of size 1000
training batch 9 of size 1000
/home/rwa/ray_results/PPO_EstimatorGame_2021-02-27_20-58-32bp5n1ofm/checkpoint_10/checkpoint-10
training batch 10 of size 1000
training batch 11 of size 1000
training batch 12 of size 1000
training batch 13 of size 1000
training batch 14 of size 1000
/home/rwa/ray_results/PPO_EstimatorGame_2021-02-27_20-58-32bp5n1ofm/checkpoint_15/checkpoint-15
training batch 15 of size 1000
training batch 16 of size 1000
training batch 17 of size 1000
training batch 18 of size 1000
training batch 19 of size 1000
/home/rwa/ray_results/PPO_EstimatorGame_2021-02-27_20-58-32bp5n1ofm/checkpoint_20/checkpoint-20
training batch 20 of size 1000
trainin

KeyboardInterrupt: 

In [None]:
%matplotlib inline
isteps = list(range(nbatches))
asteps = [i*config['train_batch_size'] for i in isteps]
import matplotlib.pyplot as plt
ax = plt.subplot(211)
ax.set_ylim(0.00001,0.01)
ax.set_ylabel('Error')
line1, = plt.semilogy(asteps,rms[:nbatches], marker='o')
line2, = plt.semilogy(asteps,dg_rms[:nbatches], marker='x')
line3, = plt.semilogy(asteps,maxerr[:nbatches], marker='.')
line4, = plt.semilogy(asteps,dg_maxerr[:nbatches], marker='+')

line1.set_label('RL rms')
line2.set_label('DG rms')
line3.set_label('RL max')
line4.set_label('DG max')
ax.legend()
plt.ticklabel_format(style='sci', axis='x', scilimits=(0,0))

ax = plt.subplot(212)
ax.set_ylim(0,100)
ax.set_ylabel('% correct')
ax.set_xlabel('training episodes')
line1, = plt.plot(asteps,cor[:nbatches], marker='o')
line2, = plt.plot(asteps,dg_cor[:nbatches], marker='x')
line1.set_label('RL policy')
line2.set_label('DG')
ax.legend()
plt.ticklabel_format(style='sci', axis='x', scilimits=(0,0))

In [None]:
# Raw RMS-error history from the periodic evaluations.
rms

Let's look for cases where the policy gets it right and the DG method gets it wrong.

In [None]:
# Look for a function where the RL policy picks the optimal action but the
# DG-jump indicator does not, then replay the policy's action for rendering.
# NOTE(review): find_optimal / find_dgjumps must be defined elsewhere.
n_search = 500
for trial in range(n_search):
    obs = env.reset()
    opt_action, opt_reward = find_optimal(obs)
    dg_action, dg_reward = find_dgjumps(env)
    pol_action, pol_reward = apply_policy(model, obs)
    if (pol_action == opt_action) and (dg_action != opt_action):
        break
else:
    # Without this the cell would silently render an arbitrary last sample.
    print("No case found in %d trials where the policy is right and DG is wrong; "
          "showing the last sample." % n_search)
env.reinit()
env.step(pol_action)
env.render()

In [None]:
# Replay the same function with the DG-jump indicator's action for comparison.
env.reinit()
env.step(dg_action)
env.render()