torch.roll — PyTorch 2.0 documentation
https://pytorch.org/docs/stable/generated/torch.roll.html

In [1]:
################################
# roll 1D tensor
################################

import torch

x = torch.arange(10)
x

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

In [2]:
shift = 1
torch.roll(x, shift, dims=0)

tensor([9, 0, 1, 2, 3, 4, 5, 6, 7, 8])

In [3]:
x.roll(shift)

tensor([9, 0, 1, 2, 3, 4, 5, 6, 7, 8])

In [4]:
x

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

In [5]:
shift = -1
x.roll(shift)

tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 0])

In [6]:
################################
# roll 2D tensor
################################
import torch

x = torch.tensor([[1,2,3],[4,5,6],[7,8,9]])
x

tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])

In [7]:
shift = 1
torch.roll(x, shift) 

tensor([[9, 1, 2],
        [3, 4, 5],
        [6, 7, 8]])

In [8]:
# same
x.roll(shift) 

tensor([[9, 1, 2],
        [3, 4, 5],
        [6, 7, 8]])

In [9]:
# same 
x.flatten().roll(shift).reshape(x.shape)

tensor([[9, 1, 2],
        [3, 4, 5],
        [6, 7, 8]])

In [10]:
x

tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])

In [11]:
shift = 1
x.roll(shift, dims=0)

tensor([[7, 8, 9],
        [1, 2, 3],
        [4, 5, 6]])

In [12]:
x.roll(shift, dims=1)

tensor([[3, 1, 2],
        [6, 4, 5],
        [9, 7, 8]])

In [13]:
###########################################
# compute q_values
###########################################

# Q value

$$ Q(s, a) = R(s, a, s^{'}) + \gamma V(s^{'}) $$

In [14]:
import gym
import torch.nn as nn
import torch
import numpy as np

class ValueNet(nn.Module):
    def __init__(self, state_dim, hidden_size=128):
        super(ValueNet, self).__init__()
        self.seq = nn.Sequential(
            nn.Linear(state_dim, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
            )
    
    def forward(self, x):
        yhat = self.seq(x)
        return yhat
    
# env 
env = gym.make("CartPole-v1", render_mode="rgb_array")

# valuenet
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
valuenet = ValueNet(state_dim, action_dim)

# parameter
GAMMA = 0.95
max_steps = 1000

# record transitions to lists
rewards = []
states = []
terminateds = []

# run 1 episode
state, info = env.reset()
for i in range(1, max_steps):
    next_state, reward, terminated, truncated, info = env.step(env.action_space.sample())
    
    # add a taransition to lists 
    rewards.append(reward)
    states.append(state)
    terminateds.append(terminated)
    
    if terminated:
        break
    
    # prepare next step
    state = next_state

env.close()

##########################
# compute q_values
##########################

# list to tensor
states = torch.tensor(np.array(states)) # states was a list of numpy arrays
rewards = torch.tensor(rewards)
terminateds = torch.tensor(terminateds)

# (n, 1) tensor -> (n, ) tensor
values = valuenet(states).squeeze() 
values

tensor([0.5118, 0.4311, 0.3736, 0.4242, 0.3700, 0.3558, 0.3414, 0.3267, 0.3353,
        0.3422, 0.3247], grad_fn=<SqueezeBackward0>)

In [15]:
next_values = values.roll(-1)
next_values

tensor([0.4311, 0.3736, 0.4242, 0.3700, 0.3558, 0.3414, 0.3267, 0.3353, 0.3422,
        0.3247, 0.5118], grad_fn=<RollBackward0>)

In [16]:
q_values = rewards + (1.0 - terminateds.float()) * GAMMA * next_values
q_values

tensor([1.4096, 1.3549, 1.4030, 1.3515, 1.3380, 1.3243, 1.3104, 1.3185, 1.3251,
        1.3084, 1.0000], grad_fn=<AddBackward0>)