In [2]:
%load_ext Cython

In [1]:
%%cython 
# distutils: language = c++
import numpy as np
import gym
import torch
#from thinker import util
import thinker.util as util

import cython
from libcpp cimport bool
from libcpp.vector cimport vector
from cpython.ref cimport PyObject, Py_INCREF, Py_DECREF
from libc.math cimport sqrt
from libc.stdlib cimport malloc, free

# util function

@cython.cdivision(True)
cdef float average(vector[float]& arr):
    """Return the arithmetic mean of ``arr``; an empty vector yields 0."""
    cdef int count = arr.size()
    cdef int idx
    cdef float total = 0
    if count == 0:
        return 0.
    for idx in range(count):
        total = total + arr[idx]
    return total / count

cdef float maximum(vector[float]& arr):
    """Return the largest element of ``arr``; an empty vector yields 0."""
    cdef int count = arr.size()
    cdef int idx
    cdef float best
    if count == 0:
        return 0.
    best = arr[0]
    for idx in range(1, count):
        # keep the running maximum; ties keep the earlier element
        best = arr[idx] if arr[idx] > best else best
    return best

# Node-related functions (a C struct is used instead of a Python class to minimize Python-level code)

# One node of the search tree. Nodes are heap-allocated by node_new and
# released by node_del; the vectors and the `encoded` reference are owned by the node.
cdef struct Node:
    int action # action that leads from the parent to this node
    float r # reward stored by the last node_expand
    float v # value stored by the last node_expand
    int t # time step passed to the last node_expand
    bool done # done flag stored by the last node_expand
    float logit # logit of this node's action, taken from the parent's expansion
    vector[Node*]* ppchildren # child nodes, one per action once expanded (empty before that)
    Node* pparent # parent node, NULL for a root
    float trail_r # discounted reward accumulated by node_propagate since the last node_visit
    float trail_discount # running discount factor matching trail_r
    float rollout_q # trail_r + trail_discount * v of the latest propagation
    bool visited # set to True by node_visit after the first visit
    vector[float]* prollout_qs # rollout returns recorded for this node
    vector[vector[Node*]*]* ppaths # node path for each entry of prollout_qs; NULL unless remember_path
    int rollout_n # number of rollouts recorded in prollout_qs by node_propagate
    float max_q # written by node_stat(detailed=True): (max rollout return - r) / discounting
    PyObject* encoded # owned reference to a Python object attached by node_expand (NULL until then)
    int rec_t # number of planning steps; used to normalize child rollout counts in node_stat
    int num_actions # number of actions, i.e. number of children created on expansion
    float discounting # discount rate
    bool remember_path # whether to remember the node path of every rollout (enables ppaths)

cdef Node* node_new(Node* pparent, int action, float logit, int num_actions, float discounting, int rec_t, bool remember_path):
    """
    Allocate an unexpanded, unvisited node and return a pointer to it.
    The node owns freshly created (empty) child and rollout vectors; the path
    vector is only created when remember_path is set. Release with node_del.
    """
    cdef Node* pnode = <Node*> malloc(sizeof(Node))

    # identity and configuration
    pnode.pparent = pparent
    pnode.action = action
    pnode.logit = logit
    pnode.num_actions = num_actions
    pnode.discounting = discounting
    pnode.rec_t = rec_t
    pnode.remember_path = remember_path

    # statistics start from a clean slate
    pnode.r = 0.
    pnode.v = 0.
    pnode.t = 0
    pnode.done = False
    pnode.trail_r = 0.
    pnode.trail_discount = 1.
    pnode.rollout_q = 0
    pnode.visited = False
    pnode.rollout_n = 0
    pnode.max_q = 0.
    pnode.encoded = NULL

    # owned containers
    pnode.ppchildren = new vector[Node*]()
    pnode.prollout_qs = new vector[float]()
    if remember_path:
        pnode.ppaths = new vector[vector[Node*]*]()
    else:
        pnode.ppaths = NULL
    return pnode

cdef bool node_expanded(Node* pnode, int t):
    """
    Return True when the node has children and its last expansion
    happened at time step t or later.
    """
    if pnode[0].ppchildren[0].size() == 0:
        # never expanded
        return False
    return pnode[0].t >= t

cdef node_expand(Node* pnode, float r, float v, int t, bool done, float[:] logits, PyObject* encoded, bool override):
    """
    Expand a node: store r, v, t, done and the encoded Python object, and
    create one child per action (or refresh the children's logits).

    pnode: node to expand
    r, v: reward and value to store on the node
    t: time step recorded as the expansion time
    done: done flag to store on the node
    logits: per-action logits; logits[a] becomes the logit of child a
    encoded: borrowed reference; the node takes its own reference to it
    override: when True and the node already has children, the existing
        statistics are updated in place instead of creating children; it is
        silently ignored for a node without children
    """
    cdef int a
    cdef PyObject* old_encoded
    if override and not node_expanded(pnode, -1):
        override = False # no override if not yet expanded

    if not override:
        assert not node_expanded(pnode, -1), "node should not be expanded"
    else:
        # the first rollout return is the node's own r + discounted v; guard the
        # write so an expanded-but-never-visited node cannot index an empty vector
        if pnode[0].prollout_qs[0].size() > 0:
            pnode[0].prollout_qs[0][0] = r + v * pnode[0].discounting
        # the remaining returns only change by the reward difference
        for a in range(1, int(pnode[0].prollout_qs[0].size())):
            pnode[0].prollout_qs[0][a] = pnode[0].prollout_qs[0][a] - pnode[0].r + r
        if pnode[0].pparent != NULL and pnode[0].remember_path:
            node_refresh(pnode[0].pparent, pnode, r - pnode[0].r, v - pnode[0].v, pnode[0].discounting, 1)
    pnode[0].r = r
    pnode[0].v = v
    pnode[0].t = t
    pnode[0].done = done
    # Take the new reference before dropping the old one, so that passing the
    # object the node already holds can never free it prematurely.
    Py_INCREF(<object>encoded)
    old_encoded = pnode[0].encoded
    pnode[0].encoded = encoded
    if old_encoded != NULL:
        Py_DECREF(<object>old_encoded)
    for a in range(pnode[0].num_actions):
        if not override:
            pnode[0].ppchildren[0].push_back(node_new(pparent=pnode, action=a, logit=logits[a],
                num_actions = pnode[0].num_actions, discounting = pnode[0].discounting, rec_t = pnode[0].rec_t,
                remember_path = pnode[0].remember_path))
        else:
            pnode[0].ppchildren[0][a][0].logit = logits[a]

cdef node_refresh(Node* pnode, Node* pnode_to_refresh, float r_diff, float v_diff, float discounting, int depth):
    """
    Patch the rollout returns of pnode (and, recursively, of its ancestors)
    whose recorded path passes through pnode_to_refresh, after that node's
    reward changed by r_diff and its value by v_diff. depth is the number of
    levels between pnode and pnode_to_refresh. Requires remember_path.
    """
    cdef int path_idx, pos
    cdef vector[Node*]* ppath
    for path_idx in range(int(pnode[0].ppaths[0].size())):
        ppath = pnode[0].ppaths[0][path_idx]
        # paths end at pnode itself, so the node `depth` levels below sits here
        pos = int(ppath[0].size()) - 1 - depth
        if pos >= 0 and ppath[0][pos] == pnode_to_refresh:
            pnode[0].prollout_qs[0][path_idx] += discounting * r_diff
            if pos == 0:
                # the refreshed node starts the path, so its value is part of the return
                pnode[0].prollout_qs[0][path_idx] += discounting * pnode[0].discounting * v_diff
    if pnode[0].pparent == NULL:
        return
    node_refresh(pnode[0].pparent, pnode_to_refresh, r_diff, v_diff, discounting * pnode[0].discounting, depth+1)

cdef node_visit(Node* pnode):
    """
    Visit a node: restart its trailing reward/discount and propagate its
    r and v up to the root. The first visit of a node counts as a new rollout
    for the node and all of its ancestors.
    """
    cdef bool first_visit = not pnode[0].visited
    cdef vector[Node*]* ppath = NULL
    pnode[0].trail_r = 0.
    pnode[0].trail_discount = 1.
    if first_visit and pnode[0].remember_path:
        # empty seed path; node_propagate copies it into the per-node paths
        ppath = new vector[Node*]()
    node_propagate(pnode=pnode, r=pnode[0].r, v=pnode[0].v, new_rollout=first_visit, ppath=ppath)
    if ppath != NULL:
        # node_propagate stores copies only, so the seed vector must be freed here
        del ppath
    pnode[0].visited = True

cdef void node_propagate(Node* pnode, float r, float v, bool new_rollout, vector[Node*]* ppath):
    """
    Accumulate reward r into the node's trailing return, recompute rollout_q
    with value v, and repeat for every ancestor up to the root.

    new_rollout: when True, rollout_q is recorded as a new rollout return on
        every node along the way (and rollout_n is incremented)
    ppath: path accumulated so far (may be NULL); it is copied, never stored,
        so the caller keeps ownership
    """
    cdef int i
    # initialised so that a defined value is passed upwards when new_rollout is False
    cdef vector[Node*]* ppath_ = NULL
    pnode[0].trail_r = pnode[0].trail_r + pnode[0].trail_discount * r
    pnode[0].trail_discount = pnode[0].trail_discount * pnode[0].discounting
    pnode[0].rollout_q = pnode[0].trail_r + pnode[0].trail_discount * v
    if new_rollout:
        if pnode[0].remember_path:
            # this node's path = path so far + this node; owned by pnode.ppaths
            ppath_ = new vector[Node*]()
            if ppath != NULL:
                for i in range(int(ppath[0].size())):
                    ppath_.push_back(ppath[0][i])
            ppath_.push_back(pnode)
            pnode[0].ppaths[0].push_back(ppath_)
        pnode[0].prollout_qs[0].push_back(pnode[0].rollout_q)
        pnode[0].rollout_n = pnode[0].rollout_n + 1
    if pnode[0].pparent != NULL:
        node_propagate(pnode[0].pparent, r, v, new_rollout, ppath=ppath_)

