In [1]:
import torch as pt
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pickle
from SymNet import SymNet

In [2]:
# Load the test fixtures from the working directory.
# Only the first 10 entries of TestStates.npy are kept.
TestStates = np.load("TestStates.npy")[:10]

# Remaining lookup tables are read in full.
dxNN = np.load("nnJumpLatVecs.npy")
RtoSiteInd = np.load("RtoSiteInd.npy")
SiteIndtoR = np.load("SiteIndtoR.npy")
GpermNNIdx = np.load("GroupNNpermutations.npy")
NNsiteList = np.load("NNsites_sitewise.npy")

# The two leading axes of the neighbour table give N_ngb and Nsites.
N_ngb, Nsites = NNsiteList.shape[0], NNsiteList.shape[1]

# NOTE: pickle.load can execute arbitrary code stored in the file, so these
# two pickles must come from a trusted source.
with open("GroupOpsIndices.pkl", "rb") as fl:
    GIndtoGDict = pickle.load(fl)
with open("supercellBCC.pkl", "rb") as fl:
    superBCC = pickle.load(fl)

In [3]:
# Show the two leading dimensions of NNsiteList that were stored in the loading cell.
N_ngb, Nsites

(9, 512)

# Testing

## First, we test for symmetry

In [4]:
# Torch copies of the loaded permutation table (as int64) and neighbour-site table.
GnnPerms = pt.tensor(GpermNNIdx).long()
NNsites = pt.tensor(NNsiteList)

# Layer count and per-layer output-channel list handed to SymNet further down.
Nlayers, NchOuts = 1, [1]

Ndim = 3
Ng = GnnPerms.shape[0]

# Square (Ng*Ndim) matrix that is zero except for Ndim x Ndim diagonal blocks:
# the block at position gInd holds g.cartrot for that dictionary entry.
gdiags = pt.zeros(Ng * Ndim, Ng * Ndim).double()
for gInd, g in GIndtoGDict.items():
    rowStart, rowEnd = gInd * Ndim, (gInd + 1) * Ndim
    diagBlock = slice(rowStart, rowEnd)
    gdiags[diagBlock, diagBlock] = pt.tensor(g.cartrot).double()

In [5]:
# Halve the raw states, cast to float64, then insert a singleton middle axis
# so the tensor is (number of states, 1, entries per state).
scaledStates = pt.tensor(TestStates / 2.0).double()
StateTensors = scaledStates.view(TestStates.shape[0], 1, TestStates.shape[1])
N_batch = StateTensors.shape[0]
StateTensors.shape

torch.Size([10, 1, 512])

## Let's do backpropagation tests to make sure gradients are calculated correctly

In [6]:
# Dummy objective for the gradient checks: it is minimised (value 0) when
# every entry of the input equals -1.
def loss_fn(batchVecs):
    """Return the batch-averaged squared distance of ``batchVecs`` from -1.

    The input is shifted by +1, the 2-norm is taken along axis 1 and squared,
    the squared norms are summed, and the sum is divided by the size of
    axis 0.

    Parameters
    ----------
    batchVecs : torch.Tensor
        Tensor with at least two axes; axis 0 is treated as the batch axis.

    Returns
    -------
    torch.Tensor
        Zero-dimensional tensor holding the loss.
    """
    shifted = batchVecs + pt.ones_like(batchVecs)
    sqDist = pt.norm(shifted, dim=1) ** 2
    return pt.sum(sqDist) / batchVecs.shape[0]

In [7]:
# Explicit, loop-based convolution used as a reference implementation.
# It handles a single sample only: Input must have batch size 1.

