In [1]:
import argparse
import logging
import os
import pprint
import threading
import time
import timeit
import traceback
import typing

os.environ["OMP_NUM_THREADS"] = "1"  # Necessary for multithreading.

import torch
from torch import multiprocessing as mp
from torch.multiprocessing import Process, Manager
from torch import nn
from torch.nn import functional as F

from torchbeast.core.environment import Environment, Vec_Environment
from torchbeast.atari_wrappers import SokobanWrapper
from torchbeast.base import BaseNet
from torchbeast.train import create_env

import gym
import gym_sokoban
import numpy as np
import math
import logging
from matplotlib import pyplot as plt
from collections import deque

logging.basicConfig(format='%(message)s', level=logging.DEBUG)
logging.getLogger('matplotlib.font_manager').disabled = True

torch.multiprocessing.set_sharing_strategy('file_system')

def get_param(net, name=None):
    keys = []
    for (k, v) in actor_wrapper.model.named_parameters(): 
        if name is None:
            print(k)
        else:
            if name == k: return v
        keys.append(k)
    return keys        

def n_step_greedy(env, net, n, temp=10.):    
    if isinstance(env, Vec_Environment):
        num_actions = env.gym_env.action_space[0].n
        bsz = len(env.gym_env.envs)
    else:
        num_actions = env.gym_env.action_space.n
        bsz = 1

    q_ret = torch.zeros(bsz, num_actions).to(device)      
    state = env.clone_state()

    for act in range(num_actions):
        obs = env.step(torch.Tensor(np.full(bsz, act)).long())      
        obs = {k:v.to(device) for k, v in obs.items()}   
        
        if n > 1:
            action, prob, sub_q_ret = n_step_greedy(env, net, n-1)
            ret = obs['reward'] + flags.discounting * torch.max(sub_q_ret, dim=1)[0] * (~obs['done']).float()
        else:
            ret = obs['reward'] + flags.discounting * net(obs)[0]['baseline'] * (~obs['done']).float()

        q_ret[:, act] = ret
        env.restore_state(state)
    
    prob = F.softmax(temp*q_ret, dim=1)
    action = torch.multinomial(prob, num_samples=1)[:, 0]
    
    return action, prob, q_ret  


[DEBUG:285693 __init__:275 2022-11-06 01:39:37,629] matplotlib data path: /home/sc/anaconda3/lib/python3.9/site-packages/matplotlib/mpl-data
[DEBUG:285693 __init__:275 2022-11-06 01:39:37,631] CONFIGDIR=/home/sc/.config/matplotlib
[DEBUG:285693 __init__:1445 2022-11-06 01:39:37,632] interactive is False
[DEBUG:285693 __init__:1446 2022-11-06 01:39:37,632] platform is linux


[DEBUG:285693 __init__:275 2022-11-06 01:39:37,699] CACHEDIR=/home/sc/.cache/matplotlib
[DEBUG:285693 font_manager:1439 2022-11-06 01:39:37,700] Using fontManager instance from /home/sc/.cache/matplotlib/fontlist-v330.json


<font size="5">Testing planning algo. for perfect model with bootstrapped values</font>

In [None]:
# Synchronous version of testing 

def test_n_step(n, net, env, temp=10.):
    
    print("Testing %d step planning" % n)

    returns = []
    obs = env.initial()
    eps_n_cur = 5

    while(len(returns) <= eps_n):
        cur_returns = obs['episode_return']    
        obs = {k:v.to(device) for k, v in obs.items()}
        net_out, core_state = net(obs)            
        if n == 0:
            action = net_out["action"][0]
        else:
            action, _, _ = n_step_greedy(env, net, n, temp)
        obs = env.step(action)
        if torch.any(obs['done']):
            returns.extend(cur_returns[obs['done']].numpy())
        if eps_n_cur <= len(returns) and len(returns) > 0: 
            eps_n_cur = len(returns) + 10
            print("Finish %d episode: avg. return: %.2f (+-%.2f) " % (len(returns),
                np.average(returns), np.std(returns) / np.sqrt(len(returns))))
            
    print("Finish %d episode: avg. return: %.2f (+-%.2f) " % (len(returns),
                np.average(returns), np.std(returns) / np.sqrt(len(returns))))
    return returns

bsz = 16    
eps_n = 500
device = torch.device("cuda")

# create environments

env = gym.vector.SyncVectorEnv([lambda: SokobanWrapper(gym.make("Sokoban-v0"), noop=True)] * bsz)
env = Vec_Environment(env, bsz)
num_actions = env.gym_env.action_space[0].n

# import the net

parser = argparse.ArgumentParser()
flags = parser.parse_args("".split())   
flags.discounting = 0.97
temp = 5

net = BaseNet(observation_shape=(3,80,80), num_actions=num_actions, flags=flags)  
net = net.to(device)
checkpoint = torch.load("/home/schk/RS/thinker/models/base_2.tar", map_location="cuda")
#checkpoint = torch.load("/home/schk/RS/thinker/logs/base/torchbeast-20221105-033530/model.tar", map_location="cuda")
net.load_state_dict(checkpoint["model_state_dict"]) 

# initialize net

core_state = net.initial_state(batch_size=bsz)
core_state = tuple(v.to(device) for v in core_state)
net.train(False)

all_returns = {}
for n in range(2,3):
    t = time.process_time()
    all_returns[n] = test_n_step(n, net, env, temp)
    print("Time required for %d step planning: %f" %(n, time.process_time()-t))

In [None]:
# Asynchronous version of testing 

def act_m(
    flags,
    actor_index: int,
    net: torch.nn.Module,
    returns: Manager().list,
    eps_n: int,
    n: int,
    temp: float,
):    
    try:    
        #logging.info("Actor %i started", actor_index)
        gym_env = create_env(flags)
        seed = actor_index ^ int.from_bytes(os.urandom(4), byteorder="little")
        gym_env.seed(seed)
        env = Environment(gym_env)
        env_output = env.initial()  
        agent_state = net.initial_state(batch_size=1)
        net_out, unused_state = net(env_output, agent_state)      
        while True:            
            if len(returns) >= eps_n: break
            with torch.no_grad():
                net_out, agent_state = net(env_output, agent_state)                            
            if n == 0:
                action = net_out["action"]
            else:
                action, _, _ = n_step_greedy(env, net, n, temp)            
            env_output = env.step(action)           
            if env_output['done']: returns.append(ret)
            ret = env_output['episode_return'].item()        
        #logging.info("Actor %i end", actor_index)
    except KeyboardInterrupt:
        pass  # Return silently.
    except Exception as e:
        logging.error("Exception in worker process %i", actor_index)
        traceback.print_exc()
        raise e

def asy_test_n_step(n, net, flags, temp):
    
    print("Testing %d step planning" % n)

    mp.set_sharing_strategy('file_system')
    net.share_memory()
    ctx = mp.get_context()        
    returns = Manager().list()

    actor_processes = []
    for i in range(flags.num_actors):
        actor = ctx.Process(target=act_m, args=(flags, i, net, returns, eps_n, n, temp),)
        actor.start()
        actor_processes.append(actor)    

    for actor in actor_processes:
        actor.join()    

    print("Finish %d episode: avg. return: %.2f (+-%.2f)" % (len(returns),
                    np.average(returns), np.std(returns) / np.sqrt(len(returns)),))        
    return returns        
        
parser = argparse.ArgumentParser()
flags = parser.parse_args("".split())   

flags.env = "Sokoban-v0"
flags.env_disable_noop = False
flags.discounting = 0.97     
flags.num_actors = 32
bsz = 1
eps_n = 500
temp = 5
device = torch.device("cpu")

net = BaseNet(observation_shape=(3,80,80), num_actions=5, flags=flags)  
net = net.to("cpu")
checkpoint = torch.load("/home/schk/RS/thinker/models/base_2.tar", map_location="cpu")
net.load_state_dict(checkpoint["model_state_dict"]) 

all_returns = {}
for n in range(4):
    t = time.time()
    all_returns[n] = asy_test_n_step(n, net, flags, temp)
    print("Time required for %d step planning: %f" %(n, time.time()-t))

Results (base_1.tar):
    
Testing 0 step planning <br>
Finish 512 episode: avg. return: 0.12 (+-0.06) <br>
Testing 1 step planning <br>
Finish 502 episode: avg. return: 0.61 (+-0.04) <br>
Testing 2 step planning <br>
Finish 501 episode: avg. return: 0.92 (+-0.04) <br>
Testing 3 step planning <br>
Finish 501 episode: avg. return: 1.01 (+-0.04) <br>

Results (base_2.tar):
Testing 0 step planning <br>
Finish 500 episode: avg. return: 0.27 (+-0.04) <br>
Time required for 0 step planning: 12.629324 <br>
Testing 1 step planning <br>
Finish 502 episode: avg. return: 0.51 (+-0.04) <br>
Time required for 1 step planning: 74.194364 <br>
Testing 2 step planning <br>
Finish 500 episode: avg. return: 0.74 (+-0.04) <br>
Time required for 2 step planning: 339.732901 <br>
Testing 3 step planning <br>
Finish 500 episode: avg. return: 0.76 (+-0.04) <br>
Time required for 3 step planning: 1695.472523 <br>

<font size="5">Model Training Phase</font>

In [38]:
# Generating data for learning model [RUN]

Buffers = typing.Dict[str, typing.List[torch.Tensor]]

