<a href="https://colab.research.google.com/github/rootAkash/reinforcement_learning/blob/master/muzero/muzero_more_efficient_PER.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# New experiences are stored with an initial priority |predicted value - z|, where z is the n-step return.
# Importance-sampling weights follow w = (1 / (N * P(i)))**beta; with uniform priorities P(i) = 1/N, this is 1.
# When refreshing priorities, the L1 distance between the predicted value and z is taken from the initial
# training forward pass (i.e. before backprop), not recomputed after the update.
# This still requires converting categorical outputs back to scalars. `catts` is vectorised, so no Python
# loop is needed, and updating priorities from the pre-backprop values is cheap.

In [None]:
# Colab environment setup (IPython shell escapes, not plain Python):
# gym with all optional extras, box2d-py for the Box2D environments (e.g. LunarLander-v2),
# and python-opengl / xvfb, presumably so env.render() can work on a headless VM.
!pip install gym[all]
!pip install box2d-py
!apt-get install python-opengl -y
!apt install xvfb -y

In [1]:
import numpy as np
def stcat(x,support=5):
  """Encode a scalar as a two-hot categorical vector over 2*support+1 bins.

  The scalar is first squashed with sign(x)*(sqrt(|x|+1)-1) + 0.001*x (the
  invertible transform used by MuZero), clipped to [-support, support], and its
  mass is split between the two neighbouring integer bins. Bin ``i`` stands for
  the value ``i - support``; ``catts`` performs the inverse mapping.

  Args:
    x: scalar to encode.
    support: half-width of the categorical support.
  Returns:
    float64 numpy vector of length 2*support+1 summing to 1.
  """
  squashed = np.sign(x) * ((abs(x) + 1) ** 0.5 - 1) + 0.001 * x
  squashed = np.clip(squashed, -support, support)
  lower = np.floor(squashed)
  upper_weight = squashed - lower
  dist = np.zeros(2 * support + 1)
  lo_idx = int(lower + support)
  dist[lo_idx] = 1 - upper_weight
  # Only touch the upper neighbour when it carries mass; at the top edge
  # (squashed == support) the upper index would be out of range.
  if upper_weight > 0:
    dist[lo_idx + 1] = upper_weight
  return dist
#allow for batch processing  
def catts(x,support=5):
  """Decode categorical distribution(s) produced by ``stcat`` back to scalars.

  Takes the expectation over bin values ``-support .. support`` along the last
  axis, then applies the inverse of the squashing transform used in ``stcat``.

  Args:
    x: array whose last axis has length 2*support+1. A 1-D input yields a
       scalar; a 2-D (batch, bins) input yields one value per row (any number
       of leading dimensions is accepted).
    support: half-width of the categorical support.
  Returns:
    numpy scalar or array of decoded values.
  Raises:
    ValueError: if ``x`` is 0-dimensional or its last axis length does not
      match the support. Previously this only printed a message and went on
      to return a meaningless result.
  """
  x = np.asarray(x)
  bins = np.arange(-support, support + 1, 1)
  if x.ndim == 0 or x.shape[-1] != bins.shape[0]:
    raise ValueError("wrong input for conversion to  scalar")
  # Expected (squashed) value under the categorical distribution.
  expected = np.sum(bins * x, axis=-1)
  eps = 0.001  # must match the 0.001 used in stcat
  return np.sign(expected) * ((((1 + 4 * eps * (np.abs(expected) + 1 + eps)) ** 0.5 - 1) / (2 * eps)) ** 2 - 1)

# Round-trip sanity check: encode +5 and -5, then decode them back to scalars.
cat = np.stack([stcat(5), stcat(-5)])  # batch of two encodings, shape (2, 11)
print(cat,cat.shape)
scalar = catts(cat)
print(scalar)
print("done")


[[0.         0.         0.         0.         0.         0.
  0.54551026 0.45448974 0.         0.         0.        ]
 [0.         0.         0.         0.45448974 0.54551026 0.
  0.         0.         0.         0.         0.        ]] (2, 11)
[ 5. -5.]
done


In [2]:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim




class MuZeroNet(nn.Module):
    """Small fully-connected MuZero model.

    * ``h`` (representation): observation -> hidden state of size ``hx_size``.
    * ``g`` (dynamics): (hidden state, one-hot action) -> next hidden state and
      a softmax over ``2 * reward_support_size + 1`` reward bins.
    * ``p`` (prediction): hidden state -> softmax policy over actions and a
      softmax over ``2 * value_support_size + 1`` value bins.

    The final layers of the value and reward heads are zero-initialised, so both
    heads start out predicting a uniform distribution.
    """

    def __init__(self, input_size, action_space_n, reward_support_size, value_support_size):
        super().__init__()
        self.hx_size = 32
        hidden = 64
        state_action = self.hx_size + action_space_n

        # Submodules are created in a fixed order, so random initialisation
        # consumes the RNG in the same sequence on every construction.
        self._representation = nn.Sequential(nn.Linear(input_size, self.hx_size), nn.Tanh())
        self._dynamics_state = nn.Sequential(
            nn.Linear(state_action, hidden), nn.Tanh(),
            nn.Linear(hidden, self.hx_size), nn.Tanh(),
        )
        self._dynamics_reward = nn.Sequential(
            nn.Linear(state_action, hidden), nn.LeakyReLU(),
            nn.Linear(hidden, 2 * reward_support_size + 1),
        )
        self._prediction_actor = nn.Sequential(
            nn.Linear(self.hx_size, hidden), nn.LeakyReLU(),
            nn.Linear(hidden, action_space_n),
        )
        self._prediction_value = nn.Sequential(
            nn.Linear(self.hx_size, hidden), nn.LeakyReLU(),
            nn.Linear(hidden, 2 * value_support_size + 1),
        )
        self.action_space_n = action_space_n

        # Zero the last layer of the value and reward heads -> uniform initial outputs.
        for head in (self._prediction_value, self._dynamics_reward):
            head[-1].weight.data.fill_(0)
            head[-1].bias.data.fill_(0)

    def p(self, state):
        """Prediction function: hidden state -> (policy probs, value distribution)."""
        policy_logits = self._prediction_actor(state)
        value_logits = self._prediction_value(state)
        return torch.softmax(policy_logits, dim=1), torch.softmax(value_logits, dim=1)

    def h(self, obs_history):
        """Representation function: observation batch -> hidden state."""
        return self._representation(obs_history)

    def g(self, state, action):
        """Dynamics function: (hidden state, action) -> (next hidden state, reward distribution)."""
        joint = torch.cat((state, action), dim=1)
        return self._dynamics_state(joint), torch.softmax(self._dynamics_reward(joint), dim=1)

    def initial_state(self, x):
        """h followed by p; returns (hidden, policy, value)."""
        hidden = self.h(x)
        policy, value = self.p(hidden)
        return hidden, policy, value

    def next_state(self, hin, a):
        """g followed by p; returns (hidden, reward, policy, value)."""
        hidden, reward = self.g(hin, a)
        policy, value = self.p(hidden)
        return hidden, reward, policy, value

    def inference_initial_state(self, x):
        """``initial_state`` without building an autograd graph."""
        with torch.no_grad():
            return self.initial_state(x)

    def inference_next_state(self, hin, a):
        """``next_state`` without building an autograd graph."""
        with torch.no_grad():
            return self.next_state(hin, a)


print("done")

done


In [3]:

# MCTS for MuZero, modified for intermediate-reward settings and using the model's predicted rewards.
# The policy is passed around as a Python list.
import torch
import math
import numpy as np