#@cython.cdivision(True)
cdef float[:] node_stat(Node* pnode, bool detailed, bool reward_transform):
    """
    Build the flat float32 statistics vector of a node.

    Layout (A = num_actions): one-hot action [0, A), reward [A], value [A+1],
    then four per-child sections of length A each (logit, mean rollout return,
    max rollout return, rollout count / rec_t). With detailed=True three more
    entries follow: trail_r / discounting, rollout_q / discounting and max_q;
    max_q is also written back to the node. With reward_transform=True every
    reward-like entry is passed through enc.
    """
    cdef int n_act = pnode[0].num_actions
    cdef int child_base = n_act + 2      # first per-child entry
    cdef int extra_base = n_act * 5 + 2  # first "detailed" entry
    cdef float[:] result = np.zeros(extra_base + 3 if detailed else extra_base, dtype=np.float32)
    cdef int i
    cdef Node* pchild
    cdef float q_mean, q_max

    result[pnode[0].action] = 1. # action
    if reward_transform:
        result[n_act] = enc(pnode[0].r) # reward
        result[n_act + 1] = enc(pnode[0].v) # value
    else:
        result[n_act] = pnode[0].r # reward
        result[n_act + 1] = pnode[0].v # value

    for i in range(int(pnode[0].ppchildren[0].size())):
        pchild = pnode[0].ppchildren[0][i]
        q_mean = average(pchild[0].prollout_qs[0])
        q_max = maximum(pchild[0].prollout_qs[0])
        if reward_transform:
            q_mean = enc(q_mean)
            q_max = enc(q_max)
        result[child_base + i] = pchild[0].logit # child_logits
        result[child_base + n_act + i] = q_mean # child_rollout_qs_mean
        result[child_base + n_act * 2 + i] = q_max # child_rollout_qs_max
        result[child_base + n_act * 3 + i] = pchild[0].rollout_n / <float>pnode[0].rec_t # child_rollout_ns_enc

    if detailed:
        pnode[0].max_q = (maximum(pnode[0].prollout_qs[0]) - pnode[0].r) / pnode[0].discounting
        if reward_transform:
            result[extra_base] = enc(pnode[0].trail_r / pnode[0].discounting)
            result[extra_base + 1] = enc(pnode[0].rollout_q / pnode[0].discounting)
            result[extra_base + 2] = enc(pnode[0].max_q)
        else:
            result[extra_base] = pnode[0].trail_r / pnode[0].discounting
            result[extra_base + 1] = pnode[0].rollout_q / pnode[0].discounting
            result[extra_base + 2] = pnode[0].max_q
    return result

cdef node_del(Node* pnode, int except_idx):
    """
    Free a node together with its whole subtree. The child at index
    except_idx (pass -1 for none) is kept alive and detached from the
    freed node so it can serve as a new root.
    """
    cdef int idx
    cdef Node* pchild
    cdef vector[float]* prollout_qs = pnode[0].prollout_qs
    cdef vector[vector[Node*]*]* ppaths = pnode[0].ppaths
    cdef vector[Node*]* pchildren = pnode[0].ppchildren

    del prollout_qs

    if ppaths != NULL:
        # each recorded path is a separately allocated vector
        for idx in range(int(ppaths[0].size())):
            del ppaths[0][idx]
        del ppaths

    for idx in range(int(pchildren[0].size())):
        pchild = pchildren[0][idx]
        if idx == except_idx:
            pchild[0].pparent = NULL
        else:
            node_del(pchild, -1)
    del pchildren

    # release the reference taken in node_expand
    if pnode[0].encoded != NULL:
        Py_DECREF(<object>pnode[0].encoded)
    free(pnode)

cdef float enc(float x):
    """Signed square-root squashing of x plus a small linear term (0.001 * x)."""
    cdef double root = sqrt(abs(x) + 1)
    cdef double linear = (0.001) * x
    return sign(x) * (root - 1) + linear

cdef float sign(float x):
    """Return 1 for positive x, -1 for negative x and 0 otherwise."""
    cdef float result = 0.
    if x > 0.:
        result = 1.
    elif x < 0.:
        result = -1.
    return result

cdef float abs(float x):
    """Return the absolute value of x (0 when x is neither positive nor negative)."""
    if x > 0.:
        return x
    elif x < 0.:
        return -x
    else:
        return 0.


