In [1]:
import time

import gym
from gym import envs
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# from torch.autograd import Variable
from torch.distributions import Categorical
from torchvision import transforms
# from torch.utils.tensorboard import SummaryWriter

In [2]:
# Length of the flattened observation fed to the network: a 163x130 grayscale crop
# of the game frame, flattened to one vector.
lenobs = 163 * 130  # == 21190


class ActorCritic(nn.Module):
    """Small actor-critic MLP over a flattened grayscale frame.

    The trunk (l1 -> l2) is shared by both heads. The actor head returns
    log-probabilities over ``n_actions`` discrete actions (log_softmax). The critic
    head returns a single value squashed into [-1, 1] by tanh. The critic reads a
    *detached* copy of the trunk features, so critic gradients never reach the
    shared layers.

    Args:
        n_inputs: length of the flattened observation (default ``lenobs``).
        n_actions: number of discrete actions the actor chooses between.

    ``forward`` accepts one observation of shape ``(n_inputs,)`` or a batch of shape
    ``(batch, n_inputs)``. All normalisation and the softmax run over the last
    (feature) dimension, so samples in a batch never influence each other.
    """

    def __init__(self, n_inputs=lenobs, n_actions=3):
        super(ActorCritic, self).__init__()
        self.l1 = nn.Linear(n_inputs, 25)
        self.l2 = nn.Linear(25, 50)
        self.actor_lin1 = nn.Linear(50, n_actions)
        self.l3 = nn.Linear(50, 25)
        self.critic_lin1 = nn.Linear(25, 1)

    def forward(self, x):
        """Return ``(actor_log_probs, critic_value)`` for observation(s) ``x``.

        Shapes: ``(n_actions,)`` / ``(1,)`` for a single observation, or
        ``(batch, n_actions)`` / ``(batch, 1)`` for a batch.
        """
        # L2-normalise each feature vector between layers (last dim = features).
        x = F.normalize(x, dim=-1)
        y = F.relu(self.l1(x))
        y = F.normalize(y, dim=-1)
        y = F.relu(self.l2(y))
        y = F.normalize(y, dim=-1)
        actor = F.log_softmax(self.actor_lin1(y), dim=-1)
        # Detach so the critic loss only trains l3/critic_lin1, not the shared trunk.
        c = F.relu(self.l3(y.detach()))
        critic = torch.tanh(self.critic_lin1(c))
        return actor, critic

In [3]:
# Inspect Pong's full discrete action set; the training cell maps the actor's
# outputs onto a subset of these action ids.
env = gym.make('Pong-v0')
action_meanings = env.unwrapped.get_action_meanings()
action_meanings

['NOOP', 'FIRE', 'RIGHT', 'LEFT', 'RIGHTFIRE', 'LEFTFIRE']

In [5]:
# Train the actor-critic on Pong. Each episode rolls out up to MAX_STEPS frames, then
# takes one Adam step on
#     sum_t -log pi(a_t|s_t) * (G_t - V(s_t))  +  CLC * sum_t (V(s_t) - G_t)^2
# where G_t are the L2-normalised discounted returns.
# env = gym.make('PongNoFrameskip-v0')
env = gym.make('Pong-v0')

# Actor output index -> Pong action id (ids as listed by get_action_meanings).
moveMapping = {
    0: 0,  # NOOP
    1: 2,  # RIGHT
    2: 3,  # LEFT
}

N_EPISODES = 10
# NOTE(review): 10 steps is far too short for a point to be scored. Every reward (and
# therefore every return) is 0, and the policy receives no learning signal. Raise this
# (or drop the cap) for real training.
MAX_STEPS = 10
GAMMA = 0.90  # discount factor for the returns
CLC = 0.1     # weight of the critic loss relative to the actor loss

model = ActorCritic()
optimizer = optim.Adam(lr=1e-4, params=model.parameters())

# CHW uint8 tensor -> PIL image, so torchvision's crop/grayscale helpers can be applied.
# Built once here; the evaluation cell below reuses it.
preprocess = transforms.Compose([
    transforms.ToPILImage(),
])

model.train()

