In [500]:
import torch
import torch.nn as nn
import numpy as np 
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
import gymnasium as gym
import json

In [501]:
class DQN(torch.nn.Module):
    """Fully connected Q-network mapping a state vector to one value per action.

    The network is ``Linear -> ReLU -> Linear -> ReLU -> Linear``. The
    ``dropout`` and ``layer_norm`` modules are created (so they are part of the
    module tree and of ``state_dict``) but are not applied in ``forward``.

    Args:
        n_states: size of the input state vector.
        n_actions: number of output action values.
        n_hidden_units: width of both hidden layers.
    """

    def __init__(self, n_states,n_actions,n_hidden_units):
        super().__init__()
        self.state_dim = n_states
        self.num_hidden_units = n_hidden_units
        self.num_actions = n_actions

        hidden = self.num_hidden_units
        self.layer1 = torch.nn.Linear(self.state_dim, hidden)
        self.layer2 = torch.nn.Linear(hidden, hidden)
        self.layer3 = torch.nn.Linear(hidden, self.num_actions)
        # Registered but currently unused regularisation / normalisation modules.
        self.dropout = torch.nn.Dropout(p=0.2)
        self.layer_norm = torch.nn.LayerNorm(hidden)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Return the action values for ``x`` (a single state or a batch of states)."""
        first_hidden = self.relu(self.layer1(x))
        second_hidden = self.relu(self.layer2(first_hidden))
        return self.layer3(second_hidden)

In [502]:
def buffer_to_JSON(replay_buffer,buffer_size):
    """Convert the first ``buffer_size`` transitions of a replay buffer to JSON-ready data.

    The buffer keeps its transitions in NumPy arrays; every field is converted
    to a plain Python list / int / float so ``json.dump`` can serialise it.

    Args:
        replay_buffer: object exposing ``states``, ``actions``, ``rewards``,
            ``terminals`` and ``next_states`` indexable by transition number.
        buffer_size (int): number of leading transitions to export.
    Returns:
        list of dicts with keys ``state``, ``action``, ``reward``, ``terminal``
        and ``next_state``, one per transition.
    """
    def _record(idx):
        # One transition as built-in Python types only.
        return {
            "state": replay_buffer.states[idx].tolist(),
            "action": int(replay_buffer.actions[idx]),
            "reward": float(replay_buffer.rewards[idx]),
            "terminal": int(replay_buffer.terminals[idx]),
            "next_state": replay_buffer.next_states[idx].tolist(),
        }

    return [_record(idx) for idx in range(buffer_size)]

def json_to_buffer(json_data,replay_buffer):
    """Copy JSON transitions into the arrays of a replay buffer, starting at index 0.

    Observations may be stored either as a flat list (the layout written by
    ``buffer_to_JSON``) or wrapped in one extra list level (``[[...]]``); in
    the nested case the first row is used. Previously only the nested layout
    was read correctly: a flat list had its first coordinate broadcast over
    the whole observation row.

    Args:
        json_data (list): dicts with keys ``state``, ``action``, ``reward``,
            ``terminal`` and ``next_state``.
        replay_buffer: object exposing ``max_size`` and the writable arrays
            ``states``, ``actions``, ``rewards``, ``terminals``, ``next_states``.
    Returns:
        int: number of transitions written.
    Raises:
        AssertionError: if ``json_data`` holds more transitions than the buffer can store.
    """
    assert len(json_data)<=replay_buffer.max_size, "JSON data does not fit into the replay buffer"

    def _observation(value):
        # Accept both [x0, x1, ...] and the legacy [[x0, x1, ...]] layout.
        arr = np.asarray(value, dtype=float)
        return arr[0] if arr.ndim > 1 else arr

    for i, transition in enumerate(json_data):
        replay_buffer.states[i] = _observation(transition["state"])
        replay_buffer.actions[i] = transition["action"]
        replay_buffer.rewards[i] = transition["reward"]
        replay_buffer.terminals[i] = transition["terminal"]
        replay_buffer.next_states[i] = _observation(transition["next_state"])
    return len(json_data)
    
def save_json(path, data):
    """Write ``data`` to ``path`` as JSON indented by two spaces.

    Args:
        path (Path): destination file; it is created or overwritten.
        data: JSON-serialisable object to store.
    """
    with open(path, "w") as handle:
        json.dump(data, handle, indent=2)

def read_json(path):
    """Read a JSON file, e.g. stored replay-buffer data.

    The file is opened in a ``with`` block so the handle is closed even when
    parsing fails (the handle used to be left open).

    Args:
        path (path): path to the JSON file.
    Returns:
        data (list): parsed JSON content to be used for experience replay.
    """
    with open(path) as f:
        return json.load(f)


In [503]:
class ReplayBuffer:
    """Fixed-capacity ring buffer of transitions kept in parallel NumPy arrays.

    Once ``max_size`` transitions have been appended the write position wraps
    to 0 and the oldest entries are overwritten.

    Args:
        buffer_size (int): maximum number of stored transitions.
        minibatch_size (int): number of transitions returned by ``sample``.
        observation_size (int): length of one state vector.
    """

    def __init__(self, buffer_size, minibatch_size, observation_size):
        self.minibatch_size = minibatch_size
        self.max_size = buffer_size
        self.pos = 0       # index the next transition is written to
        self.full = False  # True once the write index has wrapped around
        self.states = np.zeros((self.max_size,observation_size))
        self.next_states = np.zeros((self.max_size,observation_size))
        # int64 instead of int8: int8 cannot hold action indices above 127.
        self.actions = np.zeros(self.max_size,dtype=np.int64)
        self.rewards = np.zeros(self.max_size)
        self.terminals = np.zeros(self.max_size,dtype=np.int8)
        self.rand_generator = np.random.RandomState()

    def set_seed(self,seed=1):
        """Re-create the sampling random generator from ``seed``."""
        self.rand_generator = np.random.RandomState(seed)

    def append(self, state, action, reward, terminal, next_state):
        """Store one transition at the current write position.

        Args:
            state (Numpy array): The state.
            action (integer): The action.
            reward (float): The reward.
            terminal (integer): 1 if the next state is a terminal state and 0 otherwise.
            next_state (Numpy array): The next state.
        """
        self.states[self.pos] = state
        self.actions[self.pos] = action
        self.rewards[self.pos] = reward
        self.terminals[self.pos] = terminal
        self.next_states[self.pos] = next_state
        self.pos += 1
        if self.pos == self.max_size:
            # Wrap around: from now on new transitions overwrite the oldest ones.
            self.pos = 0
            self.full = True

    def sample(self):
        """Draw ``minibatch_size`` stored transitions uniformly, with replacement.

        Returns:
            list of five arrays ``[states, actions, rewards, terminals, next_states]``,
            each with ``minibatch_size`` entries along the first axis.
        Raises:
            ValueError: if the buffer holds no transitions.
        """
        stored = self.size()
        if stored == 0:
            raise ValueError("Cannot sample from an empty replay buffer")
        idxs = self.rand_generator.randint(0, stored, size=self.minibatch_size)
        return [self.states[idxs], self.actions[idxs], self.rewards[idxs],
                self.terminals[idxs], self.next_states[idxs]]

    def size(self):
        """Return the number of transitions currently stored."""
        return self.max_size if self.full else self.pos

    def save_buffer(self,path="ReplayBuffer.JSON"):
        """Write all stored transitions to ``path`` as JSON."""
        save_json(path, buffer_to_JSON(self, self.size()))

    def load_buffer(self,path="ReplayBuffer.JSON"):
        """Reset the buffer and refill it from the JSON file at ``path``."""
        self.reset()
        data_size = json_to_buffer(read_json(path), self)
        if data_size == self.max_size:
            # Exactly full: the next append overwrites index 0.
            self.full = True
            self.pos = 0
        else:
            self.pos = data_size

    def reset(self):
        """Mark the buffer as empty; the underlying arrays are not cleared."""
        self.full = False
        self.pos = 0

In [504]:
def get_td_error(states, next_states, actions, rewards, discount, terminals, target_network, current_q_network):
    """Compute Double-DQN targets and the current Q-values for a minibatch.

    The online network picks the greedy action in each next state and the
    target network supplies the value of that action:
    ``target = reward + discount * Q_target(s', argmax_a Q_online(s', a)) * (1 - terminal)``.

    Values are selected with ``gather`` and int64 indices taken from the input
    tensors, so no index tensor is created on the default device.

    Args:
        states: batch of states, one row per transition.
        next_states: batch of successor states, one row per transition.
        actions: integer tensor with the action taken in each transition.
        rewards: tensor with the reward of each transition.
        discount (float): discount factor.
        terminals: integer tensor, 1 where the transition ended the episode.
        target_network: network used to evaluate the next-state action values.
        current_q_network: network being trained; also used to pick the greedy next action.
    Returns:
        (target_vec, q_vec): 1-D tensors with one entry per transition.
        ``target_vec`` carries no gradient; ``q_vec`` is differentiable with
        respect to ``current_q_network``.
    """
    with torch.no_grad():
        # Action selection by the online network, evaluation by the target network.
        greedy_next = current_q_network(next_states).argmax(dim=1, keepdim=True)
        next_values = target_network(next_states).gather(1, greedy_next).squeeze(1)
        # (1 - terminals) removes the bootstrap term for terminal transitions.
        target_vec = rewards + discount * next_values * (torch.ones_like(terminals) - terminals)
    chosen = actions.long().unsqueeze(1)
    q_vec = current_q_network(states).gather(1, chosen).squeeze(1)
    return target_vec, q_vec

In [505]:
def optimize_network(experiences, discount, optimizer, target_network, current_q_network,device):
    """Perform one gradient step on ``current_q_network`` from a sampled minibatch.

    Args:
        experiences: sequence whose first five items are the batched states,
            actions, rewards, terminals and next_states (in that order).
        discount (float): The discount factor.
        optimizer: optimizer that updates the parameters of ``current_q_network``.
        target_network: network used for computing the targets.
        current_q_network: network that receives the replay update.
        device: device the minibatch tensors are created on.
    Returns:
        The mean-squared-error loss of the step as a NumPy value.
    """
    states, actions, rewards, terminals, next_states = experiences[:5]

    def _as_tensor(values, dtype):
        # Copy the NumPy batch into a tensor on the requested device.
        return torch.tensor(values, dtype=dtype, device=device)

    states = _as_tensor(states, torch.float32)
    next_states = _as_tensor(next_states, torch.float32)
    rewards = _as_tensor(rewards, torch.float32)
    terminals = _as_tensor(terminals, torch.int)
    actions = _as_tensor(actions, torch.int)

    # Both returned vectors are 1-D with one entry per sampled transition.
    target_vec, q_vec = get_td_error(states, next_states, actions, rewards, discount,
                                     terminals, target_network, current_q_network)
    loss = torch.nn.functional.mse_loss(target_vec, q_vec)

    optimizer.zero_grad()
    loss.backward()
    # Clip the global gradient norm to 10 before applying the update.
    torch.nn.utils.clip_grad_norm_(current_q_network.parameters(), 10)
    optimizer.step()
    return loss.detach().cpu().numpy()

In [506]:
class DQNAgent:
    """Online DQN agent with a replay buffer, epsilon-greedy exploration and a target network.

    Typical use: construct, optionally call ``set_device`` / ``set_seed`` /
    ``set_epsilon_decay``, then ``agent_init`` followed by the
    ``agent_start`` / ``agent_step`` / ``agent_end`` episode protocol.

    Args:
        buffer_config (dict): needs the keys ``replay_buffer_size``,
            ``minibatch_sz`` and ``observation_size``.
    """

    def __init__(self,buffer_config):
        self.name = "DQN"
        self.device = None
        # Unseeded by default; call set_seed for reproducible action sampling.
        self.rand_generator = np.random.RandomState()
        self.replay_buffer = ReplayBuffer(buffer_config["replay_buffer_size"],
                                          buffer_config["minibatch_sz"],
                                          buffer_config["observation_size"])

    def set_seed(self,seed=1):
        """Re-create the exploration random generator from ``seed``."""
        self.rand_generator = np.random.RandomState(seed)

    def set_epsilon_decay(self,n_steps=10000):
        """Set the multiplicative per-action epsilon decay to ``1 - 1/n_steps``."""
        self.eps_decay = 1. - 1./n_steps

    def set_device(self,device="cpu"):
        """Select the torch device; ``"cuda"`` falls back to CPU when CUDA is unavailable."""
        if device == "cuda":
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device("cpu")

    def agent_init(self, agent_config):
        """Build the networks and optimizer and reset all training state.

        Args:
            agent_config (dict): hyperparameters. Reads ``network_config``
                (``state_dim``, ``num_hidden_units``, ``num_actions``,
                ``network_type``), ``step_size``, ``double_dqn``,
                ``num_replay_updates_per_step``, ``gamma``, ``epsilon``,
                ``update_freq`` and ``warmup_steps``.
        """
        if self.device is None:
            # Store the detected device on the agent; it used to be assigned to
            # a local variable, leaving self.device as None.
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        network_config = agent_config["network_config"]
        self.state_dim = network_config.get("state_dim")
        self.num_hidden_layers = network_config.get("num_hidden_units")
        self.network_type = network_config.get("network_type")
        self.num_actions = network_config["num_actions"]

        self.q_network = DQN(self.state_dim,self.num_actions,self.num_hidden_layers).to(self.device)
        self.target_network = DQN(self.state_dim,self.num_actions,self.num_hidden_layers).to(self.device)
        self.target_network.load_state_dict(self.q_network.state_dict())

        self.step_size = agent_config['step_size']
        self.double_dqn = agent_config['double_dqn']
        self.num_replay = agent_config['num_replay_updates_per_step']
        self.discount = agent_config['gamma']
        self.epsilon = agent_config['epsilon']
        self.time_step = 0
        self.update_freq = agent_config['update_freq']
        self.loss = []
        self.episode_rewards = []
        self.loss_capacity = 5_000  # maximum number of recent losses kept in self.loss
        self.warmup_steps = agent_config['warmup_steps']
        # Keep a decay chosen through set_epsilon_decay before initialisation.
        self.eps_decay = getattr(self, "eps_decay", 0.9999)
        self.last_state = None
        self.last_action = None
        self.sum_rewards = 0
        self.episode_steps = 0
        self.optimizer = torch.optim.Adam(self.q_network.parameters(),lr=self.step_size,weight_decay=0.01)
        self.soft_update = True  # True: Polyak averaging every step, False: hard copy every update_freq steps
        self.tau = 0.005

    def _select_action(self, state, epsilon):
        """Return a uniformly random action with probability ``epsilon``, else the greedy one."""
        state = torch.tensor(state,dtype=torch.float32,device=self.device)
        if self.rand_generator.rand() >= epsilon:
            with torch.no_grad():
                action_values = self.q_network(state)
            return torch.argmax(action_values).item()
        return self.rand_generator.choice(self.num_actions)

    def greedy_policy(self,state,epsilon=0.001):
        """Pick an (almost) greedy action; the agent's epsilon schedule is not touched."""
        return self._select_action(state, epsilon)

    def epsilon_greedy_policy(self,state):
        """Pick an exploratory action and decay ``self.epsilon`` by one step.

        The exploration rate used for the draw never drops below 0.05.
        """
        epsilon = max(self.epsilon, 0.05)
        self.epsilon *= self.eps_decay
        return self._select_action(state, epsilon)

    def _replay_updates(self):
        """Run ``num_replay`` optimisation steps if the buffer holds more than one minibatch.

        Returns:
            The loss of the last step, or 0 when no step was run.
        """
        last_loss = 0
        if self.replay_buffer.size() > self.replay_buffer.minibatch_size:
            for _ in range(self.num_replay):
                experiences = self.replay_buffer.sample()
                last_loss = optimize_network(experiences, self.discount, self.optimizer,
                                             self.target_network, self.q_network, self.device)
                if len(self.loss) >= self.loss_capacity:
                    del self.loss[0]
                self.loss.append(last_loss)
        return last_loss

    def _sync_target_network(self):
        """Apply the configured target update: soft every step, or hard every ``update_freq`` steps."""
        if self.soft_update:
            self.polyak_update_target_network()
        elif self.time_step % self.update_freq == 0:
            self.update_target_network()

    def agent_start(self, state):
        """The first method called when the experiment starts, called after
        the environment starts.
        Args:
            state (Numpy array): the state from the
                environment's evn_start function.
        Returns:
            The first action the agent takes.
        """
        self.sum_rewards = 0
        self.episode_steps = 0
        self.last_state = state
        self.last_action = self.epsilon_greedy_policy(self.last_state)
        self.time_step += 1
        return self.last_action

    def agent_step(self, reward, state):
        """A step taken by the agent.
        Args:
            reward (float): the reward received for taking the last action taken
            state (Numpy array): the state from the
                environment's step based, where the agent ended up after the
                last step
        Returns:
            The action the agent is taking.
        """
        self.sum_rewards += reward
        self.episode_steps += 1
        action = self.epsilon_greedy_policy(state)
        self.replay_buffer.append(self.last_state, self.last_action, reward, False, state)

        # Learning only starts after the warm-up period.
        if self.time_step > self.warmup_steps:
            self._replay_updates()
        self._sync_target_network()

        self.last_state = state
        self.last_action = action
        self.time_step += 1
        return action

    def agent_end(self, reward):
        """Run when the agent terminates.
        Args:
            reward (float): the reward the agent received for entering the
                terminal state.
        Returns:
            The loss of the last replay update of this call, or 0 if none was run.
        """
        self.sum_rewards += reward
        self.episode_steps += 1
        self.episode_rewards.append(self.sum_rewards)
        # The terminal next-state is stored as an array of zeros.
        state = np.zeros_like(self.last_state)
        self.replay_buffer.append(self.last_state, self.last_action, reward, True, state)

        # NOTE(review): unlike agent_step there is no warm-up check here - confirm this is intended.
        end_loss = self._replay_updates()
        self._sync_target_network()

        self.time_step += 1
        return end_loss

    def agent_message(self, message):
        """Answer ``"get_sum_reward"`` with the current episode return; reject other messages."""
        if message == "get_sum_reward":
            return self.sum_rewards
        raise Exception("Unrecognized Message!")

    def update_target_network(self):
        """Hard update: copy the online network weights into the target network."""
        self.target_network.load_state_dict(self.q_network.state_dict())

    def polyak_update_target_network(self):
        """Soft update: ``target <- (1 - tau) * target + tau * online`` for every parameter."""
        with torch.no_grad():
            for param, target_param in zip(self.q_network.parameters(), self.target_network.parameters()):
                target_param.mul_(1 - self.tau).add_(param, alpha=self.tau)

    def get_loss(self):
        """Return the average of the stored recent losses."""
        return np.average(np.array(self.loss))

