In [3]:
import torch as th

In [10]:
th.ones(2, 3)


 1  1  1
 1  1  1
[torch.FloatTensor of size 2x3]

In [11]:
a = th.Tensor([[1, 2],[3, 4]])
a


 1  2
 3  4
[torch.FloatTensor of size 2x2]

In [12]:
a_gpu = a.cuda()
print(a_gpu)
a_cpu = a_gpu.cpu()
print(a_cpu)


 1  2
 3  4
[torch.cuda.FloatTensor of size 2x2 (GPU 0)]


 1  2
 3  4
[torch.FloatTensor of size 2x2]



In [13]:
a.view(1, -1)


 1  2  3  4
[torch.FloatTensor of size 1x4]

In [14]:
print(a.mean(dim=0))
print(a.std(dim=1))


 2
 3
[torch.FloatTensor of size 2]


 0.7071
 0.7071
[torch.FloatTensor of size 2]



In [2]:
from torch import autograd as ag
ag.Variable

torch.autograd.variable.Variable

In [6]:
a = ag.Variable(th.ones(2, 3), requires_grad=True)
print(a.data)
print(a.grad)



 1  1  1
 1  1  1
[torch.FloatTensor of size 2x3]

None


In [8]:
x = ag.Variable(th.ones(2), requires_grad=True)
y = 5 * (x + 2) ** 2

o = (1 / 2) * th.sum(y)
o.backward()

x.grad

Variable containing:
 15
 15
[torch.FloatTensor of size 2]

In [58]:
import gym
import numpy as np
import collections

dtype = th.cuda.FloatTensor

SAR = collections.namedtuple('SAR', 's a r')

class Environment(object):
  def __init__(self, env_name='CartPole-v0', max_episode_steps=None):
    self.env = gym.make(env_name)
    
    # Observation and action sizes
    discrete = isinstance(self.env.action_space, gym.spaces.Discrete)
    assert discrete
    
    self.ob_dim = self.env.observation_space.shape[0]
    self.ac_dim = self.env.action_space.n
    self.max_episode_steps = max_episode_steps or self.env.spec.max_episode_steps
  
  def sample_rollouts(self, policy, batch_size):
    """Samples complete episodes of at least |batch_size| under |policy|.
    
    Each episode is of at most self.max_episode_steps.
    Complete episodes (up to max_episode_steps) are sampled until at least
    |batch_size|.
    """
    episodes = []
    env_need_reset = True
    steps = 0
    while True:
      if env_need_reset:
        ob = self.env.reset()
        if steps >= batch_size:
          break
        episode = []
        episodes.append(episode)
      ac = int(policy(ob))
      ob, r, env_need_reset, _ = self.env.step(ac)
      episode.append(SAR(ob, ac, r))
      steps += 1
      env_need_reset |= len(episode) >= self.max_episode_steps
    return episodes

class Policy(object):
  def __init__(self, obs_dim, action_dim):
    self.obs_dim = obs_dim
    self.action_dim = action_dim
    self.model = self._create_policy_nn()
    
  def _create_policy_nn(self):
    hidden_dim = 64
    model = th.nn.Sequential(
      th.nn.Linear(self.obs_dim, hidden_dim),
      th.nn.Linear(hidden_dim, self.action_dim),
      th.nn.LogSoftmax(),
    )
    model.cuda()
    return model
    
  def get_action(self, obs_np):
    obs_var = ag.Variable(th.Tensor(obs_np).type(dtype))
    log_probs = self.model(obs_var).data
    probs = th.exp(log_probs)
    ac = th.multinomial(probs, 1)
    out_np = ac.cpu().numpy()
    return out_np

env = Environment()
policy = Policy(env.ob_dim, env.ac_dim)
env.sample_rollouts(policy.get_action, 50)


[[SAR(s=array([-0.02376615, -0.17296766,  0.03296098,  0.30608094]), a=0, r=1.0),
  SAR(s=array([-0.02722551,  0.02166947,  0.0390826 ,  0.02397259]), a=1, r=1.0),
  SAR(s=array([-0.02679212, -0.17399053,  0.03956205,  0.32872584]), a=0, r=1.0),
  SAR(s=array([-0.03027193, -0.3696527 ,  0.04613657,  0.63361766]), a=0, r=1.0),
  SAR(s=array([-0.03766498, -0.56538685,  0.05880892,  0.9404658 ]), a=0, r=1.0),
  SAR(s=array([-0.04897272, -0.37110475,  0.07761824,  0.66682631]), a=1, r=1.0),
  SAR(s=array([-0.05639481, -0.56721548,  0.09095476,  0.98290359]), a=0, r=1.0),
  SAR(s=array([-0.06773912, -0.76343062,  0.11061283,  1.30271355]), a=0, r=1.0),
  SAR(s=array([-0.08300773, -0.95976819,  0.13666711,  1.62787397]), a=0, r=1.0),
  SAR(s=array([-0.1022031 , -1.1562068 ,  0.16922458,  1.95983953]), a=0, r=1.0),
  SAR(s=array([-0.12532723, -0.9632351 ,  0.20842138,  1.72403161]), a=1, r=1.0),
  SAR(s=array([-0.14459194, -1.16004375,  0.24290201,  2.07368902]), a=0, r=1.0)],
 [SAR(s=array([