In [1]:
import torch as pt
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

In [2]:
class Net(nn.Module):
    """Two-stage pairwise network over sites.

    ``subNet1`` maps every site's feature vector to an ``n_q``-dimensional
    vector q_i. ``subNet2`` scores each pair (q_i, q_j) for every jump site j
    in ``jList``, and ``forward`` sums those scores over j:

        out[i] = sum_{j in jList} subNet2(cat(q_i, q_j))

    Because the sum runs over the jump sites, the output is equivariant under
    any site permutation that maps the set ``jList`` onto itself.
    """

    def __init__(self, m, jList, n_q=3, hidden=64):
        """
        m      : input dimensionality (number of features per site)
        jList  : indices of the jump sites summed over in ``forward``
        n_q    : size of the per-site vector produced by subNet1 and of the
                 per-site output of forward (default 3)
        hidden : width of the hidden layers of both sub-networks (default 64)
        """
        super().__init__()

        self.N1L1 = nn.Linear(m, hidden)
        self.N1L2 = nn.Linear(hidden, hidden)
        self.N1L3 = nn.Linear(hidden, hidden)
        self.N1L4 = nn.Linear(hidden, n_q, bias=False)

        # subNet2 sees the concatenation (q_i, q_j), hence 2 * n_q inputs;
        # deriving it from n_q keeps the two sub-networks consistent.
        self.N2L1 = nn.Linear(2 * n_q, hidden)
        self.N2L2 = nn.Linear(hidden, hidden)
        self.N2L3 = nn.Linear(hidden, hidden)
        self.N2L4 = nn.Linear(hidden, n_q, bias=False)

        self.jList = jList

    def subNet1(self, inState):
        """Per-site MLP: (..., m) features -> (..., n_q) vector q."""
        out1 = F.leaky_relu(self.N1L1(inState))
        out1 = F.leaky_relu(self.N1L2(out1))
        out1 = F.leaky_relu(self.N1L3(out1))
        return self.N1L4(out1)

    def subNet2(self, in2):
        """Pair MLP: (..., 2 * n_q) concatenated (q_i, q_j) -> (..., n_q)."""
        out2 = F.leaky_relu(self.N2L1(in2))
        out2 = F.leaky_relu(self.N2L2(out2))
        out2 = F.leaky_relu(self.N2L3(out2))
        return self.N2L4(out2)

    def forward(self, inState):
        """
        inState : (..., N, m) site features; leading batch dims are optional
        returns : (..., N, n_q) tensor with
                  out[..., i, :] = sum_{j in jList} subNet2(cat(q_i, q_j))
        """
        out1 = self.subNet1(inState)

        out2 = pt.zeros_like(out1)

        for j in self.jList:
            # Pair every site i with site j (self-pairing i == j allowed).
            # expand_as is a broadcast view; the cat below makes the only copy.
            out1j = out1[..., j, :].unsqueeze(-2).expand_as(out1)
            in2 = pt.cat((out1, out1j), dim=-1)
            out2 += self.subNet2(in2)

        return out2

In [3]:
# Toy tensors illustrating the tile-and-concatenate step of Net.forward:
# a (5, 3) "all sites" matrix and a single (3,) "site j" row.
x, y = pt.rand(5, 3), pt.rand(3)
(x, y)

(tensor([[0.6021, 0.0649, 0.5998],
         [0.1056, 0.6579, 0.0586],
         [0.4123, 0.4300, 0.1710],
         [0.8238, 0.8015, 0.6003],
         [0.6230, 0.5182, 0.6396]]),
 tensor([0.7019, 0.2249, 0.1857]))

In [4]:
# Tile the single row y once per row of x, mirroring how Net.forward
# broadcasts the site-j row across all sites. The row count comes from x
# instead of a hard-coded 5.
z = y.repeat([x.shape[0], 1])
z

tensor([[0.7019, 0.2249, 0.1857],
        [0.7019, 0.2249, 0.1857],
        [0.7019, 0.2249, 0.1857],
        [0.7019, 0.2249, 0.1857],
        [0.7019, 0.2249, 0.1857]])

In [5]:
# Append the tiled row to every row of x along the feature axis -> (5, 6)
pt.cat([x, z], dim=-1)

tensor([[0.6021, 0.0649, 0.5998, 0.7019, 0.2249, 0.1857],
        [0.1056, 0.6579, 0.0586, 0.7019, 0.2249, 0.1857],
        [0.4123, 0.4300, 0.1710, 0.7019, 0.2249, 0.1857],
        [0.8238, 0.8015, 0.6003, 0.7019, 0.2249, 0.1857],
        [0.6230, 0.5182, 0.6396, 0.7019, 0.2249, 0.1857]])