cdef class cVecFullModelWrapper():
    """Wrap the gym environment with a model; output for each 
    step is (out, reward, done, info), where out is a tuple 
    of (model_out, gym_env_out, model_encodes) that corresponds to the output 
    from the model wrapper (tree statistics), the underlying environment frame, 
    and the encoding from the model.
    Assume a learned dynamic model (flags.perfect_model must be False).
    """
    # setting
    cdef int rec_t # number of steps per planning cycle; the last step of a cycle is the real step
    cdef float discounting
    cdef float depth_discounting # base of the per-depth factor applied to the planning reward
    cdef int max_allow_depth # rollouts reaching this depth are reset to the root (only when > 0)
    cdef bool perfect_model
    cdef bool tree_carry # reuse the chosen child's subtree as the next root when possible
    cdef int reward_type # 1 adds a second (planning) reward column
    cdef bool reward_transform # pass reward-like statistics through enc
    cdef bool actor_see_encode
    cdef bool actor_see_double_encode    
    cdef int num_actions
    cdef int obs_n # length of one row of model_out
    cdef int env_n # number of parallel environments
    cdef bool time # whether to record timings in step

    # python object
    cdef object device
    cdef object env
    cdef object timings
    cdef readonly baseline_max_q
    cdef readonly baseline_mean_q    
    cdef readonly object model_out_shape
    cdef readonly object gym_env_out_shape

    # tree statistic
    cdef vector[Node*] cur_nodes # current node of each environment's rollout
    cdef vector[Node*] root_nodes # root of each environment's search tree
    cdef float[:] root_nodes_qmax
    cdef float[:] root_nodes_qmax_
    cdef int[:] rollout_depth
    cdef int[:] max_rollout_depth
    cdef int[:] cur_t # position of each environment within its planning cycle

    # internal variables only used in step function
    cdef float[:] depth_delta
    cdef int[:] max_rollout_depth_
    cdef float[:] mean_q
    cdef float[:] max_q
    cdef int[:] status # per-environment transition type set in step (1, 2, 4 or 5)
    cdef vector[Node*] cur_nodes_
    cdef float[:] par_logits
    cdef float[:, :] full_reward
    cdef bool[:] full_done
    cdef bool[:] full_real_done
    cdef int[:] total_step # number of real steps taken by each environment

    def __init__(self, env, env_n, flags, device=None, time=False):
        """Store the settings and allocate the per-environment buffers.

        env: vectorized environment; env.action_space[0].n gives the number of actions
        env_n: number of parallel environments
        flags: settings object providing rec_t, discounting, depth_discounting,
            max_depth, perfect_model (must be False), tree_carry, reward_type,
            reward_transform, actor_see_encode and actor_see_double_encode
        device: torch device for returned tensors (CPU when None)
        time: whether step records timings
        """
        assert not flags.perfect_model, "this class only supports imperfect model"
        self.device = torch.device("cpu") if device is None else device
        self.env = env
        self.rec_t = flags.rec_t
        self.discounting = flags.discounting
        self.depth_discounting = flags.depth_discounting
        self.max_allow_depth = flags.max_depth
        self.perfect_model = flags.perfect_model
        self.tree_carry = flags.tree_carry
        self.num_actions = env.action_space[0].n
        self.reward_type = flags.reward_type
        self.reward_transform = flags.reward_transform
        self.actor_see_encode = flags.actor_see_encode
        self.actor_see_double_encode = flags.actor_see_double_encode
        self.env_n = env_n
        # root statistics (5A+5) + current-node statistics (5A+2) + reset flag
        # + one-hot planning step (rec_t) + discount entry
        self.obs_n = 9 + self.num_actions * 10 + self.rec_t
        self.model_out_shape = (self.obs_n, 1, 1)
        self.gym_env_out_shape = env.observation_space.shape[1:]

        self.baseline_max_q = torch.zeros(self.env_n, dtype=torch.float32, device=self.device)
        self.baseline_mean_q = torch.zeros(self.env_n, dtype=torch.float32, device=self.device)
        self.time = time
        self.timings = util.Timings()

        # internal variable init.
        self.depth_delta = np.zeros(self.env_n, dtype=np.float32)
        self.max_rollout_depth_ = np.zeros(self.env_n, dtype=np.intc)
        self.mean_q =  np.zeros(self.env_n, dtype=np.float32)
        self.max_q = np.zeros(self.env_n, dtype=np.float32)
        self.status = np.zeros(self.env_n, dtype=np.intc)
        self.par_logits = np.zeros(self.num_actions, dtype=np.float32)
        self.full_reward = np.zeros((self.env_n, 2 if self.reward_type == 1 else 1), dtype=np.float32)
        # np.bool_ instead of the np.bool alias, which was removed in NumPy 1.24
        self.full_done = np.zeros(self.env_n, dtype=np.bool_)
        self.full_real_done = np.zeros(self.env_n, dtype=np.bool_)
        self.total_step = np.zeros(self.env_n, dtype=np.intc)
        
    def reset(self, model_net):
        """Reset every environment and build a fresh one-node search tree for each.

        Intended to be called once at the start. If it is called again, the
        trees of the previous run are freed first so that the new roots end up
        at indices 0..env_n-1.

        model_net: model called on the initial observations to obtain values,
            logits and encodings
        Returns (model_out, gym_env_out, model_encodes); model_encodes is None
        unless actor_see_encode is set.
        """
        cdef int i
        cdef Node* root_node
        cdef float[:,:] model_out

        # Drop trees left over from an earlier reset; without this the new
        # roots would be appended behind the stale ones and never be used.
        for i in range(int(self.root_nodes.size())):
            node_del(self.root_nodes[i], except_idx=-1)
        self.root_nodes.clear()
        self.cur_nodes.clear()

        with torch.no_grad():
            # some init.
            self.root_nodes_qmax = np.zeros(self.env_n, dtype=np.float32)
            self.root_nodes_qmax_ = np.zeros(self.env_n, dtype=np.float32)
            self.rollout_depth = np.zeros(self.env_n, dtype=np.intc)
            self.max_rollout_depth = np.zeros(self.env_n, dtype=np.intc)
            self.cur_t = np.zeros(self.env_n, dtype=np.intc)

            # reset obs
            obs = self.env.reset()

            # obtain output from model
            obs_py = torch.tensor(obs, dtype=torch.uint8, device=self.device)
            pass_action = torch.zeros(self.env_n, dtype=torch.long)
            _, _, vs, _, logits, model_encodes = model_net(obs_py,
                                                pass_action.unsqueeze(0).to(self.device),
                                                one_hot=False)
            vs = vs.cpu()
            logits = logits.cpu()

            # compute and update root node and current node
            for i in range(self.env_n):
                root_node = node_new(pparent=NULL, action=pass_action[i].item(), logit=0., num_actions=self.num_actions,
                    discounting=self.discounting, rec_t=self.rec_t, remember_path=True)
                encoded = {"gym_env_out": obs_py[i], "model_encodes": model_encodes[-1,i]}
                node_expand(pnode=root_node, r=0., v=vs[-1, i].item(), t=self.total_step[i], done=False,
                    logits=logits[-1, i].numpy(), encoded=<PyObject*>encoded, override=False)
                node_visit(pnode=root_node)
                self.root_nodes.push_back(root_node)
                self.cur_nodes.push_back(root_node)

            # compute model_out
            model_out = self.compute_model_out(None, None)

            gym_env_out = []
            for i in range(self.env_n):
                encoded = <dict>self.cur_nodes[i][0].encoded
                if encoded["gym_env_out"] is not None:
                    gym_env_out.append(encoded["gym_env_out"].unsqueeze(0))
            if len(gym_env_out) > 0:
                gym_env_out = torch.concat(gym_env_out)
            else:
                gym_env_out = None

            if self.actor_see_encode:
                model_encodes = []
                for i in range(self.env_n):
                    encoded = <dict>self.cur_nodes[i][0].encoded
                    model_encodes.append(encoded["model_encodes"].unsqueeze(0))
                model_encodes = torch.concat(model_encodes)

                if self.actor_see_double_encode:
                    # root and current node coincide right after a reset
                    model_encodes = torch.concat([model_encodes, model_encodes], dim=1)
            else:
                model_encodes = None

            # record initial root_nodes_qmax (max_q was set by compute_model_out via node_stat)
            for i in range(self.env_n):
                self.root_nodes_qmax[i] = self.root_nodes[i][0].max_q

            return torch.tensor(model_out, dtype=torch.float32, device=self.device), gym_env_out, model_encodes

    def step(self, action, model_net):
        """Advance every environment by one planning (imagination) or real step.

        action: tensor of shape (env_n, 3); the columns are the real action,
            the imagined action and the reset flag
        model_net: model used for real transitions (called on observations) and
            for imagined transitions (forward_encoded on stored encodings)

        Per-environment status codes used below:
            1 - real step, 2 - imagined step onto an already expanded node,
            4 - imagined step that needs expansion, 5 - rollout reset to the root.

        Returns ((model_out, gym_env_out, model_encodes), reward, done, info).
        """
        cdef int i, j, l
        cdef int[:] re_action
        cdef int[:] im_action
        cdef int[:] reset

        cdef Node* root_node
        cdef Node* cur_node
        cdef Node* next_node
        cdef vector[Node*] cur_nodes_
        cdef vector[Node*] root_nodes_
        cdef float[:,:] model_out

        cdef vector[int] pass_inds_step
        cdef vector[int] pass_inds_reset
        cdef vector[int] pass_inds_reset_
        cdef vector[int] pass_action
        cdef vector[int] pass_model_action

        cdef float[:] vs_1
        cdef float[:,:] logits_1

        cdef float[:] rs_4
        cdef float[:] vs_4
        cdef float[:,:] logits_4

        if self.time: self.timings.reset()
        action = action.cpu().int().numpy()
        re_action, im_action, reset = action[:, 0], action[:, 1], action[:, 2]

        pass_model_encodes = []

        for i in range(self.env_n):
            # compute the mask of real / imagination step
            self.max_rollout_depth_[i] = self.max_rollout_depth[i]
            self.depth_delta[i] = self.depth_discounting ** self.rollout_depth[i]
            if self.cur_t[i] < self.rec_t - 1: # imagination step
                self.cur_t[i] += 1
                self.rollout_depth[i] += 1
                self.max_rollout_depth[i] = max(self.max_rollout_depth[i], self.rollout_depth[i])
                next_node = self.cur_nodes[i][0].ppchildren[0][im_action[i]]
                if node_expanded(next_node, self.total_step[i]):
                    self.status[i] = 2
                else:
                    encoded = <dict> self.cur_nodes[i][0].encoded
                    pass_model_encodes.append(encoded["model_encodes"].unsqueeze(0))
                    pass_model_action.push_back(im_action[i])
                    self.status[i] = 4
            else: # real step
                self.cur_t[i] = 0
                self.rollout_depth[i] = 0
                self.max_rollout_depth[i] = 0
                self.total_step[i] = self.total_step[i] + 1
                # record baseline before moving on
                self.baseline_mean_q[i] = average(self.root_nodes[i][0].prollout_qs[0]) / self.discounting
                self.baseline_max_q[i] = maximum(self.root_nodes[i][0].prollout_qs[0]) / self.discounting
                pass_action.push_back(re_action[i])
                pass_inds_step.push_back(i)
                self.status[i] = 1
        if self.time: self.timings.time("misc_1")

        # one step of env
        if pass_inds_step.size() > 0:
            obs, reward, done, info = self.env.step(pass_action, inds=pass_inds_step)
            real_done = [m["real_done"] if "real_done" in m else done[n] for n, m in enumerate(info)]
        if self.time: self.timings.time("step_state")

        # reset needed?
        for i, j in enumerate(pass_inds_step):
            if done[i]:
                pass_inds_reset.push_back(j)
                pass_inds_reset_.push_back(i) # index within pass_inds_step

        # reset
        if pass_inds_reset.size() > 0:
            obs_reset = self.env.reset(inds=pass_inds_reset)
            for i, j in enumerate(pass_inds_reset_):
                obs[j] = obs_reset[i]
                pass_action[j] = 0

        # use model for status 1 transition (real transition)
        if pass_inds_step.size() > 0:
            with torch.no_grad():
                obs_py = torch.tensor(obs, dtype=torch.uint8, device=self.device)
                _, _, vs_, _, logits_, model_encodes_1 = model_net(obs_py,
                        torch.tensor(pass_action, dtype=torch.long, device=self.device).unsqueeze(0),
                        one_hot=False)
            vs_1 = vs_[-1].float().cpu().numpy()
            logits_1 = logits_[-1].float().cpu().numpy()

        if self.time: self.timings.time("misc_2")
        # use model for status 4 transition (imagination transition)
        if pass_model_action.size() > 0:
            with torch.no_grad():
                pass_model_encodes = torch.concat(pass_model_encodes)
                rs_, _, vs_, _, logits_, model_encodes_4 = model_net.forward_encoded(encoded=pass_model_encodes,
                        actions = torch.tensor(pass_model_action, dtype=torch.long, device=self.device).unsqueeze(0),
                        one_hot=False)
            rs_4 = rs_[-1].float().cpu().numpy()
            vs_4 = vs_[-1].float().cpu().numpy()
            logits_4 = logits_[-1].float().cpu().numpy()

        # compute the current and root nodes
        j = 0 # counter for status 1 transition
        l = 0 # counter for status 4 transition

        for i in range(self.env_n):
            if self.status[i] == 1:
                # real transition
                new_root = (not self.tree_carry or
                    not node_expanded(self.root_nodes[i][0].ppchildren[0][re_action[i]], -1) or done[j])
                if new_root:
                    # start a fresh tree and free the old one entirely
                    root_node = node_new(pparent=NULL, action=pass_action[j], logit=0., num_actions=self.num_actions,
                        discounting=self.discounting, rec_t=self.rec_t, remember_path=True)
                    encoded = {"gym_env_out": obs_py[j], "model_encodes": model_encodes_1[-1,j]}
                    node_expand(pnode=root_node, r=0., v=vs_1[j], t=self.total_step[i], done=False,
                        logits=logits_1[j], encoded=<PyObject*>encoded, override=False)
                    node_del(self.root_nodes[i], except_idx=-1)
                    node_visit(root_node)
                else:
                    # carry the subtree of the chosen action over as the new tree
                    root_node = self.root_nodes[i][0].ppchildren[0][re_action[i]]
                    encoded = {"gym_env_out": obs_py[j], "model_encodes": model_encodes_1[-1,j]}
                    node_expand(pnode=root_node, r=0., v=vs_1[j], t=self.total_step[i], done=False,
                        logits=logits_1[j], encoded=<PyObject*>encoded, override=True)
                    node_del(self.root_nodes[i], except_idx=re_action[i])
                    node_visit(root_node)

                j += 1
                root_nodes_.push_back(root_node)
                cur_nodes_.push_back(root_node)

            elif self.status[i] == 2:
                # expanded already
                cur_node = self.cur_nodes[i][0].ppchildren[0][im_action[i]]
                node_visit(cur_node)
                root_nodes_.push_back(self.root_nodes[i])
                cur_nodes_.push_back(cur_node)

            elif self.status[i] == 4:
                # need expand
                encoded = {"gym_env_out": None, "model_encodes": model_encodes_4[-1,l]}
                cur_node = self.cur_nodes[i][0].ppchildren[0][im_action[i]]
                node_expand(pnode=cur_node, r=rs_4[l], v=vs_4[l], t=self.total_step[i], done=False,
                        logits=logits_4[l], encoded=<PyObject*>encoded, override=True)
                node_visit(cur_node)
                root_nodes_.push_back(self.root_nodes[i])
                cur_nodes_.push_back(cur_node)
                l += 1
        self.root_nodes = root_nodes_
        self.cur_nodes = cur_nodes_
        if self.time: self.timings.time("compute_root_cur_nodes")

        # reset if search depth exceeds max depth
        if self.max_allow_depth > 0:
            for i in range(self.env_n):
                if self.rollout_depth[i] >= self.max_allow_depth:
                    action[i, 2] = 1
                    reset[i] = 1

        # compute model_out
        model_out = self.compute_model_out(action, self.status)
        gym_env_out = []
        for i in range(self.env_n):
            encoded = <dict>self.cur_nodes[i][0].encoded
            if encoded["gym_env_out"] is not None:
                gym_env_out.append(encoded["gym_env_out"].unsqueeze(0))
        if len(gym_env_out) > 0:
            gym_env_out = torch.concat(gym_env_out)
        else:
            gym_env_out = None

        if self.actor_see_encode:
            model_encodes = []
            for i in range(self.env_n):
                encoded = <dict>self.cur_nodes[i][0].encoded
                model_encodes.append(encoded["model_encodes"].unsqueeze(0))
            model_encodes = torch.concat(model_encodes)
            if self.actor_see_double_encode:
                model_encodes_ = []
                for i in range(self.env_n):
                    encoded = <dict>self.root_nodes[i][0].encoded
                    model_encodes_.append(encoded["model_encodes"].unsqueeze(0))
                model_encodes_ = torch.concat(model_encodes_)
                model_encodes = torch.concat([model_encodes, model_encodes_], dim=1)
        else:
            model_encodes = None

        if self.time: self.timings.time("compute_model_out")
        # compute reward
        j = 0
        for i in range(self.env_n):
            if self.status[i] == 1:
                self.full_reward[i][0] = reward[j]
            else:
                self.full_reward[i][0] = 0.
            if self.reward_type == 1:
                # planning reward: clipped increase of the root's max_q, scaled by depth
                self.root_nodes_qmax_[i] = self.root_nodes[i][0].max_q
                if self.status[i] != 1:
                    self.full_reward[i][1] = (self.root_nodes_qmax_[i] - self.root_nodes_qmax[i])*self.depth_delta[i]
                    if self.full_reward[i][1] < 0: self.full_reward[i][1] = 0
                else:
                    self.full_reward[i][1] = 0.
                self.root_nodes_qmax[i] = self.root_nodes_qmax_[i]
            if self.status[i] == 1:
                j += 1
        if self.time: self.timings.time("compute_reward")
        # compute done & full_real_done
        j = 0
        for i in range(self.env_n):
            if self.status[i] == 1:
                self.full_done[i] = done[j]
                self.full_real_done[i] = real_done[j]
            else:
                self.full_done[i] = False
                self.full_real_done[i] = False
            if self.status[i] == 1:
                j += 1
        # compute reset
        for i in range(self.env_n):
            if reset[i]:
                self.rollout_depth[i] = 0
                self.cur_nodes[i] = self.root_nodes[i]
                node_visit(self.cur_nodes[i])
                self.status[i] = 5
        # some extra info
        info = {"cur_t": torch.tensor(self.cur_t, dtype=torch.long, device=self.device),
                "max_rollout_depth":  torch.tensor(self.max_rollout_depth_, dtype=torch.long, device=self.device),
                "real_done": torch.tensor(self.full_real_done, dtype=torch.bool, device=self.device)}
        if self.time: self.timings.time("end")

        return ((torch.tensor(model_out, dtype=torch.float32, device=self.device), gym_env_out, model_encodes),
                torch.tensor(self.full_reward, dtype=torch.float32, device=self.device),
                torch.tensor(self.full_done, dtype=torch.bool, device=self.device),
                info)

    
    cdef float[:, :] compute_model_out(self, int[:, :]& action, int[:]& status):
        """Assemble the (env_n, obs_n) model output from the root and current nodes.

        Each row holds the detailed root statistics, the current-node statistics,
        the reset flag, a one-hot encoding of cur_t and, in the last column,
        discounting ** rollout_depth. action may be None (as in reset), in which
        case the reset flag is 1 for every environment and status is not read.
        """
        cdef int env_idx
        cdef int root_end = self.num_actions*5+5    # end of root statistics
        cdef int cur_end = self.num_actions*10+7    # end of current-node statistics / reset flag slot
        cdef int time_base = cur_end + 1            # first slot of the one-hot planning step
        cdef int last = cur_end + self.rec_t + 1    # slot of the depth discount
        cdef float[:, :] out = np.zeros((self.env_n, self.obs_n), dtype=np.float32)

        for env_idx in range(self.env_n):
            out[env_idx, :root_end] = node_stat(self.root_nodes[env_idx], detailed=True, reward_transform=self.reward_transform)
            out[env_idx, root_end:cur_end] = node_stat(self.cur_nodes[env_idx], detailed=False, reward_transform=self.reward_transform)
            # reset
            if action is not None and status[env_idx] != 1:
                out[env_idx, cur_end] = action[env_idx, 2]
            else:
                out[env_idx, cur_end] = 1.
            # time
            out[env_idx, time_base + self.cur_t[env_idx]] = 1.
            # deprec
            out[env_idx, last] = (self.discounting ** (self.rollout_depth[env_idx]))
        return out

    def close(self):
        """Free every search tree and close the wrapped environment.

        Safe to call before reset and more than once: only the trees actually
        held are freed, and the node vectors are emptied afterwards.
        """
        cdef int i
        # root_nodes is a C++ vector attribute and is not visible to hasattr(),
        # so iterate over its real size instead of probing for the attribute.
        for i in range(int(self.root_nodes.size())):
            node_del(self.root_nodes[i], except_idx=-1)
        # cur_nodes point into the trees just freed
        self.root_nodes.clear()
        self.cur_nodes.clear()
        self.env.close()

    def seed(self, x):
        """Forward the seed x to the wrapped environment."""
        self.env.seed(x)

    def print_time(self):
        """Print the summary of the timings object (filled by step when time is enabled)."""
        print(self.timings.summary())

    def clone_state(self):
        """Return the state cloned from the wrapped environment (search trees are not included)."""
        return self.env.clone_state()

    def restore_state(self, state):
        """Restore the wrapped environment to state (search trees are left untouched)."""
        self.env.restore_state(state)

    def get_action_meanings(self):
        """Return the action meanings reported by the wrapped environment."""
        return self.env.get_action_meanings()       