def create_buffers_m(flags, obs_shape, num_actions) -> Buffers:
    
    seq_len = flags.seq_len
    seq_n = flags.seq_n
    specs = dict(
        frame=dict(size=(seq_len + 1, *obs_shape), dtype=torch.uint8),
        reward=dict(size=(seq_len + 1,), dtype=torch.float32),
        done=dict(size=(seq_len + 1,), dtype=torch.bool),
        truncated_done=dict(size=(seq_len + 1,), dtype=torch.bool),
        episode_return=dict(size=(seq_len + 1,), dtype=torch.float32),
        episode_step=dict(size=(seq_len + 1,), dtype=torch.int32),
        policy_logits=dict(size=(seq_len + 1, num_actions), dtype=torch.float32),
        baseline=dict(size=(seq_len + 1,), dtype=torch.float32),
        last_action=dict(size=(seq_len + 1,), dtype=torch.int64),
        action=dict(size=(seq_len + 1,), dtype=torch.int64),
        reg_loss=dict(size=(seq_len + 1,), dtype=torch.float32)
    )
    buffers: Buffers = {key: [] for key in specs}
    for _ in range(seq_n):
        for key in buffers:
            buffers[key].append(torch.empty(**specs[key]).share_memory_())
            
    return buffers

def gen_data(
    flags,
    actor_index: int,
    net: torch.nn.Module,
    buffers: Buffers,
    free_queue: mp.SimpleQueue,
):    
    try:    
        #logging.info("Actor %i started", actor_index)
        gym_env = create_env(flags)
        seed = actor_index ^ int.from_bytes(os.urandom(4), byteorder="little")
        gym_env.seed(seed)
        env = Environment(gym_env)
        env_output = env.initial()  
        agent_state = net.initial_state(batch_size=1)
        agent_output, unused_state = net(env_output, agent_state)     
        
        while True:
            index = free_queue.get()
            if index is None:
                break         

            # Write old rollout end.
            for key in env_output:
                buffers[key][index][0, ...] = env_output[key]
            for key in agent_output:
                buffers[key][index][0, ...] = agent_output[key]

            # Do new rollout.
            for t in range(flags.seq_len):
                with torch.no_grad():
                    agent_output, agent_state = net(env_output, agent_state)
                env_output = env.step(agent_output["action"])
                for key in env_output:
                    buffers[key][index][t + 1, ...] = env_output[key]
                for key in agent_output:
                    buffers[key][index][t + 1, ...] = agent_output[key]
                    
    except KeyboardInterrupt:
        pass  # Return silently.
    except Exception as e:
        logging.error("Exception in worker process %i", actor_index)
        traceback.print_exc()
        raise e
        

# Models

DOWNSCALE_C = 2

def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=dilation,
        groups=groups, bias=False, dilation=dilation,)

def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

