In [None]:
class DBQN_learn:
    """Online learner for an ELBO-trained (Bayesian) Q-network.

    Transitions are stored in a ``PrioritizedReplayBuffer``; one gradient step
    is taken every ``train_freq`` calls to ``step`` and the online network is
    copied into the target network every ``update_freq`` calls, in both cases
    only once more than ``batch_size`` steps have been counted.

    Args:
        model: online network. It must provide ``sample_elbo(obs, act, targets)``
            and be callable as ``model(obs, use_sample, num_sample)``.
        target_model: network whose weights are overwritten from ``model`` here
            and in ``update_target``; used to build bootstrap targets.
        gamma: discount factor applied to the bootstrapped next-state value.
        lr: Adam learning rate for ``model``'s parameters.
        batch_size: minibatch size drawn from the replay buffer.
        num_steps: nominal training budget (only stored here).
        buffer_size: first argument to ``PrioritizedReplayBuffer``
            (default 10000, the previously hard-coded value).
        alpha: second argument to ``PrioritizedReplayBuffer`` (default 0.7).
            Presumably the prioritization exponent; TODO confirm.
        beta: value passed as ``beta=`` to ``replay_buffer.sample``
            (default 0.5). Presumably the importance-sampling exponent.
    """

    def __init__(self, model, target_model, gamma, lr, batch_size, num_steps,
                 buffer_size=10000, alpha=0.7, beta=0.5):
        self.dbqn = model
        self.target_dbqn = target_model
        self.update_target()

        self.t = 0  # step counter, advanced by both step() and reset()
        self.gamma = gamma
        self.optimizer = optim.Adam(lr=lr, params=self.dbqn.parameters())
        self.batch_size = batch_size
        self.num_steps = num_steps

        self.train_freq = 1
        self.update_freq = 100
        self.replay_buffer = PrioritizedReplayBuffer(buffer_size, alpha)
        self.beta = beta
        # NOTE(review): not read anywhere in the visible code; possibly meant
        # as a scaling factor for the complexity (KL) term. TODO confirm.
        self.constt = 1 / self.num_steps

    def train(self):
        """Take one optimizer step on a prioritized minibatch.

        Returns:
            ``(idxes, td_errors)``: the buffer indices of the sampled
            transitions and the TD errors reported by ``sample_elbo``.
            ``step`` uses them to refresh the sampling priorities.
        """
        batch = self.replay_buffer.sample(self.batch_size, beta=self.beta)
        obs = torch.from_numpy(batch[0]).float().to(DEVICE)
        act = torch.from_numpy(batch[1]).long().to(DEVICE)
        rew = torch.from_numpy(batch[2]).float().to(DEVICE)
        obs1 = torch.from_numpy(batch[3]).float().to(DEVICE)
        dones = torch.from_numpy(batch[4].astype(int)).float().to(DEVICE)
        wt = torch.from_numpy(batch[5]).float().to(DEVICE)
        idxes = batch[6]

        self.dbqn.train()

        # Bootstrap target r + gamma * max_a' Q_target(s', a'), zeroed on
        # terminal transitions. no_grad avoids building an unused graph, and
        # max(1)[0] keeps a 1-D batch even when batch_size == 1. A bare
        # squeeze() would collapse that case to a 0-d tensor.
        with torch.no_grad():
            next_q = self.target_dbqn(obs1, use_sample=False)
            next_q_max = next_q.max(1)[0]
            targets = rew + self.gamma * next_q_max * (1 - dones)

        log_prior, log_variational_posterior, mse, td_errors = \
            self.dbqn.sample_elbo(obs, act, targets)
        # Complexity cost (posterior - prior) plus data-fit term, weighted by
        # the buffer's importance weights. mse may be per-sample: it is
        # averaged before logging below.
        loss = (log_variational_posterior - log_prior) + mse
        weighted_loss = (wt * loss).mean()

        self.optimizer.zero_grad()
        weighted_loss.backward()
        self.optimizer.step()

        # Debug hook: stop if ANY parameter element became NaN. The previous
        # check (`== 1`) only fired when exactly one element was NaN.
        if any(torch.isnan(p).any().item() for p in self.dbqn.parameters()):
            pdb.set_trace()

        # .item() pulls a Python float off any device. The old
        # .detach().numpy() on the loss failed for CUDA tensors.
        writer.add_scalar('data/loss', weighted_loss.item(), self.t)
        writer.add_scalar('data/prior', log_prior.item(), self.t)
        writer.add_scalar('data/posterior', log_variational_posterior.item(), self.t)
        writer.add_scalar('data/mse', mse.mean().item(), self.t)

        writer.add_scalar('data/w1_sigma', np.mean(self.dbqn.l1.weight.sigma[0].detach().cpu().numpy()).item(), self.t)
        writer.add_scalar('data/w2_sigma', np.mean(self.dbqn.l2.weight.sigma[0].detach().cpu().numpy()).item(), self.t)

        return idxes, td_errors

    def step(self, obs_t, act_t, rew_t, obs_t1, done):
        """Record one transition, train/sync on schedule, return the next action.

        The action for ``obs_t1`` is chosen with ``use_sample=True, num_sample=0``.
        """
        self.replay_buffer.add(obs_t, act_t, rew_t, obs_t1, done)
        self.t = self.t + 1

        if self.t % self.train_freq == 0 and self.t > self.batch_size:
            idxes, td_errors = self.train()
            # Small epsilon keeps every priority strictly positive.
            # NOTE(review): assumes td_errors is numpy-compatible (not a CUDA
            # tensor); it comes straight from sample_elbo.
            self.replay_buffer.update_priorities(idxes, np.abs(td_errors) + 1e-6)

        if self.t % self.update_freq == 0 and self.t > self.batch_size:
            self.update_target()

        return self.act(obs_t1, use_sample=True, num_sample=0)

    def act(self, obs, use_sample, num_sample):
        """Return the argmax action index of ``self.dbqn`` for one observation.

        ``obs`` is a numpy array; a leading batch dimension of size 1 is added.
        """
        obs = torch.from_numpy(obs).float().unsqueeze(0).to(DEVICE)
        return np.argmax(self.dbqn(obs, use_sample, num_sample).detach().cpu().numpy())

    def reset(self, obs):
        """Start an episode: advance the step counter and pick the first action."""
        self.t = self.t + 1
        return self.act(obs, use_sample=True, num_sample=2)

    def update_target(self):
        """Copy the online network's weights into the target network."""
        self.target_dbqn.load_state_dict(self.dbqn.state_dict())
        