cdef class cVecModelWrapper():
    """Wrap the gym environment with a model; output for each 
    step is (out, reward, done, info), where out is a tuple 
    of (gym_env_out, model_out, model_encodes) that corresponds to underlying 
    environment frame, output from the model wrapper, and encoding from the model
    Assume a perfect dynamic model.
    """
    # setting
    cdef int rec_t
    cdef float discounting
    cdef float depth_discounting
    cdef int max_allow_depth
    cdef bool perfect_model
    cdef bool tree_carry
    cdef int reward_type
    cdef bool reward_transform
    cdef bool actor_see_encode
    cdef bool actor_see_double_encode
    cdef int num_actions
    cdef int obs_n    
    cdef int env_n
    cdef bool time 

    # python object
    cdef object device
    cdef object env
    cdef object timings
    cdef readonly baseline_max_q
    cdef readonly baseline_mean_q    
    cdef readonly object model_out_shape
    cdef readonly object gym_env_out_shape

    # tree statistic
    cdef vector[Node*] cur_nodes
    cdef vector[Node*] root_nodes    
    cdef float[:] root_nodes_qmax
    cdef float[:] root_nodes_qmax_
    cdef int[:] rollout_depth
    cdef int[:] max_rollout_depth
    cdef int[:] cur_t

    # internal variables only used in step function
    cdef float[:] depth_delta
    cdef int[:] max_rollout_depth_
    cdef float[:] mean_q
    cdef float[:] max_q
    cdef int[:] status
    cdef vector[Node*] cur_nodes_
    cdef float[:] par_logits
    cdef float[:, :] full_reward
    cdef bool[:] full_done
    cdef bool[:] full_real_done
    cdef int[:] total_step

    def __init__(self, env, env_n, flags, device=None, time=False):
        """Store the settings and allocate the per-environment buffers.

        env: vectorized environment; env.action_space[0].n gives the number of actions
        env_n: number of parallel environments
        flags: settings object providing rec_t, discounting, depth_discounting,
            max_depth, perfect_model (must be True), tree_carry, reward_type,
            reward_transform, actor_see_encode and actor_see_double_encode
        device: torch device for returned tensors (CPU when None)
        time: stored on the instance together with a fresh util.Timings()
        """
        assert flags.perfect_model, "this class only supports perfect model"
        self.device = torch.device("cpu") if device is None else device
        self.env = env
        self.rec_t = flags.rec_t
        self.discounting = flags.discounting
        self.depth_discounting = flags.depth_discounting
        self.max_allow_depth = flags.max_depth
        self.perfect_model = flags.perfect_model
        self.tree_carry = flags.tree_carry
        self.num_actions = env.action_space[0].n
        self.reward_type = flags.reward_type
        self.reward_transform = flags.reward_transform
        self.actor_see_encode = flags.actor_see_encode
        self.actor_see_double_encode = flags.actor_see_double_encode
        self.env_n = env_n
        self.obs_n = 9 + self.num_actions * 10 + self.rec_t
        self.model_out_shape = (self.obs_n, 1, 1)
        self.gym_env_out_shape = env.observation_space.shape[1:]

        self.baseline_max_q = torch.zeros(self.env_n, dtype=torch.float32, device=self.device)
        self.baseline_mean_q = torch.zeros(self.env_n, dtype=torch.float32, device=self.device)
        self.time = time
        self.timings = util.Timings()

        # internal variable init.
        self.depth_delta = np.zeros(self.env_n, dtype=np.float32)
        self.max_rollout_depth_ = np.zeros(self.env_n, dtype=np.intc)
        self.mean_q =  np.zeros(self.env_n, dtype=np.float32)
        self.max_q = np.zeros(self.env_n, dtype=np.float32)
        self.status = np.zeros(self.env_n, dtype=np.intc)
        self.par_logits = np.zeros(self.num_actions, dtype=np.float32)
        self.full_reward = np.zeros((self.env_n, 2 if self.reward_type == 1 else 1), dtype=np.float32)
        # np.bool_ instead of the np.bool alias, which was removed in NumPy 1.24
        self.full_done = np.zeros(self.env_n, dtype=np.bool_)
        self.full_real_done = np.zeros(self.env_n, dtype=np.bool_)
        self.total_step = np.zeros(self.env_n, dtype=np.intc)
        
    def reset(self, model_net):
        """Reset every environment and build a fresh one-node tree for each.

        Intended to be called once at the start. If trees from an earlier call
        exist they are freed first, so that root_nodes / cur_nodes always hold
        exactly env_n entries.

        Args:
            model_net: callable invoked as model_net(obs, actions, one_hot=False);
                its 3rd, 5th and 6th outputs are used as values, logits and
                encodings.

        Returns:
            (model_out, gym_env_out, model_encodes); model_encodes is None
            unless actor_see_encode is set.
        """
        cdef int i
        cdef Node* root_node
        cdef float[:,:] model_out        

        # free trees left by a previous reset(); step() only ever looks at the
        # first env_n entries, so stale entries must not stay in front
        for i in range(<int>self.root_nodes.size()):
            node_del(self.root_nodes[i], except_idx=-1)
        self.root_nodes.clear()
        self.cur_nodes.clear()

        with torch.no_grad():
            # some init.
            self.root_nodes_qmax = np.zeros(self.env_n, dtype=np.float32)
            self.root_nodes_qmax_ = np.zeros(self.env_n, dtype=np.float32)
            self.rollout_depth = np.zeros(self.env_n, dtype=np.intc)
            self.max_rollout_depth = np.zeros(self.env_n, dtype=np.intc)
            self.cur_t = np.zeros(self.env_n, dtype=np.intc)

            # reset obs
            obs = self.env.reset()

            # obtain output from model
            obs_py = torch.tensor(obs, dtype=torch.uint8, device=self.device)
            pass_action = torch.zeros(self.env_n, dtype=torch.long)
            _, _, vs, _, logits, model_encodes = model_net(obs_py, 
                                                pass_action.unsqueeze(0).to(self.device), 
                                                one_hot=False)  
            # cast to float32 like step() does, so the values fit the C float buffers
            vs = vs.float().cpu()
            logits = logits.float().cpu()
            env_state = self.env.clone_state(inds=np.arange(self.env_n))

            # compute and update root node and current node
            for i in range(self.env_n):
                root_node = node_new(pparent=NULL, action=pass_action[i].item(), logit=0., num_actions=self.num_actions, 
                    discounting=self.discounting, rec_t=self.rec_t, remember_path=False)                
                if not self.actor_see_encode:
                    encoded = {"env_state": env_state[i], "gym_env_out": obs_py[i]}
                else:
                    encoded = {"env_state": env_state[i], "gym_env_out": obs_py[i], "model_encodes": model_encodes[0,i]}
                node_expand(pnode=root_node, r=0., v=vs[-1, i].item(), t=self.total_step[i], done=False,
                    logits=logits[-1, i].numpy(), encoded=<PyObject*>encoded, override=False)
                node_visit(pnode=root_node)
                # the search starts at the root, so root and current node coincide
                self.root_nodes.push_back(root_node)
                self.cur_nodes.push_back(root_node)
            
            # compute model_out
            model_out = self.compute_model_out(None, None)

            gym_env_out = []
            for i in range(self.env_n):
                encoded = <dict>self.cur_nodes[i][0].encoded
                gym_env_out.append(encoded["gym_env_out"].unsqueeze(0))
            gym_env_out = torch.concat(gym_env_out)

            if self.actor_see_encode:
                model_encodes = []
                for i in range(self.env_n):
                    encoded = <dict>self.cur_nodes[i][0].encoded
                    model_encodes.append(encoded["model_encodes"].unsqueeze(0))
                model_encodes = torch.concat(model_encodes)

                if self.actor_see_double_encode:
                    # current node == root node here, so the encoding is duplicated
                    model_encodes = torch.concat([model_encodes, model_encodes], dim=1)
                
            else:
                model_encodes = None

            # record initial root_nodes_qmax 
            for i in range(self.env_n):
                self.root_nodes_qmax[i] = self.root_nodes[i][0].max_q
            
            return torch.tensor(model_out, dtype=torch.float32, device=self.device), gym_env_out, model_encodes

    def step(self, action, model_net):  
        """Advance every environment by one imagination step or one real step.

        An env takes rec_t - 1 imagination steps followed by one real step;
        `cur_t` tracks the position inside that cycle.

        Args:
            action: tensor of shape (env_n, 3); the columns are the real
                action, the imagined action and the reset flag.
            model_net: callable invoked as model_net(obs, actions, one_hot=False);
                its 3rd, 5th and 6th outputs are used as values, logits and
                encodings.

        Returns:
            ((model_out, gym_env_out, model_encodes), reward, done, info) where
            reward has one column (two when reward_type == 1) and info holds
            "cur_t", "max_rollout_depth" and "real_done".
        """

        cdef int i, j, k
        cdef int[:] re_action
        cdef int[:] im_action
        cdef int[:] reset

        cdef Node* root_node
        cdef Node* cur_node
        cdef Node* next_node
        cdef vector[Node*] cur_nodes_
        cdef vector[Node*] root_nodes_    
        cdef float[:,:] model_out        

        cdef vector[int] pass_inds_restore
        cdef vector[int] pass_inds_step
        cdef vector[int] pass_inds_reset
        cdef vector[int] pass_inds_reset_
        cdef vector[int] pass_action

        cdef float[:] vs
        cdef float[:,:] logits

        if self.time: self.timings.reset()
        action = action.cpu().int().numpy()
        re_action, im_action, reset = action[:, 0], action[:, 1], action[:, 2]

        pass_env_states = []

        # Decide per env what has to happen and record it in self.status:
        #   1 real step, 2 move to an already expanded child,
        #   3 parent is done (expand child without stepping the env),
        #   4 expand child by stepping the env
        for i in range(self.env_n):            
            # compute the mask of real / imagination step                             
            self.max_rollout_depth_[i] = self.max_rollout_depth[i]
            self.depth_delta[i] = self.depth_discounting ** self.rollout_depth[i]
            if self.cur_t[i] < self.rec_t - 1: # imagination step
                self.cur_t[i] += 1
                self.rollout_depth[i] += 1
                self.max_rollout_depth[i] = max(self.max_rollout_depth[i], self.rollout_depth[i])
                next_node = self.cur_nodes[i][0].ppchildren[0][im_action[i]]
                if node_expanded(next_node, -1):
                    self.status[i] = 2
                elif self.cur_nodes[i][0].done:
                    self.status[i] = 3
                else:
                    # NOTE(review): the original guard here was
                    # `status != 0 or status != 4`, which is always true, so the
                    # env state has always been restored before an expansion.
                    # That behaviour is kept, now without the dead condition.
                    # Skipping the restore when the env already sits in the
                    # right state looks like the intent (see status 5 below),
                    # but then only the restore bookkeeping may be skipped,
                    # never the step bookkeeping - confirm before changing.
                    encoded = <dict> self.cur_nodes[i][0].encoded
                    pass_env_states.append(encoded["env_state"])
                    pass_inds_restore.push_back(i)
                    pass_action.push_back(im_action[i])
                    pass_inds_step.push_back(i)
                    self.status[i] = 4  
            else: # real step
                self.cur_t[i] = 0
                self.rollout_depth[i] = 0          
                self.max_rollout_depth[i] = 0
                self.total_step[i] = self.total_step[i] + 1
                # record baseline before moving on
                self.baseline_mean_q[i] = average(self.root_nodes[i][0].prollout_qs[0]) / self.discounting
                self.baseline_max_q[i] = maximum(self.root_nodes[i][0].prollout_qs[0]) / self.discounting
                encoded = <dict> self.root_nodes[i][0].encoded
                pass_env_states.append(encoded["env_state"])
                pass_inds_restore.push_back(i)
                pass_action.push_back(re_action[i])
                pass_inds_step.push_back(i)
                self.status[i] = 1                              
        if self.time: self.timings.time("misc_1")

        # restore env      
        if pass_inds_restore.size() > 0:
            self.env.restore_state(pass_env_states, inds=pass_inds_restore)

        # one step of env
        # obs / reward / done / real_done are indexed by position in pass_inds_step
        if pass_inds_step.size() > 0:
            obs, reward, done, info = self.env.step(pass_action, inds=pass_inds_step) 
            real_done = [m["real_done"] if "real_done" in m else done[n] for n, m in enumerate(info)]
        if self.time: self.timings.time("step_state")

        # reset needed?
        for i, j in enumerate(pass_inds_step):
            if self.status[j] == 1 and done[i]:
                pass_inds_reset.push_back(j)
                pass_inds_reset_.push_back(i) # index within pass_inds_step
        # reset
        if pass_inds_reset.size() > 0:
            obs_reset = self.env.reset(inds=pass_inds_reset) 
            for i, j in enumerate(pass_inds_reset_):
                obs[j] = obs_reset[i]
                pass_action[j] = 0            
        if self.time: self.timings.time("misc_2")

        # use model
        if pass_inds_step.size() > 0:
            with torch.no_grad():
                obs_py = torch.tensor(obs, dtype=torch.uint8, device=self.device)
                # torch.long (as in reset) instead of the bare name `long`
                _, _, vs_, _, logits_, model_encodes = model_net(obs_py, 
                        torch.tensor(pass_action, dtype=torch.long, device=self.device).unsqueeze(0), 
                        one_hot=False)  
            vs = vs_[-1].float().cpu().numpy()
            logits = logits_[-1].float().cpu().numpy()
            if self.time: self.timings.time("model")
            env_state = self.env.clone_state(inds=pass_inds_step)   
            if self.time: self.timings.time("clone_state")

        # compute the current and root nodes
        # j walks through the stepped envs (status 1 and 4) in env order
        j = 0
        for i in range(self.env_n):
            if self.status[i] == 1:
                # real transition
                new_root = (not self.tree_carry or 
                    not node_expanded(self.root_nodes[i][0].ppchildren[0][re_action[i]], -1) or done[j])
                if new_root:
                    root_node = node_new(pparent=NULL, action=pass_action[j], logit=0., num_actions=self.num_actions, 
                        discounting=self.discounting, rec_t=self.rec_t, remember_path=False)
                    if not self.actor_see_encode:
                        encoded = {"env_state": env_state[j], "gym_env_out": obs_py[j]}
                    else:
                        encoded = {"env_state": env_state[j], "gym_env_out": obs_py[j], "model_encodes": model_encodes[0,j]}
                    node_expand(pnode=root_node, r=0., v=vs[j], t=self.total_step[i], done=False,
                        logits=logits[j], encoded=<PyObject*>encoded, override=False)
                    node_del(self.root_nodes[i], except_idx=-1)
                    node_visit(root_node)
                else:
                    # carry the tree: the chosen child becomes the new root
                    root_node = self.root_nodes[i][0].ppchildren[0][re_action[i]]
                    if not self.actor_see_encode:
                        encoded = {"env_state": env_state[j], "gym_env_out": obs_py[j]}
                    else:
                        encoded = {"env_state": env_state[j], "gym_env_out": obs_py[j], "model_encodes": model_encodes[0,j]}
                    node_expand(pnode=root_node, r=0., v=vs[j], t=self.total_step[i], done=False,
                        logits=logits[j], encoded=<PyObject*>encoded, override=True)                        
                    node_del(self.root_nodes[i], except_idx=re_action[i])
                    node_visit(root_node)
                    
                j += 1
                root_nodes_.push_back(root_node)
                cur_nodes_.push_back(root_node)

            elif self.status[i] == 2:
                # expanded already
                cur_node = self.cur_nodes[i][0].ppchildren[0][im_action[i]]
                node_visit(cur_node)
                root_nodes_.push_back(self.root_nodes[i])
                cur_nodes_.push_back(cur_node)    

            elif self.status[i] == 3:
                # done already
                # the child reuses the logits of its siblings and the parent's encoding
                for k in range(self.num_actions):
                    self.par_logits[k] = self.cur_nodes[i][0].ppchildren[0][k][0].logit
                cur_node = self.cur_nodes[i][0].ppchildren[0][im_action[i]]
                node_expand(pnode=cur_node, r=0., v=0., t=self.total_step[i], done=True,
                        logits=self.par_logits, encoded=self.cur_nodes[i][0].encoded, override=False)
                node_visit(cur_node)
                root_nodes_.push_back(self.root_nodes[i])
                cur_nodes_.push_back(cur_node)              
            
            elif self.status[i] == 4:
                # need expand
                if not self.actor_see_encode:
                    encoded = {"env_state": env_state[j], "gym_env_out": obs_py[j]}
                else:
                    encoded = {"env_state": env_state[j], "gym_env_out": obs_py[j], "model_encodes": model_encodes[0,j]}
                cur_node = self.cur_nodes[i][0].ppchildren[0][im_action[i]]
                node_expand(pnode=cur_node, r=reward[j], v=vs[j] if not done[j] else 0., t=self.total_step[i], done=done[j],
                        logits=logits[j], encoded=<PyObject*>encoded, override=False)
                node_visit(cur_node)
                root_nodes_.push_back(self.root_nodes[i])
                cur_nodes_.push_back(cur_node)   
                j += 1                            

        self.root_nodes = root_nodes_
        self.cur_nodes = cur_nodes_
        if self.time: self.timings.time("compute_root_cur_nodes")

        # reset if search depth exceeds max depth
        if self.max_allow_depth > 0:
            for i in range(self.env_n):
                if self.rollout_depth[i] >= self.max_allow_depth:
                    action[i, 2] = 1
                    reset[i] = 1

        # compute model_out        
        model_out = self.compute_model_out(action, self.status)

        gym_env_out = []
        for i in range(self.env_n):
            encoded = <dict>self.cur_nodes[i][0].encoded
            gym_env_out.append(encoded["gym_env_out"].unsqueeze(0))
        gym_env_out = torch.concat(gym_env_out)

        if self.actor_see_encode:
            model_encodes = []
            for i in range(self.env_n):
                encoded = <dict>self.cur_nodes[i][0].encoded
                model_encodes.append(encoded["model_encodes"].unsqueeze(0))
            model_encodes = torch.concat(model_encodes)

            if self.actor_see_double_encode:
                # append the root's encoding to the current node's encoding
                model_encodes_ = []
                for i in range(self.env_n):
                    encoded = <dict>self.root_nodes[i][0].encoded
                    model_encodes_.append(encoded["model_encodes"].unsqueeze(0))
                model_encodes_ = torch.concat(model_encodes_)
                model_encodes = torch.concat([model_encodes, model_encodes_], dim=1)
        else:
            model_encodes = None

        if self.time: self.timings.time("compute_model_out")

        # compute reward
        # column 0: env reward (real steps only); column 1: change of the root's
        # max_q scaled by depth_delta (imagination steps only)
        j = 0
        for i in range(self.env_n):
            if self.status[i] == 1:
                self.full_reward[i][0] = reward[j]
            else:
                self.full_reward[i][0] = 0.
            if self.reward_type == 1:                        
                self.root_nodes_qmax_[i] = self.root_nodes[i][0].max_q
                if self.status[i] != 1:                
                    self.full_reward[i][1] = (self.root_nodes_qmax_[i] - self.root_nodes_qmax[i])*self.depth_delta[i]
                else:
                    self.full_reward[i][1] = 0.
                self.root_nodes_qmax[i] = self.root_nodes_qmax_[i]
            if self.status[i] == 1 or self.status[i] == 4:
                j += 1
        if self.time: self.timings.time("compute_reward")

        # compute done & full_real_done
        j = 0
        for i in range(self.env_n):
            if self.status[i] == 1:
                self.full_done[i] = done[j]
                self.full_real_done[i] = real_done[j]
            else:
                self.full_done[i] = False
                self.full_real_done[i] = False
            if self.status[i] == 1 or self.status[i] == 4:
                j += 1

        # compute reset
        for i in range(self.env_n):
            if reset[i]:
                self.rollout_depth[i] = 0
                self.cur_nodes[i] = self.root_nodes[i]
                node_visit(self.cur_nodes[i])
                self.status[i] = 5 # need to restore state on the next transition, so we need to alter the status from 4
        
        # some extra info
        info = {"cur_t": torch.tensor(self.cur_t, dtype=torch.long, device=self.device),
                "max_rollout_depth":  torch.tensor(self.max_rollout_depth_, dtype=torch.long, device=self.device),
                "real_done": torch.tensor(self.full_real_done, dtype=torch.bool, device=self.device)}
        if self.time: self.timings.time("end")

        return ((torch.tensor(model_out, dtype=torch.float32, device=self.device), gym_env_out, model_encodes), 
                torch.tensor(self.full_reward, dtype=torch.float32, device=self.device), 
                torch.tensor(self.full_done, dtype=torch.bool, device=self.device), 
                info)

    
    cdef float[:, :] compute_model_out(self, int[:, :]& action, int[:]& status):
        # Build the (env_n, obs_n) feature matrix handed to the actor.
        # Row layout: [root statistics | current-node statistics | reset flag |
        #              one-hot of cur_t (rec_t slots) | discounting ** rollout_depth]
        # `action` and `status` are both None when called from reset().
        cdef int i
        cdef int root_end = self.num_actions*5+5
        cdef int cur_end = self.num_actions*10+7
        cdef bint no_action = action is None
        cdef float[:, :] out = np.zeros((self.env_n, self.obs_n), dtype=np.float32)

        for i in range(self.env_n):
            out[i, :root_end] = node_stat(self.root_nodes[i], detailed=True, reward_transform=self.reward_transform)
            out[i, root_end:cur_end] = node_stat(self.cur_nodes[i], detailed=False, reward_transform=self.reward_transform)
            # reset flag: forced to 1 right after reset() and after a real step
            if no_action or status[i] == 1:
                out[i, cur_end] = 1.
            else:
                out[i, cur_end] = action[i, 2]
            # one-hot position inside the planning cycle
            out[i, cur_end+1+self.cur_t[i]] = 1.
            # last column: discount accumulated over the current rollout depth
            out[i, cur_end+self.rec_t+1] = (self.discounting ** (self.rollout_depth[i]))
        return out

    def close(self):
        """Free every search tree, then close the wrapped environment.

        Safe to call before reset() and more than once: only the trees that
        actually exist are freed, and the node vectors are emptied afterwards.
        """
        cdef int i
        # `root_nodes` is a private C++ vector attribute and therefore invisible
        # to hasattr(); the old `hasattr(self, "root_nodes")` guard was always
        # False, so the trees were never freed. Use the vector's size instead.
        for i in range(<int>self.root_nodes.size()):
            node_del(self.root_nodes[i], except_idx=-1)
        # the current nodes live inside the freed trees; drop the stale pointers
        self.root_nodes.clear()
        self.cur_nodes.clear()
        self.env.close()

    def seed(self, x):
        """Forward the seed(s) `x` to the wrapped environment; returns nothing."""
        wrapped_env = self.env
        wrapped_env.seed(x)

    def print_time(self):
        """Print the summary produced by the timing helper."""
        report = self.timings.summary()
        print(report)

    def clone_state(self):
        """Return whatever the wrapped environment's clone_state() returns."""
        snapshot = self.env.clone_state()
        return snapshot

    def restore_state(self, state):
        """Hand `state` to the wrapped environment's restore_state(); returns nothing."""
        wrapped_env = self.env
        wrapped_env.restore_state(state)

    def get_action_meanings(self):
        """Return the action meanings reported by the wrapped environment."""
        meanings = self.env.get_action_meanings()
        return meanings