class ResBlock(nn.Module):
    expansion: int = 1

    def __init__(self, inplanes, outplanes=None):
        super().__init__()
        if outplanes is None: outplanes = inplanes 
        norm_layer = nn.BatchNorm2d
        self.conv1 = conv3x3(inplanes, inplanes)
        self.bn1 = norm_layer(inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(inplanes, outplanes)
        self.bn2 = norm_layer(outplanes)
        self.skip_conv = (outplanes != inplanes)
        if outplanes != inplanes:
            self.conv3 = conv1x1(inplanes, outplanes)

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.skip_conv:
            out += self.conv3(identity)
        else:
            out += identity
        out = self.relu(out)
        return out
    
class FrameEncoder(nn.Module):    
    def __init__(self, num_actions, frame_channels=3, type_nn=0):
        self.num_actions = num_actions
        super(FrameEncoder, self).__init__() 
        
        if type_nn == 0:
            n_block = 1
        elif type_nn == 1:
            n_block = 2
        
        self.conv1 = nn.Conv2d(in_channels=frame_channels+num_actions, out_channels=128//DOWNSCALE_C, kernel_size=3, stride=2, padding=1) 
        res = nn.ModuleList([ResBlock(inplanes=128//DOWNSCALE_C) for i in range(n_block)]) # Deep: 2 blocks here
        self.res1 = torch.nn.Sequential(*res)
        self.conv2 = nn.Conv2d(in_channels=128//DOWNSCALE_C, out_channels=256//DOWNSCALE_C, 
                               kernel_size=3, stride=2, padding=1) 
        res = nn.ModuleList([ResBlock(inplanes=256//DOWNSCALE_C) for i in range(n_block)]) # Deep: 3 blocks here
        self.res2 = torch.nn.Sequential(*res)
        self.avg1 = nn.AvgPool2d(3, stride=2, padding=1)
        res = nn.ModuleList([ResBlock(inplanes=256//DOWNSCALE_C) for i in range(n_block)]) # Deep: 3 blocks here
        self.res3 = torch.nn.Sequential(*res)
        self.avg2 = nn.AvgPool2d(3, stride=2, padding=1)
    
    def forward(self, x, actions):        
        # input shape: B, C, H, W        
        # action shape: B 
        
        x = x.float() / 255.0    
        actions = actions.unsqueeze(-1).unsqueeze(-1).tile([1, 1, x.shape[2], x.shape[3]])        
        x = torch.concat([x, actions], dim=1)
        x = F.relu(self.conv1(x))
        x = self.res1(x)
        x = F.relu(self.conv2(x))
        x = self.res2(x)
        x = self.avg1(x)
        x = self.res3(x)
        x = self.avg2(x)
        return x
    
class DynamicModel(nn.Module):
    def __init__(self, num_actions, inplanes=256, type_nn=0):        
        super(DynamicModel, self).__init__()
        self.type_nn = type_nn
        if type_nn == 0:
            res = nn.ModuleList([ResBlock(inplanes=inplanes+num_actions, outplanes=inplanes)] + [
                    ResBlock(inplanes=inplanes) for i in range(4)]) 
        elif type_nn == 1:                      
            res = nn.ModuleList([ResBlock(inplanes=inplanes) for i in range(15)] + [
                    ResBlock(inplanes=inplanes, outplanes=inplanes*num_actions)])

        
        self.res = torch.nn.Sequential(*res)
        self.num_actions = num_actions
    
    def forward(self, x, actions):              
        bsz, c, h, w = x.shape
        if self.training:
            x.register_hook(lambda grad: grad * 0.5)
        if self.type_nn == 0:
            actions = actions.unsqueeze(-1).unsqueeze(-1).tile([1, 1, x.shape[2], x.shape[3]])        
            x = torch.concat([x, actions], dim=1)
            out = self.res(x)
        elif self.type_nn == 1:            
            res_out = self.res(x).view(bsz, self.num_actions, c, h, w)        
            actions = actions.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
            out = torch.sum(actions * res_out, dim=1)
        return out
    
class Output_rvpi(nn.Module):   
    def __init__(self, num_actions, input_shape):         
        super(Output_rvpi, self).__init__()        
        c, h, w = input_shape
        self.conv1 = nn.Conv2d(in_channels=c, out_channels=c//2, kernel_size=3, padding='same') 
        self.conv2 = nn.Conv2d(in_channels=c//2, out_channels=c//4, kernel_size=3, padding='same') 
        fc_in = h * w * (c // 4)
        self.fc_r = nn.Linear(fc_in, 1) 
        self.fc_v = nn.Linear(fc_in, 1) 
        self.fc_logits = nn.Linear(fc_in, num_actions)         
        
    def forward(self, x):    
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = torch.flatten(x, start_dim=1)
        r, v, logits = self.fc_r(x), self.fc_v(x), self.fc_logits(x)
        return r, v, logits

class Model(nn.Module):    
    def __init__(self, flags, obs_shape, num_actions):        
        super(Model, self).__init__()      
        self.flags = flags
        self.obs_shape = obs_shape
        self.num_actions = num_actions          
        self.type_nn = flags.model_type_nn # type_nn: type of neural network for the model; 0 for small, 1 for large
        self.frameEncoder = FrameEncoder(num_actions=num_actions, frame_channels=obs_shape[0], type_nn=self.type_nn)
        self.dynamicModel = DynamicModel(num_actions=num_actions, inplanes=256//DOWNSCALE_C, type_nn=self.type_nn)
        self.output_rvpi = Output_rvpi(num_actions=num_actions, input_shape=(256//DOWNSCALE_C, 
                      obs_shape[1]//16, obs_shape[1]//16))
        
    def forward(self, x, actions, one_hot=False):
        # Input
        # x: frames with shape (B, C, H, W), in the form of s_t
        # actions: action (int64) with shape (k+1, B), in the form of a_{t-1}, a_{t}, a_{t+1}, .. a_{t+k-1}
        # Output
        # reward: predicted reward with shape (k, B), in the form of r_{t+1}, r_{t+2}, ..., r_{t+k}
        # value: predicted value with shape (k+1, B), in the form of v_{t}, v_{t+1}, v_{t+2}, ..., v_{t+k}
        # policy: predicted policy with shape (k+1, B), in the form of pi_{t}, pi_{t+1}, pi_{t+2}, ..., pi_{t+k}
        # encoded: encoded states with shape (k+1, B), in the form of z_t, z_{t+1}, z_{t+2}, ..., z_{t+k}
        # Recall the transition notation: s_t, a_t, r_{t+1}, s_{t+1}, ...
        
        if not one_hot:
            actions = F.one_hot(actions, self.num_actions)                
        encoded = self.frameEncoder(x, actions[0])
        return self.forward_encoded(encoded, actions[1:], one_hot=True)
    
    def forward_encoded(self, encoded, actions, one_hot=False):
        if not one_hot:
            actions = F.one_hot(actions, self.num_actions)                
        
        r, v, logits = self.output_rvpi(encoded)
        r_list, v_list, logits_list = [], [v.squeeze(-1).unsqueeze(0)], [logits.unsqueeze(0)]
        encoded_list = [encoded.unsqueeze(0)]
        
        for k in range(actions.shape[0]):            
            encoded = self.dynamicModel(encoded, actions[k])
            r, v, logits = self.output_rvpi(encoded)
            r_list.append(r.squeeze(-1).unsqueeze(0))
            v_list.append(v.squeeze(-1).unsqueeze(0))
            logits_list.append(logits.unsqueeze(0))
            encoded_list.append(encoded.unsqueeze(0))        
        
        if len(r_list) > 0:
            rs = torch.concat(r_list, dim=0)
        else:
            rs = None
            
        vs = torch.concat(v_list, dim=0)
        logits = torch.concat(logits_list, dim=0)
        encodeds = torch.concat(encoded_list, dim=0)        
        
        return rs, vs, logits, encodeds

#model = Model(flags, (3, 80, 80), num_actions=5)
#rs, vs, logits = model(torch.rand(16, 3, 80, 80), torch.ones(8, 16).long())

# functions for training models

def get_batch_m(flags, buffers: Buffers):
    batch_indices = np.random.randint(flags.seq_n, size=flags.bsz)
    time_indices = np.random.randint(flags.seq_len - flags.unroll_len, size=flags.bsz)
    batch = {key: torch.stack([buffers[key][m][time_indices[n]:time_indices[n]+flags.unroll_len+1] 
                          for n, m in enumerate(batch_indices)], dim=1) for key in buffers}
    batch = {k: t.to(device=flags.device, non_blocking=True) for k, t in batch.items()}
    return batch

def compute_cross_entropy_loss(logits, target_logits, mask):
    target_policy = F.softmax(target_logits, dim=-1)
    log_policy = F.log_softmax(logits, dim=-1)
    return -torch.sum(target_policy * log_policy * (~mask).float().unsqueeze(-1))

def compute_loss_m(model, batch):
    rs, vs, logits, _ = model(batch['frame'][0], batch['action'])
    logits = logits[:-1]

    target_rewards = batch['reward'][1:]
    target_logits = batch['policy_logits'][1:]

    target_vs = []
    target_v = model(batch['frame'][-1], batch['action'][[-1]])[1][0].detach()
    
    for t in range(vs.shape[0]-1, 0, -1):
        new_target_v = batch['reward'][t] + flags.discounting * (target_v * (~batch['done'][t]).float())# +
                           #vs[t-1] * (batch['truncated_done'][t]).float())
        target_vs.append(new_target_v.unsqueeze(0))
        target_v = new_target_v
    target_vs.reverse()
    target_vs = torch.concat(target_vs, dim=0)

    # if done on step j, r_{j}, v_{j-1}, a_{j-1} has the last valid loss 
    # rs is stored in the form of r_{t+1}, ..., r_{t+k}
    # vs is stored in the form of v_{t}, ..., v_{t+k-1}
    # logits is stored in the form of a{t}, ..., a_{t+k-1}

    done_masks = []
    done = torch.zeros(vs.shape[1]).bool().to(batch['done'].device)
    for t in range(vs.shape[0]):
        if t > 0: done = torch.logical_or(done, batch['done'][t])
        done_masks.append(done.unsqueeze(0))

    done_masks = torch.concat(done_masks[:-1], dim=0)
    
    # compute final loss
    huberloss = torch.nn.HuberLoss(reduction='none', delta=1.0)    
    #rs_loss = torch.sum(huberloss(rs, target_rewards.detach()) * (~done_masks).float())
    rs_loss = torch.sum(((rs - target_rewards) ** 2) * (~done_masks).float())
    #vs_loss = torch.sum(huberloss(vs[:-1], target_vs.detach()) * (~done_masks).float())
    vs_loss = torch.sum(((vs[:-1] - target_vs) ** 2) * (~done_masks).float())
    logits_loss = compute_cross_entropy_loss(logits, target_logits.detach(), done_masks)
    
    return rs_loss, vs_loss, logits_loss

# n_step_greedy for testing

def n_step_greedy_model(state, action, model, n, encoded=None, temp=20.): 
    
    # Either input state, action (S_t, A_{t-1}) or the encoded Z_t
    # state / encoded in the shape of (B, C, H, W)
    # action in the shape of (B)    
    with torch.no_grad():    
      bsz = state.shape[0] if encoded is None else encoded.shape[0]
      device = state.device if encoded is None else encoded.device
      num_actions = model.num_actions    

      q_ret = torch.zeros(bsz, num_actions).to(device)        

      for act in range(num_actions):        
          new_action = torch.Tensor(np.full(bsz, act)).long().to(device)    
          if encoded is None:            
              old_new_actions = torch.concat([action.unsqueeze(0), new_action.unsqueeze(0)], dim=0)
              rs, vs, logits, encodeds = model(state, old_new_actions)
          else:
              rs, vs, logits, encodeds = model.forward_encoded(encoded, new_action.unsqueeze(0))

          if n > 1:
              action, prob, sub_q_ret = n_step_greedy_model(state=None, action=None, 
                         model=model, n=n-1, encoded=encodeds[1])
              ret = rs[0] + flags.discounting * torch.max(sub_q_ret, dim=1)[0] 
          else:
              ret = rs[0] + flags.discounting * vs[1]
          q_ret[:, act] = ret

      prob = F.softmax(temp*q_ret, dim=1)
      action = torch.multinomial(prob, num_samples=1)[:, 0]
    
    return action, prob, q_ret        
   
#n_step_greedy_model(batch['frame'][0], batch['action'][0], model, 4)  

def test_n_step_model(n, model, flags, eps_n=100, temp=20.):    
    
    print("Testing %d step planning" % n) 
    
    bsz = 100
    env = gym.vector.SyncVectorEnv([lambda: SokobanWrapper(gym.make("Sokoban-v0"), noop=True)] * bsz)
    env = Vec_Environment(env, bsz)
    num_actions = env.gym_env.action_space[0].n
    
    model.train(False)
    returns = []
    
    obs = env.initial()
    action = torch.zeros(bsz).long().to(flags.device)
    eps_n_cur = 5

    while(len(returns) <= eps_n):
        cur_returns = obs['episode_return']    
        obs = {k:v.to(flags.device) for k, v in obs.items()}
        new_action, _, _ = n_step_greedy_model(obs['frame'][0], action, model, n, None, temp)        
        obs = env.step(new_action)
        action = new_action
        if torch.any(obs['done']):
            returns.extend(cur_returns[obs['done']].numpy())
        if eps_n_cur <= len(returns) and len(returns) > 0: 
            eps_n_cur = len(returns) + 10
            #print("Finish %d episode: avg. return: %.2f (+-%.2f) " % (len(returns),
            #    np.average(returns), np.std(returns) / np.sqrt(len(returns))))
            
    returns = returns[:eps_n]
    print("Finish %d episode: avg. return: %.2f (+-%.2f) " % (len(returns),
                np.average(returns), np.std(returns) / np.sqrt(len(returns))))
    return returns

In [44]:
# Start training models

parser = argparse.ArgumentParser()
flags = parser.parse_args("".split())       

flags.env = "Sokoban-v0"
flags.env_disable_noop = False
flags.bsz = 64
flags.unroll_len = 5
flags.num_actors = 32
flags.seq_n = 1000
flags.seq_len = 200
flags.learning_rate = 0.0001
flags.loop_batch_n = 3
flags.discounting = 0.97
flags.tot_epoch = 10000
flags.grad_norm_clipping = 60

flags.model_type_nn = 1

flags.device = torch.device("cuda")

# Create buffer for actors to write

mp.set_sharing_strategy('file_system')
ctx = mp.get_context()        

env = create_env(flags)
obs_shape, num_actions = env.observation_space.shape, env.action_space.n
buffers = create_buffers_m(flags, obs_shape, num_actions)
print("Buffer created successfully.")

# Initialize the model and optimizer

env = create_env(flags)
model = Model(flags, obs_shape, num_actions=num_actions).to(device=flags.device)
optimizer = torch.optim.Adam(model.parameters(), lr=flags.learning_rate)

print("model size: ", sum(p.numel() for p in model.parameters()))
for k, v in model.named_parameters(): print(k, v.numel())    
    
tot_step = int(flags.loop_batch_n * flags.seq_n * flags.seq_len / flags.bsz / flags.unroll_len) 

Buffer created successfully.
model size:  6905607
frameEncoder.conv1.weight 4608
frameEncoder.conv1.bias 64
frameEncoder.res1.0.conv1.weight 36864
frameEncoder.res1.0.bn1.weight 64
frameEncoder.res1.0.bn1.bias 64
frameEncoder.res1.0.conv2.weight 36864
frameEncoder.res1.0.bn2.weight 64
frameEncoder.res1.0.bn2.bias 64
frameEncoder.res1.1.conv1.weight 36864
frameEncoder.res1.1.bn1.weight 64
frameEncoder.res1.1.bn1.bias 64
frameEncoder.res1.1.conv2.weight 36864
frameEncoder.res1.1.bn2.weight 64
frameEncoder.res1.1.bn2.bias 64
frameEncoder.conv2.weight 73728
frameEncoder.conv2.bias 128
frameEncoder.res2.0.conv1.weight 147456
frameEncoder.res2.0.bn1.weight 128
frameEncoder.res2.0.bn1.bias 128
frameEncoder.res2.0.conv2.weight 147456
frameEncoder.res2.0.bn2.weight 128
frameEncoder.res2.0.bn2.bias 128
frameEncoder.res2.1.conv1.weight 147456
frameEncoder.res2.1.bn1.weight 128
frameEncoder.res2.1.bn1.bias 128
frameEncoder.res2.1.conv2.weight 147456
frameEncoder.res2.1.bn2.weight 128
frameEncoder.

In [46]:
temp = 20

# Load the preset policy

net = BaseNet(observation_shape=(3,80,80), num_actions=5, flags=flags)  
checkpoint = torch.load("/home/sc/RS/thinker/models/base_1.tar", map_location="cpu")
net.load_state_dict(checkpoint["model_state_dict"])   
net.train(False)
net.share_memory()

# Get the actors to write on the buffer

actor_processes = []
free_queue = mp.SimpleQueue()
loss_stats = [deque(maxlen=400) for _ in range(4)]

net.train(False)
for i in range(flags.num_actors):
    actor = ctx.Process(target=gen_data, args=(flags, i, net, buffers, free_queue, ),)
    actor.start()
    actor_processes.append(actor)   
    
for m in range(flags.seq_n): free_queue.put(m)

# Start training loop    

model.train(True)
for epoch in range(flags.tot_epoch):    
    print("Batch [%d] starts" % epoch)
    while(not free_queue.empty()): time.sleep(1)
    for step in range(tot_step):
        if step == 0: 
            test_n_step_model(1, model, flags, eps_n=100, temp=temp)
            test_n_step_model(2, model, flags, eps_n=100, temp=temp)
            model.train(True)
        
        batch = get_batch_m(flags, buffers)
        rs_loss, vs_loss, logits_loss = compute_loss_m(model, batch)
        tot_loss = rs_loss + vs_loss + 0.01 * logits_loss
        for n, l in enumerate([tot_loss, rs_loss, vs_loss, logits_loss]):
            loss_stats[n].append(l.item())
        
        if step % 100 == 0:
            print("[%d:%d] F: %d \t tot_loss %f rs_loss %f vs_loss %f logits_loss %f" % ((
                epoch, step, (step + epoch * tot_step) * flags.bsz * flags.unroll_len,) +
                tuple(np.average(l) for l in loss_stats)))
        optimizer.zero_grad()        
        tot_loss.backward()
        optimize_params = optimizer.param_groups[0]['params']
        if flags.grad_norm_clipping > 0:
            total_norm = nn.utils.clip_grad_norm_(optimize_params, flags.grad_norm_clipping)
        optimizer.step()    
    for m in range(flags.seq_n): free_queue.put(m)
        
for _ in range(flags.num_actors): free_queue.put(None)        
for actor in actor_processes: actor.join(timeout=1)        

model size:  1095814
Batch [0] starts
Testing 1 step planning
Finish 100 episode: avg. return: 0.72 (+-0.15) 
Testing 2 step planning
Finish 100 episode: avg. return: 0.58 (+-0.15) 
[0:0] F: 0 	 tot_loss 9.398491 rs_loss 0.195361 vs_loss 4.774497 logits_loss 442.863281
[0:100] F: 32000 	 tot_loss 13.375923 rs_loss 1.491374 vs_loss 7.435394 logits_loss 444.915540
[0:200] F: 64000 	 tot_loss 17.217728 rs_loss 1.610392 vs_loss 11.151394 logits_loss 445.594250
[0:300] F: 96000 	 tot_loss 16.869749 rs_loss 1.582948 vs_loss 10.826971 logits_loss 445.982990
[0:400] F: 128000 	 tot_loss 16.696534 rs_loss 1.403860 vs_loss 10.832572 logits_loss 446.010267
[0:500] F: 160000 	 tot_loss 21.215325 rs_loss 1.629517 vs_loss 15.118014 logits_loss 446.779423
[0:600] F: 192000 	 tot_loss 20.476190 rs_loss 1.573477 vs_loss 14.431993 logits_loss 447.071967
[0:700] F: 224000 	 tot_loss 20.692786 rs_loss 1.337962 vs_loss 14.883223 logits_loss 447.160041
[0:800] F: 256000 	 tot_loss 20.481714 rs_loss 1.340234

[4:100] F: 2432000 	 tot_loss 22.043298 rs_loss 2.056556 vs_loss 15.539055 logits_loss 444.768738
[4:200] F: 2464000 	 tot_loss 21.739563 rs_loss 1.877954 vs_loss 15.411787 logits_loss 444.982181
[4:300] F: 2496000 	 tot_loss 21.313017 rs_loss 2.091234 vs_loss 14.769619 logits_loss 445.216346
[4:400] F: 2528000 	 tot_loss 18.016853 rs_loss 1.327983 vs_loss 12.231220 logits_loss 445.764984
[4:500] F: 2560000 	 tot_loss 18.719354 rs_loss 1.603701 vs_loss 12.654410 logits_loss 446.124201
[4:600] F: 2592000 	 tot_loss 17.909904 rs_loss 1.657528 vs_loss 11.792423 logits_loss 445.995277
[4:700] F: 2624000 	 tot_loss 17.333489 rs_loss 1.322830 vs_loss 11.549471 logits_loss 446.118799
[4:800] F: 2656000 	 tot_loss 18.287805 rs_loss 1.451660 vs_loss 12.380371 logits_loss 445.577464
[4:900] F: 2688000 	 tot_loss 19.371993 rs_loss 1.703586 vs_loss 13.217405 logits_loss 445.100170
[4:1000] F: 2720000 	 tot_loss 18.082279 rs_loss 1.432121 vs_loss 12.196819 logits_loss 445.333897
[4:1100] F: 2752000

[8:200] F: 4864000 	 tot_loss 17.261416 rs_loss 1.157718 vs_loss 11.651554 logits_loss 445.214406
[8:300] F: 4896000 	 tot_loss 18.287397 rs_loss 1.367892 vs_loss 12.469827 logits_loss 444.967877
[8:400] F: 4928000 	 tot_loss 18.264929 rs_loss 1.452715 vs_loss 12.365178 logits_loss 444.703622
[8:500] F: 4960000 	 tot_loss 17.112293 rs_loss 1.215900 vs_loss 11.451975 logits_loss 444.441848
[8:600] F: 4992000 	 tot_loss 16.388406 rs_loss 1.333411 vs_loss 10.613973 logits_loss 444.102142
[8:700] F: 5024000 	 tot_loss 15.187745 rs_loss 1.260408 vs_loss 9.488771 logits_loss 443.856612
[8:800] F: 5056000 	 tot_loss 16.768452 rs_loss 1.450645 vs_loss 10.879671 logits_loss 443.813593
[8:900] F: 5088000 	 tot_loss 16.810241 rs_loss 1.448875 vs_loss 10.919629 logits_loss 444.173790
[8:1000] F: 5120000 	 tot_loss 16.593635 rs_loss 1.383812 vs_loss 10.763246 logits_loss 444.657722
[8:1100] F: 5152000 	 tot_loss 19.684774 rs_loss 1.451960 vs_loss 13.784214 logits_loss 444.860028
[8:1200] F: 5184000

[12:300] F: 7296000 	 tot_loss 20.149824 rs_loss 1.892716 vs_loss 13.819561 logits_loss 443.754663
[12:400] F: 7328000 	 tot_loss 20.549679 rs_loss 2.083116 vs_loss 14.029286 logits_loss 443.727820
[12:500] F: 7360000 	 tot_loss 18.088189 rs_loss 1.278421 vs_loss 12.375892 logits_loss 443.387593
[12:600] F: 7392000 	 tot_loss 19.216534 rs_loss 1.494177 vs_loss 13.286142 logits_loss 443.621575
[12:700] F: 7424000 	 tot_loss 19.596708 rs_loss 1.530317 vs_loss 13.628876 logits_loss 443.751483
[12:800] F: 7456000 	 tot_loss 17.253484 rs_loss 1.063472 vs_loss 11.754202 logits_loss 443.581077
[12:900] F: 7488000 	 tot_loss 17.912873 rs_loss 1.072749 vs_loss 12.397231 logits_loss 444.289341
[12:1000] F: 7520000 	 tot_loss 15.553484 rs_loss 0.801156 vs_loss 10.313832 logits_loss 443.849551
[12:1100] F: 7552000 	 tot_loss 13.871137 rs_loss 0.757775 vs_loss 8.677126 logits_loss 443.623536
[12:1200] F: 7584000 	 tot_loss 17.498812 rs_loss 1.241835 vs_loss 11.816131 logits_loss 444.084690
[12:1300

[16:400] F: 9728000 	 tot_loss 15.862046 rs_loss 0.957397 vs_loss 10.469704 logits_loss 443.494552
[16:500] F: 9760000 	 tot_loss 16.341551 rs_loss 0.988442 vs_loss 10.916370 logits_loss 443.673836
[16:600] F: 9792000 	 tot_loss 15.459333 rs_loss 0.829924 vs_loss 10.189698 logits_loss 443.971120
[16:700] F: 9824000 	 tot_loss 13.560743 rs_loss 0.522828 vs_loss 8.602854 logits_loss 443.506204
[16:800] F: 9856000 	 tot_loss 13.262928 rs_loss 0.561934 vs_loss 8.265752 logits_loss 443.524235
[16:900] F: 9888000 	 tot_loss 14.633675 rs_loss 0.788739 vs_loss 9.410243 logits_loss 443.469303
[16:1000] F: 9920000 	 tot_loss 13.747763 rs_loss 0.754477 vs_loss 8.560747 logits_loss 443.253942
[16:1100] F: 9952000 	 tot_loss 14.038305 rs_loss 0.761355 vs_loss 8.837389 logits_loss 443.956110
[16:1200] F: 9984000 	 tot_loss 14.042110 rs_loss 0.741323 vs_loss 8.857599 logits_loss 444.318883
[16:1300] F: 10016000 	 tot_loss 12.923177 rs_loss 0.644109 vs_loss 7.834534 logits_loss 444.453405
[16:1400] F:

[20:400] F: 12128000 	 tot_loss 16.174898 rs_loss 1.152071 vs_loss 10.567517 logits_loss 445.531034
[20:500] F: 12160000 	 tot_loss 16.748454 rs_loss 1.372271 vs_loss 10.919778 logits_loss 445.640463
[20:600] F: 12192000 	 tot_loss 15.673618 rs_loss 1.056761 vs_loss 10.160173 logits_loss 445.668401
[20:700] F: 12224000 	 tot_loss 16.575764 rs_loss 1.283534 vs_loss 10.838039 logits_loss 445.419085
[20:800] F: 12256000 	 tot_loss 16.472525 rs_loss 1.045582 vs_loss 10.975231 logits_loss 445.171197
[20:900] F: 12288000 	 tot_loss 17.082335 rs_loss 0.976997 vs_loss 11.649559 logits_loss 445.577893
[20:1000] F: 12320000 	 tot_loss 17.428450 rs_loss 0.985555 vs_loss 11.985635 logits_loss 445.726077
[20:1100] F: 12352000 	 tot_loss 16.524490 rs_loss 0.734344 vs_loss 11.331999 logits_loss 445.814709
[20:1200] F: 12384000 	 tot_loss 18.529669 rs_loss 0.964916 vs_loss 13.103749 logits_loss 446.100429
[20:1300] F: 12416000 	 tot_loss 18.101684 rs_loss 0.788161 vs_loss 12.856773 logits_loss 445.675

[24:400] F: 14528000 	 tot_loss 15.762354 rs_loss 0.727878 vs_loss 10.600810 logits_loss 443.366570
[24:500] F: 14560000 	 tot_loss 15.260044 rs_loss 0.611803 vs_loss 10.211038 logits_loss 443.720297
[24:600] F: 14592000 	 tot_loss 17.800912 rs_loss 1.060551 vs_loss 12.303312 logits_loss 443.704897
[24:700] F: 14624000 	 tot_loss 18.587081 rs_loss 1.576883 vs_loss 12.572245 logits_loss 443.795293
[24:800] F: 14656000 	 tot_loss 22.822903 rs_loss 2.655522 vs_loss 15.729669 logits_loss 443.771222
[24:900] F: 14688000 	 tot_loss 24.166384 rs_loss 2.786600 vs_loss 16.941675 logits_loss 443.810954
[24:1000] F: 14720000 	 tot_loss 23.057169 rs_loss 2.609710 vs_loss 16.013698 logits_loss 443.376145
[24:1100] F: 14752000 	 tot_loss 21.519476 rs_loss 1.999378 vs_loss 15.088113 logits_loss 443.198559
[24:1200] F: 14784000 	 tot_loss 18.999277 rs_loss 1.173263 vs_loss 13.393374 logits_loss 443.264072
[24:1300] F: 14816000 	 tot_loss 16.999544 rs_loss 1.033095 vs_loss 11.534512 logits_loss 443.193

[28:400] F: 16928000 	 tot_loss 15.352701 rs_loss 0.724891 vs_loss 10.185083 logits_loss 444.272653
[28:500] F: 16960000 	 tot_loss 18.299600 rs_loss 1.282773 vs_loss 12.573818 logits_loss 444.300905
[28:600] F: 16992000 	 tot_loss 17.345650 rs_loss 1.280094 vs_loss 11.621226 logits_loss 444.432984
[28:700] F: 17024000 	 tot_loss 18.707971 rs_loss 1.398339 vs_loss 12.865154 logits_loss 444.447890
[28:800] F: 17056000 	 tot_loss 17.883867 rs_loss 1.454611 vs_loss 11.993010 logits_loss 443.624521
[28:900] F: 17088000 	 tot_loss 14.951602 rs_loss 1.062149 vs_loss 9.450222 logits_loss 443.923176
[28:1000] F: 17120000 	 tot_loss 16.304252 rs_loss 1.165747 vs_loss 10.698069 logits_loss 444.043566
[28:1100] F: 17152000 	 tot_loss 16.403356 rs_loss 1.188718 vs_loss 10.772021 logits_loss 444.261696
[28:1200] F: 17184000 	 tot_loss 16.560559 rs_loss 1.108846 vs_loss 11.003324 logits_loss 444.838860
[28:1300] F: 17216000 	 tot_loss 21.257740 rs_loss 1.712202 vs_loss 15.101262 logits_loss 444.4276

[32:400] F: 19328000 	 tot_loss 16.314191 rs_loss 0.816347 vs_loss 11.051511 logits_loss 444.633272
[32:500] F: 19360000 	 tot_loss 12.512091 rs_loss 0.582966 vs_loss 7.482995 logits_loss 444.613012
[32:600] F: 19392000 	 tot_loss 13.034643 rs_loss 0.611723 vs_loss 7.976818 logits_loss 444.610225
[32:700] F: 19424000 	 tot_loss 14.302225 rs_loss 0.851606 vs_loss 9.004105 logits_loss 444.651374
[32:800] F: 19456000 	 tot_loss 14.649441 rs_loss 0.933363 vs_loss 9.263650 logits_loss 445.242726
[32:900] F: 19488000 	 tot_loss 15.306425 rs_loss 0.880452 vs_loss 9.975449 logits_loss 445.052433
[32:1000] F: 19520000 	 tot_loss 15.627170 rs_loss 1.031646 vs_loss 10.151781 logits_loss 444.374320
[32:1100] F: 19552000 	 tot_loss 15.161479 rs_loss 1.009922 vs_loss 9.704050 logits_loss 444.750732
[32:1200] F: 19584000 	 tot_loss 15.967439 rs_loss 0.937731 vs_loss 10.585549 logits_loss 444.415921
[32:1300] F: 19616000 	 tot_loss 17.159189 rs_loss 1.136083 vs_loss 11.582137 logits_loss 444.096937
[3

[36:400] F: 21728000 	 tot_loss 20.184947 rs_loss 1.514830 vs_loss 14.229224 logits_loss 444.089338
[36:500] F: 21760000 	 tot_loss 17.231334 rs_loss 1.232329 vs_loss 11.564432 logits_loss 443.457325
[36:600] F: 21792000 	 tot_loss 17.188200 rs_loss 1.060454 vs_loss 11.692055 logits_loss 443.569126
[36:700] F: 21824000 	 tot_loss 17.873950 rs_loss 1.266513 vs_loss 12.175299 logits_loss 443.213876
[36:800] F: 21856000 	 tot_loss 17.471251 rs_loss 1.023905 vs_loss 12.016686 logits_loss 443.065957
[36:900] F: 21888000 	 tot_loss 18.919273 rs_loss 1.320837 vs_loss 13.160873 logits_loss 443.756342
[36:1000] F: 21920000 	 tot_loss 18.238991 rs_loss 1.100449 vs_loss 12.698050 logits_loss 444.049202
[36:1100] F: 21952000 	 tot_loss 17.341475 rs_loss 1.055285 vs_loss 11.840297 logits_loss 444.589329
[36:1200] F: 21984000 	 tot_loss 17.238000 rs_loss 1.238882 vs_loss 11.550202 logits_loss 444.891608
[36:1300] F: 22016000 	 tot_loss 18.970032 rs_loss 1.341468 vs_loss 13.180988 logits_loss 444.757

[40:400] F: 24128000 	 tot_loss 17.098542 rs_loss 1.563343 vs_loss 11.083530 logits_loss 445.166889
[40:500] F: 24160000 	 tot_loss 15.994621 rs_loss 1.136698 vs_loss 10.404624 logits_loss 445.329817
[40:600] F: 24192000 	 tot_loss 15.825994 rs_loss 1.127228 vs_loss 10.246408 logits_loss 445.235761
[40:700] F: 24224000 	 tot_loss 13.545900 rs_loss 0.678463 vs_loss 8.412867 logits_loss 445.457012
[40:800] F: 24256000 	 tot_loss 15.296952 rs_loss 0.959322 vs_loss 9.880718 logits_loss 445.691252
[40:900] F: 24288000 	 tot_loss 14.773010 rs_loss 0.914497 vs_loss 9.404090 logits_loss 445.442350
[40:1000] F: 24320000 	 tot_loss 16.128346 rs_loss 1.154443 vs_loss 10.518090 logits_loss 445.581352
[40:1100] F: 24352000 	 tot_loss 15.814594 rs_loss 1.170501 vs_loss 10.192050 logits_loss 445.204196
[40:1200] F: 24384000 	 tot_loss 15.199416 rs_loss 0.925333 vs_loss 9.825508 logits_loss 444.857608
[40:1300] F: 24416000 	 tot_loss 16.908102 rs_loss 1.108527 vs_loss 11.349614 logits_loss 444.996075


[44:400] F: 26528000 	 tot_loss 15.749089 rs_loss 0.832870 vs_loss 10.462105 logits_loss 445.411425
[44:500] F: 26560000 	 tot_loss 17.730723 rs_loss 1.086768 vs_loss 12.189467 logits_loss 445.448751
[44:600] F: 26592000 	 tot_loss 16.520607 rs_loss 0.731820 vs_loss 11.339401 logits_loss 444.938569
[44:700] F: 26624000 	 tot_loss 17.443516 rs_loss 0.906670 vs_loss 12.088709 logits_loss 444.813592
[44:800] F: 26656000 	 tot_loss 17.583899 rs_loss 1.213338 vs_loss 11.919337 logits_loss 445.122381
[44:900] F: 26688000 	 tot_loss 15.394912 rs_loss 0.934092 vs_loss 10.010300 logits_loss 445.052065
[44:1000] F: 26720000 	 tot_loss 17.231105 rs_loss 1.366534 vs_loss 11.413920 logits_loss 445.065067
[44:1100] F: 26752000 	 tot_loss 16.085836 rs_loss 1.176039 vs_loss 10.453992 logits_loss 445.580482
[44:1200] F: 26784000 	 tot_loss 15.976790 rs_loss 0.899864 vs_loss 10.631663 logits_loss 444.526280
[44:1300] F: 26816000 	 tot_loss 16.756789 rs_loss 1.065927 vs_loss 11.245298 logits_loss 444.556

[48:400] F: 28928000 	 tot_loss 16.461327 rs_loss 1.486493 vs_loss 10.537164 logits_loss 443.767067
[48:500] F: 28960000 	 tot_loss 17.595275 rs_loss 1.562693 vs_loss 11.594623 logits_loss 443.795851
[48:600] F: 28992000 	 tot_loss 19.595181 rs_loss 1.925076 vs_loss 13.231995 logits_loss 443.811020
[48:700] F: 29024000 	 tot_loss 16.999094 rs_loss 1.290986 vs_loss 11.265766 logits_loss 444.234200
[48:800] F: 29056000 	 tot_loss 17.935289 rs_loss 1.255618 vs_loss 12.235699 logits_loss 444.397264
[48:900] F: 29088000 	 tot_loss 19.789758 rs_loss 1.500613 vs_loss 13.844061 logits_loss 444.508450
[48:1000] F: 29120000 	 tot_loss 15.691419 rs_loss 0.706155 vs_loss 10.539827 logits_loss 444.543781
[48:1100] F: 29152000 	 tot_loss 17.366406 rs_loss 1.147648 vs_loss 11.775572 logits_loss 444.318601
[48:1200] F: 29184000 	 tot_loss 17.575500 rs_loss 1.178597 vs_loss 11.952278 logits_loss 444.462623
[48:1300] F: 29216000 	 tot_loss 18.467067 rs_loss 1.599898 vs_loss 12.423435 logits_loss 444.373

[52:400] F: 31328000 	 tot_loss 17.209156 rs_loss 1.149061 vs_loss 11.624775 logits_loss 443.532075
[52:500] F: 31360000 	 tot_loss 16.592372 rs_loss 1.107667 vs_loss 11.046839 logits_loss 443.786684
[52:600] F: 31392000 	 tot_loss 15.781751 rs_loss 0.913542 vs_loss 10.428177 logits_loss 444.003210
[52:700] F: 31424000 	 tot_loss 16.153050 rs_loss 1.173249 vs_loss 10.547372 logits_loss 443.242917
[52:800] F: 31456000 	 tot_loss 19.714996 rs_loss 1.709827 vs_loss 13.571529 logits_loss 443.364024
[52:900] F: 31488000 	 tot_loss 22.966148 rs_loss 1.958503 vs_loss 16.578982 logits_loss 442.866365
[52:1000] F: 31520000 	 tot_loss 22.353411 rs_loss 1.897691 vs_loss 16.025698 logits_loss 443.002204
[52:1100] F: 31552000 	 tot_loss 19.630808 rs_loss 1.301201 vs_loss 13.896917 logits_loss 443.269051
[52:1200] F: 31584000 	 tot_loss 15.747257 rs_loss 0.726154 vs_loss 10.583091 logits_loss 443.801262
[52:1300] F: 31616000 	 tot_loss 13.308045 rs_loss 0.723019 vs_loss 8.143277 logits_loss 444.1748

[56:400] F: 33728000 	 tot_loss 17.221692 rs_loss 1.209293 vs_loss 11.578455 logits_loss 443.394369
[56:500] F: 33760000 	 tot_loss 16.389727 rs_loss 0.889977 vs_loss 11.065575 logits_loss 443.417567
[56:600] F: 33792000 	 tot_loss 16.402409 rs_loss 0.713105 vs_loss 11.256800 logits_loss 443.250407
[56:700] F: 33824000 	 tot_loss 16.246397 rs_loss 0.866974 vs_loss 10.945891 logits_loss 443.353181
[56:800] F: 33856000 	 tot_loss 18.205771 rs_loss 1.311719 vs_loss 12.457706 logits_loss 443.634605
[56:900] F: 33888000 	 tot_loss 17.526187 rs_loss 1.302061 vs_loss 11.791025 logits_loss 443.310066
[56:1000] F: 33920000 	 tot_loss 16.119561 rs_loss 1.247869 vs_loss 10.440743 logits_loss 443.094968
[56:1100] F: 33952000 	 tot_loss 15.386701 rs_loss 1.048329 vs_loss 9.907979 logits_loss 443.039315
[56:1200] F: 33984000 	 tot_loss 14.650480 rs_loss 0.882679 vs_loss 9.332319 logits_loss 443.548226
[56:1300] F: 34016000 	 tot_loss 15.005664 rs_loss 1.161068 vs_loss 9.408510 logits_loss 443.608580

[60:400] F: 36128000 	 tot_loss 16.889969 rs_loss 1.192242 vs_loss 11.258858 logits_loss 443.886889
[60:500] F: 36160000 	 tot_loss 18.840314 rs_loss 1.377463 vs_loss 13.028182 logits_loss 443.466962
[60:600] F: 36192000 	 tot_loss 24.872948 rs_loss 2.446608 vs_loss 17.991282 logits_loss 443.505827
[60:700] F: 36224000 	 tot_loss 22.909071 rs_loss 2.067535 vs_loss 16.404380 logits_loss 443.715636
[60:800] F: 36256000 	 tot_loss 21.770005 rs_loss 2.363401 vs_loss 14.967427 logits_loss 443.917614
[60:900] F: 36288000 	 tot_loss 19.286440 rs_loss 2.223943 vs_loss 12.618132 logits_loss 444.436495
[60:1000] F: 36320000 	 tot_loss 13.672459 rs_loss 1.061396 vs_loss 8.160742 logits_loss 445.032125
[60:1100] F: 36352000 	 tot_loss 13.683898 rs_loss 1.061285 vs_loss 8.174947 logits_loss 444.766708
[60:1200] F: 36384000 	 tot_loss 12.347315 rs_loss 0.678781 vs_loss 7.214723 logits_loss 445.381096
[60:1300] F: 36416000 	 tot_loss 15.623554 rs_loss 0.972157 vs_loss 10.201909 logits_loss 444.948817

KeyboardInterrupt: 

In [42]:
for i in range(flags.num_actors):
    actor = ctx.Process(target=gen_data, args=(flags, i, net, buffers, free_queue, ),)
    actor.start()
    actor_processes.append(actor)   
    
for m in range(flags.seq_n): free_queue.put(m)

In [None]:
free_queue.empty()

In [None]:
# stop the threads

for _ in range(flags.num_actors): free_queue.put(None)  
for actor in actor_processes: actor.join(timeout=1)   

<font size="5">Model testing / debug </font>

In [48]:
torch.save({"model_state_dict": model.state_dict(),},"../models/large_model_1.tar")

In [45]:
# Load trained model 

#parser = argparse.ArgumentParser()
#flags = parser.parse_args("".split())         
flags.env = "Sokoban-v0"
flags.env_disable_noop = False
flags.discounting = 0.97
flags.device = torch.device("cuda")
flags.model_type_nn = 1
env = create_env(flags)
obs_shape, num_actions = env.observation_space.shape, env.action_space.n

model = Model(flags, obs_shape, num_actions=num_actions).to(device=flags.device)
checkpoint = torch.load("../models/large_model_1.tar")
model.load_state_dict(checkpoint["model_state_dict"])  

<All keys matched successfully>

In [47]:
all_returns = {}
for n in range(1,5):
    t = time.process_time()
    all_returns[n] = test_n_step_model(n, model, flags, eps_n=500, temp=20)
    print("Time required for %d step planning: %f" %(n, time.process_time()-t))

Testing 1 step planning
Finish 500 episode: avg. return: 0.89 (+-0.09) 
Time required for 1 step planning: 35.609551
Testing 2 step planning
Finish 500 episode: avg. return: 1.06 (+-0.12) 
Time required for 2 step planning: 67.660493
Testing 3 step planning


KeyboardInterrupt: 

In [None]:
# MCTS testing

class MCTS:
    """
    Core Monte Carlo Tree Search algorithm.
    To decide on an action, we run N simulations, always starting at the root of
    the search tree and traversing the tree according to the UCB formula until we
    reach a leaf node.
    """
    def __init__(self, flags, num_actions):
        self.flags = flags
        self.num_actions = num_actions

    def run(self, model, obs, add_exploration_noise,):
        """
        At the root of the search tree we use the representation function to obtain a
        hidden state given the current observation.
        We then run a Monte Carlo Tree Search using only action sequences and the model
        learned by the network.
        Only supports a batch size of 1.        
        """
        with torch.no_grad():
            root = Node(0)
            _, root_predicted_value, policy_logits, hidden_state = model(
                obs["frame"][0], obs["last_action"], one_hot=False)
            reward = 0.
            root_predicted_value = root_predicted_value[-1].item()
            policy_logits = policy_logits[-1]
            hidden_state = hidden_state[-1]

            root.expand(num_actions, reward, policy_logits, hidden_state,)

            if add_exploration_noise:
                root.add_exploration_noise(
                    dirichlet_alpha=self.flags.root_dirichlet_alpha,
                    exploration_fraction=self.flags.root_exploration_fraction,
                )

            min_max_stats = MinMaxStats()

            max_tree_depth = 0
            
            #print("p at root:", torch.softmax(policy_logits, dim=-1))
            for k in range(self.flags.num_simulations): 
                
                #print("=======%d iteration======"%k)
                node = root
                search_path = [node]
                current_tree_depth = 0

                while node.expanded():
                    current_tree_depth += 1                    
                    action, node = self.select_child(node, min_max_stats)                    
                    search_path.append(node)
                    #print("action sel: %d" % action)
                
                #np.set_printoptions(precision=5)
                #for x in ["prior_score", "value_score", "pb_c", "prior", "visit_count"]:                    
                #    print(x, "\t", np.array([getattr(search_path[0].children[n], x) for n in range(5)]))

                # Inside the search tree we use the dynamics function to obtain the next hidden
                # state given an action and the previous hidden state
                parent = search_path[-2]     
                reward, value, policy_logits, hidden_state = model.forward_encoded(
                    parent.hidden_state, torch.tensor([[action]]).to(parent.hidden_state.device), one_hot=False)
                reward = reward[-1].item()
                value = value[-1].item()
                #print("model final output: %4f" % value)
                policy_logits = policy_logits[-1]
                hidden_state = hidden_state[-1]
                node.expand(num_actions, reward, policy_logits, hidden_state)
                self.backpropagate(search_path, value, min_max_stats)
                max_tree_depth = max(max_tree_depth, current_tree_depth)

            extra_info = {
                "max_tree_depth": max_tree_depth,
                "root_predicted_value": root_predicted_value,
            }
        return root, extra_info

    def select_child(self, node, min_max_stats):
        """
        Select the child with the highest UCB score.
        """
        max_ucb = max(
            self.ucb_score(node, child, min_max_stats)
            for action, child in node.children.items()
        )
        action = np.random.choice(
            [
                action
                for action, child in node.children.items()
                if self.ucb_score(node, child, min_max_stats) == max_ucb
            ]
        )
        return action, node.children[action]

    def ucb_score(self, parent, child, min_max_stats):
        """
        The score for a node is based on its value, plus an exploration bonus based on the prior.
        """
        pb_c = (
            math.log(
                (parent.visit_count + self.flags.pb_c_base + 1) / self.flags.pb_c_base
            )
            + self.flags.pb_c_init
        )
        pb_c *= math.sqrt(parent.visit_count) / (child.visit_count + 1)

        prior_score = pb_c * child.prior

        if child.visit_count > 0:
            # Mean value Q
            value_score = min_max_stats.normalize(
                child.reward + self.flags.discounting * child.value())
        else:
            value_score = 0
            
        child.pb_c = pb_c
        child.prior_score = prior_score
        child.value_score = value_score
        
        return prior_score + value_score

    def backpropagate(self, search_path, value, min_max_stats):
        """
        At the end of a simulation, we propagate the evaluation all the way up the tree
        to the root.
        """
        #print("bs value: %.4f" % value)
        for n, node in enumerate(reversed(search_path)):
            node.value_sum += value
            node.visit_count += 1
            min_max_stats.update(node.reward + self.flags.discounting * node.value())
            value = node.reward + self.flags.discounting * value
            #print("%d - val: %.4f r: %.4f" % (n, value, node.reward))
            #print("node value_sum %.4f" % node.value_sum)

class Node:
    def __init__(self, prior):
        self.visit_count = 0
        self.prior = prior
        self.value_sum = 0
        self.children = {}
        self.hidden_state = None
        self.reward = 0

    def expanded(self):
        return len(self.children) > 0

    def value(self):
        if self.visit_count == 0:
            return 0
        return self.value_sum / self.visit_count

    def expand(self, num_actions, reward, policy_logits, hidden_state):
        """
        We expand a node using the value, reward and policy prediction obtained from the
        neural network.
        """
        self.reward = reward
        self.hidden_state = hidden_state
        policy_values = torch.softmax(policy_logits[0], dim=0).tolist()
        for a in range(num_actions):
            self.children[a] = Node(policy_values[a])

    def add_exploration_noise(self, dirichlet_alpha, exploration_fraction):
        """
        At the start of each search, we add dirichlet noise to the prior of the root to
        encourage the search to explore new actions.
        """
        actions = list(self.children.keys())
        noise = np.random.dirichlet([dirichlet_alpha] * len(actions))
        frac = exploration_fraction
        for a, n in zip(actions, noise):
            self.children[a].prior = self.children[a].prior * (1 - frac) + n * frac

class MinMaxStats:
    """
    A class that holds the min-max values of the tree.
    """

    def __init__(self):
        self.maximum = -float("inf")
        self.minimum = float("inf")

    def update(self, value):
        self.maximum = max(self.maximum, value)
        self.minimum = min(self.minimum, value)

    def normalize(self, value):
        if self.maximum > self.minimum:
            # We normalize only when we have set the maximum and minimum values
            return (value - self.minimum) / (self.maximum - self.minimum)
        return value            

def select_action(node, temperature):
    """
    Select action according to the visit count distribution and the temperature.
    The temperature is changed dynamically with the visit_softmax_temperature function
    in the config.
    """
    visit_counts = np.array(
        [child.visit_count for child in node.children.values()], dtype="int32"
    )
    actions = [action for action in node.children.keys()]
    if temperature == 0:
        action = actions[np.argmax(visit_counts)]
    elif temperature == float("inf"):
        action = np.random.choice(actions)
    else:
        # See paper appendix Data Generation
        visit_count_distribution = visit_counts ** (1 / temperature)
        visit_count_distribution = visit_count_distribution / sum(
            visit_count_distribution
        )
        action = np.random.choice(actions, p=visit_count_distribution)
    #print("visit_counts", visit_counts)
    #print("visit_count_distribution", visit_count_distribution)
    return action
    
    
parser = argparse.ArgumentParser()      
flags = parser.parse_args([])   

env = SokobanWrapper(gym.make("Sokoban-v0"), noop=True)
env = Environment(env)
env.initial()
obs_shape, num_actions = env.gym_env.observation_space.shape, env.gym_env.action_space.n

parser = argparse.ArgumentParser()
flags = parser.parse_args([])   
flags.discounting = 0.97
flags.pb_c_init = 1.25
flags.pb_c_base = 19652
flags.root_dirichlet_alpha = 0.25
flags.root_exploration_fraction = 0.
flags.num_simulations = 3
flags.temp = 0.5
flags.device = torch.device("cuda")

eps_n = 10
eps_n_cur = 0

model = Model(flags, obs_shape, num_actions=num_actions).to(device=flags.device)
checkpoint = torch.load("../models/model_1.tar")
model.load_state_dict(checkpoint["model_state_dict"])  

obs = env.initial()
returns = []
mcts = MCTS(flags, num_actions)
obs = {k:v.to(flags.device) for k, v in obs.items()}
root, extra_info = mcts.run(model, obs, add_exploration_noise=True)   

plt.imshow(torch.swapaxes(torch.swapaxes(obs['frame'][0,0].to(
    flags.device).clone().cpu(),0,2),0,1), interpolation='nearest')
plt.show()

actions = torch.tensor([0, 1, 3, 4, 4, 4, 2, 4, 2, 2, 2, 1, 1, 1, 1]).long().to(flags.device).reshape(-1, 1)
reward, value, policy_logits, hidden_state  = model(obs["frame"][0], actions, one_hot=False)
print(reward, value)

while(len(returns) <= eps_n):
    cur_returns = obs['episode_return']    
    obs = {k:v.to(flags.device) for k, v in obs.items()}
    root, extra_info = mcts.run(model, obs, add_exploration_noise=True)    
    new_action = select_action(root, flags.temp)    
    #plt.imshow(torch.swapaxes(torch.swapaxes(obs['frame'][0,0].to(
    #flags.device).clone().cpu(),0,2),0,1), interpolation='nearest')
    #plt.show()
    #print("action selected", new_action)
    #print("===========================================")
    obs = env.step(torch.tensor([new_action]))
    if torch.any(obs['done']):
        returns.extend(cur_returns[obs['done']].numpy())
    if eps_n_cur <= len(returns) and len(returns) > 0: 
        eps_n_cur = len(returns) + 10
print("Finish %d episode: avg. return: %.2f (+-%.2f) " % (len(returns),
            np.average(returns), np.std(returns) / np.sqrt(len(returns))))

In [None]:
# Test planning algorithm
bsz = 1
env = gym.vector.SyncVectorEnv([lambda: SokobanWrapper(gym.make("Sokoban-v0"), noop=True)] * bsz)
env = Vec_Environment(env, bsz)
obs = env.initial()
state = obs['frame'][0].to(flags.device).clone()
action = torch.zeros(bsz).long().to(flags.device)
encoded = None

In [None]:
action = torch.Tensor([4]).long().to(flags.device)
obs = env.step(action)
state = obs['frame'][0].to(flags.device).clone()

In [None]:
plt.imshow(torch.swapaxes(torch.swapaxes(state[0].cpu(),0,2),0,1), interpolation='nearest')
plt.show()

In [None]:
from matplotlib import pyplot as plt
import logging
logging.getLogger('matplotlib.font_manager').disabled = True
device = flags.device

for _ in range(1):
    plt.imshow(torch.swapaxes(torch.swapaxes(state[0].cpu(),0,2),0,1), interpolation='nearest')
    plt.show()
    ret = np.zeros((5, 5, 5))
    for i in range(5):
        for j in range(5):
            for k in range(5):
                test_action_seq = [i,j,k]
                test_action_seq = torch.Tensor(test_action_seq).unsqueeze(-1).long().to(device)  
                old_new_actions = torch.concat([action.unsqueeze(0), test_action_seq], dim=0)
                rs, vs, logits, encodeds = model(state, old_new_actions)
                ret[i, j, k] = rs[0] + rs[1] * 0.97 + rs[2] * (0.97**2) + vs[-1] * (0.97**3)
    print(np.max(ret), (np.max(ret) == ret).nonzero())    
    new_action = torch.Tensor((np.max(ret) == ret).nonzero()[0]).long().to(flags.device)
    #obs = env.step(new_action)
    #state = obs['frame'][0].to(flags.device).clone()
    #action = new_action            

In [None]:
action, prob, q_ret = n_step_greedy_model(state, action, model, 3, encoded=None, temp=10.)
print("action: ", action)
print("prob: ", prob)
print("q_ret: ", q_ret)

In [None]:
test_action_seq = [2,3,1]
test_action_seq = torch.Tensor(test_action_seq).unsqueeze(-1).long().to(device)  
old_new_actions = torch.concat([action.unsqueeze(0), test_action_seq], dim=0)
rs, vs, logits, encodeds = model(state, old_new_actions)
ret[i, j, k] = rs[0] + rs[1] * 0.97 + rs[2] * (0.97**2) + vs[-1] * (0.97**3)
print("rs", rs)
print("vs", vs)
print("logits", logits)
print("ret", ret[i,j,k])

In [None]:
temp = 10.

bsz = state.shape[0]
device = state.device 
num_actions = model.num_actions    
model.train(False)

q_ret = torch.zeros(bsz, num_actions).to(device)        
rs_act = torch.zeros(bsz, num_actions).to(device)        
vs_act = torch.zeros(bsz, num_actions).to(device)        

for act in range(num_actions):        
    new_action = torch.Tensor(np.full(bsz, act)).long().to(device)    
    old_new_actions = torch.concat([action.unsqueeze(0), new_action.unsqueeze(0)], dim=0)
    rs, vs, logits, encodeds = model(state, old_new_actions)
    ret = rs[0] + flags.discounting * vs[1]
    rs_act[:, act] = rs[0]
    vs_act[:, act] = vs[1]
    q_ret[:, act] = ret

prob = F.softmax(temp*q_ret, dim=1)
action = torch.multinomial(prob, num_samples=1)[:, 0]

print("rs_act", rs_act)
print("vs_act", vs_act)
print("q_ret", q_ret)
print("prob", prob)

In [None]:
device = flags.device
net_state = env.clone_state()

bsz = 1
temp = 10.
q_ret = torch.zeros(bsz, num_actions).to(device)      
rs_act = torch.zeros(bsz, num_actions).to(device)        
vs_act = torch.zeros(bsz, num_actions).to(device)   

net = net.to(device)

for act in range(num_actions):
    obs = env.step(torch.Tensor(np.full(bsz, act)).long())      
    obs = {k:v.to(device) for k, v in obs.items()}   
    ret = obs['reward'] + flags.discounting * net(obs)[0]['baseline'] * (~obs['done']).float()
    rs_act[:, act] = obs['reward']
    vs_act[:, act] = net(obs)[0]['baseline']
    q_ret[:, act] = ret
    env.restore_state(net_state)

prob = F.softmax(temp*q_ret, dim=1)
action = torch.multinomial(prob, num_samples=1)[:, 0]

print("rs_act", rs_act)
print("vs_act", vs_act)
print("q_ret", q_ret)
print("prob", prob)

plt.imshow(torch.swapaxes(torch.swapaxes(state[0].cpu(),0,2),0,1), interpolation='nearest')
plt.show()

In [None]:
batch = get_batch_m(flags, buffers)
print(torch.max(batch["reward"]), (torch.max(batch["reward"]) == batch["reward"]).nonzero())
print(batch["done"].nonzero())

In [None]:
# DEBUG LOSS

#batch = get_batch_m(flags, buffers)

model.train(False)

rs, vs, logits, _ = model(batch['frame'][0], batch['action'])
logits = logits[:-1]

target_rewards = batch['reward'][1:]
target_logits = batch['policy_logits'][1:]

target_vs = []
target_v = model(batch['frame'][-1], batch['action'][[-1]])[1][0]    

for t in range(vs.shape[0]-1, 0, -1):
    new_target_v = batch['reward'][t] + flags.discounting * (target_v * (~batch['done'][t]).float() +
                       vs[t-1] * (batch['truncated_done'][t]).float())
    target_vs.append(new_target_v.unsqueeze(0))
    target_v = new_target_v
target_vs.reverse()
target_vs = torch.concat(target_vs, dim=0)

# if done on step j, r_{j}, v_{j-1}, a_{j-1} has the last valid loss 
# rs is stored in the form of r_{t+1}, ..., r_{t+k}
# vs is stored in the form of v_{t}, ..., v_{t+k-1}
# logits is stored in the form of a{t}, ..., a_{t+k-1}

done_masks = []
done = torch.zeros(vs.shape[1]).bool().to(batch['done'].device)
for t in range(vs.shape[0]):
    done = torch.logical_or(done, batch['done'][t])
    done_masks.append(done.unsqueeze(0))

done_masks = torch.concat(done_masks[:-1], dim=0)

# compute final loss
huberloss = torch.nn.HuberLoss(reduction='none', delta=1.0)    
#rs_loss = torch.sum(huberloss(rs, target_rewards) * (~done_masks).float())
rs_loss = torch.sum(((rs - target_rewards) ** 2) * (~done_masks).float())
#vs_loss = torch.sum(huberloss(vs[:-1], target_vs) * (~done_masks).float())
vs_loss = torch.sum(((vs[:-1] - target_vs) ** 2) * (~done_masks).float())
logits_loss = compute_cross_entropy_loss(logits, target_logits, done_masks)

# debug
ind = 10

target_vs = []
target_v = vs[-1]
for t in range(vs.shape[0]-1, 0, -1):        
    new_target_v = batch['reward'][t] + flags.discounting * (target_v * (~batch['done'][t]).float() +
                       vs[t-1] * (batch['truncated_done'][t]).float())
    print(t, 
          "reward %2f" % batch['reward'][t,ind].item(), 
          "bootstrap %2f" % (target_v * (~batch['done'][t]).float())[ind].item(), 
          "truncated %2f" % (vs[t-1] * (batch['truncated_done'][t]).float())[ind].item(),
          "vs[t-1] %2f" % vs[t-1][ind].item(),
          "new_targ %2f" % new_target_v[ind].item())
    target_vs.append(new_target_v.unsqueeze(0))    
    target_v = new_target_v
target_vs.reverse()
target_vs = torch.concat(target_vs, dim=0)   
print("done", batch["done"][:, ind])
print("done_masks", done_masks[:, ind])
print("vs: ", vs[:, ind])
print("target_vs: ", target_vs[:, ind])
print("reward: ", rs[:, ind])
print("target_reward: ", target_rewards[:, ind])
print("logits: ", logits[:, ind])
print("target_logits: ", target_logits[:, ind])

In [None]:
def compute_loss_m(model, batch):
    rs, vs, logits, _ = model(batch['frame'][0], batch['action'])
    logits = logits[:-1]

    target_rewards = batch['reward'][1:]
    target_logits = batch['policy_logits'][1:]

    target_vs = []
    target_v = model(batch['frame'][-1], batch['action'][[-1]])[1][0].detach()
    
    for t in range(vs.shape[0]-1, 0, -1):
        new_target_v = batch['reward'][t] + flags.discounting * (target_v * (~batch['done'][t]).float())# +
                           #vs[t-1] * (batch['truncated_done'][t]).float())
        target_vs.append(new_target_v.unsqueeze(0))
        target_v = new_target_v
    target_vs.reverse()
    target_vs = torch.concat(target_vs, dim=0)

    # if done on step j, r_{j}, v_{j-1}, a_{j-1} has the last valid loss 
    # rs is stored in the form of r_{t+1}, ..., r_{t+k}
    # vs is stored in the form of v_{t}, ..., v_{t+k-1}
    # logits is stored in the form of a{t}, ..., a_{t+k-1}

    done_masks = []
    done = torch.zeros(vs.shape[1]).bool().to(batch['done'].device)
    for t in range(vs.shape[0]):
        done = torch.logical_or(done, batch['done'][t])
        done_masks.append(done.unsqueeze(0))

    done_masks = torch.concat(done_masks[:-1], dim=0)
    
    # compute final loss
    huberloss = torch.nn.HuberLoss(reduction='none', delta=1.0)    
    rs_loss = torch.sum(huberloss(rs, target_rewards.detach()) * (~done_masks).float())
    #rs_loss = torch.sum(((rs - target_rewards) ** 2) * (~done_masks).float())
    vs_loss = torch.sum(huberloss(vs[:-1], target_vs.detach()) * (~done_masks).float())
    #vs_loss = torch.sum(((vs[:-1] - target_vs) ** 2) * (~done_masks).float())
    logits_loss = compute_cross_entropy_loss(logits, target_logits.detach(), done_masks)
    
    return rs_loss, vs_loss, logits_loss

# alt. version of computing loss by treading terminal state as absorbing state (as in MuZero)

def compute_loss_m(model, batch):

    rs, vs, logits, _ = model(batch['frame'][0], batch['action'])
    logits = logits[:-1]

    target_logits = batch['policy_logits'][1:].clone()
    target_rewards = batch['reward'][1:].clone()

    done_masks = []
    done = torch.zeros(vs.shape[1]).bool().to(batch['done'].device)

    c_logits = target_logits[0]
    c_state = batch['frame'][0]
    for t in range(vs.shape[0]-1):
        if t > 0: done = torch.logical_or(done, batch['done'][t])
        c_logits = torch.where(done.unsqueeze(-1), c_logits, target_logits[t])
        target_logits[t] = c_logits
        c_state = torch.where(done.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1), c_state, batch['frame'][t])  
        done_masks.append(done.unsqueeze(0))
    done_masks = torch.concat(done_masks, dim=0)
    done = torch.logical_or(done, batch['done'][-1])
    c_state = torch.where(done.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1), c_state, batch['frame'][-1])
    target_rewards = target_rewards * (~done_masks).float()

    target_vs = []
    target_v = model(c_state, batch['action'][[-1]])[1][0].detach()
    
    for t in range(vs.shape[0]-1, 0, -1):
        new_target_v = batch['reward'][t] + flags.discounting * target_v
        target_vs.append(new_target_v.unsqueeze(0))
        target_v = new_target_v
    target_vs.reverse()
    target_vs = torch.concat(target_vs, dim=0)
    
    # compute final loss
    huberloss = torch.nn.HuberLoss(reduction='none', delta=1.0)    
    rs_loss = torch.sum(huberloss(rs, target_rewards.detach()))
    #rs_loss = torch.sum(((rs - target_rewards) ** 2) * (~r_logit_done_masks).float())
    vs_loss = torch.sum(huberloss(vs[:-1], target_vs.detach()))
    #vs_loss = torch.sum(((vs[:-1] - target_vs) ** 2) * (~v_done_masks).float())
    logits_loss = compute_cross_entropy_loss(logits, target_logits.detach(), None)
    
    return rs_loss, vs_loss, logits_loss