In [1]:
import torch
import numpy as np
import gym

# Gym CartPole-v1 environment; its observation-space bounds are inspected in a later cell.
env = gym.make("CartPole-v1")

In [18]:
def _create_grid(lower, upper, bins, offsets):
    return [np.linspace(lower[dim], upper[dim], bins[dim] + 1)[1:-1] + offsets[dim] for dim in range(len(bins))]

def _create_tilings(lower, upper, specs):
    return [_create_grid(lower, upper, bins, offsets) for bins, offsets in specs]

def _discretize(sample, grid):
    return tuple(int(np.digitize(s, g)) for s, g in zip(sample, grid))

def _tile_encoding(sample, tilings):
    return [_discretize(sample, grid) for grid in tilings]

def _get_indices(tile_encoding):
    n_bins = 10
    n_tilings = len(tile_encoding)
    indices = [i*n_bins + j + n*(n_bins**2) for n, (i, j) in enumerate(tile_encoding)]
    features = np.zeros(n_bins**2*n_tilings)
    features[indices]=1

    return features

def _tiling_features(x, lower, upper, specs):

    tilings = _create_tilings(lower, upper, specs)
    
    if len(x.size()) == 1:
        tile_encoding = _tile_encoding(x[[0,2]], tilings)
        features = _get_indices(tile_encoding)
        
        return torch.cat([x[[1,3]],torch.Tensor(features)], -1)
    
    elif len(x.size()) == 2:       
        features = []
        for xi in x:
            tile_encoding = _tile_encoding(xi[[0,2]], tilings)
            features.append(_get_indices(tile_encoding))
        
        print(features)

        return torch.cat([x[:,[1,3]],torch.Tensor(features)], -1)







In [14]:
# Two 10x10 tilings over [-3, 3] x [-0.3, 0.3]; the second one is shifted by
# +2 in both dimensions.
# NOTE(review): a +2 offset on the second dimension moves its split points to
# about 1.76..2.24 (see the printed grid below), far outside [-0.3, 0.3] --
# confirm whether per-dimension offsets were intended.
tilings = _create_tilings(lower=[-3,-0.3], upper=[3,0.3], specs=[([10,10],[0,0]), ([10,10],[2,2])])
print(tilings)
# NOTE(review): `x` is not defined in this cell; it has to exist in the kernel
# already (a later cell builds a sample batch with that name).
# Every grid has two dimensions, so the zip inside _discretize encodes only
# the first two components of each four-component row.
for xi in x:    
    t=_tile_encoding(xi, tilings)
    print(t)
    print(_get_indices(t))


[[array([-2.4, -1.8, -1.2, -0.6,  0. ,  0.6,  1.2,  1.8,  2.4]), array([-0.24, -0.18, -0.12, -0.06,  0.  ,  0.06,  0.12,  0.18,  0.24])], [array([-0.4,  0.2,  0.8,  1.4,  2. ,  2.6,  3.2,  3.8,  4.4]), array([1.76, 1.82, 1.88, 1.94, 2.  , 2.06, 2.12, 2.18, 2.24])]]
[(6, 5), (3, 0)]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0.]
[(6, 9), (3, 5)]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0

In [10]:
# Sample inputs: a batch of three four-component rows and a single
# four-component vector.
x = torch.Tensor([[1,0,3,1],[1,2,3,0],[2,3,1,1]])
x1 = torch.Tensor([1,2,3,4])
# Number of tensor dimensions of the batch; _tiling_features branches on it.
len(x.size())

2

In [19]:
# Encode the sample batch with three 10x10 tilings offset by 0, +2 and -2.
# NOTE(review): the recorded output below prints a list of tensors and ends in
# a ValueError, which does not match helpers that return NumPy arrays -- it
# looks like it came from a different revision; re-run on a fresh kernel.
print(_tiling_features(x, lower=[-3,-0.3], upper=[3,0.3], specs=[([10,10],[0,0]), ([10,10],[2,2]), ([10,10],[-2,-2])]))

[tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0.,

ValueError: only one element tensors can be converted to Python scalars

In [76]:
# Lower bounds of the observation space; in the recorded output components 1
# and 3 are limited only by the float32 minimum (about -3.4e38).
env.observation_space.low

array([-4.8000002e+00, -3.4028235e+38, -4.1887903e-01, -3.4028235e+38],
      dtype=float32)