from thinker.gym_add.asyn_vector_env import AsyncVectorEnv
import thinker.util as util
from thinker.net import ModelNet
from thinker.util import Timings
import thinker.env
import gym
import gym_csokoban
import os 

# Smoke test: wrap a small vectorised env, reset it and run a fixed sequence of
# (real action, imagined action, reset) triples, printing the decoded reset flag.
flags = util.parse([])
flags.rec_t = 20
flags.flex_t = False
flags.env = "BreakoutNoFrameskip-v4"
flags.tree_carry = True
# NOTE(review): the wrapper's __init__ reads `flags.actor_see_encode`; confirm
# that `see_encode` is still the flag meant here.
flags.see_encode = True
flags.max_depth = 3
flags.perfect_model = True

env_n = 2
# fall back to the CPU so the cell also runs on machines without a GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#env = Environment(flags, env_n=env_n, model_wrap=True, device=device)
env = AsyncVectorEnv([lambda: thinker.env.PreWrap(gym.make(flags.env), flags.env) for _ in range(env_n)])
try:
    num_actions = env.action_space[0].n
    c = cVecModelWrapper if flags.perfect_model else cVecFullModelWrapper
    env = c(env, env_n, flags, device=device, time=True)
    env.seed(np.arange(env_n))

    model_net = ModelNet(env.gym_env_out_shape, num_actions, flags)
    _ = model_net.train(False)
    model_net.to(device)

    model_out, gym_env_out, model_encodes = env.reset(model_net)
    timings = Timings()
    timings.reset()

    im_actions =    [1, 2, 3]
    im_actions_tensor = torch.tensor([[im_actions[n] for _ in range(env_n)] for n in range(len(im_actions))])
    #_, _, _, _, encodes_ = model_net.forward_encoded(model_encodes, im_actions_tensor.to(device))

    # one entry per step; every env receives the same triple
    real_actions =  [0, 0, 0, 1, 0, 0, 0, 0]
    im_actions =    [1, 2, 3, 0, 2, 3, 1, 2]
    reset_actions = [0, 0, 0, 0, 0, 0, 0, 0]
    for n in range(len(real_actions)):
        action = torch.tensor([[real_actions[n],
                                im_actions[n],
                                reset_actions[n],
                            ] for _ in range(env_n)], dtype=torch.long)
        (model_out, gym_env_out, model_encodes), reward, done, info = env.step(action.to(device), model_net)    
        print(n, util.decode_model_out(model_out.unsqueeze(0), num_actions, flags.reward_transform)['reset'])