In [6]:
m = 4  # no. of features per site
N = 7  # no. of sites
jList = [1, 3, 4]  # jump sites
assert all(0 <= j < N for j in jList), "jump sites must be valid site indices"

# Seed torch's global RNG so this and all later random draws are
# reproducible under Restart Kernel -> Run All.
pt.manual_seed(0)
X = pt.rand((N, m)).double()
X

tensor([[0.6409, 0.5702, 0.7404, 0.0263],
        [0.4199, 0.7718, 0.8205, 0.4603],
        [0.4594, 0.5396, 0.9294, 0.8874],
        [0.8762, 0.2657, 0.7161, 0.8022],
        [0.7508, 0.1731, 0.3376, 0.4173],
        [0.8306, 0.5140, 0.5209, 0.7198],
        [0.2668, 0.9711, 0.4964, 0.4781]], dtype=torch.float64)

In [7]:
# float64 weights so the network matches the dtype of X
net = Net(m, jList)
net = net.to(pt.float64)

In [8]:
# Call the module itself (not .forward) so any registered hooks also run
Y_x = net(X)

In [9]:
# Simulate a group operation with a permutation of the sites. The constraint
# is that jump sites (jList) are only permuted amongst each other, and every
# other site only amongst the non-jump sites. Both sub-permutations are built
# directly from jList, so no rejection loop or hard-coded indices are needed.
n_sites = X.shape[0]
jump_sites = pt.tensor(sorted(set(jList)), dtype=pt.long)
is_jump = pt.zeros(n_sites, dtype=pt.bool)
is_jump[jump_sites] = True
other_sites = pt.arange(n_sites)[~is_jump]

perm = pt.arange(n_sites)
# Cyclic shift of the jump sites: a guaranteed non-trivial relabelling
# (for jList = [1, 3, 4] this sets perm[1], perm[3], perm[4] = 4, 1, 3).
perm[jump_sites] = pt.roll(jump_sites, 1)
# Uniformly random permutation of the remaining sites
perm[other_sites] = other_sites[pt.randperm(other_sites.numel())]

assert pt.equal(pt.sort(perm).values, pt.arange(n_sites)), "perm is not a bijection"
assert set(perm[jump_sites].tolist()) == set(jList), "jump sites not mapped onto themselves"
perm

tensor([5, 4, 0, 1, 3, 2, 6])

In [10]:
# Evaluate on the permuted sites; call the module so hooks run
Y_x_perm = net(X[perm])

In [11]:
# Equivariance check: permuting the input sites should permute the output
# rows in the same way. After permutation the terms of the sum over jList
# may be added in a different order. Floating-point addition is not
# associative, so compare with allclose rather than exact equality.
is_equivariant = pt.allclose(Y_x_perm, Y_x[perm])
assert is_equivariant, "Net is not equivariant under the jump-site-preserving permutation"
is_equivariant

True

In [13]:
# Check that gradients flow back to every parameter of the network
def sq_sum(y):
    """Return 0.5 * sum(y**2) over all elements of y (a scalar toy loss)."""
    return 0.5 * pt.sum(y * y)

# NOTE(review): `opt` is never stepped in the cells shown; kept in case later cells use it
opt = pt.optim.SGD(net.parameters(), lr=0.01)
net.zero_grad()  # .grad accumulates across backward calls; clear it so re-runs are clean
y = net(X)
l = sq_sum(y)
l.backward()
for name, p in net.named_parameters():
    assert p.grad is not None, f"no gradient reached {name}"
    assert pt.isfinite(p.grad).all(), f"non-finite gradient in {name}"

In [14]:
# Fresh, independently initialised network for the explicit cross-check below
net = Net(m, jList).double()
out1 = net.subNet1(X)
Y_x = net(X)

In [15]:
# Cross-check forward against an explicit per-site evaluation of
#   out[i] = sum_{j in jList} subNet2(cat(out1[i], out1[j]))
# Only values are compared, so no autograd graph is needed.
with pt.no_grad():
    for i in range(out1.shape[0]):
        out1_i = out1[i].view(1, -1)       # keep a (1, n_q) row
        out2_i = pt.zeros_like(out1_i)     # same dtype/device as out1
        for j in jList:
            in2 = pt.cat((out1_i, out1[j].view(1, -1)), dim=1)
            out2_i += net.subNet2(in2)

        assert pt.allclose(Y_x[i], out2_i[0]), f"forward disagrees with explicit sum at site {i}"