In [32]:
import numpy as np
import numpy.random as ran
import torch
import torch.nn as nn
import torch.optim as optim

%load_ext autoreload
%autoreload 1
%aimport ContinuousControlNet

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [22]:
class IdentityPolicy(ContinuousControlNet.ContinuousControlNet):
    """Parameter-free test policy whose forward pass echoes the observation.

    Handy as a fixture: whatever the base class does with the output of
    ``forward`` can be checked directly against the observation passed in.
    """

    def __init__(self):
        # Positional base-class arguments; presumably (action noise std,
        # action dimension) -- TODO confirm against ContinuousControlNet.
        base_args = (0.1, 5)
        super().__init__(*base_args)

    def forward(self, obs):
        """Return ``obs`` unchanged."""
        return obs

    def getUnclonedCopy(self):
        """Build and return a brand-new instance of this object's own class.

        NOTE(review): presumably used by the base class's ``clone`` -- confirm.
        """
        policy_cls = type(self)
        return policy_cls()

In [17]:
#Test to assert that we compute the normal probability correctly
#And that it doesn't change between the returned value and calling action prob
from scipy.stats import multivariate_normal
import scipy.linalg

policy = IdentityPolicy()
ran.seed(2)
torch.manual_seed(3)

obs = ran.normal(0, 1, (5))
action, prob = policy.act(obs)

#Probability recomputed by the policy for the action it just returned
recomputed_prob = policy.computeActionProbFromObs(obs, action)
#Independent reference: isotropic Gaussian centred on obs (IdentityPolicy's forward
#returns obs); a scalar cov is treated by scipy as cov * identity.
#0.1 is the first constructor argument of IdentityPolicy -- presumably the action std.
reference_prob = multivariate_normal.pdf(action, mean = obs, cov = 0.1**2)

#Check these are all the same
print(prob)
print(recomputed_prob)
print(reference_prob)

#Actually enforce the check instead of relying on eyeballing the printed values.
#Tolerance is loose because the policy's values print as float32-precision tensors
#while scipy works in float64.
np.testing.assert_allclose(float(prob), float(recomputed_prob), rtol = 1e-3,
                           err_msg = 'act() and computeActionProbFromObs() disagree')
np.testing.assert_allclose(float(prob), reference_prob, rtol = 1e-3,
                           err_msg = 'policy probability differs from the scipy reference pdf')

tensor(2.1121e-07)
tensor(2.1121e-07)
2.1121052370629137e-07


In [26]:
#Test to assert that eval() actions are deterministic (repeated calls return the same action)
#But that these are not the same as training vals
ran.seed(2)
torch.manual_seed(3)

policy = IdentityPolicy()
obs = ran.normal(0, 1, (5))

policy.eval()
print(obs)
eval_first = policy.act(obs)
print(eval_first)
eval_second = policy.act(obs)
print(eval_second)
policy.train()
train_result = policy.act(obs)
print(train_result)

#act() returns an (action, prob) pair; compare the action arrays explicitly so a
#regression fails loudly instead of only changing the printed output
assert np.array_equal(eval_first[0], eval_second[0]), 'eval() actions are not deterministic'
assert not np.array_equal(train_result[0], eval_first[0]), 'train() action should differ from the eval() action'


[-0.41675785 -0.05626683 -2.1361961   1.64027081 -1.79343559]
(array([-0.41675785, -0.05626683, -2.1361961 ,  1.6402708 , -1.7934356 ],
      dtype=float32), 1)
(array([-0.41675785, -0.05626683, -2.1361961 ,  1.6402708 , -1.7934356 ],
      dtype=float32), 1)
(array([-0.68294174,  0.10275824, -2.52999081,  1.30571696, -2.08088902]), tensor(2.1121e-07))


In [44]:
class SimplePolicy(ContinuousControlNet.ContinuousControlNet):
    """Smallest trainable test policy: one linear layer from 5 inputs to 1 output."""

    def __init__(self):
        # Positional base-class arguments; presumably (action noise std,
        # action dimension) -- TODO confirm against ContinuousControlNet.
        base_args = (0.1, 1)
        super().__init__(*base_args)
        self.fc = nn.Linear(in_features=5, out_features=1)

    def forward(self, obs):
        """Apply the single linear layer to ``obs`` and return the result."""
        linear_out = self.fc(obs)
        return linear_out

    def getUnclonedCopy(self):
        """Build and return a brand-new instance of this object's own class.

        The new instance goes through ``__init__`` again, so it gets its own
        freshly constructed ``nn.Linear`` layer.
        NOTE(review): presumably used by the base class's ``clone`` -- confirm.
        """
        policy_cls = type(self)
        return policy_cls()

In [28]:
#Test to assert that eval() actions are deterministic across clones
ran.seed(2)
torch.manual_seed(3)

#Draw the observation in this cell: previously `obs` was only defined by whichever
#other cell happened to run first, so this cell depended on hidden kernel state
obs = ran.normal(0, 1, (5))

policy = SimplePolicy()
policy.eval()
clonePolicy = policy.clone()
clonePolicy.eval()

policy_result = policy.act(obs)
print(policy_result)
clone_result = clonePolicy.act(obs)
print(clone_result) #Should be the same due to parameter cloning

newPolicy = SimplePolicy()
newPolicy.eval()
new_result = newPolicy.act(obs)
print(new_result) #Should be different due to different random parameter initialization

#act() returns an (action, prob) pair; enforce the expectations on the action arrays
assert np.array_equal(policy_result[0], clone_result[0]), 'clone does not reproduce the eval() action'
assert not np.array_equal(policy_result[0], new_result[0]), 'independently initialised policy matched the original'

(array([-0.42809087], dtype=float32), 1)
(array([-0.42809087], dtype=float32), 1)
(array([1.8498709], dtype=float32), 1)


In [48]:
#Confirm that gradient descent only affects a policy and not its clone
ran.seed(2)
torch.manual_seed(3)

policy = SimplePolicy() #the policy we'll be training
newPolicy = policy.clone() #control 
newPolicy.eval()
optimizer = optim.Adam(policy.parameters(), lr = 0.01)

obs = ran.normal(0, 1, (5))

policy.eval()
print('Before first gradient')
baseline_result = policy.act(obs)
print(baseline_result)
control_baseline = newPolicy.act(obs)
print(control_baseline)

#Take two optimizer steps, checking both policies after each one
for step_label in ('After first gradient)', 'After second gradient'):
    policy.train()
    optimizer.zero_grad()
    #Find something arbitrary to do gradient descent on. In this case, the action prob itself
    action, _ = policy.act(obs)
    action_prob = policy.computeActionProbFromObs(obs, action)
    torch.autograd.backward(action_prob)
    optimizer.step()

    policy.eval()
    #See if the trained policy or the control clone have changed after the optimizer step
    print(step_label)
    trained_result = policy.act(obs)
    print(trained_result)
    control_result = newPolicy.act(obs)
    print(control_result)

    #act() returns an (action, prob) pair; enforce the expectations on the action arrays
    assert not np.array_equal(trained_result[0], baseline_result[0]), 'optimizer step did not change the trained policy'
    assert np.array_equal(control_result[0], control_baseline[0]), 'optimizer step leaked into the cloned policy'

Before first gradient
(array([-0.42809087], dtype=float32), 1)
(array([-0.42809087], dtype=float32), 1)
After first gradient)
(array([-0.35766163], dtype=float32), 1)
(array([-0.42809087], dtype=float32), 1)
After second gradient
(array([-0.40134367], dtype=float32), 1)
(array([-0.42809087], dtype=float32), 1)
