In [34]:
from typing import List, Tuple

import pytorch_lightning as pl
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.utilities import DistributedType

import torch
from torch import Tensor, nn
from torch.distributions import Categorical
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.optim.optimizer import Optimizer
from torch.utils.data.dataset import IterableDataset

import gym
import gym_simplifiedtetris


In [39]:
class CriticNet(nn.Module):
    """State-value network: a one-hidden-layer MLP with a single scalar output.

    Args:
        obs_size: number of input features (last dimension of the input).
        hidden_size: width of the hidden ReLU layer.

    ``forward`` maps inputs of shape ``(..., obs_size)`` to ``(..., 1)``.
    """

    def __init__(self, obs_size, hidden_size = 50):
        super().__init__()
        layers = [
            nn.Linear(obs_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        ]
        # Attribute name kept as ``critic`` so state_dict keys stay stable.
        self.critic = nn.Sequential(*layers)

    def forward(self, x):
        return self.critic(x)

class ActorNet(nn.Module):
    """Policy network: maps an observation to a categorical distribution over actions.

    Args:
        obs_size: number of input features (last dimension of the input).
        n_actions: number of discrete actions.
        hidden_size: width of the hidden ReLU layer.
    """

    def __init__(self, obs_size, n_actions, hidden_size = 50):
        super().__init__()

        self.actor = nn.Sequential(
            nn.Linear(obs_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_actions),
        )

    def forward(self, x):
        """Return ``(dist, action)`` for observation(s) ``x``.

        ``dist`` is a ``Categorical`` built from the network logits and ``action``
        is one sample from it. Leading dimensions of ``x`` are treated as batch
        dimensions, so ``action`` has shape ``x.shape[:-1]``.
        """
        logits = self.actor(x)
        dist = Categorical(logits=logits)
        action = dist.sample()

        return dist, action


class ActorCritic():
    """Gradient-free wrapper that queries an actor and a critic together.

    Calls run under ``torch.no_grad()``, so the returned tensors carry no
    autograd history. This makes the wrapper suitable for collecting rollouts.

    Args:
        critic: callable mapping a state to a value estimate.
        actor: callable mapping a state to ``(dist, action)``.
        hidden_size: unused; kept only for backward compatibility of the signature.
    """

    def __init__(self, critic, actor, hidden_size = 50):
        self.critic = critic
        self.actor = actor

    @torch.no_grad()
    def __call__(self, state: torch.Tensor):
        """Return ``(dist, action, log_prob_of_action, value)`` for ``state``."""
        dist, action = self.actor(state)
        probs = dist.log_prob(action)
        val = self.critic(state)

        return dist, action, probs, val


In [36]:
class RLDataSet(IterableDataset):
    """Iterable dataset whose samples come from a zero-argument factory.

    ``batch_maker`` is invoked on every ``iter()`` call and its result is
    returned directly, so it must produce an iterator (e.g. a generator).
    """

    def __init__(self, batch_maker):
        self.batch_maker = batch_maker

    def __iter__(self):
        make_samples = self.batch_maker
        return make_samples()

In [37]:
class PPOLightning(LightningModule):
    """Proximal Policy Optimization on the simplified-Tetris gym environment.

    Rollouts are generated on the fly by ``make_batch`` (wrapped in an
    ``RLDataSet``) and served in mini-batches of ``batch_size``. Actor and
    critic have separate Adam optimizers; ``training_step`` dispatches on
    ``opt_id`` (0 = actor, 1 = critic).

    Args:
        lr: learning rate for both optimizers.
        batch_size: DataLoader mini-batch size.
        clip_eps: PPO clipping range for the probability ratio.
        tau: GAE lambda; TD residuals are discounted by ``gamma * tau``.
        epoch_steps: environment steps collected per training epoch.
        gamma: reward discount factor.
    """

    def __init__(
        self,
        lr: float = 3e-4,
        batch_size: int = 5,
        clip_eps: float = 0.2,
        tau : float = 0.95,
        epoch_steps: int = 50,
        gamma: float = 0.99
    ):
        super().__init__()
        # Exposes the constructor arguments as self.hparams.<name>.
        self.save_hyperparameters()

        self.env = gym.make("simplifiedtetris-binary-10x10-1-v0")
        self.state = torch.Tensor(self.env.reset())
        self.ep_step = 0
        obs_size = self.env.observation_space.shape[0]
        n_actions = self.env.action_space.n

        # Per-epoch buffers, one entry per environment step.
        self.batch_states = []
        self.batch_actions = []
        self.batch_probs = []
        self.batch_advs = []
        self.batch_vals = []
        # Per-episode buffers, flushed when an episode (or the epoch) ends.
        self.ep_rewards = []
        self.ep_vals = []

        self.critic = CriticNet(obs_size)
        self.actor = ActorNet(obs_size, n_actions)

        # No-grad rollout wrapper sharing parameters with the modules above.
        self.agent = ActorCritic(self.critic, self.actor)

    def forward(self, x):
        """Return ``(dist, action, value)`` for observation(s) ``x``."""
        dist, action = self.actor(x)
        val = self.critic(x)

        return dist, action, val

    def act_loss(self, state, action, prob_old, val, adv):
        """Clipped PPO surrogate objective for the actor, negated for minimisation.

        ``prob_old`` is the log-probability of ``action`` recorded at rollout time;
        ``val`` is accepted for signature symmetry with ``crit_loss`` and unused.
        """
        dist, _ = self.actor(state)
        prob = dist.log_prob(action)
        ratio = torch.exp(prob - prob_old)
        clip_eps = self.hparams.clip_eps
        clipped = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * adv
        return -torch.min(ratio * adv, clipped).mean()

    def crit_loss(self, state, action, prob_old, val, adv):
        """Mean-squared error between the critic's estimate and the return target ``val``."""
        # Critic outputs (B, 1); squeeze so it lines up with the (B,) targets
        # instead of broadcasting to a (B, B) matrix.
        val_new = self.critic(state).squeeze(-1)
        return (val - val_new).pow(2).mean()

    def compute_gae(self, rewards, next_value, values):
        """Generalized Advantage Estimates for one episode segment.

        Args:
            rewards: per-step rewards of the segment.
            next_value: value used for the state after the last step
                (0 for a terminal state, bootstrapped otherwise).
            values: critic estimates for each step's state.

        Returns:
            List of advantages, same length as ``rewards``.
        """
        gamma = self.hparams.gamma
        vals = list(values) + [next_value]
        # TD residuals: r_t + gamma * V(s_{t+1}) - V(s_t)
        deltas = [r + gamma * vals[i + 1] - vals[i] for i, r in enumerate(rewards)]
        return self.compute_reward(deltas, gamma * self.hparams.tau)

    @staticmethod
    def compute_reward(rewards, gamma):
        """Reverse-discounted cumulative sums: out[t] = sum_k gamma**k * rewards[t + k]."""
        returns = []
        running = 0.0
        for r in reversed(rewards):
            running = running * gamma + r
            returns.append(running)
        returns.reverse()
        return returns

    def make_batch(self):
        """Run ``epoch_steps`` environment steps, then yield one training sample per step.

        Each yielded sample is ``(state, action, old_log_prob, return, advantage)``.
        """
        gamma = self.hparams.gamma
        n_steps = self.hparams.epoch_steps
        for i in range(n_steps):
            _, action, probs, val = self.agent(self.state)
            next_state, reward, done, _ = self.env.step(action.item())
            self.ep_step += 1

            self.batch_states.append(self.state)
            self.batch_actions.append(action)
            self.batch_probs.append(probs)
            self.ep_rewards.append(reward)
            self.ep_vals.append(val.item())

            # Advance the rollout so the next action (and any bootstrap) sees the new state.
            self.state = torch.Tensor(next_state)

            end = i == (n_steps - 1)
            if done or end:
                if end and not done:
                    # Epoch ended mid-episode: bootstrap from the value of the next state.
                    _, _, _, next_v = self.agent(self.state)
                    next_val = next_v.item()
                else:
                    next_val = 0.0

                # Return targets: discount with the bootstrap appended, then drop its slot.
                self.batch_vals += self.compute_reward(self.ep_rewards + [next_val], gamma)[:-1]
                self.batch_advs += self.compute_gae(self.ep_rewards, next_val, self.ep_vals)

                self.ep_rewards.clear()
                self.ep_vals.clear()
                self.ep_step = 0
                self.state = torch.Tensor(self.env.reset())

            if end:
                # Materialise before clearing: zip is lazy and would otherwise see empty lists.
                samples = list(zip(self.batch_states,
                                   self.batch_actions,
                                   self.batch_probs,
                                   self.batch_vals,
                                   self.batch_advs))

                self.batch_states.clear()
                self.batch_actions.clear()
                self.batch_probs.clear()
                self.batch_vals.clear()
                self.batch_advs.clear()

                yield from samples

    def training_step(self, batch, batch_id, opt_id):
        """Compute the actor loss (``opt_id == 0``) or critic loss (``opt_id == 1``)."""
        state, action, prob_old, val, adv = batch
        # Python floats collate to float64; match the networks' float32 outputs.
        val = val.float()
        adv = adv.float()
        # Normalise advantages; std of a single sample is NaN, so skip that case.
        if adv.numel() > 1:
            adv = (adv - adv.mean()) / (adv.std() + 1e-8)

        if opt_id == 0:
            return self.act_loss(state, action, prob_old, val, adv)
        elif opt_id == 1:
            return self.crit_loss(state, action, prob_old, val, adv)

    def configure_optimizers(self) -> List[Optimizer]:
        """One Adam optimizer per network: index 0 = actor, index 1 = critic."""
        a_opt = optim.Adam(self.actor.parameters(), lr=self.hparams.lr)
        c_opt = optim.Adam(self.critic.parameters(), lr=self.hparams.lr)
        return [a_opt, c_opt]

    def __dataloader(self):
        """DataLoader streaming fresh rollouts from ``make_batch``."""
        dataset = RLDataSet(self.make_batch)
        dataloader = DataLoader(dataset=dataset, batch_size=self.hparams.batch_size)
        return dataloader

    def train_dataloader(self):
        return self.__dataloader()


In [38]:
# Seed Python / NumPy / torch RNGs before building the model so weight init and
# action sampling are reproducible on Restart & Run All. The gym env may need
# its own seeding as well — TODO confirm for gym_simplifiedtetris.
pl.seed_everything(42)

model = PPOLightning()

trainer = Trainer(
    gpus=0,
    max_epochs=200,
)

# NOTE(review): the recorded "'FloatProgress' object has no attribute 'style'"
# error follows a "Widget Javascript not detected" warning, so it likely comes
# from the notebook progress-bar widget environment (ipywidgets), not this code.
trainer.fit(model)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name   | Type      | Params
-------------------------------------
0 | critic | CriticNet | 5.2 K 
1 | actor  | ActorNet  | 5.6 K 
-------------------------------------
10.8 K    Trainable params
0         Non-trainable params
10.8 K    Total params
0.043     Total estimated model params size (MB)
Widget Javascript not detected.  It may not be installed or enabled properly. Reconnecting the current kernel may help.


AttributeError: 'FloatProgress' object has no attribute 'style'