def conv(net, Input, layerInd, Filter, Gdict):
    """Loop-based reference for one group-averaged convolution layer.

    For every output channel, site and key ``gInd`` of ``Gdict``:
    the weights of each input channel are permuted with
    ``net.GnnPerms[gInd]``, multiplied with the input on the site itself
    (weight slot 0) and on its neighbours (weight slots 1 .. N_ngb-1),
    summed over input channels, offset by the layer bias and passed through
    softplus. The softplus values are accumulated over all ``gInd`` and the
    total is divided by ``net.Ng``.

    Parameters
    ----------
    net : object
        Must expose ``weightList``, ``biasList``, ``GnnPerms``, ``N_ngb``
        and ``Ng``.
    Input : torch.Tensor
        Indexed as ``Input[0, chIn, site]``; only batch entry 0 is read.
    layerInd : int
        Index into ``net.weightList`` / ``net.biasList``.
    Filter
        Unused; kept so existing call sites keep working.
    Gdict : dict
        Only its keys are used, as indices into ``net.GnnPerms``.

    Returns
    -------
    torch.Tensor
        float64 tensor of shape ``(1, weights.shape[0], Input.shape[2])``.

    Notes
    -----
    Reads the notebook-level arrays ``SiteIndtoR``, ``dxNN`` and
    ``RtoSiteInd``.
    """
    weights = net.weightList[layerInd]
    bias = net.biasList[layerInd]

    Nsites = Input.shape[2]
    NchOut, NchIn = weights.shape[0], weights.shape[1]

    # Periodic wrap taken from the lookup table's own extent instead of a
    # literal 8 (assumes RtoSiteInd spans exactly one supercell per axis).
    cellDims = np.array(RtoSiteInd.shape[:3])

    # The neighbour sites depend only on the site, not on the channel or the
    # group operation, so they are looked up once here.
    ngbSites = []
    for siteInd in range(Nsites):
        Rsite = SiteIndtoR[siteInd]
        siteRow = []
        for ngb in range(1, net.N_ngb):
            Rngb = (Rsite + dxNN[ngb - 1]) % cellDims
            siteRow.append(RtoSiteInd[Rngb[0], Rngb[1], Rngb[2]])
        ngbSites.append(siteRow)

    Out = pt.zeros(1, NchOut, Nsites).double()

    for chOut in range(NchOut):
        for siteInd in range(Nsites):
            for gInd in Gdict:
                sumSite = 0.
                for chIn in range(NchIn):
                    # permute the weights of this (chOut, chIn) pair
                    psi_perm = weights[chOut, chIn][net.GnnPerms[gInd]]

                    # neighbour contributions: weight slot ngb pairs with dxNN[ngb-1]
                    for ngb in range(1, net.N_ngb):
                        siteIndNgb = ngbSites[siteInd][ngb - 1]
                        sumSite += Input[0, chIn, siteIndNgb] * psi_perm[ngb]

                    # The site itself uses weight slot 0. psi_perm has N_ngb
                    # entries, so it must not be indexed with the site index.
                    # (assumes slot 0 is the self-site weight — consistent
                    # with the neighbour loop starting at 1)
                    sumSite += Input[0, chIn, siteInd] * psi_perm[0]

                sumSite += bias[chOut, 0]
                # NOTE(review): softplus is applied here, while the network
                # in this notebook is built with active="relu" — confirm
                # which activation this reference should mirror.
                Out[0, chOut, siteInd] += F.softplus(sumSite)
    return Out / net.Ng

In [11]:
# Seed torch's global RNG before the network is built so that a fresh
# Restart & Run All gives the same numbers. SymNet's initialisation is not
# visible from this notebook; the seed only matters if it draws from torch's RNG.
SEED = 0
pt.manual_seed(SEED)

TestNet = SymNet(
    Nlayers, NchOuts, GnnPerms, GIndtoGDict, gdiags, NNsites, 3, active="relu"
).double()
optimizer = pt.optim.Adam(TestNet.parameters(), lr=0.001)

# NOTE(review): RotateParams() runs after the optimizer has captured
# TestNet.parameters() — confirm it modifies the parameters in place.
TestNet.RotateParams()

In [12]:
# Evaluate gradients with the network itself, using only the first sample.

# Clear any gradients left on the parameters by earlier runs of this cell.
optimizer.zero_grad()

# Slicing with 0:1 keeps the batch axis, so the input has the same
# (1, 1, sites) shape as before without hard-coding the site count.
firstSample = StateTensors[0:1]

# Call the module rather than .forward() so that nn.Module's call machinery
# (hooks) is not bypassed.
InLayers, outlayersG, outlayers, outVecSites, out = TestNet(firstSample, Test=True)

