# PPO in MAB environment
Code taken and adapted 
From https://github.com/nikhilbarhate99/PPO-PyTorch/blob/master/PPO_continuous.py

In [11]:
from torch.distributions import MultivariateNormal
import torch
import torch.nn.functional as F
from task.TaskGenerator import SinTaskGenerator
import matplotlib.pyplot as plt
import numpy as np

In [12]:
import gym

class SinMabEnv(gym.Env):
    
    def __init__(self, amplitude=1, phase=0, frequency=1, noise_std=0.1, min_x=-5, max_x=5):
        # Env. parameters
        self.a = amplitude
        self.p = phase
        self.f = frequency
        self.noise_std = noise_std
        
        self.min_x = -5
        self.max_x = 5
        
        self.state = 0 # There is a single state
        
    def reset(self):
        self.state = 0
        return self.state
    
    def get_state(self):
        return self.state
    
    def step(self, action):
        if action < self.min_x or action > self.max_x:
            return self.state, torch.Tensor([0]), False, _ 
        
        noise = torch.from_numpy(np.random.normal(loc=0, scale=self.noise_std, size=1))
        reward = self.a * torch.sin(self.f * action + self.p) + noise
        done = False
        return self.state, reward, done, _

In [13]:
device = "cpu"

env = SinMabEnv()

class Memory:
    def __init__(self):
        self.actions = []
        self.states = []
        self.logprobs = []
        self.rewards = []
        self.is_terminals = []
    
    def clear_memory(self):
        del self.actions[:]
        del self.states[:]
        del self.logprobs[:]
        del self.rewards[:]
        del self.is_terminals[:]

class ActorCritic(torch.nn.Module):
    def __init__(self, state_dim, action_dim, action_std):
        super(ActorCritic, self).__init__()
        # action mean range -1 to 1
        self.actor =  torch.nn.Sequential(
                torch.nn.Linear(state_dim, 128),
                torch.nn.ELU(),
                torch.nn.Linear(128, 32),
                torch.nn.ELU(),
                torch.nn.Linear(32, action_dim),
                torch.nn.ELU()
                )
        # critic
        self.critic = torch.nn.Sequential(
                torch.nn.Linear(state_dim, 128),
                torch.nn.ELU(),
                torch.nn.Linear(128, 32),
                torch.nn.ELU(),
                torch.nn.Linear(32, 1)
                )
        self.action_var = torch.full((action_dim,), action_std*action_std).to(device)
        
    def forward(self):
        raise NotImplementedError
    
    def act_no_store(self, state):
        action_mean = self.actor(state)
        cov_mat = torch.diag(self.action_var).to(device)
        
        dist = MultivariateNormal(action_mean, cov_mat)
        action = dist.sample()
        
        return action.detach()
    
    def act(self, state, memory):
        action_mean = self.actor(state)
        cov_mat = torch.diag(self.action_var).to(device)
        
        dist = MultivariateNormal(action_mean, cov_mat)
        action = dist.sample()
        action_logprob = dist.log_prob(action)
        
        memory.states.append(state)
        memory.actions.append(action)
        memory.logprobs.append(action_logprob)
        
        return action.detach()
    
    def evaluate(self, state, action):   
        action_mean = self.actor(state)
        
        action_var = self.action_var.expand_as(action_mean)
        cov_mat = torch.diag_embed(action_var).to(device)
        
        dist = MultivariateNormal(action_mean, cov_mat)
        
        action_logprobs = dist.log_prob(action)
        dist_entropy = dist.entropy()
        state_value = self.critic(state)
        
        return action_logprobs, torch.squeeze(state_value), dist_entropy

    

class PPO:
    def __init__(self, state_dim, action_dim, action_std, lr, betas, gamma, K_epochs, eps_clip):
        self.lr = lr
        self.betas = betas
        self.gamma = gamma
        self.eps_clip = eps_clip
        self.K_epochs = K_epochs
        
        self.policy = ActorCritic(state_dim, action_dim, action_std).to(device)
        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr, betas=betas)
        
        self.policy_old = ActorCritic(state_dim, action_dim, action_std).to(device)
        self.policy_old.load_state_dict(self.policy.state_dict())
        
        self.MseLoss = torch.nn.MSELoss()
        
    def sample_action_no_store(self, state):
        return self.policy_old.act_no_store(state).cpu().data.numpy().flatten()
    
    def select_action(self, state, memory):
        state = torch.FloatTensor(state.reshape(1, -1)).to(device)
        return self.policy_old.act(state, memory).cpu().data.numpy().flatten()
    
    def update(self, memory):
        # Monte Carlo estimate of rewards:
        rewards = []
        discounted_reward = 0
        for reward, is_terminal in zip(reversed(memory.rewards), reversed(memory.is_terminals)):
            if is_terminal:
                discounted_reward = 0
            discounted_reward = reward + (self.gamma * discounted_reward)
            rewards.insert(0, discounted_reward)
        
        # Normalizing the rewards:
        rewards = torch.tensor(rewards).to(device)
        rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-5)
        
        # convert list to tensor
        old_states = torch.squeeze(torch.stack(memory.states).to(device), 1).detach()
        old_actions = torch.squeeze(torch.stack(memory.actions).to(device), 1).detach()
        old_logprobs = torch.squeeze(torch.stack(memory.logprobs), 1).to(device).detach()
        
        # Optimize policy for K epochs:
        for _ in range(self.K_epochs):
            # Evaluating old actions and values :
            logprobs, state_values, dist_entropy = self.policy.evaluate(old_states, old_actions)
            state_values = state_values.double()
            dist_entropy = dist_entropy.double()
            
            # Finding the ratio (pi_theta / pi_theta__old):
            ratios = torch.exp(logprobs - old_logprobs.detach())

            # Finding Surrogate Loss:
            advantages = rewards - state_values.detach()   
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1-self.eps_clip, 1+self.eps_clip) * advantages
            loss = -torch.min(surr1, surr2) + 0.5*self.MseLoss(state_values, rewards) - 0.01*dist_entropy
            
            # take gradient step
            self.optimizer.zero_grad()
            loss.mean().backward()
            self.optimizer.step()
            
        # Copy new weights into old policy:
        self.policy_old.load_state_dict(self.policy.state_dict())