finally:
    # always shut the env (and its worker processes) down, even after a failure
    env.close()
#print(torch.sum(torch.abs(encodes_[-1] - model_encodes)))

UsageError: Cell magic `%%cython` not found.


In [3]:
%load_ext Cython
%load_ext autoreload
%autoreload 2

import os
import sys
# Make the in-repo package importable; append its directory only once.
module_path = os.path.abspath('thinker/thinker')
already_on_path = module_path in sys.path
if not already_on_path:
    sys.path.append(module_path)

from collections import namedtuple
from matplotlib import pyplot as plt
import matplotlib.ticker as mticker
from collections import deque
import time
import numpy as np
import argparse
import torch
import torch.nn.functional as F
from thinker.env import Environment, EnvOut
from thinker.net import ActorNet, ModelNet
from thinker.buffer import ModelBuffer
from torch import nn
import thinker.util as util
import gym
import gym_csokoban

def gplot(x, ax=None, title=None):
    """Show a channel-first image on a matplotlib axis.

    Args:
        x: torch tensor (any device, may require grad, may be a Tensor
            subclass such as nn.Parameter) or numpy array; axis 0 is moved to
            the end before plotting, i.e. (C, H, W) is drawn as (H, W, C).
        ax: axis to draw on; a new figure is created when None.
        title: optional axis title.
    """
    if ax is None: _, ax = plt.subplots()
    if isinstance(x, torch.Tensor):
        # isinstance also covers Tensor subclasses; detach so that tensors
        # tracking gradients can be converted for plotting
        x = x.detach().cpu()
    elif isinstance(x, np.ndarray):
        x = torch.tensor(x)
    # (C, H, W) -> (W, H, C) -> (H, W, C)
    ax.imshow(torch.swapaxes(torch.swapaxes(x, 0, 2), 0, 1), interpolation='nearest', aspect="auto")
    if title is not None: ax.set_title(title)