In [None]:
def test_agent(agent, num_episodes=100):
    """Evaluate ``agent`` on the global ``env`` and return its mean episode reward.

    Runs exactly ``num_episodes`` complete episodes (default 100, the
    previously hard-coded count) with ``agent.dbqn`` in eval mode. The first
    action of each episode uses ``num_sample=4``; later actions use
    ``num_sample=0``. The network is always switched back to train mode,
    even if an episode raises.

    Previously, after the last episode finished, the loop still reset the
    environment and took one extra, discarded step before exiting; that stray
    step is gone.

    Args:
        agent: object exposing ``dbqn`` (with ``eval``/``train``) and
            ``act(obs, use_sample, num_sample)``.
        num_episodes: number of finished episodes to average over.

    Returns:
        Mean of the undiscounted per-episode reward sums.
    """
    test_return = []
    agent.dbqn.eval()
    try:
        for _ in range(num_episodes):
            obs = env.reset()
            act = agent.act(obs, use_sample=True, num_sample=4)
            episode_rew = 0
            done = False
            while not done:
                obs, rew, done, _ = env.step(act)
                act = agent.act(obs, use_sample=True, num_sample=0)
                episode_rew = episode_rew + rew
            test_return.append(episode_rew)
    finally:
        agent.dbqn.train()
    return np.mean(np.array(test_return))

In [None]:
# How many independent training runs to perform, and the container that
# collects each run's list of episode rewards.
runs = 5
run_result = list()

# Disabled learning-rate decay schedule, kept for reference:
# lambda1 = lambda lr: lr * 0.999

In [None]:
# Hyperparameters shared by every run (hoisted out of the loop; they remain
# module-level names).
lr = 1e-3
batch_size = 32
gamma = 0.95
STEPS = 30000