In [14]:
solved_reward = 0.93     # stop training if avg_reward > solved_reward
log_interval = 10           # print avg reward in the interval
max_episodes = 10000        # max training episodes
max_timesteps = 1500        # max timesteps in one episode

update_timestep = 2500      # update policy every n timesteps
action_std = 0.25           # constant std for action distribution (Multivariate Normal)
K_epochs = 80               # update policy for K epochs
eps_clip = 0.2              # clip parameter for PPO
gamma = 0.99                # discount factor

lr=0.0001
betas=(0.9, 0.999)

state_dim = 1
action_dim = 1

memory = Memory()
ppo = PPO(state_dim, action_dim, action_std, lr, betas, gamma, K_epochs, eps_clip)

running_reward = 0
avg_length = 0
time_step = 0

In [15]:
# training loop
env_name = "Try"
for i_episode in range(1, max_episodes+1):
    state = torch.Tensor([env.reset()])
    
    for t in range(max_timesteps):
        time_step +=1

        # Running policy_old:
        action = torch.Tensor([ppo.select_action(state, memory)])
        state, reward, done, _ = env.step(action)
        
        state = torch.Tensor([state])

        # Saving reward and is_terminals:
        memory.rewards.append(reward)
        memory.is_terminals.append(done)

        # update if its time
        if time_step % update_timestep == 0:
            ppo.update(memory)
            memory.clear_memory()
            time_step = 0
        running_reward += reward
        
        if done:
            break

    avg_length += t

    if i_episode % log_interval == 0:
        avg = 0
        for _ in range(1000):
            action = torch.Tensor([ppo.sample_action_no_store(torch.Tensor([0]))])
            state, reward, done, _ = env.step(action)
            avg += reward
        avg = avg / 1000
        print("Average reward: {}".format(avg))
        if avg > solved_reward:
            print("########## Solved! ##########")
            torch.save(ppo.policy.state_dict(), './PPO_continuous_solved_{}.pth'.format(env_name))
            break
        
    
    # stop training if avg_reward > solved_reward
    #if running_reward > (log_interval*solved_reward):
    #    print("########## Solved! ##########")
    #    torch.save(ppo.policy.state_dict(), './PPO_continuous_solved_{}.pth'.format(env_name))
    #    break

    # save every 500 episodes
    #if i_episode % 500 == 0:
    #    torch.save(ppo.policy.state_dict(), './PPO_continuous_{}.pth'.format(env_name))

    # logging
    if i_episode % log_interval == 0:
        avg_length = int(avg_length/log_interval)
        running_reward = int((running_reward/log_interval))

        print('Episode {} \t Avg length: {} \t Avg reward: {}'.format(i_episode, avg_length, running_reward))
        running_reward = 0
        avg_length = 0

Average reward: tensor([[0.0426]], dtype=torch.float64)
Episode 10 	 Avg length: 1499 	 Avg reward: -104
Average reward: tensor([[0.2159]], dtype=torch.float64)
Episode 20 	 Avg length: 1499 	 Avg reward: 173
Average reward: tensor([[0.3452]], dtype=torch.float64)
Episode 30 	 Avg length: 1499 	 Avg reward: 394
Average reward: tensor([[0.4753]], dtype=torch.float64)
Episode 40 	 Avg length: 1499 	 Avg reward: 602
Average reward: tensor([[0.5353]], dtype=torch.float64)
Episode 50 	 Avg length: 1499 	 Avg reward: 772
Average reward: tensor([[0.5675]], dtype=torch.float64)
Episode 60 	 Avg length: 1499 	 Avg reward: 818
Average reward: tensor([[0.6807]], dtype=torch.float64)
Episode 70 	 Avg length: 1499 	 Avg reward: 951
Average reward: tensor([[0.7380]], dtype=torch.float64)
Episode 80 	 Avg length: 1499 	 Avg reward: 1061
Average reward: tensor([[0.8206]], dtype=torch.float64)
Episode 90 	 Avg length: 1499 	 Avg reward: 1165
Average reward: tensor([[0.8698]], dtype=torch.float64)
Episo