The Cython extension is already loaded. To reload it, use:
  %reload_ext Cython
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [13]:

# Smoke test of the Node helpers with remember_path=True: builds the chain
# root -> child 3 -> grandchild 3, prints the recorded rollout returns and
# paths, then re-expands the grandchild with override=True and prints again.
pydict = {"A": 1 }
cdef Node* proot_node = node_new(pparent=NULL, action=0, logit=0.25, num_actions=4, discounting=1, rec_t=5, remember_path=True)    
cdef Node* pcur_node = proot_node
cdef Node* ptest_node
node_expand(pnode=proot_node, r=1, v=2, t=0, done=False, logits=np.array([3,4,5,1], dtype=np.float32), encoded=<PyObject*>pydict, override=False)
node_visit(pnode=proot_node)
# ppaths is a vector of node paths; print the value of every node on every path
print("root paths")
for n in range(proot_node.ppaths[0].size()):
    for m in range(proot_node.ppaths[0][n][0].size()):
        print(n, m, proot_node.ppaths[0][n][0][m][0].v)

# descend to the root's child for action 3 and expand it
pcur_node = pcur_node[0].ppchildren[0][3]
node_expand(pnode=pcur_node, r=3, v=4, t=0, done=False, logits=np.array([3,4,6,1], dtype=np.float32), encoded=<PyObject*>pydict, override=False)
node_visit(pnode=pcur_node)

# descend one level further (again action 3) and expand
pcur_node = pcur_node[0].ppchildren[0][3]
node_expand(pnode=pcur_node, r=5, v=6, t=0, done=False, logits=np.array([3,4,6,1], dtype=np.float32), encoded=<PyObject*>pydict, override=False)
node_visit(pnode=pcur_node)

# ptest_node is the root's child for action 3 (the first node pcur_node moved
# to). The labels printed below say "root", but the values come from this child.
ptest_node = proot_node[0].ppchildren[0][3]
print("root rollout_qs")
for n, i in enumerate(ptest_node.prollout_qs[0]): print(n, i)

