In [1]:
import torch as pt
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

In [14]:
class Net(nn.Module):
    """Two-stage pairwise network over a set of sites (rows of the input).

    Stage 1 (``subNet1``) maps every input row independently to an
    ``n_out``-dimensional feature vector.  Stage 2 (``subNet2``) is applied to
    the concatenation of each site's stage-1 features with the stage-1
    features of one reference site ``j``.  The stage-2 outputs are summed over
    all ``j`` in ``jList``.

    Because stage 1 is row-wise and the reference sites enter only through a
    sum, permuting the input rows with a permutation that maps the ``jList``
    sites onto themselves permutes the output rows the same way.
    """

    def __init__(self, m, jList, hidden=64, n_out=3):
        """
        m      : input dimensionality (features per site / row)
        jList  : indices of the reference sites paired with every site in
                 ``forward``
        hidden : width of the hidden layers of both sub-networks
        n_out  : output width of both sub-networks; sub-network 2 takes
                 ``2 * n_out`` inputs (a site's features + a reference site's)
        """
        super().__init__()

        # Layers are created in the same order as before so that a given
        # torch seed yields the same initialisation.
        self.N1L1 = nn.Linear(m, hidden)
        self.N1L2 = nn.Linear(hidden, hidden)
        self.N1L3 = nn.Linear(hidden, hidden)
        self.N1L4 = nn.Linear(hidden, n_out)

        self.N2L1 = nn.Linear(2 * n_out, hidden)
        self.N2L2 = nn.Linear(hidden, hidden)
        self.N2L3 = nn.Linear(hidden, hidden)
        self.N2L4 = nn.Linear(hidden, n_out)

        self.jList = jList

    def subNet1(self, inState):
        """Per-site MLP: (n_sites, m) -> (n_sites, n_out)."""
        out1 = F.leaky_relu(self.N1L1(inState))
        out1 = F.leaky_relu(self.N1L2(out1))
        out1 = F.leaky_relu(self.N1L3(out1))
        return self.N1L4(out1)

    def subNet2(self, in2):
        """Pairwise MLP: (n_sites, 2 * n_out) -> (n_sites, n_out)."""
        out2 = F.leaky_relu(self.N2L1(in2))
        out2 = F.leaky_relu(self.N2L2(out2))
        out2 = F.leaky_relu(self.N2L3(out2))
        return self.N2L4(out2)

    def forward(self, inState):
        """
        inState : tensor of shape (n_sites, m)
        returns : tensor of shape (n_sites, n_out) whose row i is
                  sum_{j in jList} subNet2([subNet1(x_i), subNet1(x_j)]);
                  all zeros when jList is empty
        """
        out1 = self.subNet1(inState)
        n_sites = out1.shape[0]

        out2 = pt.zeros_like(out1)

        for j in self.jList:
            # concatenate the q values for this site to all sites
            # Self correlation allowed for now
            # expand() broadcasts the reference row without copying it;
            # cat() then materialises the (n_sites, 2 * n_out) input.
            in2 = pt.cat((out1, out1[j].expand(n_sites, -1)), dim=1)
            # Out-of-place add keeps the autograd graph free of in-place ops.
            out2 = out2 + self.subNet2(in2)

        return out2

In [None]:
# Toy tensors for trying out the repeat/cat pattern used in Net.forward:
# x plays the role of per-site features, y of a single reference row.
x, y = pt.rand(5, 3), pt.rand(3)
x, y

In [None]:
# Tile the reference row y once per row of x (same pattern as Net.forward,
# which uses out1.shape[0] rather than a hard-coded row count).
z = y.repeat(x.shape[0], 1)
z

In [None]:
# Append the tiled reference row to every row of x (last axis == columns here).
pt.cat([x, z], dim=-1)

In [4]:
# Problem size: N sites (rows), each with m input features.  jList holds the
# reference sites that every site is paired with inside Net.forward.
m = 4
N = 7
jList = [1, 3, 4]
# Seed the global torch RNG so X, the network initialisation and the random
# permutation further down are reproducible on Restart & Run All.
pt.manual_seed(0)
# Sample directly in float64 rather than sampling float32 and casting.
X = pt.rand((N, m), dtype=pt.float64)
X

