In [2]:
import torch
import gym
from gym import envs
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from torch.distributions import Categorical
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
import pickle

In [3]:
lenobs = 21190
class ActorCritic(nn.Module):
    def __init__(self):
        super(ActorCritic, self).__init__()
        self.l1 = nn.Linear(lenobs,25)
        self.l2 = nn.Linear(25,50)
        self.actor_lin1 = nn.Linear(50,3)
        self.l3 = nn.Linear(50,25)
        self.critic_lin1 = nn.Linear(25,1)

    def forward(self,x):
        x = F.normalize(x,dim=0)
        y = F.relu(self.l1(x))
        y = F.normalize(y,dim=0)
        y = F.relu(self.l2(y))
        y = F.normalize(y,dim=0)
#         actor = F.softmax(self.actor_lin1(y),dim=0)
        actor = F.log_softmax(self.actor_lin1(y),dim=0)
        c = F.relu(self.l3(y.detach()))
        critic = torch.tanh(self.critic_lin1(c))
        return actor, critic

In [4]:
env = gym.make('Pong-v0')
env.unwrapped.get_action_meanings()

['NOOP', 'FIRE', 'RIGHT', 'LEFT', 'RIGHTFIRE', 'LEFTFIRE']

In [5]:
tb = SummaryWriter()
# env = gym.make('PongNoFrameskip-v0')
env = gym.make('Pong-v0')

moveMapping = {
    0:0,
    1:2,
    2:3
}

# 0 NOOP
# 2 RIGHT
# 3 LEFT

model = ActorCritic()
optimizer = optim.Adam(lr=1e-3,params=model.parameters())

one_reward_count = [0]
pickle.dump(one_reward_count, open('one_reward_count.pkl', 'wb'))

i_episode = 0
preprocess = transforms.Compose([
    transforms.ToPILImage(),
])


minloss = float('inf')

while True:
    i_episode+=1
    model.train()
    print('Epoch {}'.format(i_episode))
    values = []
    rewards = []
    logprobs = []
    observation = env.reset()

    done = False
    N = 0
    
    
    while done == False and N<100:
        N+=1
        pobservation = torch.from_numpy(observation).permute(2,0,1)
        cropped_image = transforms.functional.crop(preprocess(pobservation),32,15,163,130)
        gs_image = transforms.functional.to_grayscale(cropped_image)
        gs_tensor = transforms.ToTensor()(gs_image)
        flattened_pobservation = gs_tensor.view(-1).float()
        policy, value = model(flattened_pobservation)
        values.append(value.item())
        sampler = Categorical(policy)
        action = sampler.sample()
        logprobs.append(policy[action.item()].item())
        observation, reward, done, log = env.step(moveMapping[action.item()])
        # To save the number of times the agent won
        if reward == 1.0:
            one_reward_count = pickle.load(open('one_reward_count.pkl', 'rb'))
            one_reward_count[0]+=1
            pickle.dump(one_reward_count, open('one_reward_count.pkl', 'wb'))
            
        if reward == 1.0:
            rewards.append(5.0)
        elif reward == -1.0:
            rewards.append(-6.0)
        else:
            rewards.append(reward)
        

    torch_values = torch.tensor(values, requires_grad = True).view(-1)#.flip(0)
    torch_rewards = torch.tensor(rewards, requires_grad = True).flip(0)
    torch_logprobs = torch.tensor(logprobs, requires_grad = True)#.flip(0)
    

    returns = []
    gamma = 0.90
    clc = 0.1
    ret = torch.tensor([0])
    for r in torch_rewards:
        ret = r + gamma*ret
        returns.append(ret)
    returns = torch.tensor(returns, requires_grad = True)
    returns = F.normalize(returns,dim=0)
    actor_loss = -1*torch_logprobs * (returns - torch_values.detach())
    critic_loss = torch.pow(torch_values - returns,2)
    loss = actor_loss.sum() + clc*critic_loss.sum()
    tb.add_scalar('Loss',loss,i_episode)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    #Saving the best model
    if loss < minloss:
        minloss = loss
        model.eval()
        model_path = './a2c.pth'
        state_dict = model.state_dict()
        torch.save(state_dict, model_path)
    
    
    print('---------------')

print('Completed Training')

Epoch 1
---------------
Epoch 2
---------------
Epoch 3
---------------
Epoch 4
---------------
Epoch 5
---------------
Epoch 6
---------------
Epoch 7
---------------
Epoch 8
---------------
Epoch 9
---------------
Epoch 10
---------------
Epoch 11
---------------
Epoch 12
---------------
Epoch 13
---------------
Epoch 14
---------------
Epoch 15
---------------
Epoch 16
---------------
Epoch 17
---------------
Epoch 18
---------------
Epoch 19
---------------
Epoch 20
---------------
Epoch 21
---------------
Epoch 22
---------------
Epoch 23
---------------
Epoch 24
---------------
Epoch 25
---------------
Epoch 26
---------------
Epoch 27
---------------
Epoch 28
---------------
Epoch 29
---------------
Epoch 30
---------------
Epoch 31
---------------
Epoch 32
---------------
Epoch 33
---------------
Epoch 34
---------------
Epoch 35
---------------
Epoch 36
---------------
Epoch 37
---------------
Epoch 38
---------------
Epoch 39
---------------
Epoch 40
---------------
Epoch 41


KeyboardInterrupt: 

Play with trained model

In [None]:
import time
observation = env.reset()
model.eval()
done = False
a = time.time()
while done == False:
    pobservation = torch.from_numpy(observation).permute(2,0,1)
    cropped_image = transforms.functional.crop(preprocess(pobservation),32,15,163,130)
    gs_image = transforms.functional.to_grayscale(cropped_image)
    gs_tensor = transforms.ToTensor()(gs_image)
    flattened_pobservation = gs_tensor.view(-1).float()
    policy, value = model(flattened_pobservation)
    sampler = Categorical(policy)
    action = sampler.sample()
    observation, reward, done, log = env.step(moveMapping[action.item()])
    env.render()
    
if done:
    env.close()
    
b = time.time()-a
print(b)

Play without trained model

In [None]:
# env = gym.make('PongNoFrameskip-v0')
env = gym.make('Pong-v0')

# moveMapping = {
#     0:2,
#     1:3
# }

m1 = ActorCritic()
observation = env.reset()
import time
# observation = env.reset()
m1.eval()
done = False
a = time.time()
count = 0
preprocess = transforms.Compose([
    transforms.ToPILImage(),
])
while done == False:
    pobservation = torch.from_numpy(observation).permute(2,0,1)
    
    cropped_image = transforms.functional.crop(preprocess(pobservation),32,15,163,130)
    gs_image = transforms.functional.to_grayscale(cropped_image)
    gs_tensor = transforms.ToTensor()(gs_image)
    
    flattened_pobservation = gs_tensor.view(-1).float()
    policy, value = m1(flattened_pobservation)
    sampler = Categorical(policy)
    action = sampler.sample()
    observation, reward, done, log = env.step(moveMapping[action.item()])
    if reward == 1:
        count+=1
    env.render()
    
if done:
    env.close()
    
b = time.time()-a
# print(b)
print(count)