In [507]:
class offlineDQNAgent:
    """DQN agent that is trained from a replay buffer loaded from a JSON file.

    ``agent_init`` loads the buffer from ``buffer_path``; ``learn_offline``
    then performs replay updates without interacting with an environment. The
    ``agent_start`` / ``agent_step`` / ``agent_end`` protocol is also provided
    and keeps appending to the same buffer.

    Args:
        buffer_config (dict): needs the keys ``replay_buffer_size``,
            ``minibatch_sz`` and ``observation_size``.
    """

    def __init__(self,buffer_config):
        self.name = "offlineDQN"
        self.device = None
        # Unseeded by default; call set_seed for reproducible action sampling.
        self.rand_generator = np.random.RandomState()
        self.replay_buffer = ReplayBuffer(buffer_config["replay_buffer_size"],
                                          buffer_config["minibatch_sz"],
                                          buffer_config["observation_size"])
        self.buffer_path = "experienceReplay.json"
        self.loss = []  # lets mean_loss() be called before agent_init

    def set_seed(self,seed=1):
        """Re-create the exploration random generator from ``seed``."""
        self.rand_generator = np.random.RandomState(seed)

    def set_epsilon_decay(self,n_steps=10000):
        """Set the multiplicative per-action epsilon decay to ``1 - 1/n_steps``."""
        self.eps_decay = 1. - 1./n_steps

    def set_buffer_path(self,path="experienceReplay.json"):
        """Set the JSON file the replay buffer is loaded from."""
        self.buffer_path = path

    def mean_loss(self):
        """Return the mean of the stored recent losses, or -1 when none is stored."""
        if len(self.loss) == 0:
            return -1
        return np.mean(self.loss)

    def set_device(self,device="cpu"):
        """Select the torch device; ``"cuda"`` falls back to CPU when CUDA is unavailable."""
        if device == "cuda":
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device("cpu")

    def agent_init(self, agent_config):
        """Build the networks and optimizer, reset training state and load the buffer.

        Args:
            agent_config (dict): hyperparameters. Reads ``network_config``
                (``state_dim``, ``num_hidden_units``, ``num_actions``,
                ``network_type``), ``step_size``, ``num_replay_updates_per_step``,
                ``gamma``, ``epsilon`` and ``update_freq``.
        """
        if self.device is None:
            self.device = torch.device("cpu")

        network_config = agent_config["network_config"]
        self.state_dim = network_config.get("state_dim")
        self.num_hidden_layers = network_config.get("num_hidden_units")
        self.network_type = network_config.get("network_type")
        self.num_actions = network_config["num_actions"]

        self.q_network = DQN(self.state_dim,self.num_actions,self.num_hidden_layers).to(self.device)
        self.target_network = DQN(self.state_dim,self.num_actions,self.num_hidden_layers).to(self.device)
        self.target_network.load_state_dict(self.q_network.state_dict())

        self.step_size = agent_config['step_size']
        self.num_replay = agent_config['num_replay_updates_per_step']
        self.discount = agent_config['gamma']
        self.epsilon = agent_config['epsilon']
        self.update_freq = agent_config['update_freq']

        self.time_step = 0
        self.loss = []
        self.episode_rewards = []
        self.loss_capacity = 5_000  # maximum number of recent losses kept in self.loss
        # Keep a decay chosen through set_epsilon_decay before initialisation.
        self.eps_decay = getattr(self, "eps_decay", 0.9999)
        self.last_state = None
        self.last_action = None
        self.sum_rewards = 0
        self.episode_steps = 0
        self.optimizer = torch.optim.Adam(self.q_network.parameters(),lr=self.step_size,weight_decay=0.01)
        self.tau = 0.005
        # Resets the buffer and refills it from the configured file.
        self.replay_buffer.load_buffer(self.buffer_path)

    def _select_action(self, state, epsilon):
        """Return a uniformly random action with probability ``epsilon``, else the greedy one."""
        state = torch.tensor(state,dtype=torch.float32,device=self.device)
        if self.rand_generator.rand() >= epsilon:
            with torch.no_grad():
                action_values = self.q_network(state)
            return torch.argmax(action_values).item()
        return self.rand_generator.choice(self.num_actions)

    def epsilon_greedy_policy(self,state):
        """Pick an exploratory action and decay ``self.epsilon`` by one step.

        The exploration rate used for the draw never drops below 0.05.
        """
        epsilon = max(self.epsilon, 0.05)
        self.epsilon *= self.eps_decay
        return self._select_action(state, epsilon)

    def greedy_policy(self,state,epsilon=0.001):
        """Pick an (almost) greedy action; the agent's epsilon schedule is not touched."""
        return self._select_action(state, epsilon)

    def load_buffer(self):
        """Reload the replay buffer from ``self.buffer_path``.

        The configured path is passed on explicitly; it used to be ignored in
        favour of the replay buffer's own default file name.
        """
        self.replay_buffer.load_buffer(self.buffer_path)

    def _run_replay_updates(self):
        """Run ``num_replay`` optimisation steps and record their losses.

        Returns:
            The loss of the last step, or 0 when ``num_replay`` is 0.
        """
        last_loss = 0
        for _ in range(self.num_replay):
            experiences = self.replay_buffer.sample()
            last_loss = optimize_network(experiences, self.discount, self.optimizer,
                                         self.target_network, self.q_network, self.device)
            if len(self.loss) >= self.loss_capacity:
                del self.loss[0]
            self.loss.append(last_loss)
        return last_loss

    def _buffer_ready(self):
        """True when the buffer holds more transitions than one minibatch."""
        return self.replay_buffer.size() > self.replay_buffer.minibatch_size

    def learn_offline(self):
        """One offline training step: replay updates followed by a soft target update."""
        self.episode_steps += 1
        self._run_replay_updates()
        self.update_target_network()

    def agent_start(self, state):
        """The first method called when the experiment starts, called after
        the environment starts.
        Args:
            state (Numpy array): the state from the
                environment's evn_start function.
        Returns:
            The first action the agent takes.
        """
        self.sum_rewards = 0
        self.episode_steps = 0
        self.last_state = state
        self.last_action = self.epsilon_greedy_policy(self.last_state)
        self.time_step += 1
        return self.last_action

    def agent_step(self, reward, state):
        """A step taken by the agent.
        Args:
            reward (float): the reward received for taking the last action taken
            state (Numpy array): the state from the
                environment's step based, where the agent ended up after the
                last step
        Returns:
            The action the agent is taking.
        """
        self.sum_rewards += reward
        self.episode_steps += 1
        action = self.epsilon_greedy_policy(state)
        self.replay_buffer.append(self.last_state, self.last_action, reward, False, state)

        if self._buffer_ready():
            self._run_replay_updates()
            self.update_target_network()

        self.last_state = state
        self.last_action = action
        self.time_step += 1
        return action

    def agent_end(self, reward):
        """Run when the agent terminates.
        Args:
            reward (float): the reward the agent received for entering the
                terminal state.
        Returns:
            The loss of the last replay update of this call, or 0 if none was run.
        """
        self.sum_rewards += reward
        self.episode_steps += 1
        self.episode_rewards.append(self.sum_rewards)
        # The terminal next-state is stored as an array of zeros.
        state = np.zeros_like(self.last_state)
        self.replay_buffer.append(self.last_state, self.last_action, reward, True, state)

        end_loss = 0
        if self._buffer_ready():
            end_loss = self._run_replay_updates()
            self.update_target_network()

        self.time_step += 1
        return end_loss

    def update_target_network(self):
        """Soft update: ``target <- (1 - tau) * target + tau * online`` for every parameter."""
        with torch.no_grad():
            for param, target_param in zip(self.q_network.parameters(), self.target_network.parameters()):
                target_param.mul_(1 - self.tau).add_(param, alpha=self.tau)

    def get_loss(self):
        """Return the average of the stored recent losses."""
        return np.average(np.array(self.loss))