tensor([[0.7734, 0.1324, 0.5890, 0.2512],
        [0.4313, 0.6103, 0.3311, 0.6348],
        [0.5896, 0.4304, 0.7134, 0.6115],
        [0.5376, 0.2732, 0.6855, 0.5245],
        [0.8192, 0.6285, 0.3535, 0.2143],
        [0.2955, 0.6404, 0.8958, 0.8796],
        [0.8777, 0.1423, 0.9256, 0.0512]], dtype=torch.float64)

In [15]:
# Build the network and cast its floating-point parameters to float64 to match X.
net = Net(m, jList).to(dtype=pt.float64)

In [17]:
# Reference output on the unpermuted input.  Call the module (not .forward)
# so nn.Module hooks are honoured.
Y_x = net(X)

In [22]:
# Simulate a group operation with a random permutation of the sites.
# Constraint: sites in jList are permuted only amongst each other (here via
# the fixed assignment position 1 <- site 4, 3 <- 1, 4 <- 3); every other
# site is shuffled uniformly at random amongst the remaining non-jList sites.
# Building the permutation directly replaces a rejection-sampling loop that
# needed ~35 redraws on average for N = 7, with the same distribution.
n_sites = Y_x.shape[0]
j_perm = {1: 4, 3: 1, 4: 3}  # position -> source site, restricted to jList
assert sorted(j_perm) == sorted(j_perm.values()) == sorted(jList)

other_sites = pt.tensor([i for i in range(n_sites) if i not in j_perm], dtype=pt.long)
perm = pt.arange(n_sites)
perm[other_sites] = other_sites[pt.randperm(len(other_sites))]
perm[list(j_perm)] = pt.tensor(list(j_perm.values()))
perm

tensor([2, 4, 5, 1, 3, 0, 6])

In [27]:
# Output on the permuted input; call the module so hooks are honoured.
Y_x_perm = net(X[perm])
Y_x_perm

tensor([[-0.3569, -0.4478,  0.0506],
        [-0.3579, -0.4477,  0.0503],
        [-0.3561, -0.4478,  0.0509],
        [-0.3569, -0.4479,  0.0509],
        [-0.3569, -0.4478,  0.0508],
        [-0.3579, -0.4476,  0.0503],
        [-0.3581, -0.4475,  0.0498]], dtype=torch.float64,
       grad_fn=<AddBackward0>)

In [32]:
# Equivariance check: permuting the input rows should permute the output rows.
# Exact equality can fail because, after permutation, the contributions of the
# jList sites are summed in a different order and floating-point addition is
# not associative, so compare with a tolerance.
equivariant = pt.allclose(Y_x_perm, Y_x[perm])
assert equivariant, "network output is not permutation-equivariant"
equivariant

True

In [31]:
# Unpermuted output with its rows reordered by perm; should match Y_x_perm
# up to floating-point tolerance.
Y_x[perm]

tensor([[-0.3569, -0.4478,  0.0506],
        [-0.3579, -0.4477,  0.0503],
        [-0.3561, -0.4478,  0.0509],
        [-0.3569, -0.4479,  0.0509],
        [-0.3569, -0.4478,  0.0508],
        [-0.3579, -0.4476,  0.0503],
        [-0.3581, -0.4475,  0.0498]], dtype=torch.float64,
       grad_fn=<IndexBackward>)

In [42]:
def sq_sum(y):
    """Return half the sum of squares of all entries of ``y``.

    Computes 0.5 * sum(y ** 2) over every element, for a tensor of any shape,
    and returns a 0-d tensor.  Used here as a toy scalar loss: its gradient
    with respect to ``y`` is ``y`` itself.
    """
    return 0.5 * (y * y).sum()

In [44]:
# Gradient of the toy loss sq_sum(net(X)) with respect to the network
# parameters.  Gradients are cleared first so re-running this cell does not
# accumulate into .grad values left over from a previous run.
opt = pt.optim.SGD(net.parameters(), lr=0.01)
opt.zero_grad()
y = net(X)
l = sq_sum(y)
l.backward()
# NOTE(review): opt is never stepped; add opt.step() here if a parameter
# update is intended.