import random
def dynamics(net,state,action):
    """Run one step of the learned dynamics model from a hidden state.

    Args:
        net: model exposing ``inference_next_state`` and ``parameters()``.
        state: hidden-state tensor (callers here pass a CPU tensor of one row).
        action: one-hot action vector (numpy array or list).

    Returns:
        ``(next_state, reward, prob, value)``: the next hidden state moved to CPU,
        the decoded scalar reward, the policy of the first row as a Python list,
        and the decoded scalar value.
    """
    # Run on the device the network actually lives on instead of assuming CUDA
    # whenever it is available (a CPU-resident net would otherwise fail on GPU hosts).
    first_param = next(net.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    next_state, reward, prob, value = net.inference_next_state(
        state.to(device), torch.tensor([action]).float().to(device))
    reward = catts(reward.cpu().numpy().ravel())
    value = catts(value.cpu().numpy().ravel())
    prob = prob.cpu().tolist()[0]
    return next_state.cpu(), reward, prob, value


class MinMaxStats:
    """Running minimum/maximum of node values seen within one search tree.

    Used by ``ucb_score`` to rescale values into [0, 1] once at least two
    distinct values have been observed.
    """

    def __init__(self):
        self.MAXIMUM_FLOAT_VALUE = float('inf')
        # Start from an inverted (empty) range so the first update sets both bounds.
        self.maximum, self.minimum = -self.MAXIMUM_FLOAT_VALUE, self.MAXIMUM_FLOAT_VALUE

    def update(self, value: float):
        """Widen the tracked range to include ``value``.

        Raises:
            ValueError: if ``value`` is None.
        """
        if value is None:
            raise ValueError
        if value > self.maximum:
            self.maximum = value
        if value < self.minimum:
            self.minimum = value

    def normalize(self, value: float) -> float:
        """Map ``value`` into [0, 1] using the tracked range.

        Returns 0.0 for an unknown (None) value, and ``value`` unchanged while the
        range is still empty or degenerate (maximum not above minimum).
        """
        if value is None:
            return 0.0
        has_range = self.maximum > self.minimum
        if not has_range:
            return value
        return (value - self.minimum) / (self.maximum - self.minimum)


class Node:
    """One state in the MCTS tree.

    Attributes (all plain instance fields):
        prior: prior probability of reaching this node from its parent.
        visit_count / value_sum: accumulated by backpropagation.
        children: dict mapping action -> child Node.
        hidden_state: model hidden state stored when the node is expanded.
        reward: predicted reward for the transition into this node.
        to_play: player id, -1 until expanded.
    """

    def __init__(self, prior: float):
        self.prior = prior
        self.to_play = -1
        self.visit_count = 0
        self.value_sum = 0
        self.reward = 0
        self.hidden_state = None
        self.children = {}

    def expanded(self):
        """True once at least one child has been attached."""
        return bool(self.children)

    def value(self):
        """Mean backed-up value, or None if the node has never been visited."""
        return None if self.visit_count == 0 else self.value_sum / self.visit_count


def softmax_sample(visit_counts, actions, t):
    """Sample an action with probability softmax(visit_counts / t).

    Args:
        visit_counts: per-action visit counts, aligned with ``actions``.
        actions: candidate actions.
        t: temperature (> 0). Lower values are greedier; higher values are closer to uniform.

    Returns:
        The sampled element of ``actions``.

    Raises:
        ValueError: if ``t`` is not positive.
    """
    if t <= 0:
        raise ValueError("temperature t must be positive")
    # Divide the logits by t. Scaling exp(.) by 1/t afterwards would cancel out
    # on normalization and leave the temperature with no effect.
    logits = np.asarray(visit_counts, dtype=np.float64) / t
    # Subtract the max before exponentiating so large visit counts cannot overflow to inf/nan.
    weights = np.exp(logits - logits.max())
    probs = weights / weights.sum()
    action_idx = np.random.choice(len(actions), p=probs)
    return actions[action_idx]


"""MCTS module: where MuZero thinks inside the tree."""


def add_exploration_noise( node, dirichlet_alpha=0.25, exploration_fraction=0.25):
    """
    At the start of each search, we add dirichlet noise to the prior of the root
    to encourage the search to explore new actions.

    Each child prior becomes ``prior * (1 - exploration_fraction) + noise * exploration_fraction``
    with ``noise ~ Dirichlet(dirichlet_alpha)``. Nodes without children are left untouched.

    Args:
        node: node whose ``children`` priors are perturbed in place.
        dirichlet_alpha: concentration of the Dirichlet noise (default 0.25).
        exploration_fraction: weight given to the noise (default 0.25).
    """
    actions = list(node.children.keys())
    if not actions:
        return
    noise = np.random.dirichlet([dirichlet_alpha] * len(actions))
    for a, n in zip(actions, noise):
        child = node.children[a]
        child.prior = child.prior * (1 - exploration_fraction) + n * exploration_fraction



def ucb_score(parent, child,min_max_stats):
    """pUCT score of ``child`` under ``parent``.

    The score is the normalized value (``min_max_stats.normalize(child.value())``)
    plus an exploration bonus proportional to ``child.prior``. The bonus grows with
    sqrt(parent visits) and shrinks with the child's own visit count.
    """
    pb_c_base, pb_c_init = 19652, 1.25
    exploration = math.log((parent.visit_count + pb_c_base + 1) / pb_c_base) + pb_c_init
    exploration *= math.sqrt(parent.visit_count) / (child.visit_count + 1)
    return min_max_stats.normalize(child.value()) + exploration * child.prior

def select_child(node, min_max_stats):
    """
    Select the child with the highest UCB score.

    Returns:
        ``(action, child)``. Ties on score are broken by the larger action,
        because the comparison falls back to the tuple's second element.
    """
    # With zero parent visits, sqrt(0) zeroes every exploration bonus and all
    # children are unvisited, so every score is equal: pick uniformly at random.
    # random.sample() rejects dict views on Python >= 3.11, so use
    # random.choice over an explicit list instead.
    if node.visit_count == 0:
        return random.choice(list(node.children.items()))

    _, action, child = max(
        (ucb_score(node, child, min_max_stats), action, child)
        for action, child in node.children.items())
    return action, child




def expand_node(node, to_play, actions_space,hidden_state,reward,policy):
    """Expand ``node`` with the network's predictions.

    Stores the player, hidden state and predicted reward on the node, then
    attaches one child per action in ``actions_space``. Each child's prior is
    ``policy[action]`` renormalised over ``actions_space``. This is a no-op
    scaling when the policy is already a softmax over exactly these actions.
    """
    node.to_play, node.hidden_state, node.reward = to_play, hidden_state, reward
    legal_priors = {a: policy[a] for a in actions_space}
    total_mass = sum(legal_priors.values())
    for action in legal_priors:
        node.children[action] = Node(legal_priors[action] / total_mass)


def backpropagate(search_path, value,to_play,discount, min_max_stats):
    """Propagate a leaf evaluation from the leaf back up to the root.

    Walking the path leaf-first, each node gets the current return added to its
    ``value_sum`` and one more visit, and its new mean value is fed to
    ``min_max_stats``. The return is then rolled back one level as
    ``node.reward + discount * return``. ``to_play`` is accepted but unused
    (single-player setting).
    """
    bootstrap = value
    for node in reversed(search_path):
        node.value_sum += bootstrap
        node.visit_count += 1
        min_max_stats.update(node.value())
        bootstrap = node.reward + discount * bootstrap


def select_action(node, mode ='softmax'):
    """
    After running simulations inside in MCTS, we select an action based on the root's children visit counts.
    During training we use a softmax sample for exploration.
    During evaluation we select the most visited child.

    Args:
        node: searched root node.
        mode: ``'softmax'`` (sample via ``softmax_sample`` at temperature 1)
            or ``'max'`` (most visited child).

    Returns:
        ``(action, visit_distribution, root_value)`` where ``visit_distribution``
        is the children's visit counts normalised to sum to 1.

    Raises:
        ValueError: for an unknown ``mode``. Previously this silently returned ``None`` as the action.
    """
    children = node.children
    visit_counts = [child.visit_count for child in children.values()]
    actions = list(children.keys())
    if mode == 'softmax':
        action = softmax_sample(visit_counts, actions, 1.0)
    elif mode == 'max':
        action, _ = max(children.items(), key=lambda item: item[1].visit_count)
    else:
        raise ValueError(f"unknown action-selection mode: {mode!r}")
    # The MCTS policy target is the normalised visit-count distribution.
    policy = np.array(visit_counts) / sum(visit_counts)
    return action, policy, node.value()

def run_mcts(net, state,prob,root_value,num_simulations,discount = 0.9):
    """
    Core Monte Carlo Tree Search algorithm.
    To decide on an action, we run N simulations, always starting at the root of
    the search tree and traversing the tree according to the UCB formula until we
    reach a leaf node.

    Args:
        net: model passed through to ``dynamics`` for every leaf expansion.
        state: root hidden state (callers in this notebook pass a CPU tensor).
        prob: root policy tensor; only its first row is used (``prob.tolist()[0]``).
        root_value: root value distribution; decoded with ``catts`` below, but
            NOTE(review): the decoded scalar is never used afterwards.
        num_simulations: number of simulations (leaf expansions) to run.
        discount: discount applied to rewards/values during backpropagation.

    Returns:
        The root ``Node``; callers read its children's visit counts.
    """
    prob, root_value = prob.tolist()[0] ,catts(root_value.numpy().ravel())
    # Single-player setting: to_play is passed along but never changes.
    to_play = True
    # One action per policy entry: 0 .. len(prob)-1.
    action_space=[ i for i in range(len(prob))]#history.action_space()
    #print("action space",action_space)
    root = Node(0)
    # The root gets zero reward; its children take their priors from the root policy.
    expand_node(root, to_play,action_space,state,0.0,prob)#node, to_play, actions_space ,hidden_state,reward,policy
    # Dirichlet noise on the root priors (applied on every call, including from eval_game).
    add_exploration_noise( root)


    min_max_stats = MinMaxStats()

    for _ in range(num_simulations):
        node = root
        search_path = [node]

        # Descend by UCB score until reaching an unexpanded node.
        while node.expanded():
            action, node = select_child( node, min_max_stats)
            search_path.append(node)

        # Inside the search tree we use the dynamics function to obtain the next
        # hidden state given an action and the previous hidden state.
        # The root is expanded before the loop, so search_path always has at least 2 entries here.
        parent = search_path[-2]

        #network_output = network.recurrent_inference(parent.hidden_state, action)
        next_state,r,action_probs, value = dynamics(net,parent.hidden_state,onehot(action,len(action_space)))
        expand_node(node, to_play, action_space,next_state,r,action_probs)#node, to_play, actions_space ,hidden_state,reward,policy

        # Push the network's leaf value back up, discounting through each node's reward.
        backpropagate(search_path, value, to_play, discount, min_max_stats)#search_path, value, to_play, discount, min_max_stats
    return root


In [4]:
import gym
class ScalingObservationWrapper(gym.ObservationWrapper):
    """
    Rescale each observation dimension so that [low, high] maps onto [-1, 1].

    ``low``/``high`` default to the wrapped env's observation-space bounds.
    NOTE(review): ``self.observation_space`` itself is not rewritten, and a
    dimension with ``high == low`` would divide by zero.
    """

    def __init__(self, env, low=None, high=None):
        super().__init__(env)
        assert isinstance(env.observation_space, gym.spaces.Box)

        space = self.observation_space
        lo = np.array(space.low if low is None else low)
        hi = np.array(space.high if high is None else high)

        # Centre of the box and its half-width per dimension.
        self.mean = (hi + lo) / 2
        self.max = hi - self.mean

    def observation(self, observation):
        """Centre on the box midpoint and divide by the half-width."""
        centered = observation - self.mean
        return centered / self.max

In [7]:

import random
import numpy as np
import torch
from tqdm import tqdm
def onehot(a,n=2):
  """Return row ``a`` of the n x n float64 identity matrix (a one-hot vector)."""
  identity = np.identity(n)
  return identity[a]
def play_game(env,net,n_sim,discount,render,device,n_act,max_steps,td_steps,per):
    """Self-play one episode with MCTS and return its trajectory.

    Each trajectory entry is a mutable list
    ``[state, onehot_action, mcts_policy, root_value, reward, priority]``.
    ``priority`` is 1 unless ``per`` is set, in which case it is the initial
    prioritized-replay priority from ``get_initial_priorities``.

    Args:
        env: gym-style environment using the 4-tuple ``step`` API.
        net: model used for root inference and inside MCTS.
        n_sim: MCTS simulations per move.
        discount: discount for MCTS backups and the n-step returns.
        render: call ``env.render()`` every step when True.
        device: device for the root observation tensor.
        n_act: number of discrete actions (one-hot width).
        max_steps: hard cap on the number of environment steps in the episode.
        td_steps: bootstrap horizon for the priority returns.
        per: whether to compute prioritized-replay priorities.
    """
    trajectory = []
    root_values, pred_values, rewards = [], [], []
    state = env.reset()
    done = False
    r = 0
    stp = 0
    while not done:
        if render:
            env.render()
        stp += 1
        h, prob, pred_value = net.inference_initial_state(torch.tensor([state]).float().to(device))
        root = run_mcts(net, h.cpu(), prob.cpu(), pred_value.cpu(), num_simulations=n_sim, discount=discount)
        action, action_prob, mcts_val = select_action(root)
        next_state, reward, done, info = env.step(action)
        r += reward
        # `stp` already counts the step just taken, so `>=` caps the episode at
        # exactly max_steps steps (a strict `>` would permit one extra step).
        if stp >= max_steps:
            done = True
        data = [state, onehot(action, n_act), action_prob, mcts_val, reward, 1]  # state, onehot action, action_prob, mcts_val, reward, priority
        root_values.append(mcts_val)
        pred_values.append(catts(pred_value.cpu().numpy().ravel()))
        rewards.append(reward)
        trajectory.append(data)
        state = next_state
    # Priority = |predicted value - n-step return z|.
    if per:
        priorities = get_initial_priorities(root_values, pred_values, rewards, discount=discount, td_steps=td_steps)
        assert len(trajectory) == len(priorities)
        for step_data, p in zip(trajectory, priorities):
            step_data[5] = p
    print("DATA collection:played for ",len(trajectory)," steps , rewards",r)
    return trajectory
def get_initial_priorities(root_values,pred_values,rewards,discount=0.99, td_steps=10, alpha=1, eps=0.00001):
    """Initial prioritized-replay priorities for one finished episode.

    For each step i the n-step return is
    ``z_i = sum_{k<td_steps} discount**k * rewards[i+k] + discount**td_steps * root_values[i+td_steps]``.
    The bootstrap term is dropped when ``i + td_steps`` runs past the episode.
    The priority is ``|pred_values[i] - z_i| ** alpha + eps``.

    Args:
        root_values: MCTS root values per step.
        pred_values: network-predicted scalar values per step.
        rewards: environment rewards per step.
        discount: per-step discount.
        td_steps: bootstrap horizon.
        alpha: priority exponent (default 1, i.e. plain L1 error).
        eps: small constant keeping every priority strictly positive.

    Returns:
        List with one priority per step.
    """
    n_steps = len(root_values)
    z_values = []
    for current_index in range(n_steps):
        bootstrap_index = current_index + td_steps
        if bootstrap_index < n_steps:
            value = root_values[bootstrap_index] * discount ** td_steps
        else:
            value = 0
        for i, reward in enumerate(rewards[current_index:bootstrap_index]):
            value += reward * discount ** i
        z_values.append(value)
    p = np.abs(np.array(pred_values) - np.array(z_values)) ** alpha + eps
    return list(p)
def eval_game(env,net,n_sim,render,device,max_steps,discount=0.99):
    """Play one evaluation episode choosing the most-visited MCTS action.

    Args:
        env: gym-style environment using the 4-tuple ``step`` API.
        net: model used for root inference and inside MCTS.
        n_sim: MCTS simulations per move.
        render: call ``env.render()`` every step when True.
        device: device for the root observation tensor.
        max_steps: hard cap on the number of environment steps.
        discount: discount for MCTS value backups. It is an explicit parameter
            rather than a read of the notebook-global ``discount``, which raised
            NameError when absent. The default matches the training cell's 0.99.

    Returns:
        Total episode reward (also printed).
    """
    state = env.reset()
    done = False
    r = 0
    stp = 0
    while not done:
        if render:
            env.render()
        stp += 1
        h, prob, value = net.inference_initial_state(torch.tensor([state]).float().to(device))
        root = run_mcts(net, h.cpu(), prob.cpu(), value.cpu(), num_simulations=n_sim, discount=discount)
        action, action_prob, mcts_val = select_action(root, "max")
        next_state, reward, done, info = env.step(action)
        # `stp` already counts the step just taken; `>=` caps the episode at max_steps steps.
        if stp >= max_steps:
            done = True
        r += reward
        state = next_state
    print("Eval:played for ",r ," rewards")
    return r
    
def sample_games(buffer,batch_size):
    """Draw ``batch_size`` game indices uniformly at random, with replacement."""
    n_games = len(buffer)
    picks = np.random.choice(n_games, size=batch_size)
    return list(picks)

def sample_position(trajectory,priority=None):
    """Sample a position index from a game, uniformly or by priority.

    Args:
        trajectory: the game (only its length is used).
        priority: optional per-position probabilities summing to 1 (list or
            numpy array). The check uses ``is None``: ``priority == None`` is
            elementwise for numpy arrays and made the ``if`` raise.

    Returns:
        The sampled position index.
    """
    n_positions = len(trajectory)
    if priority is None:
        return np.random.choice(n_positions, 1)[0]
    return np.random.choice(n_positions, 1, p=priority)[0]
    #return np.random.choice(list(range(0, len(trajectory))),1,p = priority)[0]

def update_priorites(buffer,indexes,new_priority):
    """Write new priorities back into the replay buffer in place.

    ``indexes[i]`` is a ``(game_index, position)`` pair, and ``new_priority[i]``
    is stored in slot 5 (the priority field) of that step record. Records must
    be mutable lists, as built by ``play_game``.
    """
    for i, (game_idx, pos) in enumerate(indexes):
        buffer[game_idx][pos][5] = new_priority[i]


def sample_batch(action_space_size,buffer,discount,batch_size,num_unroll_steps, td_steps,per):
    """Assemble one training batch of unroll targets from the replay buffer.

    ``batch_size`` games are drawn from ``buffer`` (uniformly, with replacement).
    From each, one position is picked: proportionally to the stored priorities
    when ``per`` is true, uniformly otherwise. ``num_unroll_steps`` actions are
    taken from that position, padded with random one-hot actions past the end of
    the game, together with value/reward/policy targets from ``make_target``.

    Returns:
        ``(obs, actions, reward_targets, value_targets, policy_targets,
        is_weights, indexes)``. The first six are tensors, and ``indexes`` lists
        ``(game_index, position)`` pairs for later priority updates.
    """
    obs_rows, action_rows, reward_rows, value_rows, policy_rows, weight_rows = [], [], [], [], [], []
    indexes = []
    for game_id in sample_games(buffer, batch_size):
        game = buffer[game_id]
        # Transpose the per-step records into per-field columns.
        states, actions, visit_dists, root_vals, rewards, priorities = [list(col) for col in zip(*game)]

        if per:
            probs = np.array(priorities) / np.sum(np.array(priorities))
            pos = sample_position(game, list(probs))
            beta = 1
            # Importance-sampling correction for the non-uniform position sampling.
            weight = (1 / (len(game) * probs[pos])) ** beta
        else:
            weight = 1.0
            pos = sample_position(game)

        unroll_actions = actions[pos:pos + num_unroll_steps]
        n_missing = num_unroll_steps - len(unroll_actions)
        unroll_actions += [onehot(np.random.randint(0, action_space_size), action_space_size)
                           for _ in range(n_missing)]

        values, step_rewards, policies = make_target(child_visits=visit_dists, root_values=root_vals,
                                                     rewards=rewards, state_index=pos, discount=discount,
                                                     num_unroll_steps=num_unroll_steps, td_steps=td_steps)
        obs_rows.append(states[pos])
        action_rows.append(unroll_actions)
        reward_rows.append(step_rewards)
        value_rows.append(values)
        policy_rows.append(policies)
        weight_rows.append(weight)
        indexes.append((game_id, pos))

    return (torch.tensor(obs_rows).float(),
            torch.tensor(action_rows).long(),
            torch.tensor(reward_rows).float(),
            torch.tensor(value_rows).float(),
            torch.tensor(policy_rows).float(),
            torch.tensor(weight_rows).float(),
            indexes)


def make_target(child_visits,root_values,rewards,state_index,discount=0.99, num_unroll_steps=5, td_steps=10):
    """Build categorical value/reward targets and policy targets for one unroll.

    For each of the ``num_unroll_steps + 1`` positions starting at ``state_index``,
    the value target is the n-step return: the discounted rewards up to
    ``td_steps`` ahead, plus the discounted root value ``td_steps`` ahead when that
    index still lies inside the game. Values and rewards are encoded with ``stcat``.

    Returns:
        ``(target_values, target_rewards, target_policies)`` lists.
    """
    horizon = len(root_values)
    target_values, target_rewards, target_policies = [], [], []
    for idx in range(state_index, state_index + num_unroll_steps + 1):
        boot = idx + td_steps
        value = root_values[boot] * discount ** td_steps if boot < horizon else 0
        for k, r in enumerate(rewards[idx:boot]):
            value += r * discount ** k

        if idx >= horizon:
            # Past the end of the game: absorbing state with zero value/reward and
            # an all-zero policy target, so no policy loss is incurred there.
            target_values.append(stcat(0))
            target_rewards.append(stcat(0))
            target_policies.append(child_visits[0] * 0.0)
            continue

        target_values.append(stcat(value))
        target_rewards.append(stcat(rewards[idx]))
        target_policies.append(child_visits[idx])

    return target_values, target_rewards, target_policies


def _categorical_cross_entropy(prediction, target, eps=1e-12):
    """Row-wise cross-entropy -sum(target * log(prediction)) over dim 1.

    ``prediction`` is clamped to at least ``eps`` before the log. A softmax bin
    that underflows to exactly 0 would otherwise give log(0) * 0 = nan and
    poison the whole loss.
    """
    return -(torch.log(prediction.clamp_min(eps)) * target).sum(1)


def scalar_reward_loss( prediction, target, eps=1e-12):
    """Cross-entropy between the predicted reward distribution and its categorical target."""
    return _categorical_cross_entropy(prediction, target, eps)


def scalar_value_loss( prediction, target, eps=1e-12):
    """Cross-entropy between the predicted value distribution and its categorical target."""
    return _categorical_cross_entropy(prediction, target, eps)
def update_weights(model, action_space_size, optimizer, replay_buffer,discount,batch_size,num_unroll_steps, td_steps,per ):
    """Run one optimiser step of MuZero training on a freshly sampled batch.

    Samples a batch via ``sample_batch``, unrolls ``model`` for
    ``num_unroll_steps`` steps from each sampled position, and minimises the
    importance-weighted sum of the policy, value and reward cross-entropy losses.

    Returns:
        ``(indexes, new_priorities)`` when ``per`` is true, where
        ``new_priorities`` is |catts(root value prediction) - catts(value target)|
        from the forward pass *before* the optimiser step; otherwise ``(None, None)``.
    """
    batch = sample_batch(action_space_size,replay_buffer,discount,batch_size,num_unroll_steps, td_steps,per)
    obs_batch, action_batch, target_reward, target_value, target_policy,target_weights,indexes = batch
    # NOTE(review): the device is chosen independently of where `model` lives;
    # this assumes the model was moved to CUDA whenever CUDA is available.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    obs_batch = obs_batch.to(device)
    action_batch = action_batch.to(device)
    target_reward = target_reward.to(device)
    target_value = target_value.to(device)
    target_policy = target_policy.to(device)
    target_weights = target_weights.to(device)
    # Plain aliases: the targets are already categorical (encoded by stcat in make_target).
    target_reward_phi =target_reward
    target_value_phi = target_value

    hidden_state, policy_prob,value  = model.initial_state(obs_batch) # initial model call (representation + prediction)

    # Root-step (index 0) losses; no reward is predicted at the root.
    # NOTE(review): torch.log on a softmax can give log(0) * 0 = nan if a bin underflows.
    value_loss = scalar_value_loss(value, target_value_phi[:, 0])
    policy_loss = -(torch.log(policy_prob) * target_policy[:, 0]).sum(1)
    reward_loss = torch.zeros(batch_size, device=device)
    # Detached root value distributions, reused below to compute PER priorities.
    initial_state_values = value.detach()
    gradient_scale = 1 / num_unroll_steps
    for step_i in range(num_unroll_steps):
        # action_batch rows are one-hot actions stored as a long tensor (sample_batch);
        # g() concatenates them with the float hidden state.
        hidden_state, reward,policy_prob,value  = model.next_state(hidden_state, action_batch[:, step_i])
        #h,pred_reward,pred_policy,pred_value= net.next_state(h,act)
        # Policy/value targets at step_i + 1 belong to the state reached after
        # step_i + 1 actions; the reward target at step_i is the reward for action step_i.
        policy_loss += -(torch.log(policy_prob) * target_policy[:, step_i + 1]).sum(1)
        value_loss += scalar_value_loss(value, target_value_phi[:, step_i + 1])
        reward_loss += scalar_reward_loss(reward, target_reward_phi[:, step_i])
        # Halve the gradient flowing back through this hidden state into the dynamics network.
        hidden_state.register_hook(lambda grad: grad * 0.5)

    # optimize
    value_loss_coeff = 1
    loss = (policy_loss + value_loss_coeff * value_loss + reward_loss) # value loss coefficient fixed at 1 (open question whether another value works better)
    weights = target_weights#/target_weights.max()#dividing by max doesnt work
    # PER importance-sampling weights (all 1.0 when per is False).
    weighted_loss = (weights * loss).mean()#1?
    # Scales every gradient by 1/num_unroll_steps, including the root-step terms.
    weighted_loss.register_hook(lambda grad: grad * gradient_scale)
    # NOTE(review): this mean is never used; only weighted_loss is backpropagated.
    loss = loss.mean()

    optimizer.zero_grad()
    weighted_loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
    optimizer.step()
    if per:
      # New priorities come from the pre-update forward pass; this avoids a second forward pass:
      #updated_h,updated_prob,updated_pred_value= model.inference_initial_state(obs_batch)
      #return indexes,updated_pred_value.cpu().numpy()
      return indexes,np.abs(catts(initial_state_values.cpu().numpy())-catts(target_value[:, 0].cpu().numpy()))
    return None,None

def adjust_lr(optimizer, step_count, lr_init=0.05, lr_decay_rate=0.01, lr_decay_steps=10000, lr_min=0.001):
    """Set an exponentially decayed learning rate on every param group.

    ``lr = max(lr_init * lr_decay_rate ** (step_count / lr_decay_steps), lr_min)``.
    With the defaults, the rate decays from 0.05 by a factor of 100 every 10000
    steps and is floored at 0.001.

    Returns:
        The learning rate that was applied.
    """
    lr = max(lr_init * lr_decay_rate ** (step_count / lr_decay_steps), lr_min)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr
def soft_update(target, source, tau):
    """Blend ``source`` weights into ``target`` in place: t <- t * (1 - tau) + s * tau.

    ``tau=1`` makes ``target`` equal to ``source`` (for finite weights).
    Parameters are paired positionally via ``parameters()``.
    """
    pairs = zip(target.parameters(), source.parameters())
    for tgt, src in pairs:
        blended = tgt.data * (1.0 - tau) + src.data * tau
        tgt.data.copy_(blended)

def get_scalars(new_pred_values):
    """Decode each row of a 2-D array of categorical distributions with ``catts``.

    Returns:
        A Python list with one decoded scalar per row.
    """
    n_rows = new_pred_values.shape[0]
    return [catts(new_pred_values[row, :]) for row in range(n_rows)]
learning_rate = [0.05]
def net_train(net,  action_space_size, replay_buffer,discount,batch_size,num_unroll_steps, td_steps,training_steps=1000,per = False):
    """Train ``net`` in place for ``training_steps`` SGD steps on ``replay_buffer``.

    A fresh SGD optimiser (momentum 0.9, weight decay 1e-4) is created on each
    call. The learning rate follows ``adjust_lr``, and the most recent rate is
    kept in the module-level ``learning_rate[0]``. When ``per`` is true, the
    sampled positions' priorities are refreshed after every step.

    Returns:
        The same (trained) model object.

    Raises:
        ValueError: if ``replay_buffer`` is empty. Nothing fills the buffer
            concurrently in this notebook, so waiting for data would spin forever.
    """
    if len(replay_buffer) == 0:
        raise ValueError("replay_buffer is empty; collect games with play_game() before training")

    model = net
    optimizer = optim.SGD(model.parameters(), lr=learning_rate[0], momentum=0.9, weight_decay=1e-4)

    for step_count in tqdm(range(training_steps)):
        learning_rate[0] = adjust_lr(optimizer, step_count)
        indexes, new_priority = update_weights(model, action_space_size, optimizer, replay_buffer, discount,
                                               batch_size, num_unroll_steps, td_steps, per)
        if per:
            update_priorites(replay_buffer, indexes, new_priority)

    return model


In [None]:
import gym
import numpy as np
from collections import deque

# --- experiment configuration ---
render = False
episodes_per_train = 30   # self-play games collected per outer iteration
episodes_per_eval = 5
buffer = []
training_steps = 50       # outer iterations of collect -> train -> evaluate
max_steps = 5000
n_sim = 50                # MCTS simulations per move
discount = 0.99
batch_size = 512
envs = ['CartPole-v1', 'MountainCar-v0', 'LunarLander-v2']
device = 'cuda' if torch.cuda.is_available() else 'cpu'

print("training for ", envs[0])
env = gym.make(envs[0])
# Hand-picked CartPole bounds; the wrapper maps each [low, high] range to [-1, 1].
env = ScalingObservationWrapper(env, low=[-2.4, -2.0, -0.42, -3.5], high=[2.4, 2.0, 0.42, 3.5])

s_dim = env.observation_space.shape[0]
print("s_dim: ", s_dim)
a_dim = env.action_space.n
print("a_dim: ", a_dim)
a_bound = 1  # not used by the discrete-action code below
print("a_bound: ", a_bound)

net = MuZeroNet(input_size=s_dim, action_space_n=a_dim, reward_support_size=5, value_support_size=5).to(device)
targetnet = MuZeroNet(input_size=s_dim, action_space_n=a_dim, reward_support_size=5, value_support_size=5).to(device)
soft_update(target=targetnet, source=net, tau=1)  # tau=1: copy net's weights into targetnet
# NOTE(review): targetnet is not used anywhere after this point.

for t in range(training_steps):
    # Prioritized replay for the first 20 iterations, uniform sampling afterwards.
    priority = t < 20
    tr_stp = 2000
    buffer = []  # on-policy: old games are discarded every iteration
    for _ in range(episodes_per_train):
        buffer.append(play_game(env, net, n_sim, discount, render, device, a_dim, max_steps, td_steps=10, per=priority))
    print("training from ",len(buffer)," games")

    print("training with "," priority ",priority," training_steps ",tr_stp," discount ",discount," batch_size ",batch_size)
    net = net_train(net, action_space_size=a_dim, replay_buffer=buffer, discount=discount, batch_size=batch_size,
                    num_unroll_steps=5, td_steps=10, training_steps=tr_stp, per=priority)
    for _ in range(episodes_per_eval):
        eval_game(env, net, n_sim, render, device, max_steps)
  


training for  CartPole-v1
s_dim:  4
a_dim:  2
a_bound:  1
DATA collection:played for  15  steps , rewards 15.0
DATA collection:played for  17  steps , rewards 17.0
DATA collection:played for  20  steps , rewards 20.0
DATA collection:played for  11  steps , rewards 11.0
DATA collection:played for  17  steps , rewards 17.0
DATA collection:played for  18  steps , rewards 18.0
DATA collection:played for  21  steps , rewards 21.0
DATA collection:played for  46  steps , rewards 46.0
DATA collection:played for  13  steps , rewards 13.0
DATA collection:played for  14  steps , rewards 14.0
DATA collection:played for  43  steps , rewards 43.0
DATA collection:played for  20  steps , rewards 20.0
DATA collection:played for  11  steps , rewards 11.0
DATA collection:played for  25  steps , rewards 25.0
DATA collection:played for  40  steps , rewards 40.0
DATA collection:played for  20  steps , rewards 20.0
DATA collection:played for  20  steps , rewards 20.0
DATA collection:played for  27  steps , r

  0%|          | 0/2000 [00:00<?, ?it/s]

DATA collection:played for  20  steps , rewards 20.0
training from  30  games
training with   priority  True  training_steps  2000  discount  0.99  batch_size  512


100%|██████████| 2000/2000 [13:13<00:00,  2.52it/s]


Eval:played for  500.0  rewards
Eval:played for  500.0  rewards
Eval:played for  500.0  rewards
Eval:played for  500.0  rewards
Eval:played for  500.0  rewards
DATA collection:played for  500  steps , rewards 500.0
DATA collection:played for  499  steps , rewards 499.0
DATA collection:played for  500  steps , rewards 500.0
DATA collection:played for  500  steps , rewards 500.0
DATA collection:played for  500  steps , rewards 500.0
DATA collection:played for  500  steps , rewards 500.0
DATA collection:played for  500  steps , rewards 500.0
DATA collection:played for  500  steps , rewards 500.0
DATA collection:played for  500  steps , rewards 500.0
DATA collection:played for  500  steps , rewards 500.0
DATA collection:played for  500  steps , rewards 500.0
DATA collection:played for  500  steps , rewards 500.0
DATA collection:played for  349  steps , rewards 349.0
DATA collection:played for  424  steps , rewards 424.0
DATA collection:played for  500  steps , rewards 500.0
DATA collection

  0%|          | 0/2000 [00:00<?, ?it/s]

DATA collection:played for  500  steps , rewards 500.0
training from  30  games
training with   priority  True  training_steps  2000  discount  0.99  batch_size  512


 12%|█▏        | 236/2000 [02:19<18:09,  1.62it/s]