In [508]:
class RL:
    """Driver for training and evaluating agents, with progress statistics.

    The agent passed to ``learn`` must implement ``agent_init``,
    ``agent_start``, ``agent_step`` and ``agent_end``; the environment is used
    through ``reset()`` returning ``(state, info)`` and ``step(action)``
    returning ``(state, reward, terminated, truncated, info)``.
    """

    # Agent names whose weights live in ``q_network`` / ``target_network``.
    _Q_NETWORK_AGENTS = ("DQN", "offlineDQN")

    def __init__(self) -> None:
        self.name = "ButaChanRL"
        self.mean_episode_length = 0
        self.mean_episode_rew = 0
        self.mean_loss= 0
        self.step = 0
        self.output_step = 0
        self.epsiode_rewards = []
        self.episode_lens = []
        self.loss = []

        self.model_dir = "./models/"
        self.num_episodes = 0
        self.average_over = 20  # number of most recent episodes used by summarize

    def set_output_step(self,output_step):
        """Set how many steps lie between two progress summaries (0 disables them)."""
        self.output_step = output_step

    def set_model_dir(self,name):
        """Set the directory best-model weights are written to."""
        self.model_dir = name

    def create_model_dir(self):
        """Create the model directory if it does not exist yet."""
        os.makedirs(self.model_dir, exist_ok=True)

    def load_model(self,model,filename="model.weights"):
        """Load network weights from ``filename`` into ``model``.

        Raises:
            NotImplementedError: if ``model.name`` is not a supported agent type
                (the exception used to be created but never raised).
        """
        if model.name not in self._Q_NETWORK_AGENTS and model.name != "ActorCritic":
            raise NotImplementedError(f"load_model does not support agent '{model.name}'")
        # Read the file once and map tensors onto the agent's device when it has one.
        weights = torch.load(filename, map_location=getattr(model, "device", None))
        if model.name in self._Q_NETWORK_AGENTS:
            model.target_network.load_state_dict(weights)
            model.q_network.load_state_dict(weights)
        else:
            model.actor_critic_network.load_state_dict(weights)

    def save_model(self,model,filename="model.weights"):
        """Save the network weights of ``model`` to ``filename``.

        Raises:
            NotImplementedError: if ``model.name`` is not a supported agent type
                (previously nothing was written and no error was reported).
        """
        if model.name in self._Q_NETWORK_AGENTS:
            torch.save(model.q_network.state_dict(),filename)
        elif model.name == "ActorCritic":
            torch.save(model.actor_critic_network.state_dict(),filename)
        else:
            raise NotImplementedError(f"save_model does not support agent '{model.name}'")

    def plot_live(self,data,n_mean=20,plot_start=20):
        """Redraw the episode-reward plot and, once enough data exists, its running mean.

        Args:
            data: sequence of episode rewards.
            n_mean (int): window length of the running mean.
            plot_start (int): minimum number of episodes before the mean is drawn.
        """
        plt.ion()
        plt.figure(1)
        plot_data = torch.tensor(data, dtype=torch.float,requires_grad=False)
        plt.clf()
        plt.title('Training...')
        plt.xlabel('Episode')
        plt.ylabel('Episode Reward')
        plt.plot(plot_data.numpy(),"o")
        # The running mean needs at least one full window of n_mean episodes.
        if len(plot_data) >= max(plot_start, n_mean):
            means = plot_data.unfold(0, n_mean, 1).mean(1).view(-1)
            # Pad with n_mean - 1 zeros so the curve has one point per episode.
            means = torch.cat((torch.zeros(n_mean - 1), means))
            plt.plot(means.numpy())
        plt.pause(0.1)  # pause a bit so that plots are updated

    def episode_summarize(self,episode,episode_reward):
        """Print the reward of one evaluation episode."""
        print(f"Episode: {episode}, Reward: {episode_reward}")

    def summarize(self):
        """Update and print the mean episode length, reward and loss.

        Episode statistics are averaged over the ``average_over`` most recent
        episodes, including the latest one (it used to be left out).
        """
        self.mean_episode_length = 0
        self.mean_episode_rew = 0
        if len(self.episode_lens) > 0:
            self.mean_episode_length = np.average(self.episode_lens[-self.average_over:])
            self.mean_episode_rew = np.average(self.epsiode_rewards[-self.average_over:])
        self.mean_loss = 0
        if len(self.loss) > 0:
            self.mean_loss = np.average(self.loss)
        print(f"Step:{self.step}, Episode:{self.num_episodes} Mean_Epi_Len: {self.mean_episode_length:5.2f},Mean_Epi_Rew {self.mean_episode_rew:5.2f}, Loss: {self.mean_loss:5.2f}")

    def _is_best_episode(self, episode_reward):
        """True when ``episode_reward`` beats every recorded episode (or none is recorded)."""
        return len(self.epsiode_rewards) == 0 or episode_reward > max(self.epsiode_rewards)

    def learn(self,agent,env,agent_parameters,NSTEPS=10000,visualize=False,output_step=1000,save_best_weights=False):
        """Train ``agent`` on ``env`` for ``NSTEPS`` environment steps.

        Args:
            agent: agent to train; it is (re)initialised with ``agent_parameters``.
            env: environment to interact with.
            agent_parameters (dict): passed to ``agent.agent_init``.
            NSTEPS (int): total number of environment steps.
            visualize (bool): redraw the reward plot at every summary.
            output_step (int): steps between summaries; 0 disables them.
            save_best_weights (bool): save the weights whenever an episode
                beats all previously recorded episode rewards.
        Returns:
            The trained agent.
        """
        self.output_step = output_step
        agent.agent_init(agent_parameters)

        state, _ = env.reset()
        action = agent.agent_start(state)
        epsiode_reward = 0
        episode_len = 0

        for i in tqdm(range(1,NSTEPS+1)):
            self.step = i
            state,reward,terminated,truncated,_ = env.step(action)
            epsiode_reward += reward
            # Count every environment step, the final one of an episode included.
            episode_len += 1

            if self.output_step>0 and self.step%self.output_step==0:
                self.summarize()
                if visualize and len(self.epsiode_rewards)>0:
                    self.plot_live(self.epsiode_rewards)

            if terminated or truncated:
                # NOTE(review): time-limit truncation is handed to agent_end like a
                # true termination - confirm this is intended.
                self.loss.append(agent.agent_end(reward))

                if save_best_weights and self._is_best_episode(epsiode_reward):
                    self.create_model_dir()
                    self.save_model(agent, os.path.join(self.model_dir, f"model_{self.step}"))

                self.epsiode_rewards.append(epsiode_reward)
                self.episode_lens.append(episode_len)
                self.num_episodes += 1

                # Start the next episode.
                state, _ = env.reset()
                action = agent.agent_start(state)
                epsiode_reward = 0
                episode_len = 0
            else:
                action = agent.agent_step(reward,state)
        return agent

    def learn_offline(self,agent,agent_parameters,NSTEPS=10000,output_step=1000):
        """Train ``agent`` for ``NSTEPS`` offline updates from its replay buffer.

        Args:
            agent: agent providing ``agent_init``, ``learn_offline`` and ``mean_loss``.
            agent_parameters (dict): passed to ``agent.agent_init``.
            NSTEPS (int): number of ``agent.learn_offline`` calls.
            output_step (int): updates between two loss printouts; 0 disables them.
        Returns:
            The trained agent.
        """
        self.output_step = output_step
        agent.agent_init(agent_parameters)
        for i in tqdm(range(1,NSTEPS+1)):
            agent.learn_offline()
            if self.output_step > 0 and i % self.output_step == 0:
                print("Offline mean loss: ",agent.mean_loss())
        return agent

    def _evaluation_action(self, agent, state, epsilon):
        """Choose an evaluation action without touching the agent's training epsilon.

        Falls back to ``epsilon_greedy_policy`` for agents without ``greedy_policy``.
        """
        if hasattr(agent, "greedy_policy"):
            return agent.greedy_policy(state, epsilon)
        return agent.epsilon_greedy_policy(state)

    def evaluate(self,agent,env,n_episodes=10,seed=1,visualize=False,eval_espilon=0.001):
        """Run ``n_episodes`` evaluation episodes and report the reward statistics.

        Actions come from ``agent.greedy_policy(state, eval_espilon)``; the
        training policy, which decays the agent's epsilon on every call, is no
        longer used here.

        Args:
            agent: trained agent.
            env: environment to evaluate on.
            n_episodes (int): number of episodes to run.
            seed: currently unused.
            visualize: currently unused.
            eval_espilon (float): exploration rate during evaluation.
        Returns:
            (mean_rew, std_rew): mean and standard deviation of the episode rewards.
        """
        epsiode_rewards = []
        for episode in range(1,n_episodes+1):
            state, _ = env.reset()
            done = False
            epsiode_reward = 0
            while not done:
                action = self._evaluation_action(agent, state, eval_espilon)
                state,reward,terminated,truncated,_ = env.step(action)
                epsiode_reward += reward
                done = terminated or truncated
            epsiode_rewards.append(epsiode_reward)
            self.episode_summarize(episode,epsiode_reward)
        mean_rew = np.average(epsiode_rewards)
        std_rew = np.std(epsiode_rewards)
        return (mean_rew,std_rew)