for run in range(runs):
    # `writer` must stay a module-level name: DBQN_learn.train logs through it.
    writer = SummaryWriter()
    dbqn = BayesianNetwork(env, 4, batch_size).to(DEVICE)
    target_dbqn = BayesianNetwork(env, 4, batch_size).to(DEVICE)
    agent = DBQN_learn(dbqn, target_dbqn, gamma, lr, batch_size, STEPS)

    done = False
    episode_rew = 0
    episode_count = 0
    res = []

    obs = env.reset()
    act = agent.reset(obs)

    # Continue until BOTH the step budget is spent and 300 episodes are done.
    while agent.t <= STEPS or episode_count < 300:

        if done:
            print(f"Episode {episode_count} with reward = {episode_rew}")
            writer.add_scalar('data/reward', episode_rew, episode_count)
            res.append(episode_rew)
            episode_rew = 0
            episode_count = episode_count + 1

            # Periodic evaluation. test_agent leaves the env mid-reset, which
            # is fine because the env is reset right below.
            if episode_count % 25 == 0:
                test_result = test_agent(agent)
                print(f"Test Result = {test_result}")
                writer.add_scalar('data/test_reward', test_result, episode_count)

            obs = env.reset()
            act = agent.reset(obs)

        obs1, rew, done, _ = env.step(act)
        act = agent.step(obs, act, rew, obs1, done)
        obs = obs1
        episode_rew = episode_rew + rew

    # One JSON export per run: a single shared path was overwritten by every
    # run, so only the last run's scalars survived.
    writer.export_scalars_to_json(f"./all_scalars_run{run}.json")
    writer.close()

    run_result.append(res)

Episode 0 with reward = 10.0
Episode 1 with reward = 9.0
Episode 2 with reward = 10.0
Episode 3 with reward = 14.0
Episode 4 with reward = 11.0
Episode 5 with reward = 10.0
Episode 6 with reward = 10.0
Episode 7 with reward = 12.0
Episode 8 with reward = 11.0
Episode 9 with reward = 30.0
Episode 10 with reward = 10.0
Episode 11 with reward = 11.0
Episode 12 with reward = 14.0
Episode 13 with reward = 11.0
Episode 14 with reward = 9.0
Episode 15 with reward = 10.0
Episode 16 with reward = 11.0
Episode 17 with reward = 15.0
Episode 18 with reward = 10.0
Episode 19 with reward = 10.0
Episode 20 with reward = 10.0
Episode 21 with reward = 9.0
Episode 22 with reward = 9.0
Episode 23 with reward = 11.0
Episode 24 with reward = 10.0
Test Result = 9.59
Episode 25 with reward = 13.0
Episode 26 with reward = 12.0
Episode 27 with reward = 9.0
Episode 28 with reward = 10.0
Episode 29 with reward = 9.0
Episode 30 with reward = 9.0
Episode 31 with reward = 12.0
Episode 32 with reward = 9.0
Episode 3

Episode 260 with reward = 200.0
Episode 261 with reward = 200.0
Episode 262 with reward = 200.0
Episode 263 with reward = 200.0
Episode 264 with reward = 200.0
Episode 265 with reward = 200.0
Episode 266 with reward = 200.0
Episode 267 with reward = 200.0
Episode 268 with reward = 200.0
Episode 269 with reward = 200.0
Episode 270 with reward = 200.0
Episode 271 with reward = 200.0
Episode 272 with reward = 121.0
Episode 273 with reward = 177.0
Episode 274 with reward = 200.0
Test Result = 171.21
Episode 275 with reward = 200.0
Episode 276 with reward = 200.0
Episode 277 with reward = 200.0
Episode 278 with reward = 200.0
Episode 279 with reward = 200.0
Episode 280 with reward = 200.0
Episode 281 with reward = 200.0
Episode 282 with reward = 200.0
Episode 283 with reward = 200.0
Episode 284 with reward = 200.0
Episode 285 with reward = 200.0
Episode 286 with reward = 164.0
Episode 287 with reward = 200.0
Episode 288 with reward = 200.0


In [None]:
# Shortest run length (never more than 10000), so every run can be cut to a
# common number of episodes before averaging.
min_length = min([10000] + [len(rewards) for rewards in run_result])

In [None]:
# Truncate each run's episode-reward list to the common length and stack them
# into a (num_runs, min_length) array. Iterating over run_result itself,
# instead of a hard-coded range(10), avoids an IndexError when fewer than 10
# runs were recorded (the training cell uses runs = 5).
tmp_result = np.stack([rewards[:min_length] for rewards in run_result], axis=0)

In [None]:
# Learning curve: per-episode reward averaged over all runs.
plt.plot(tmp_result.mean(axis=0))