In [None]:
class BayesianLinear(nn.Module):
    """Linear layer with a Gaussian variational posterior over its weights and bias.

    Posterior means are initialised uniformly in [-0.2, 0.2] and the ``rho``
    parameters uniformly in [-2, -1]; each (mean, rho) pair is wrapped in a
    ``DifferentiableGaussian`` (defined elsewhere in the notebook).  Weights and
    bias share a ``ScaleMixtureGaussian`` prior with pi = 0.5,
    sigma1 = exp(0) and sigma2 = exp(1).

    In training mode every forward pass stores the log prior and log variational
    posterior of the weights it used in ``self.log_prior`` and
    ``self.log_variational_posterior`` (both reset to 0 in eval mode).
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        # Weight parameters (variational mean and rho).
        self.init_weight_mean = nn.Parameter(torch.Tensor(out_features, in_features).uniform_(-0.2, 0.2))
        self.init_weight_rho = nn.Parameter(torch.Tensor(out_features, in_features).uniform_(-2, -1))
        self.weight = DifferentiableGaussian(self.init_weight_mean, self.init_weight_rho)

        # Bias parameters (variational mean and rho).
        self.init_bias_mean = nn.Parameter(torch.Tensor(out_features).uniform_(-0.2, 0.2))
        self.init_bias_rho = nn.Parameter(torch.Tensor(out_features).uniform_(-2, -1))
        self.bias = DifferentiableGaussian(self.init_bias_mean, self.init_bias_rho)

        # Scale-mixture prior shared by weights and bias.
        pi = 0.5
        prior_sigma1 = torch.FloatTensor([math.exp(-0)])
        prior_sigma2 = torch.FloatTensor([math.exp(1)])
        self.weight_prior = ScaleMixtureGaussian(pi, prior_sigma1, prior_sigma2)
        self.bias_prior = ScaleMixtureGaussian(pi, prior_sigma1, prior_sigma2)

        # Cached, detached weight/bias sample used by forward(use_sample=True).
        # NOTE(review): these are plain attributes, not registered buffers, so
        # ``.to(device)`` does not move them; they only follow the parameters'
        # device once refreshed via forward(..., num_sample != 0) — confirm intended.
        self.weight_sample, self.bias_sample = self.weight.sample().detach(), self.bias.sample().detach()
        self.log_prior = 0
        self.log_variational_posterior = 0

    def sample(self, MC_samples=4):
        """Return the element-wise mean of ``MC_samples`` weight and bias draws.

        The draws are taken without building an autograd graph, so the returned
        tensors carry no gradient history.

        Args:
            MC_samples: number of draws to average.

        Returns:
            Tuple ``(mean_weight, mean_bias)``.
        """
        weight_draws = []
        bias_draws = []
        # no_grad: the result is only ever used as plain data, so skip graph building.
        with torch.no_grad():
            for _ in range(MC_samples):
                # Interleave weight/bias draws to keep the original RNG consumption order.
                weight_draws.append(self.weight.sample())
                bias_draws.append(self.bias.sample())
        return torch.stack(weight_draws).mean(dim=0), torch.stack(bias_draws).mean(dim=0)

    def forward(self, input, use_sample, num_sample=0):
        """Apply the layer to ``input`` with sampled or mean weights.

        Args:
            input: tensor passed straight to ``F.linear``.
            use_sample: if truthy, use the cached weight/bias sample; otherwise
                use the posterior means.
            num_sample: when truthy (and ``use_sample`` is truthy) the cached
                sample is first refreshed with ``self.sample()``; only its
                truthiness matters here.

        Returns:
            ``F.linear(input, weight, bias)``.

        Raises:
            FloatingPointError: if the weight posterior mean contains a NaN.
        """
        # Any NaN (not just exactly one) means training has diverged.
        if torch.isnan(self.weight.mean).any():
            raise FloatingPointError("BayesianLinear: NaN detected in the weight posterior mean")

        if use_sample:
            # Draw fresh samples to obtain autograd-connected tensors, then swap in
            # the cached sample's values: the forward pass uses the cached weights
            # while gradients flow through the fresh draws' graph (presumably a
            # reparameterised mean + sigma * eps in DifferentiableGaussian — confirm).
            weight = self.weight.sample()
            bias = self.bias.sample()
            if num_sample:
                self.weight_sample.data, self.bias_sample.data = self.sample()
            weight.data = self.weight_sample
            bias.data = self.bias_sample
        else:
            weight = self.weight.mean
            bias = self.bias.mean

        if self.training:
            self.log_prior = self.weight_prior.log_prob(weight) + self.bias_prior.log_prob(bias)
            self.log_variational_posterior = self.weight.log_prob(weight) + self.bias.log_prob(bias)
        else:
            self.log_prior, self.log_variational_posterior = 0, 0

        return F.linear(input, weight, bias)


In [None]:
class BayesianNetwork(nn.Module):
    """Bayesian MLP: ``input_features`` -> ``hidden_features`` (ReLU) -> ``output_features``.

    Built from ``BayesianLinear`` layers.  ``sample_elbo`` runs the network
    ``num_samples`` times with freshly refreshed weight samples and returns the
    pieces of the loss used by the learner.
    """

    def __init__(self, input_features, output_features, num_samples, batch_size, hidden_features=32):
        """Build the network.

        Args:
            input_features: size of each input vector.
            output_features: number of outputs (one column per action in ``sample_elbo``'s gather).
            num_samples: forward passes averaged by ``sample_elbo``.
            batch_size: stored for reference; ``sample_elbo`` sizes its buffers from ``obs``.
            hidden_features: width of the hidden layer (default 32).
        """
        super().__init__()

        self.num_samples = num_samples
        # Layer sizes come from the arguments rather than from a global ``env``.
        self.l1 = BayesianLinear(input_features, hidden_features)
        self.l2 = BayesianLinear(hidden_features, output_features)

        # The same layers are also kept in a ModuleList for generic iteration.
        layer_arr = [self.l1, self.l2]
        self.layer_arr = nn.ModuleList(layer_arr)
        self.layer_num = len(layer_arr)

        self.batch_size = batch_size

    def forward(self, x, use_sample, num_sample=0):
        """Run ``x`` through every layer, with ReLU between layers and none after the last."""
        # Unpacking avoids relying on a leaked loop index (which broke for a single layer).
        *hidden_layers, output_layer = self.layer_arr
        for layer in hidden_layers:
            x = F.relu(layer(x, use_sample, num_sample))
        return output_layer(x, use_sample, num_sample)

    def log_prior(self):
        """Sum of the ``log_prior`` values stored by each layer's last forward pass."""
        return sum(layer.log_prior for layer in self.layer_arr)

    def log_variational_posterior(self):
        """Sum of the ``log_variational_posterior`` values stored by each layer's last forward pass."""
        return sum(layer.log_variational_posterior for layer in self.layer_arr)

    def sample_elbo(self, obs, act, targets):
        """Estimate the loss terms from ``num_samples`` sampled forward passes.

        Args:
            obs: batch of observations, first dimension is the batch.
            act: integer action indices, one per observation (reshaped to a column for ``gather``).
            targets: regression targets, one per observation.

        Returns:
            Tuple ``(log_prior, log_variational_posterior, mse, td_errors)``.
            The first two are means over samples, each divided by the global ``STEPS``.
            ``mse`` is the per-element squared error of the sample-averaged Q-values.
            ``td_errors`` is the corresponding difference as a NumPy array.
        """
        # Size buffers from the actual batch instead of assuming self.batch_size.
        batch = obs.shape[0]
        outputs = torch.zeros(self.num_samples, batch, device=DEVICE)
        log_prior = torch.zeros(self.num_samples, device=DEVICE)
        log_variational_posterior = torch.zeros(self.num_samples, device=DEVICE)
        for i in range(self.num_samples):
            # num_sample=1 refreshes each layer's cached weight sample before this pass.
            q_values = self.forward(obs, use_sample=True, num_sample=1)
            # squeeze(1) keeps a length-1 batch as a vector.
            outputs[i] = q_values.gather(1, act.view(-1, 1)).squeeze(1)
            log_prior[i] = self.log_prior() / STEPS
            log_variational_posterior[i] = self.log_variational_posterior() / STEPS

        q_mean = outputs.mean(0)
        mse = nn.MSELoss(reduction='none')(q_mean, targets)
        td_errors = (q_mean - targets).detach().cpu().numpy()
        return log_prior.mean(), log_variational_posterior.mean(), mse, td_errors

In [None]:
class DBQN_learn():
    """Online learner for a Bayesian Q-network with a target network and prioritised replay.

    Each environment transition is passed to ``step``.  ``step`` stores it in a
    ``PrioritizedReplayBuffer``, trains every ``train_freq`` steps, syncs the
    target network every ``update_freq`` steps, and returns the next action.
    Training metrics are written to the global ``writer``.
    """

    def __init__(self, model, target_model, gamma, lr, batch_size, num_steps):
        """Set up the learner.

        Args:
            model: online ``BayesianNetwork`` (optimised with Adam).
            target_model: network used to compute bootstrap targets; synced from ``model``.
            gamma: discount factor.
            lr: Adam learning rate.
            batch_size: replay minibatch size; training starts once ``t > batch_size``.
            num_steps: total training steps (only used for ``constt``).
        """
        self.dbqn = model
        self.target_dbqn = target_model
        self.update_target()

        self.t = 0
        self.gamma = gamma
        self.optimizer = optim.Adam(lr = lr, params = self.dbqn.parameters())
        self.batch_size = batch_size
        self.num_steps = num_steps

        self.train_freq = 1
        self.update_freq = 100
        # Capacity 10000, priority exponent 0.7 (presumably alpha — confirm against the buffer).
        self.replay_buffer = PrioritizedReplayBuffer(10000, 0.7)
        # NOTE(review): not read anywhere in this class.
        self.constt = 1/self.num_steps

    def train(self):
        """Take one optimisation step on a prioritised minibatch.

        Returns:
            ``(idxes, td_errors)`` for the buffer's priority update.

        Raises:
            FloatingPointError: if any parameter becomes NaN after the step.
        """
        buffer_sample = self.replay_buffer.sample(self.batch_size, beta=0.5)
        obs = torch.from_numpy(buffer_sample[0]).float().to(DEVICE)
        act = torch.from_numpy(buffer_sample[1]).long().to(DEVICE)
        rew = torch.from_numpy(buffer_sample[2]).float().to(DEVICE)
        obs1 = torch.from_numpy(buffer_sample[3]).float().to(DEVICE)
        dones = torch.from_numpy(buffer_sample[4].astype(int)).float().to(DEVICE)
        wt = torch.from_numpy(buffer_sample[5]).float().to(DEVICE)
        idxes = buffer_sample[6]

        self.dbqn.train()

        # Bootstrap targets from the target network's posterior-mean Q-values; no gradients needed.
        with torch.no_grad():
            next_q = self.target_dbqn(obs1, use_sample=False).max(1).values
        targets = rew + self.gamma * next_q * (1 - dones)

        log_prior, log_variational_posterior, mse, td_errors = self.dbqn.sample_elbo(obs, act, targets)
        # Scalar complexity term plus per-sample squared error, weighted per sample by ``wt``.
        loss = (log_variational_posterior - log_prior) + mse
        weighted_loss = (wt * loss).mean()

        self.optimizer.zero_grad()
        weighted_loss.backward()
        self.optimizer.step()

        # Any NaN parameter (not just exactly one element) means training diverged.
        for p in self.dbqn.parameters():
            if torch.isnan(p).any():
                raise FloatingPointError("DBQN_learn.train: NaN detected in model parameters after optimizer step")

        # .item() works for CPU and CUDA tensors alike (the old .numpy() call failed on CUDA).
        writer.add_scalar('data/loss', weighted_loss.item(), self.t)
        writer.add_scalar('data/prior', log_prior.item(), self.t)
        writer.add_scalar('data/posterior', log_variational_posterior.item(), self.t)
        writer.add_scalar('data/mse', mse.mean().item(), self.t)
        # Mean sigma of the first weight row of each layer.
        writer.add_scalar('data/w1_sigma', self.dbqn.l1.weight.sigma[0].detach().mean().item(), self.t)
        writer.add_scalar('data/w2_sigma', self.dbqn.l2.weight.sigma[0].detach().mean().item(), self.t)

        return idxes, td_errors

    def step(self, obs_t, act_t, rew_t, obs_t1, done ):
        """Record a transition, train/sync as scheduled, and return the action for ``obs_t1``."""
        self.replay_buffer.add(obs_t, act_t, rew_t, obs_t1, done)
        self.t = self.t + 1

        if self.t > self.batch_size:
            if self.t % self.train_freq == 0:
                idxes, td_errors = self.train()
                self.replay_buffer.update_priorities(idxes, np.abs(td_errors) + 1e-6)
            if self.t % self.update_freq == 0:
                self.update_target()

        return self.act(obs_t1, use_sample=True, num_sample=0)

    def act(self, obs, use_sample, num_sample):
        """Return the greedy action index for a single observation ``obs`` (a NumPy array)."""
        obs = torch.from_numpy(obs).float().unsqueeze(0).to(DEVICE)
        # Action selection needs no autograd graph.
        with torch.no_grad():
            q_values = self.dbqn(obs, use_sample, num_sample)
        return np.argmax(q_values.cpu().numpy())

    def reset(self, obs):
        """Start an episode: bump ``t`` and act with a refreshed cached weight sample (``num_sample=2``)."""
        self.t = self.t + 1
        return self.act(obs, use_sample=True, num_sample=2)

    def update_target(self):
        """Copy the online network's state into the target network."""
        self.target_dbqn.load_state_dict(self.dbqn.state_dict())
        

In [None]:
def test_agent(agent, num_episodes=100, test_env=None):
    """Evaluate ``agent`` over ``num_episodes`` episodes and return the mean episode return.

    The agent's network is switched to eval mode for the evaluation and back to
    train mode afterwards, even if an exception is raised.  Actions come from
    ``agent.act(..., use_sample=True, ...)``.  The first action of each episode
    uses ``num_sample=4``; later actions use ``num_sample=0``.

    Args:
        agent: object with a ``dbqn`` attribute (``eval()``/``train()``) and
            ``act(obs, use_sample, num_sample)``.
        num_episodes: number of evaluation episodes (default 100, as before).
        test_env: environment with ``reset()`` and a 4-tuple ``step(action)``;
            defaults to the notebook-global ``env``.

    Returns:
        Mean of the per-episode summed rewards, via ``np.mean``.
    """
    # The global ``env`` is only looked up when no environment is passed in.
    environment = env if test_env is None else test_env
    episode_returns = []

    agent.dbqn.eval()
    try:
        for _episode in range(num_episodes):
            obs = environment.reset()
            act = agent.act(obs, use_sample=True, num_sample=4)
            episode_rew = 0
            done = False
            while not done:
                obs1, rew, done, _info = environment.step(act)
                episode_rew = episode_rew + rew
                # No need to pick an action after the terminal step.
                if not done:
                    act = agent.act(obs1, use_sample=True, num_sample=0)
            episode_returns.append(episode_rew)
    finally:
        agent.dbqn.train()

    return np.mean(np.array(episode_returns))

In [None]:
run_result = []  # one list of per-episode rewards is appended per run
runs = 5  # number of independent training runs

# A learning-rate decay schedule was tried and disabled.  Note that
# optim.lr_scheduler.LambdaLR passes the epoch index (not the lr) to lr_lambda,
# so an exponential decay would be written `lambda epoch: 0.999 ** epoch`.

In [None]:
for run in range(runs):
    # Hyper-parameters, re-set for every run.  STEPS stays a module-level global
    # because BayesianNetwork.sample_elbo divides its log terms by it.
    lr = 1e-3
    batch_size = 32
    gamma = 0.95
    STEPS = 30000
    # Global writer, also used by DBQN_learn.train for logging.
    writer = SummaryWriter()

    # Size the networks from the environment.  The constructor takes
    # (input_features, output_features, num_samples, batch_size); the old call
    # passed only three arguments.  4 = forward passes per sample_elbo estimate.
    n_obs = env.observation_space.shape[0]
    n_actions = env.action_space.n
    dbqn = BayesianNetwork(n_obs, n_actions, 4, batch_size).to(DEVICE)
    target_dbqn = BayesianNetwork(n_obs, n_actions, 4, batch_size).to(DEVICE)
    agent = DBQN_learn(dbqn, target_dbqn, gamma, lr, batch_size, STEPS)

    done = False

    episode_rew = 0
    episode_count = 0
    res = []

    obs = env.reset()
    act = agent.reset(obs)

    # Keep going until both the step budget and the minimum episode count are met.
    while agent.t <= STEPS or episode_count < 300:

        if done:
            print("Episode " + str(episode_count) + " with reward = " + str(episode_rew))
            writer.add_scalar('data/reward', episode_rew, episode_count)
            res.append(episode_rew)
            episode_rew = 0
            episode_count = episode_count + 1

            # Periodic evaluation every 25 episodes.
            if episode_count%25 == 0:
                test_result = test_agent(agent)
                print("Test Result = " + str(test_result))
                writer.add_scalar('data/test_reward', test_result, episode_count)

            obs = env.reset()
            act = agent.reset(obs)

        obs1, rew, done, _ = env.step(act)
        act = agent.step(obs, act, rew, obs1, done)
        obs = obs1
        episode_rew = episode_rew + rew

    writer.export_scalars_to_json("./all_scalars.json")
    writer.close()

    run_result.append(res)

Episode 0 with reward = 10.0
Episode 1 with reward = 9.0
Episode 2 with reward = 10.0
Episode 3 with reward = 14.0
Episode 4 with reward = 11.0
Episode 5 with reward = 10.0
Episode 6 with reward = 10.0
Episode 7 with reward = 12.0
Episode 8 with reward = 11.0
Episode 9 with reward = 30.0
Episode 10 with reward = 10.0
Episode 11 with reward = 11.0
Episode 12 with reward = 14.0
Episode 13 with reward = 11.0
Episode 14 with reward = 9.0
Episode 15 with reward = 10.0
Episode 16 with reward = 11.0
Episode 17 with reward = 15.0
Episode 18 with reward = 10.0
Episode 19 with reward = 10.0
Episode 20 with reward = 10.0
Episode 21 with reward = 9.0
Episode 22 with reward = 9.0
Episode 23 with reward = 11.0
Episode 24 with reward = 10.0
Test Result = 9.59
Episode 25 with reward = 13.0
Episode 26 with reward = 12.0
Episode 27 with reward = 9.0
Episode 28 with reward = 10.0
Episode 29 with reward = 9.0
Episode 30 with reward = 9.0
Episode 31 with reward = 12.0
Episode 32 with reward = 9.0
Episode 3

In [None]:
# Shortest run length (capped at 10000 episodes), so that every run can be
# truncated to a common number of episodes before averaging.
min_length = min([10000] + [len(episodes) for episodes in run_result])

In [None]:
# Truncate every run to the common length and stack them into an array of
# shape (number_of_runs, min_length).  Iterating run_result directly (instead of
# a hard-coded range(10)) avoids an IndexError when fewer than 10 runs exist.
tmp_result = np.stack([episodes[:min_length] for episodes in run_result], axis=0)

In [None]:
# Plot the episode reward averaged over runs, one point per episode index.
mean_reward_curve = tmp_result.mean(axis=0)
plt.plot(mean_reward_curve)