In [509]:
env = gym.make("CartPole-v1") # just create gym environment natively
torch.set_num_threads(1)
n_state = env.observation_space.shape[0]
n_actions = env.action_space.n

agent_parameters = { # this is where you can choose neural networks adjust hyperparameters
'network_config': {
    'state_dim': n_state,
    'num_hidden_units': 128,
    'num_actions': n_actions,
    "network_type":"dqn"
},
'replay_buffer_size': 1_000_000,
'minibatch_sz': 32,
'observation_size':n_state,
'num_replay_updates_per_step': 1,
"step_size": 3e-4,
'gamma': 0.99,
'epsilon': 1,
'update_freq':100,
'warmup_steps':1000,
'double_dqn':False
}
buffer_parameters = {
    "replay_buffer_size":1_000_000,
    "minibatch_sz":32,
    "observation_size":n_state
}
#agent = DQNAgent(buffer_parameters) # You can change to other agents such as SARSA, ActorCritic
agent = offlineDQNAgent(buffer_parameters)
agent.set_device(device="cpu")
rl = RL() # You need this to control the overall process of training
#s,_=env.reset(seed=1)
#agent.set_seed(1)
#agent.replay_buffer.set_seed(1)
# rl.learn_offline() calls agent.agent_init(), which resets the replay buffer and
# reloads it from agent.buffer_path. Loading a file straight into
# agent.replay_buffer beforehand is therefore discarded, so the data file is
# registered through set_buffer_path instead.
agent.set_buffer_path("replay_buffer_2024.json")
trained_agent = rl.learn_offline(agent,agent_parameters)
#print(agent.replay_buffer.size())
#trained_agent = rl.learn(agent,env,agent_parameters,NSTEPS=100_000,output_step=1000,visualize=True,save_best_weights=False) # training loop



  1%|          | 58/10000 [00:00<00:35, 283.83it/s]

