In [1]:
import time

import torch
import torch.nn.functional as F
import numpy as np

In [5]:
# Toy batch of states, one state per row -> shape (N_STATES, STATE_SIZE).
# dtype is pinned to int64 explicitly rather than left to inference.
states = torch.tensor(
    [[0, 1, 2],
     [3, 4, 7]],
    dtype=torch.int64,
)

print(states.shape)

torch.Size([2, 3])


In [43]:
# Repeat every state once per generator, keeping the copies of one state
# adjacent: rows are [S1, S1, S1, S1, S2, S2, S2, S2].
# The repeat count 4 is N_GENS, hard-coded here because `generators` is
# defined in a later cell.
# Shape: (N_STATES * N_GENS, STATE_SIZE).
expanded_states = torch.repeat_interleave(states, repeats=4, dim=0)

expanded_states

tensor([[0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [3, 4, 7],
        [3, 4, 7],
        [3, 4, 7],
        [3, 4, 7]])

In [44]:
# One generator per row, shape (N_GENS, STATE_SIZE). Row g gives, for each
# output position j, the input position to read from; it is used below as the
# `index` argument of gather, which requires an int64 tensor.
generators = torch.tensor(
    [
        [0, 1, 2],  # identity
        [2, 1, 0],  # swap first and last positions
        [1, 2, 0],  # rotate left by one
        [1, 0, 2],  # swap first two positions
    ],
    dtype=torch.int64,
)
print("generators.shape", generators.shape)

generators.shape torch.Size([4, 3])


In [49]:
# Generator (action) id for every row of `expanded_states`: the sequence
# 0..N_GENS-1 tiled once per state -> [0, 1, 2, 3, 0, 1, 2, 3].
# Shape: (N_STATES * N_GENS,).
# N_GENS is read from `generators` instead of being hard-coded, so the ids
# always cover exactly the defined generators.
expanded_actions = torch.arange(generators.shape[0]).repeat(states.shape[0])

# Each action id must pair with exactly one row of `expanded_states`
# (which was built with a fixed repeat count).
assert expanded_actions.shape[0] == expanded_states.shape[0]

expanded_actions

tensor([0, 1, 2, 3, 0, 1, 2, 3])

In [51]:
# Apply every generator to every state. Row i of `expanded_states` is permuted
# by generator expanded_actions[i]:
#   result[i, j] = expanded_states[i, generators[expanded_actions[i], j]]
# Shape (N_STATES * N_GENS, STATE_SIZE), ordered [A1(S1), A2(S1), ..., AN(SN)].
neighbours_states = expanded_states.gather(1, generators[expanded_actions])
neighbours_states

tensor([[0, 1, 2],
        [2, 1, 0],
        [1, 2, 0],
        [1, 0, 2],
        [3, 4, 7],
        [7, 4, 3],
        [4, 7, 3],
        [4, 3, 7]])

In [52]:
# Echo of `states` (defined earlier); presumably shown for side-by-side
# comparison with `neighbours_states` above.
states

tensor([[0, 1, 2],
        [3, 4, 7]])

In [53]:
# Echo of `expanded_states` (defined earlier); presumably shown for comparison
# with `neighbours_states`, which shares its row order.
expanded_states

tensor([[0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [0, 1, 2],
        [3, 4, 7],
        [3, 4, 7],
        [3, 4, 7],
        [3, 4, 7]])

In [54]:
# Toy per-neighbour policy values, shape (N_STATES, N_GENS): entry [s, g] is
# the value for generator g applied to state s.
neighbors_policy = torch.tensor(
    [[-1, -2, -3, -4], [-10, -20, -30, -40]],
    dtype=torch.int64,
)

In [57]:
# Row-major flatten, so entry i lines up with row i of `neighbours_states`:
# [POLICY_(A1(S1)), POLICY_(A2(S1)), ..., POLICY_(AN(SN))],
# shape (N_STATES * N_GENS,).
neighbors_policy_flatten = neighbors_policy.reshape(-1)
neighbors_policy_flatten

tensor([ -1,  -2,  -3,  -4, -10, -20, -30, -40])

In [58]:
# Toy cumulative policy value per parent state, shape (N_STATES,).
parent_cumulative_policy = torch.tensor([0.1, 0.9])

In [61]:
# Repeat each parent's cumulative policy once per neighbour, so it lines up
# element-wise with `neighbors_policy_flatten`:
#   [CUM(S1), ..., CUM(S1), ..., CUM(SN), ..., CUM(SN)]
# with each value repeated N_GENS times; shape (N_STATES * N_GENS,).
# N_GENS is taken from `neighbors_policy` rather than hard-coded as 4, so the
# two tensors cannot silently drift out of alignment.
expanded_parent_cumulative_policy = parent_cumulative_policy.repeat_interleave(
    neighbors_policy.shape[1]
)
assert expanded_parent_cumulative_policy.shape == neighbors_policy_flatten.shape

expanded_parent_cumulative_policy

tensor([0.1000, 0.1000, 0.1000, 0.1000, 0.9000, 0.9000, 0.9000, 0.9000])

In [63]:
# NOTE(review): scratch check of Python slice semantics ([:-2] drops the last
# two elements); it is unrelated to the tensors above -- consider deleting.
[0, 1, 2, 3, 4][:-2]

[0, 1, 2]