In [None]:
import torch
from torch import nn 
import numpy as np
from collections import deque
import gym

In [None]:
class Model(nn.Module):
    """Actor-critic network with a shared trunk and two heads.

    The trunk is two ReLU-activated linear layers. The actor head maps the
    trunk features to a softmax distribution over actions; the critic head
    maps the same features to one value per action.

    Args:
        obs_dim: Size of the last dimension of the input observation.
        n_actions: Number of outputs of both the actor and the critic head.
        hidden_dim: Width of the two trunk layers.

    The defaults (8, 4, 48) reproduce the sizes that used to be hard-coded,
    so ``Model()`` builds exactly the same architecture as before.
    """

    def __init__(self, obs_dim=8, n_actions=4, hidden_dim=48):
        super().__init__()
        # Submodule names (common / actor / critic) are kept unchanged so
        # state_dict keys of previously saved checkpoints still match.
        self.common=nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        self.actor=nn.Sequential(
            nn.Linear(hidden_dim, n_actions),
            # Softmax over the last dim so batched and unbatched inputs both work.
            nn.Softmax(dim=-1)
        )

        self.critic=nn.Linear(hidden_dim, n_actions)

    def forward(self, x):
        """Return ``(action_probabilities, per_action_values)`` for ``x``.

        Both outputs have ``n_actions`` entries along their last dimension.
        """
        x=self.common(x)
        return self.actor(x), self.critic(x)

In [None]:
def _cuda_tensor(x):
    return torch.tensor(x, dtype=torch.float32, device='cuda')

class Agent:
    """On-policy actor-critic agent trained one episode at a time.

    ``play_episode`` rolls out one episode and stores its transitions in
    ``episode_memory``; ``learn`` then performs one optimizer update per
    stored transition.

    The environment is used through the classic gym interface: ``reset()``
    returns the first observation and ``step(action)`` returns the 4-tuple
    ``(observation, reward, done, info)``.

    Each stored transition is the list
    ``[state, action, reward, done, next_state]``; every transition except
    the last one of an episode additionally carries, as a sixth element, the
    action that was taken from ``next_state``.
    """

    def __init__(self):
        self.gamma=.99
        # Fall back to the CPU so the agent can also be built without a GPU.
        self.device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model=Model().to(self.device)
        self.optimizer=torch.optim.Adam(self.model.parameters(), 5e-5)
        self.env=gym.make('LunarLander-v2')
        self.episode_memory=deque()

    def _tensor(self, x):
        """Return ``x`` as a float32 tensor on the same device as the model."""
        return torch.tensor(x, dtype=torch.float32, device=self.device)

    def step(self, i_state:np.ndarray):
        """Roll the environment forward from ``i_state`` until it reports done.

        Every transition is appended to ``episode_memory``. The rollout is a
        plain loop (not one recursive call per environment step), so episode
        length is not bounded by Python's recursion limit, and errors raised
        by the model or the environment propagate to the caller instead of
        being silently turned into a terminal transition.
        """
        state=i_state
        previous=None  # transition still waiting for its follow-up action
        done=False
        while not done:
            with torch.no_grad():
                probs, _=self.model(self._tensor(state))
            # np.random.choice validates that p sums to 1 with float64
            # tolerance; float32 softmax output can miss that, so upcast and
            # renormalise before sampling.
            p=probs.cpu().numpy().astype(np.float64)
            p/=p.sum()
            action=int(np.random.choice(len(p), p=p))
            f_state, reward, done, _=self.env.step(action)
            # The action chosen now is the "next action" of the previous
            # transition. On the first step there is no previous transition.
            if previous is not None:
                previous.append(action)
            previous=[np.copy(state), action, reward, done, np.copy(f_state)]
            self.episode_memory.append(previous)
            self.env.render('rgb_array')
            state=f_state

    def learn(self):
        """Run one optimizer update per stored transition, newest first.

        For each transition the loss is the sum of

        * the actor term ``-log pi(a|s) * Q(s, a)`` with ``Q`` detached, and
        * the critic term ``(target - Q(s, a))**2`` where ``target`` is
          ``reward`` for terminal transitions and
          ``reward + gamma * Q(s', a')`` otherwise, ``a'`` being the action
          recorded as taken from ``s'``.

        ``episode_memory`` is only read, not reordered, so calling ``learn``
        repeatedly always visits transitions in the same order.

        Raises:
            ValueError: If a non-terminal transition has no recorded next
                action to bootstrap from.
        """
        for transition in reversed(self.episode_memory):
            i_state, action, reward, done, f_state=transition[:5]
            reward=float(reward)

            probs_tensor, qvals_tensor=self.model(self._tensor(i_state))

            # Clamp before the log so a zero probability cannot produce an
            # infinite loss.
            log_prob=torch.log(probs_tensor[action].clamp_min(1e-8))
            policy_grad=-log_prob*qvals_tensor[action].detach()

            if done:
                target=reward
            else:
                if len(transition)<6:
                    raise ValueError(
                        'non-terminal transition has no recorded next action'
                    )
                next_action=transition[5]
                with torch.no_grad():
                    _, next_qvals=self.model(self._tensor(f_state))
                target=reward+self.gamma*next_qvals[next_action]
            critic_loss=(target-qvals_tensor[action])**2

            loss=policy_grad+critic_loss
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

    def play_episode(self):
        """Discard the stored transitions and roll out one fresh episode."""
        self.episode_memory.clear()
        self.step(self.env.reset())


In [None]:
# Single shared agent instance; constructing it builds the model, the
# optimizer and the gym environment.
agent = Agent()

In [None]:
# Number of collect-then-update iterations (one environment episode each).
NUM_EPISODES=1500

for episode in range(NUM_EPISODES):
    # Roll out one full episode, then update the model on its transitions.
    agent.play_episode()
    agent.learn()