[array([[-6.78643063e-02, -5.46489596e-01,  1.16474174e-01,
         9.90149319e-01],
       [ 3.35760787e-02, -2.89423149e-02,  3.26016955e-02,
        -3.27154584e-02],
       [-5.01716370e-03, -7.52492726e-01, -7.99900107e-03,
         8.44170332e-01],
       [ 4.08975855e-02, -1.70508832e-01, -8.06274489e-02,
         1.18869096e-01],
       [ 5.83583713e-02,  3.76314610e-01, -2.43161283e-02,
        -5.62329471e-01],
       [ 1.64453134e-01,  1.58360749e-01, -5.48701920e-03,
         1.02495849e-01],
       [-1.36812508e+00, -1.23909378e+00, -8.53576958e-02,
        -5.97624421e-01],
       [-8.91038775e-02, -1.63956061e-01,  5.82424849e-02,
         2.30030924e-01],
       [-3.27616513e-01, -8.83857310e-01, -8.37309435e-02,
         1.98153928e-01],
       [-9.05046821e-01, -9.47263062e-01, -1.19357176e-01,
        -2.06533801e-02],
       [-7.75983036e-02,  1.83490917e-01,  1.28075436e-01,
        -1.18117824e-01],
       [ 4.41548899e-02,  4.22282785e-01, -1.46427024e-02,
     

  1%|          | 116/10000 [00:00<00:34, 282.44it/s]

[array([[-4.26085927e-02, -2.20026851e-01,  2.06812825e-02,
         1.61469474e-01],
       [ 3.01708728e-02, -2.35834606e-02,  1.75795391e-01,
         7.21766949e-01],
       [ 3.36607080e-03,  2.22198740e-01, -2.72981767e-02,
        -2.65587598e-01],
       [-6.38523549e-02, -5.86713612e-01,  4.23819013e-02,
         6.66036725e-01],
       [-9.86942947e-02,  1.97804987e-01,  1.93910152e-01,
         4.92805392e-02],
       [-1.36812508e+00, -1.23909378e+00, -8.53576958e-02,
        -5.97624421e-01],
       [ 3.75481024e-02, -5.68697333e-01, -8.11452866e-02,
         5.84520340e-01],
       [-6.27690479e-02,  3.63030583e-01,  5.54154366e-02,
        -4.24542218e-01],
       [ 4.53649871e-02,  4.35001850e-01, -4.62013595e-02,
        -7.05429316e-01],
       [ 3.31267357e-01, -2.44471610e-01,  1.66949496e-01,
         1.04172027e+00],
       [ 6.20115362e-03, -1.96512774e-01, -1.01409573e-02,
         2.72022814e-01],
       [ 8.56432766e-02,  2.14509904e-01, -3.58040594e-02,
     

  2%|▏         | 175/10000 [00:00<00:34, 287.51it/s]

[array([[-9.96743292e-02, -5.92664182e-01,  3.98694426e-02,
         8.52955997e-01],
       [-4.23801750e-01, -5.95611870e-01, -1.50823081e-02,
         2.71569520e-01],
       [-5.13369143e-01, -1.28766775e-01,  2.07149480e-02,
        -5.46608508e-01],
       [-1.86145268e-02,  1.64573103e-01, -1.51982773e-02,
        -3.16233605e-01],
       [ 3.69451158e-02,  9.66491163e-01, -6.34095520e-02,
        -1.47329390e+00],
       [-2.05407548e-03,  1.80600569e-01, -1.08786626e-02,
        -2.14570522e-01],
       [ 1.55346483e-01,  3.98256838e-01, -6.13443777e-02,
        -2.53046483e-01],
       [ 7.24872425e-02,  3.71529162e-01,  5.86790256e-02,
        -3.64689790e-02],
       [ 1.70461163e-01, -3.58268946e-01, -3.57318744e-02,
         2.22602978e-01],
       [-5.26750050e-02, -7.54069567e-01, -9.37599540e-02,
         2.97665864e-01],
       [-6.86019287e-02,  2.13956926e-02,  7.87807256e-03,
         9.34656560e-02],
       [ 1.66566715e-01, -7.53138840e-01, -1.75752565e-01,
     

  2%|▏         | 204/10000 [00:00<00:34, 284.13it/s]

[array([[-1.54692505e-03,  4.06633735e-01, -1.29815387e-02,
        -5.78241348e-01],
       [ 3.78263257e-02, -4.20793802e-01, -4.68329377e-02,
         4.27235991e-01],
       [ 3.42644542e-01,  4.11180735e-01, -8.61320496e-02,
        -1.12392053e-01],
       [ 2.90598303e-01,  4.95036632e-01,  1.53167605e-01,
         3.54026824e-01],
       [ 6.68935925e-02,  8.04494321e-01, -7.91933686e-02,
        -1.22635400e+00],
       [-5.43399513e-01,  1.82666510e-01,  1.64027326e-02,
        -8.50744009e-01],
       [ 7.95482025e-02, -4.88841049e-02, -3.68426256e-02,
         2.04979584e-01],
       [-6.89283460e-02, -7.37086296e-01,  8.82392526e-02,
         1.23462391e+00],
       [ 5.89085780e-02,  3.71846706e-01, -1.95718985e-02,
        -5.32350063e-01],
       [ 2.31052265e-02, -2.15389624e-01,  5.43219857e-02,
         3.25536281e-01],
       [ 1.35681070e-02, -3.53179753e-01, -4.64951769e-02,
         5.26598871e-01],
       [ 7.30973408e-02,  7.87727773e-01, -7.07396567e-02,
     

  3%|▎         | 262/10000 [00:00<00:35, 276.32it/s]

[array([[ 2.45258808e-01,  1.53151378e-01, -5.35623450e-03,
         2.21455008e-01],
       [ 1.55588984e-02,  4.16697592e-01, -4.80875187e-03,
        -6.19883776e-01],
       [-9.17313397e-02, -3.97149235e-01,  2.88429037e-02,
         5.51326931e-01],
       [ 2.33335897e-01, -1.77922562e-01, -8.14088359e-02,
         2.56568700e-01],
       [ 2.05946472e-02, -1.72050551e-01, -3.02879256e-03,
         2.98895270e-01],
       [-4.46331687e-02, -4.75560222e-03,  2.95336191e-02,
        -2.02176403e-02],
       [ 2.57160957e-03,  1.49419174e-01,  1.62698980e-03,
        -1.58029050e-01],
       [ 7.74951279e-02,  9.16747272e-01, -3.34415399e-02,
        -1.07631755e+00],
       [-1.62070885e-03, -5.36952138e-01,  3.07906549e-02,
         8.69575679e-01],
       [ 9.76793095e-03, -1.58266217e-01,  7.61341602e-02,
         4.89634573e-01],
       [-2.64026243e-02,  1.58456296e-01, -4.09285128e-02,
        -3.16737801e-01],
       [-5.38128950e-02, -2.27219760e-01, -3.31110880e-02,
     

  3%|▎         | 320/10000 [00:01<00:38, 252.20it/s]

[array([[-8.60617459e-01, -3.68167698e-01, -5.06116264e-02,
        -7.61013865e-01],
       [ 1.05559448e-04,  2.49098316e-02,  6.35265037e-02,
         3.95476758e-01],
       [ 4.80506420e-01,  1.07175243e+00,  1.03419788e-01,
        -2.28969902e-02],
       [-3.56258526e-02, -2.99986392e-01, -1.47075698e-01,
        -4.44489270e-01],
       [-1.44160718e-01, -4.13229525e-01,  1.54776111e-01,
         4.26182270e-01],
       [-3.11706103e-02,  1.56464174e-01, -2.95962933e-02,
        -2.72662759e-01],
       [ 3.15027386e-02,  1.99174926e-01, -2.80711446e-02,
        -3.13408524e-01],
       [-1.14135727e-01,  2.48801000e-02, -1.61124811e-01,
        -4.02552485e-01],
       [ 2.15443745e-02,  9.26694393e-01, -7.45273829e-02,
        -1.33697534e+00],
       [-5.15750051e-03, -2.19881758e-01,  3.25015932e-03,
         2.06072137e-01],
       [ 1.89328578e-03, -1.76397398e-01,  1.26774147e-01,
         8.14601183e-01],
       [-2.08023772e-01,  3.96583825e-01,  1.07675409e-02,
     

  4%|▍         | 377/10000 [00:01<00:36, 265.07it/s]

[array([[-5.25051415e-01, -1.07234073e+00, -4.60068742e-03,
         3.43694001e-01],
       [ 2.28042360e-02,  3.82603496e-01,  2.53291931e-02,
        -4.13534999e-01],
       [ 8.66645500e-02,  8.28400016e-01, -1.10948890e-01,
        -1.21482456e+00],
       [ 1.47869334e-01,  5.27169466e-01,  4.41976599e-02,
        -4.42100577e-02],
       [-9.82412230e-03,  4.53534573e-02, -3.27255689e-02,
        -1.54769510e-01],
       [-4.85994108e-02, -2.28129819e-01, -3.35242487e-02,
         1.61862880e-01],
       [ 2.31476463e-02, -9.88851845e-01, -1.90372895e-02,
         1.08503771e+00],
       [-9.63659398e-03, -8.90531461e-04,  5.07484376e-02,
         1.55977607e-01],
       [-3.15268920e-03, -1.88396573e-02,  4.42402102e-02,
        -4.58971523e-02],
       [ 1.34961391e-02,  4.11934525e-01, -1.45771832e-03,
        -5.87116122e-01],
       [ 3.10445111e-02,  6.00103617e-01, -1.44182831e-01,
        -1.16887462e+00],
       [-7.64323771e-02, -9.87525702e-01,  1.16068654e-01,
     

  4%|▍         | 432/10000 [00:01<00:36, 258.84it/s]

[array([[-0.13918546, -0.62230307,  0.04248252,  0.75081182],
       [-0.00980014,  0.3564992 ,  0.02339416, -0.42003903],
       [ 0.31920651,  0.18086858, -0.09360509, -0.11709028],
       [-0.14229487, -0.03967594,  0.15822901,  0.20145646],
       [-0.03769335,  0.17339058, -0.03014198, -0.05562299],
       [ 0.23632361,  0.70301318,  0.14254445,  0.08918682],
       [ 0.00783058,  0.22156288, -0.05684305, -0.28628376],
       [-0.02179163,  0.00504456,  0.05097077,  0.02489962],
       [ 0.03958736, -0.01227044, -0.03358779,  0.02704067],
       [ 0.01105888, -0.04601753, -0.01683446,  0.0123569 ],
       [-0.01007724,  0.17217761,  0.11287244,  0.07547077],
       [-0.12704186, -0.26143974,  0.19464599,  0.78767955],
       [-0.00655979,  0.15762387,  0.04945727, -0.00940532],
       [-0.01352305,  0.76533878, -0.12766871, -1.20593452],
       [ 0.05871443,  0.43332747, -0.03474676, -0.57039666],
       [ 0.01654014,  0.16971533, -0.02955611, -0.24660705],
       [ 0.20927948,  0

  5%|▍         | 487/10000 [00:01<00:36, 258.26it/s]

[array([[-1.17700227e-01, -1.71931255e+00,  1.47465304e-01,
         2.60036445e+00],
       [ 4.56052780e-01,  1.76745251e-01,  6.45984383e-03,
         6.39837325e-01],
       [ 2.93573029e-02, -1.37366783e-02, -2.69106012e-02,
        -5.82656749e-02],
       [-1.77429926e+00, -1.66873205e+00, -1.07253641e-01,
        -1.52399465e-01],
       [ 2.48711780e-02, -2.66790967e-02, -5.28066531e-02,
        -5.55401146e-02],
       [-9.53299925e-02, -3.08380593e-02,  1.83614179e-01,
         3.89550239e-01],
       [ 1.04452312e-01,  9.68702018e-01, -1.34885281e-01,
        -1.56594217e+00],
       [-2.30835471e-03,  2.36848295e-02,  1.28659993e-01,
         4.55836296e-01],
       [-5.50733060e-02, -4.69700098e-01,  1.98765352e-01,
         1.39230716e+00],
       [-1.13943875e-01, -2.05504388e-01,  7.22385645e-02,
         2.80182451e-01],
       [-2.61677653e-01, -2.24802680e-02,  4.88213543e-03,
        -1.39998049e-01],
       [-9.89427865e-02,  1.66118741e-01,  9.16676130e-03,
     

  5%|▌         | 543/10000 [00:02<00:36, 261.52it/s]

[array([[-9.86573845e-02,  1.66369557e-01,  1.82709292e-01,
         4.52444144e-02],
       [-3.45414621e-03, -2.10419502e-02, -9.88858193e-02,
        -1.72565863e-01],
       [ 7.18788058e-02,  6.40912056e-02, -1.65012106e-01,
        -5.04976153e-01],
       [-7.10815862e-02, -8.40242952e-03,  1.34130893e-02,
        -2.76111186e-01],
       [-2.62669101e-02,  1.67862400e-01,  5.49009163e-03,
        -3.47551078e-01],
       [ 2.71182172e-02, -1.92633450e-01, -1.14121944e-01,
        -1.62001804e-01],
       [ 1.20444238e-01,  5.45317471e-01,  5.51915541e-02,
        -4.07031596e-01],
       [-4.62918654e-02, -3.76693219e-01,  2.87330709e-03,
         5.31350553e-01],
       [ 4.56024110e-02,  3.71964216e-01, -1.09639689e-01,
        -7.36242294e-01],
       [ 1.68940216e-01,  1.39892077e+00, -1.29645318e-01,
        -1.93459058e+00],
       [ 8.23031008e-01,  1.32311225e+00,  1.03289122e-02,
        -6.48162007e-01],
       [-2.55919173e-02,  3.81497562e-01,  2.81620398e-02,
     

  6%|▌         | 598/10000 [00:02<00:35, 265.50it/s]

[array([[-4.23047915e-02, -3.74069840e-01, -6.30780607e-02,
         3.80760372e-01],
       [ 1.51214236e-03, -1.49437159e-01, -4.03664559e-02,
         2.28636503e-01],
       [-3.88712026e-02, -4.43503223e-02, -4.32417169e-02,
        -1.46214902e-01],
       [ 1.14351295e-01,  3.36613953e-01, -5.05492724e-02,
        -1.79272681e-01],
       [ 4.29853750e-03, -1.55562893e-01, -3.16806436e-02,
         1.78034991e-01],
       [ 2.66246609e-02, -1.78327173e-01,  6.42192140e-02,
         3.79927605e-01],
       [ 3.67973857e-02, -1.21653247e+00,  1.28264353e-01,
         2.00167584e+00],
       [-1.78918671e-02, -1.89962015e-02,  1.22092590e-01,
         3.66931021e-01],
       [-5.34273125e-02, -5.97696483e-01,  3.10061723e-02,
         8.69798779e-01],
       [ 3.56612116e-01, -4.62377548e-01,  1.46778360e-01,
         1.73729026e+00],
       [-2.48920750e-02, -4.01455201e-02, -3.79128531e-02,
         5.29775359e-02],
       [ 1.25740135e+00,  1.23691177e+00,  1.59284711e-01,
     

  7%|▋         | 685/10000 [00:02<00:34, 273.49it/s]

[array([[ 1.09222680e-02, -9.56704080e-01, -5.24433143e-02,
         1.12063754e+00],
       [ 6.11680388e-01,  9.61549997e-01,  9.77300629e-02,
        -9.47628543e-02],
       [-1.04524074e-02,  3.83452505e-01,  1.41838379e-02,
        -5.21239281e-01],
       [ 2.76419252e-01,  6.12915397e-01, -1.93674609e-01,
        -9.93093252e-01],
       [-2.56386362e-02, -3.60599846e-01,  5.78812808e-02,
         6.73049212e-01],
       [-3.06360237e-02,  2.11846203e-01,  1.83359496e-02,
        -2.66931802e-01],
       [-4.09243673e-01,  4.31389838e-01, -6.70090243e-02,
        -6.48800552e-01],
       [-1.67536847e-02,  1.81326166e-01, -2.22577732e-02,
        -2.68891633e-01],
       [-9.96381976e-03, -2.38404453e-01, -3.31780873e-02,
         1.77196413e-01],
       [-5.86758833e-03,  2.39075255e-03,  3.14905867e-02,
         8.34443942e-02],
       [ 7.55826570e-03,  2.01932564e-01, -1.28891447e-03,
        -3.71027440e-01],
       [-6.05203956e-03,  7.62080491e-01,  5.42554818e-02,
     

  7%|▋         | 742/10000 [00:02<00:33, 277.12it/s]

[array([[ 1.09680034e-01,  9.18531884e-03, -1.87531784e-01,
        -3.47482890e-01],
       [ 4.17429730e-02,  1.75835826e-02, -6.00982085e-03,
         3.04501038e-02],
       [-4.48149852e-02,  4.52218354e-02,  2.04266924e-02,
        -1.40488828e-02],
       [-3.45558554e-01,  7.97259986e-01, -8.10069293e-02,
        -1.29668033e+00],
       [-8.23475141e-03, -1.16428956e-01, -8.38267282e-02,
        -4.79065180e-01],
       [ 2.03088075e-02, -1.84913725e-01,  4.70345318e-02,
         3.58750015e-01],
       [ 2.92452630e-02,  7.77726650e-01,  5.01682833e-02,
        -7.68788517e-01],
       [ 4.68378246e-01,  1.16804790e+00,  1.42441094e-01,
        -6.45135820e-01],
       [-1.33965854e-02, -2.08186135e-01, -2.73018312e-02,
         3.00525844e-01],
       [ 2.84925681e-02, -7.71849275e-01,  1.33230758e-03,
         1.15355194e+00],
       [ 1.27257049e-01,  3.68138701e-01,  2.11934373e-02,
        -2.90492233e-02],
       [-1.38623253e-01,  1.70404419e-01,  1.34175390e-01,
     

  8%|▊         | 799/10000 [00:02<00:33, 278.11it/s]

[array([[-3.76830190e-01, -7.07494080e-01, -1.91453755e-01,
         4.14261311e-01],
       [-2.12955661e-02, -3.16544361e-02,  4.95175496e-02,
         1.54568017e-01],
       [ 9.85804051e-02,  1.59204662e-01,  6.80304170e-02,
        -4.42460552e-02],
       [-2.51526654e-01,  2.12868765e-01,  6.46526739e-02,
        -6.82049632e-01],
       [ 3.25982198e-02,  7.99789548e-01, -9.83031690e-02,
        -1.27551961e+00],
       [ 7.79865533e-02,  2.20351234e-01, -1.35446653e-01,
        -5.13581991e-01],
       [ 2.01507941e-01, -3.28235812e-02,  8.76574144e-02,
         3.13668728e-01],
       [-1.44763738e-02, -1.49685398e-01,  5.09881340e-02,
         2.92124778e-01],
       [-3.78021561e-02, -1.32551515e+00,  1.95203233e-03,
         1.16299975e+00],
       [ 6.97138980e-02,  4.13632512e-01, -1.20122984e-01,
        -7.66183376e-01],
       [-1.58760637e-01, -4.51566398e-01,  1.90393180e-01,
         1.00976717e+00],
       [-4.08415161e-02, -2.32778877e-01, -7.69991130e-02,
     

  8%|▊         | 827/10000 [00:03<00:38, 236.83it/s]

[array([[-1.29284421e-02, -4.17501241e-01, -8.68297927e-03,
         5.30513465e-01],
       [-2.17411648e-02,  4.15729970e-01, -1.00172952e-01,
        -9.04487550e-01],
       [ 1.87319666e-01,  4.02349621e-01, -8.40852559e-02,
        -3.44015539e-01],
       [ 3.25268432e-02, -5.11525154e-01, -9.75461677e-02,
         2.15943530e-01],
       [ 2.92113703e-02,  2.39879102e-01, -6.64613843e-02,
        -3.65486562e-01],
       [ 1.63574710e-01,  1.57682344e-01, -2.07636163e-01,
        -5.85485935e-01],
       [-1.22428931e-01, -3.81367773e-01,  1.17431395e-01,
         6.54121637e-01],
       [ 1.12213922e+00,  6.64129257e-01,  7.66064078e-02,
         9.48090553e-01],
       [ 6.00894809e-01,  7.05196440e-01,  2.25889478e-02,
         2.07640544e-01],
       [ 8.12143832e-02,  6.18585229e-01, -7.07352757e-02,
        -9.56821918e-01],
       [-2.53809750e-01,  3.67805451e-01, -1.19248465e-01,
        -9.55184519e-01],
       [-8.48735720e-02,  1.95889339e-01,  1.54403687e-01,
     

  9%|▉         | 881/10000 [00:03<00:36, 249.47it/s]

[array([[ 1.05899788e-01,  5.79748869e-01, -4.14411128e-02,
        -7.51114726e-01],
       [ 8.12143832e-02,  6.18585229e-01, -7.07352757e-02,
        -9.56821918e-01],
       [-1.78105049e-02, -5.56406081e-01,  4.68768626e-02,
         9.03255939e-01],
       [ 1.22725859e-03, -2.41520941e-01,  6.43622577e-02,
         4.30438191e-01],
       [ 2.11162537e-01,  1.58263758e-01,  9.06928107e-02,
         1.09288543e-01],
       [-6.07281663e-02, -4.42974657e-01,  1.49248913e-01,
         7.58163691e-01],
       [ 1.29999984e-02, -3.71900618e-01,  1.72872059e-02,
         5.23216367e-01],
       [-1.54515669e-01, -3.57976735e-01, -4.83146347e-02,
        -4.18521762e-01],
       [-4.50566001e-02,  1.67444766e-01,  7.21874684e-02,
         3.54734249e-02],
       [-8.35929066e-03,  1.69835150e-01,  6.12116791e-02,
        -1.50906816e-01],
       [-5.79973757e-01,  7.61708692e-02, -1.43342003e-01,
        -1.06507850e+00],
       [-6.81740120e-02, -1.73838288e-01,  9.74738598e-03,
     

  9%|▉         | 933/10000 [00:03<00:36, 249.58it/s]

[array([[-4.55824584e-02,  2.30129406e-01,  6.13418594e-02,
        -1.20891444e-01],
       [-2.82173157e-01, -2.22712636e-01,  3.07286680e-02,
         4.23191786e-02],
       [ 8.03804845e-02,  1.89032331e-01, -1.21887930e-01,
        -4.53423411e-01],
       [ 1.72734702e+00,  1.50142145e+00,  6.99844807e-02,
         3.11406050e-02],
       [-4.71240059e-02, -1.62796259e-01, -2.64909901e-02,
         1.45143628e-01],
       [ 4.05777916e-02,  2.39359692e-01, -3.81822176e-02,
        -4.00957137e-01],
       [-2.17014506e-01, -9.88816619e-01,  1.19960271e-01,
         1.06377053e+00],
       [-6.61448836e-02,  1.68791890e-01,  5.84324524e-02,
        -1.50850832e-01],
       [-4.47023034e-01, -6.99766636e-01, -6.22148737e-02,
         3.06031168e-01],
       [-3.43814678e-03,  9.78988349e-01,  2.59537883e-02,
        -1.14412534e+00],
       [-8.28531861e-01, -1.25996709e+00, -8.86009932e-02,
        -1.35809466e-01],
       [ 5.52726015e-02,  3.82805794e-01, -8.75532627e-02,
     

 10%|▉         | 990/10000 [00:03<00:34, 263.76it/s]

[array([[-1.13020567e-02,  2.35100865e-01,  6.47065938e-02,
        -2.14753002e-01],
       [ 2.17272043e-01,  1.67494923e-01, -4.89344448e-02,
        -3.55197191e-02],
       [-1.42741054e-02,  5.69370627e-01, -1.91467423e-02,
        -8.64460826e-01],
       [ 1.66465640e-02,  1.44613564e-01,  8.15766398e-03,
        -5.20327650e-02],
       [ 4.26259674e-02, -3.52380574e-01,  3.48963179e-02,
         5.79234838e-01],
       [ 1.19300926e+00,  7.46272802e-01, -1.58980023e-02,
         4.19059135e-02],
       [ 4.00782637e-02,  7.32586205e-01, -1.01266891e-01,
        -1.06851208e+00],
       [-1.29238045e+00, -1.04083502e+00, -1.34841517e-01,
        -2.65718579e-01],
       [ 1.05899788e-01,  5.79748869e-01, -4.14411128e-02,
        -7.51114726e-01],
       [ 5.95287904e-02,  7.55062759e-01, -1.44790098e-01,
        -1.42099142e+00],
       [-3.23010213e-03,  1.47625839e-03,  9.10294205e-02,
         5.06572664e-01],
       [-7.66299153e-03,  1.14791084e-03, -4.52681025e-03,
     

 10%|█         | 1047/10000 [00:03<00:34, 256.18it/s]

[array([[ 9.29894671e-02,  5.13242841e-01,  1.84512080e-03,
        -3.24231386e-01],
       [-4.43626791e-01,  4.29218531e-01, -1.79992877e-02,
        -5.99751770e-01],
       [ 5.26005439e-02,  6.17605865e-01, -2.73690801e-02,
        -9.33576822e-01],
       [-3.61799300e-02, -5.76674342e-01,  4.28533927e-02,
         7.55599558e-01],
       [ 4.04059738e-02, -3.30950134e-02, -3.71245332e-02,
        -1.02621242e-01],
       [ 2.19468459e-01,  3.93240213e-01, -9.15766433e-02,
        -3.10341746e-01],
       [-6.58810511e-02, -2.17786297e-01,  1.48840919e-02,
         2.81562805e-01],
       [ 5.33336140e-02,  3.96256447e-01, -4.11500484e-02,
        -5.87062001e-01],
       [-2.97603905e-02, -5.30496955e-01, -2.86813267e-02,
         6.62956357e-01],
       [ 2.29062319e-01, -1.28849459e+00,  3.13652679e-02,
         1.27825916e+00],
       [-1.16050072e-01, -7.99399495e-01,  1.22893341e-01,
         9.22906101e-01],
       [ 2.05385357e-01,  2.95894027e-01, -1.08955633e-02,
     

 11%|█         | 1103/10000 [00:04<00:33, 266.73it/s]

[array([[-4.62845378e-02, -1.70868725e-01,  1.06122851e-01,
         5.00959218e-01],
       [-1.73364133e-01, -8.15277338e-01,  1.82393730e-01,
         1.36758113e+00],
       [-6.77333891e-01, -6.88904524e-01, -2.76619643e-02,
         6.57893494e-02],
       [-7.62104839e-02, -1.32036284e-02,  5.55240847e-02,
        -4.79958057e-02],
       [ 9.30171192e-01,  1.14550376e+00,  6.37449650e-03,
        -1.40652537e-01],
       [ 3.83262187e-02, -3.61301452e-01, -4.00013067e-02,
         4.96462792e-01],
       [ 3.63938779e-01,  1.52551681e-01,  4.96666990e-02,
         2.94369161e-01],
       [ 4.91587460e-01,  1.73311245e+00, -1.96984172e-01,
        -2.09112597e+00],
       [ 1.21471263e-01,  4.12761331e-01,  1.12016378e-02,
        -2.29019318e-02],
       [-3.40303518e-02,  2.27379017e-02,  1.87838003e-02,
        -8.89463816e-03],
       [-3.23040001e-02, -4.18273479e-01,  4.04646508e-02,
         5.60007632e-01],
       [ 7.41199926e-02, -1.95769206e-01, -1.18798293e-01,
     

 12%|█▏        | 1159/10000 [00:04<00:36, 243.08it/s]

[array([[ 5.78999460e-01, -1.00966088e-01,  4.34906185e-02,
         1.21185899e+00],
       [ 1.82945672e-02,  1.52481541e-01,  5.60176149e-02,
         1.04201190e-01],
       [-3.30792591e-02, -1.48142010e-01,  2.66150143e-02,
         2.81003803e-01],
       [-7.64014050e-02, -2.02971607e-01,  9.70233902e-02,
         4.50457782e-01],
       [ 4.09585424e-03,  3.69101673e-01, -9.64374691e-02,
        -6.32642627e-01],
       [-1.93656459e-02,  7.14852661e-03,  4.54599969e-02,
        -2.23232005e-02],
       [ 1.31011426e-01, -2.15657353e-01,  2.29946710e-02,
         4.71987128e-01],
       [-2.88942382e-02, -3.88847142e-02, -3.85341085e-02,
        -1.28482236e-03],
       [ 4.60359594e-03,  1.79921627e-01, -4.79931384e-02,
        -2.96873242e-01],
       [ 5.33480197e-02,  9.26426589e-01,  1.16158992e-01,
        -3.98944408e-01],
       [-4.09798697e-02,  3.41846868e-02,  5.89240305e-02,
         1.90496027e-01],
       [ 1.86616451e-01, -1.78637922e-01, -1.55452956e-02,
     

 12%|█▏        | 1212/10000 [00:04<00:35, 246.12it/s]

[array([[ 0.07093523,  0.36017972, -0.0577572 , -0.60365123],
       [ 0.01869283, -0.37566867, -0.03882695,  0.60444397],
       [ 0.09282617,  0.94785279, -0.19562647, -1.67390597],
       [ 0.41281712,  0.35783792, -0.13473155, -0.28464499],
       [-0.05411715, -1.19475377,  0.07261599,  1.7744931 ],
       [-0.02348893, -0.01407758,  0.02159508,  0.08852084],
       [-0.25953934,  0.59506619,  0.08924788, -0.94545943],
       [-0.49134716, -0.59563208,  0.01108522,  0.27198163],
       [ 0.08616608,  0.82930338, -0.1468567 , -1.35894978],
       [-0.03311731, -0.18203793,  0.02011599,  0.2095255 ],
       [-0.17385821, -0.54247069, -0.16040653, -0.20686674],
       [-0.28662741, -0.02804449,  0.03157505, -0.24051237],
       [ 0.48295811,  1.50595462, -0.14074035, -1.15502989],
       [ 0.09272627, -0.15651546, -0.0268963 ,  0.29782343],
       [ 0.21534956,  0.12718433,  0.04740179,  0.16900772],
       [-0.40961015, -1.1037482 , -0.02888833,  0.90379924],
       [ 0.0913215 ,  0

 12%|█▏        | 1244/10000 [00:04<00:33, 261.65it/s]

[array([[-3.54875374e-04,  5.65726161e-01, -7.08995163e-02,
        -6.88524663e-01],
       [ 1.39277387e+00,  3.31699640e-01,  1.44149929e-01,
         8.49931121e-01],
       [ 2.93689221e-01, -3.81554574e-01,  1.55202091e-01,
         4.31155145e-01],
       [-1.44011565e-02, -3.76984596e-01, -3.91874416e-03,
         4.98327672e-01],
       [-7.18676001e-02, -5.53223848e-01,  1.58709064e-01,
         1.09725749e+00],
       [-8.49345624e-02,  3.04984068e-03,  1.43719152e-01,
         5.34226537e-01],
       [ 1.62415151e-02, -5.26850581e-01, -1.22237355e-01,
         2.13517055e-01],
       [-6.82528913e-02, -2.45362774e-01,  1.33418679e-01,
         7.63013124e-01],
       [ 4.00509238e-01, -7.62097001e-01,  3.96491587e-02,
         7.98077345e-01],
       [ 3.22436206e-02,  3.88147295e-01, -2.04774383e-02,
        -5.86217821e-01],
       [ 5.60712740e-02,  1.99294388e-01, -1.13592893e-01,
        -4.76752400e-01],
       [-3.09381671e-02,  2.21829321e-02,  4.89093326e-02,
     




KeyboardInterrupt: 