# re-expand the deepest node with new reward/value (override=True, t=1)
node_expand(pnode=pcur_node, r=7, v=8, t=1, done=False, logits=np.array([3,4,6,1], dtype=np.float32), encoded=<PyObject*>pydict, override=True)
node_visit(pnode=pcur_node)

print("new root rollout_qs")
for n, i in enumerate(ptest_node.prollout_qs[0]): print(n, i)

print("root paths")
for n in range(ptest_node.ppaths[0].size()):
    for m in range(ptest_node.ppaths[0][n][0].size()):
        print(n, m, ptest_node.ppaths[0][n][0][m][0].v)

# except_idx=-1: presumably keeps no child, i.e. frees the whole tree - verify against node_del
node_del(proot_node, -1)        


/home/sc/.cache/ipython/cython/_cython_magic_ebd584b4aea78c44f6db178c00dc29dc.cpp: In function ‘void __pyx_f_46_cython_magic_ebd584b4aea78c44f6db178c00dc29dc_node_propagate(__pyx_t_46_cython_magic_ebd584b4aea78c44f6db178c00dc29dc_Node*, float, float, bool, std::vector<__pyx_t_46_cython_magic_ebd584b4aea78c44f6db178c00dc29dc_Node*>*)’:
 3601 |       for (__pyx_t_5 = 0; __pyx_t_5 < __pyx_t_4; __pyx_t_5+=1) {
      |                           ~~~~~~~~~~^~~~~~~~~~~
/home/sc/.cache/ipython/cython/_cython_magic_ebd584b4aea78c44f6db178c00dc29dc.cpp: At global scope:
 3753 | static __Pyx_memviewslice __pyx_f_46_cython_magic_ebd584b4aea78c44f6db178c00dc29dc_node_stat(struct __pyx_t_46_cython_magic_ebd584b4aea78c44f6db178c00dc29dc_Node *__pyx_v_pnode, bool __pyx_v_detailed, bool __pyx_v_reward_transform) {
      |                           ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


root paths
0 0 2.0
root rollout_qs
0 7.0
1 14.0
new root rollout_qs
0 7.0
1 18.0
root paths
0 0 4.0
1 0 8.0
1 1 4.0


In [None]:
# Run the perfect-model wrapper and the learned-model wrapper side by side on
# the same actions and fail as soon as any of their outputs differ.
#env = Environment(flags, env_n=env_n, device=device)
env = AsyncVectorEnv([lambda: thinker.env.PreWrap(gym.make(flags.env), flags.env) for _ in range(env_n)])
env_ = None
try:
    num_actions = env.action_space[0].n
    env = cVecModelWrapper(env, env_n, flags, device=device, time=True)
    env.seed(np.arange(env_n))

    model_net = ModelNet(env.gym_env_out_shape, num_actions, flags)
    _ = model_net.train(False)
    model_net.to(device)
    model_out, gym_env_out, model_encodes = env.reset(model_net)


    flags.perfect_model = False
    env_ = AsyncVectorEnv([lambda: thinker.env.PreWrap(gym.make(flags.env), flags.env) for _ in range(env_n)])
    # cVecModelWrapper asserts flags.perfect_model, so it cannot be built with
    # the flag switched off; use the wrapper the driver cell selects for that case.
    # NOTE(review): confirm cVecFullModelWrapper is the intended reference here.
    env_ = cVecFullModelWrapper(env_, env_n, flags, device=device, time=True)
    env_.seed(np.arange(env_n))
    model_out_, gym_env_out_, model_encodes_ = env_.reset(model_net)
    print("diff: ", torch.sum(torch.abs(model_out - model_out_)))


    # NOTE(review): np.random.randint(1) is always 0, so the reset flag is never
    # set, and the action is sampled once and reused for all 100 steps - confirm
    # that both are intended.
    action = torch.tensor([[np.random.randint(num_actions),
                            np.random.randint(num_actions),
                            np.random.randint(1),
                            ] for _ in range(env_n)], dtype=torch.long)

    timings = Timings()
    timings.reset()


    for n in range(100):
        timings.time("s0")
        (model_out, gym_env_out, model_encodes), reward, done, info = env.step(action.to(device), model_net)
        
        timings.time("s1")
        (model_out_, gym_env_out_, model_encodes_), reward_, done_, info_ = env_.step(action.to(device), model_net)
        timings.time("s2")

        model_out_diff = torch.sum(torch.abs(model_out - model_out_))
        gym_env_out_diff = torch.sum(torch.abs(gym_env_out - gym_env_out_))
        reward_diff = torch.sum(torch.abs(reward - reward_))
        done_diff = torch.sum(torch.abs(done.float() - done_.float()))
        
        if model_out_diff > 1e-5: 
            # show the row with the largest discrepancy before failing
            err_ind = torch.argmax(torch.max(torch.abs(model_out - model_out_), dim=1)[0])     
            print(model_out[err_ind], model_out_[err_ind])
            raise Exception("model_out_diff %f" % model_out_diff)
        if gym_env_out_diff > 1e-5: raise Exception("gym_env_out_diff %f" % gym_env_out_diff)
        # each failure now names the quantity that actually differs
        if reward_diff > 1e-5: raise Exception("reward_diff %f" % reward_diff)
        if done_diff > 1e-5: raise Exception("done_diff %f" % done_diff)

        if n % 50 == 0: print("Finish %d step" % n)
finally:
    # close both envs (and their worker processes) even when a check fails
    env.close()
    if env_ is not None:
        env_.close()
print(timings.summary()) 

In [None]:
%%cython 
# test for Node
# (the cell magic has to be the first line of the cell, so it now precedes this comment)
# Builds the same small tree with the C Node helpers and with the Python
# reference class thinker.env.Node, and prints the statistics of both.
# NOTE(review): the Node helpers are defined in the large %%cython cell; this
# code presumably has to be compiled together with them - confirm.

pydict = {"A": 1 }
# calls updated to the current helper signatures: node_new takes remember_path,
# node_expand takes t and done, node_stat takes reward_transform
cdef Node* proot_node = node_new(pparent=NULL, action=0, logit=0.25, num_actions=4, discounting=0.97, rec_t=5, remember_path=False)    
node_expand(pnode=proot_node, r=0.5, v=1.2, t=0, done=False, logits=np.array([1.2,3.4,1.2,-3.4], dtype=np.float32), encoded=<PyObject*>pydict, override=False)
node_visit(pnode=proot_node)
cdef Node* pcur_node = proot_node[0].ppchildren[0][3]
node_expand(pnode=pcur_node, r=0.4, v=1.9, t=0, done=False, logits=np.array([1.4,3.5,5.2,-3.4], dtype=np.float32), encoded=<PyObject*>pydict, override=False)
node_visit(pnode=pcur_node)

pcur_node = pcur_node[0].ppchildren[0][3]
node_expand(pnode=pcur_node, r=1.4, v=1.9, t=0, done=False, logits=np.array([1.4,3.5,5.2,-3.4], dtype=np.float32), encoded=<PyObject*>pydict, override=False)
node_visit(pnode=pcur_node)

# back to the root, then expand a second branch (action 2)
node_visit(pnode=proot_node)
pcur_node = proot_node[0].ppchildren[0][2]
node_expand(pnode=pcur_node, r=4.4, v=-1.9, t=0, done=False, logits=np.array([5.4,1.5,5.2,-3.4], dtype=np.float32), encoded=<PyObject*>pydict, override=False)
node_visit(pnode=pcur_node)

for i in range(proot_node[0].prollout_qs[0].size()):
    print(proot_node[0].prollout_qs[0][i])
# NOTE(review): reward_transform=False is assumed to match the Python
# reference's node.stat(True) below - confirm.
print(np.array(node_stat(proot_node, True, reward_transform=False)))

# delete the tree except the branch for action 2, which pcur_node points to
node_del(proot_node, 2)
print(np.array(node_stat(pcur_node, True, reward_transform=False)), pcur_node[0].pparent == NULL)


import thinker.env
import torch
import numpy as np
# same tree built with the Python reference implementation
pydict = None
node = thinker.env.Node(None, action=0, logit=0.25, num_actions=4, discounting=0.97, rec_t=5)   
node.expand(r=torch.tensor([0.5]), v=torch.tensor([1.2]), logits=torch.tensor([1.2,3.4,1.2,-3.4], dtype=float), encoded=pydict)
node.visit()
cur_node = node.children[3]
cur_node.expand(r=torch.tensor([0.4]), v=torch.tensor([1.9]), logits=torch.tensor([1.4,3.5,5.2,-3.4], dtype=float), encoded=pydict, override=False)
cur_node.visit()
cur_node = cur_node.children[3]
cur_node.expand(r=torch.tensor([1.4]), v=torch.tensor([1.9]), logits=torch.tensor([1.4,3.5,5.2,-3.4], dtype=float), encoded=pydict, override=False)
cur_node.visit()

node.visit()
cur_node = node.children[2]
cur_node.expand(r=torch.tensor([4.4]), v=torch.tensor([-1.9]), logits=torch.tensor([5.4,1.5,5.2,-3.4], dtype=float), encoded=pydict, override=False)
cur_node.visit()


print(node.stat(True))
print(node.rollout_qs)

In [32]:
%%cython --annotate
from __future__ import print_function
cdef class Shrubbery:
    """Extension type holding the integer width and height of a shrubbery."""
    cdef int width, height

    def __init__(self, int w, int h):
        """Store the two dimensions."""
        self.width, self.height = w, h

    cdef describe(self):
        # cdef method: callable from Cython code only
        print("This shrubbery is", self.width, "by", self.height, "cubits.")