In [65]:
import torch
import gym
from gym import envs
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from torch.autograd import Variable
from torch.distributions import Categorical

In [82]:
lenobs = 100800
class ActorCritic(nn.Module):
    def __init__(self):
        super(ActorCritic, self).__init__()
        self.l1 = nn.Linear(lenobs,25)
        self.l2 = nn.Linear(25,50)
        self.actor_lin1 = nn.Linear(50,2)
        self.l3 = nn.Linear(50,25)
        self.critic_lin1 = nn.Linear(25,1)

    def forward(self,x):
        x = F.normalize(x,dim=0)
        y = F.relu(self.l1(x))
        y = F.relu(self.l2(y))
        actor = F.softmax(self.actor_lin1(y),dim=0)
#         actor = F.log_softmax(self.actor_lin1(y),dim=0)
        c = F.relu(self.l3(y.detach()))
        critic = torch.tanh(self.critic_lin1(c))
        return actor, critic

In [155]:
env = gym.make('PongNoFrameskip-v0')
moveMapping = {
    0:2,
    1:3
}

model = ActorCritic()
optimizer = optim.Adam(lr=1e-4,params=model.parameters())

'''
loss = Variable(loss, requires_grad = True)
actor_loss = Variable(actor_loss, requires_grad = True)
critic_loss = Variable(critic_loss, requires_grad = True)
'''

for i_episode in range(200):
#     reward = 0.0
    values = []
    rewards = []
    logprobs = []
    observation = env.reset()
    print('---------------')
    done = False
    N = 0
    while N<10 and done == False:
        N+=1
#         print(t)
        pobservation = torch.from_numpy(observation)
        flattened_pobservation = pobservation.view(-1).float()
        policy, value = model(flattened_pobservation)
#         print(policy)
        values.append(value)
        sampler = Categorical(policy)
        action = sampler.sample()
#         print(action)
#         action = np.random.choice(np.array([0,1]), p = policy.view(2,).data.numpy())
        logprobs.append(torch.log(policy.view(-1)[action]))
#         print('Action: {}'.format('right' if action==2 else 'left'))
        observation, reward, done, log = env.step(moveMapping[action.item()])
        rewards.append(reward)
#         print('{}.Reward: {}'.format(t,reward))
#         print('---')
        if done:
            print(done)
            print('Episode:{} State:{} Reward:{}'.format(i_episode,t,reward))
            print("Episode finished after {} timesteps".format(N+1))    
            break
        else:
            G = value
        
      
    # Reversing because earlier actions have greater value
    torch_values = torch.tensor(values, requires_grad = True).view(-1).flip(0)
    torch_rewards = torch.tensor(rewards, requires_grad = True).flip(0)
    torch_logprobs = torch.tensor(logprobs, requires_grad = True).flip(0)
    
    
    returns = []
    gamma = 0.95
    clc = 0.1
    ret = torch.tensor([G])
    for r in torch_rewards:
        ret = r + gamma*ret
        returns.append(ret)
    returns = torch.tensor(returns, requires_grad = True)
    returns = F.normalize(returns,dim=0)
    actor_loss = -1*torch_logprobs * (returns - torch_values)
    critic_loss = torch.pow(torch_values - returns,2)
    loss = actor_loss.sum() + clc*critic_loss.sum()
    print('Loss: {}'.format(loss))
#     print('Starting Backpropagation')
    loss.backward()
    optimizer.step()
#     print('Completed Backpropagation')


---------------
Loss: 1.0126217603683472
---------------
Loss: 1.0142792463302612
---------------
Loss: 1.0116404294967651
---------------
Loss: 1.0172951221466064
---------------
Loss: 1.0147579908370972
---------------
Loss: 1.0038697719573975
---------------
Loss: 1.0099772214889526
---------------
Loss: 1.0079989433288574
---------------
Loss: 1.0111610889434814
---------------
Loss: 1.0080387592315674
---------------
Loss: 1.0110759735107422
---------------
Loss: 1.0115782022476196
---------------
Loss: 1.0074641704559326
---------------
Loss: 1.0081162452697754
---------------
Loss: 1.0098222494125366
---------------
Loss: 1.0104503631591797
---------------
Loss: 1.011675238609314
---------------
Loss: 1.014186978340149
---------------
Loss: 1.013444185256958
---------------
Loss: 1.0095933675765991
---------------
Loss: 1.0130388736724854
---------------
Loss: 1.007004737854004
---------------
Loss: 1.0086712837219238
---------------
Loss: 1.0161775350570679
---------------
Loss

In [None]:
############################ Experimentations ###############################################

In [139]:
from torch.distributions import Categorical
a = torch.tensor([56,43,12,78,9]).float()
# Note that this is equivalent to what used to be called multinomial
m = Categorical(a)
b = m.sample()
b

tensor(3)

In [152]:
a = torch.tensor(3)
a.item()

3

In [None]:
#############################################################################################