for i_episode in range(N_EPISODES):
    print('Epoch {}'.format(i_episode))
    values = []    # critic outputs V(s_t), kept as graph-connected tensors
    rewards = []   # raw env rewards r_t
    logprobs = []  # log pi(a_t | s_t), kept as graph-connected tensors
    observation = env.reset()

    done = False
    N = 0
    while not done and N < MAX_STEPS:
        N += 1
        # HWC frame -> CHW tensor -> 163x130 grayscale crop -> flat float vector.
        pobservation = torch.from_numpy(observation).permute(2, 0, 1)
        cropped_image = transforms.functional.crop(preprocess(pobservation), 32, 15, 163, 130)
        gs_image = transforms.functional.to_grayscale(cropped_image)
        gs_tensor = transforms.ToTensor()(gs_image)
        flattened_pobservation = gs_tensor.view(-1).float()

        policy, value = model(flattened_pobservation)
        # The actor returns log-probabilities, so Categorical must receive them as
        # `logits`. Passing them as `probs` yields a wrong or invalid distribution.
        sampler = Categorical(logits=policy)
        action = sampler.sample()
        # Keep tensors (no .item()) so loss.backward() actually reaches the model weights.
        logprobs.append(sampler.log_prob(action))
        values.append(value.view(-1))

        observation, reward, done, log = env.step(moveMapping[action.item()])
        # Record the real reward on every step, including the last one. Overwriting
        # the terminal reward with +1 would reward a lost final point.
        rewards.append(reward)

    # Reverse time so the return recursion can run front to back. values and logprobs
    # are reversed the same way, so all three sequences stay aligned.
    torch_values = torch.cat(values).flip(0)
    torch_logprobs = torch.stack(logprobs).flip(0)
    reversed_rewards = list(reversed(rewards))

    returns = []
    ret = 0.0
    for r in reversed_rewards:
        ret = r + GAMMA * ret
        returns.append(ret)
    # Returns are fixed targets, so they must not require grad.
    returns = F.normalize(torch.tensor(returns, dtype=torch.float32), dim=0)

    actor_loss = -1 * torch_logprobs * (returns - torch_values.detach())
    critic_loss = torch.pow(torch_values - returns, 2)
    loss = actor_loss.sum() + CLC * critic_loss.sum()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Report the undiscounted episode reward; the normalised returns are not a meaningful score.
    print('Return: {}'.format(sum(rewards)))
    print('---------------')

print('Completed Training')

Epoch 0
Return: 0.0
---------------
Epoch 1
Return: 0.0
---------------
Epoch 2
Return: 0.0
---------------
Epoch 3
Return: 0.0
---------------
Epoch 4
Return: 0.0
---------------
Epoch 5
Return: 0.0
---------------
Epoch 6
Return: 0.0
---------------
Epoch 7
Return: 0.0
---------------
Epoch 8
Return: 0.0
---------------
Epoch 9
Return: 0.0
---------------
Completed Training


Play with the trained model

In [8]:
# Play one full episode with the trained model, rendering every frame, and print the
# wall-clock time the episode took.
# Uses `env`, `model`, `preprocess` and `moveMapping` from the training cell above.
observation = env.reset()
model.eval()
done = False
a = time.perf_counter()
try:
    with torch.no_grad():  # inference only: no autograd graph is needed
        while not done:
            pobservation = torch.from_numpy(observation).permute(2, 0, 1)
            cropped_image = transforms.functional.crop(preprocess(pobservation), 32, 15, 163, 130)
            gs_image = transforms.functional.to_grayscale(cropped_image)
            gs_tensor = transforms.ToTensor()(gs_image)
            flattened_pobservation = gs_tensor.view(-1).float()
            policy, value = model(flattened_pobservation)
            # `policy` holds log-probabilities, so it must be passed as logits.
            sampler = Categorical(logits=policy)
            action = sampler.sample()
            observation, reward, done, log = env.step(moveMapping[action.item()])
            env.render()
finally:
    # Release the env/render window even if the rollout raises.
    env.close()

b = time.perf_counter() - a
print(b)

6.323566198348999


Play with an untrained (randomly initialised) model as a baseline

In [9]:
# Baseline: play one full episode with a freshly initialised (untrained) ActorCritic
# and count the steps that return reward +1 (points won by the agent in Pong).
# Uses `moveMapping` from the training cell above.
env = gym.make('Pong-v0')

m1 = ActorCritic()
observation = env.reset()
m1.eval()
done = False
count = 0
# CHW uint8 tensor -> PIL image, so torchvision's crop/grayscale helpers can be applied.
preprocess = transforms.Compose([
    transforms.ToPILImage(),
])
try:
    with torch.no_grad():  # inference only: no autograd graph is needed
        while not done:
            pobservation = torch.from_numpy(observation).permute(2, 0, 1)
            cropped_image = transforms.functional.crop(preprocess(pobservation), 32, 15, 163, 130)
            gs_image = transforms.functional.to_grayscale(cropped_image)
            gs_tensor = transforms.ToTensor()(gs_image)
            flattened_pobservation = gs_tensor.view(-1).float()
            policy, value = m1(flattened_pobservation)
            # `policy` holds log-probabilities, so it must be passed as logits.
            sampler = Categorical(logits=policy)
            action = sampler.sample()
            observation, reward, done, log = env.step(moveMapping[action.item()])
            if reward == 1:
                count += 1
            env.render()
finally:
    # Release the env/render window even if the rollout raises.
    env.close()

print(count)

0