l = loss_fn(out)
# retain_graph=True keeps the graph alive after this backward pass.
l.backward(retain_graph=True)

In [13]:
# Gradient left on the first layer's weights by l.backward() in the previous cell.
# NOTE(review): the recorded values are all of order 1e-15 or smaller — confirm
# that an essentially vanishing gradient is the expected result here.
TestNet.weightList[0].grad

tensor([[[4.2793e-15, 9.3221e-15, 1.9611e-15, 2.2934e-15, 2.0710e-15,
          9.0832e-15, 4.1914e-15, 6.0701e-15, 2.2797e-16]]],
       dtype=torch.float64)

In [16]:
# Gradient accumulated on TestNet.wtVC; the recorded output is a 3 x 9 float64 tensor.
TestNet.wtVC.grad

tensor([[-4.4409e-16,  1.5543e-15, -1.9984e-15,  2.8866e-15, -1.9984e-15,
          2.4425e-15, -3.3307e-15,  2.8866e-15, -4.6629e-15],
        [-4.4409e-16,  1.9984e-15, -1.1102e-15, -3.3307e-15,  3.3307e-15,
         -2.4425e-15,  1.5543e-15,  2.8866e-15, -2.4425e-15],
        [ 4.4409e-16,  1.7764e-15, -2.6645e-15, -1.7764e-15,  2.6645e-15,
          1.7764e-15, -3.1086e-15, -2.6645e-15,  3.1086e-15]],
       dtype=torch.float64)

In [17]:
# Shape of the fourth value returned by the forward pass with Test=True.
outVecSites.shape

torch.Size([1, 3, 512])

In [22]:
# Keep the first entry of `outlayers` (returned by the forward pass) for the manual check below.
sumOutsLayer0 = outlayers[0]

In [26]:
# Manual construction of per-site vectors from the layer-0 outputs.
# For every site and every group operation g: rotate the Cartesian jump
# vector dxNN[yInd] with g, convert it back to lattice coordinates, read the
# layer-0 output on the site reached by that jump, and add it times column k
# of g's Cartesian rotation matrix.
yInd = 2
k = 1
lattice = superBCC.crys.lattice
yVecCart = np.dot(lattice, dxNN[yInd])
siteVecs = pt.zeros(3, Nsites).double()
siteGcomps = pt.zeros(Nsites, Ng, 3)  # allocated here but not filled in this cell

# Hoisted out of the loops: neither depends on the site or the operation.
latticeInv = np.linalg.inv(lattice)
# Periodic wrap from the lookup table's own extent instead of a literal 8
# (assumes RtoSiteInd spans exactly one supercell per axis).
cellDims = np.array(RtoSiteInd.shape[:3])

# Per-operation quantities do not depend on the site, so compute them once.
gTerms = []
for gInd, g in GIndtoGDict.items():
    gTens = pt.tensor(g.cartrot).double()

    # Column k of the rotation matrix
    gvec = gTens[:, k].clone()

    yvecCartG = np.dot(g.cartrot, yVecCart)

    # Round before casting: astype(int) alone truncates toward zero, so a
    # coordinate such as 0.9999999 would become 0 and pick the wrong site.
    yvecLat = np.rint(np.dot(latticeInv, yvecCartG)).astype(int)

    gTerms.append((gvec, yvecLat))

for siteInd in range(Nsites):

    Rsite = SiteIndtoR[siteInd]

    sumg = pt.zeros(3).double()

    for gvec, yvecLat in gTerms:

        RNew = (Rsite + yvecLat) % cellDims

        siteNgb = RtoSiteInd[RNew[0], RNew[1], RNew[2]]

        s = sumOutsLayer0[0, 0, siteNgb]

        sumg += s * gvec

    siteVecs[:, siteInd] = sumg

In [27]:
# Display the manually assembled (3, Nsites) site vectors from the previous cell.
siteVecs

tensor([[ 0.2299, -0.1445,  0.4050,  ..., -0.0531, -0.4143,  0.0202],
        [-0.2087, -0.0083,  0.0572,  ...,  0.0639, -0.5701,  0.0202],
        [-0.0034,  0.3385, -0.3975,  ..., -0.0934,  0.2141,  0.3243]],
       dtype=torch.float64)