In [1]:
import torch as pt
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pickle
from SymNet import SymNet

In [2]:
# Load test data

# Only the first 10 stored test states are used by the checks in this notebook.
TestStates = np.load("TestStates.npy")[:10]

# Lattice bookkeeping arrays saved as .npy files.
dxNN = np.load("nnJumpLatVecs.npy")
RtoSiteInd = np.load("RtoSiteInd.npy")
SiteIndtoR = np.load("SiteIndtoR.npy")
GpermNNIdx = np.load("GroupNNpermutations.npy")
NNsiteList = np.load("NNsites_sitewise.npy")

# The two leading axis lengths of the neighbour table are reused everywhere below.
N_ngb, Nsites = NNsiteList.shape[0], NNsiteList.shape[1]

# Pickled objects. NOTE: pickle.load can execute arbitrary code, so these files
# must come from a trusted source.
with open("GroupOpsIndices.pkl", "rb") as pklFile:
    GIndtoGDict = pickle.load(pklFile)

with open("supercellBCC.pkl", "rb") as pklFile:
    superBCC = pickle.load(pklFile)

In [3]:
# Show the two sizes taken from NNsiteList.shape (axis 0 and axis 1).
N_ngb, Nsites

(9, 512)

In [4]:
# Make shells for symmetry-grouping sites.
# A shell is the set of site indices reached from one site by applying every
# group operation stored in GIndtoGDict.
shells = []
AllSites = set()
for siteInd in range(Nsites):
    if siteInd in AllSites:
        # Already a member of a shell generated from an earlier site.
        continue
    Rsite = SiteIndtoR[siteInd]
    newShell = set()
    for g in GIndtoGDict.values():
        Rnew, _ = superBCC.crys.g_pos(g, Rsite, (0, 0))
        # Wrap the image back into the periodic supercell explicitly (same
        # "% 8" as the symmetry checks later in this notebook) instead of
        # relying on NumPy's negative-index wraparound, which only coincides
        # with the periodic image for components in [-8, 7].
        Rnew %= 8
        siteIndNew = RtoSiteInd[Rnew[0], Rnew[1], Rnew[2]]
        newShell.add(siteIndNew)
    AllSites.update(newShell)
    shells.append(list(newShell))

In [5]:
# Next, order the shells by how close they come to the origin
def sortkey(shell):
    """Return the smallest norm of dot(lattice, R) over the sites in ``shell``.

    R is the SiteIndtoR entry of each site index in the shell; the value is
    used as the sort key for ordering shells.
    """
    cellVecs = superBCC.crys.lattice
    return min(np.linalg.norm(np.dot(cellVecs, SiteIndtoR[idx])) for idx in shell)

shellsSorted = sorted(shells, key=sortkey)

In [6]:
# Map every site index to the position of its shell in the sorted shell list
SitesToShells = pt.zeros(Nsites, dtype=pt.long)
for rank, members in enumerate(shellsSorted):
    for member in members:
        SitesToShells[member] = rank

# Shell labels together with the number of sites carrying each label
pt.unique(SitesToShells, return_counts=True)

(tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28]),
 tensor([ 1,  8,  6, 12, 24,  8,  6, 24, 24, 24,  8, 24, 12, 48, 24,  6, 24, 24,
         24,  4, 24, 24, 24, 24, 48,  3, 12, 12,  6]))

# Testing

## First, we test for symmetry

In [7]:
# Tensor versions of the neighbour-permutation table and the neighbour-site table
GnnPerms = pt.tensor(GpermNNIdx).long()
NNsites = pt.tensor(NNsiteList)

# Network layout: number of layers and output channels per layer
Nlayers = 3
NchOuts = [2, 2, 1]

Ng = GnnPerms.shape[0]
Ndim = 3

# Block-diagonal matrix: the gInd-th (Ndim x Ndim) diagonal block holds g.cartrot
gdiags = pt.zeros(Ng * Ndim, Ng * Ndim).double()
for gInd, g in GIndtoGDict.items():
    blk = slice(gInd * Ndim, (gInd + 1) * Ndim)
    gdiags[blk, blk] = pt.tensor(g.cartrot).double()

TestNet = SymNet(
    Nlayers, NchOuts, GnnPerms, GIndtoGDict, gdiags, NNsites,
    SitesToShells, Ndim, act="relu",
).double()

In [8]:
# Call RotateParams() on the network before running the forward pass below.
# NOTE(review): presumably this prepares the symmetry-rotated parameters that
# forward() consumes — confirm against the SymNet implementation.
TestNet.RotateParams()

In [9]:
# Halve the loaded states and reshape to (n_states, 1, n_sites_per_state)
nStates, nStateSites = TestStates.shape[0], TestStates.shape[1]
StateTensors = pt.tensor(TestStates / 2.0).double().view(nStates, 1, nStateSites)
N_batch = nStates
StateTensors.shape

torch.Size([10, 1, 512])

In [10]:
# Run the network on the test batch and unpack the five returned values.
# outlayersG / outlayers are indexed per layer below; outVecSites and out are
# checked in the R3 section. NOTE(review): presumably Test=True is what makes
# forward() return the intermediates in addition to `out` — confirm in SymNet.
InLayers, outlayersG, outlayers, outVecSites, out = TestNet.forward(StateTensors, Test=True)

In [11]:
# Display the final value returned by the forward pass.
out

tensor([[-11.7782,   6.6165,  -2.1869],
        [ -6.6165,   2.1869,  11.7782],
        [ 11.7782,   6.6165,  -2.1869],
        [  2.1869,  -6.6165,  11.7782],
        [ -6.6165, -11.7782,   2.1869],
        [  2.1869,   6.6165, -11.7782],
        [ 11.7782,  -6.6165,  -2.1869],
        [ 11.7782,   2.1869,   6.6165],
        [ -2.1869,  11.7782,  -6.6165],
        [ -2.1869,  -6.6165, -11.7782]], dtype=torch.float64,
       grad_fn=<DivBackward0>)

In [12]:
# First, let's test for symmetry conformity of each layer.
# For every sample index i, the layer output of sample i is compared with the
# output of sample 0 after moving each site value to the site's image under
# group operation GIndtoGDict[i].
# NOTE(review): this presumes sample i of the batch is sample 0 transformed by
# GIndtoGDict[i] — confirm against how TestStates.npy was generated.

# The site permutation induced by a group operation depends neither on the
# layer nor on the channel, so build it once per sample up front.
sitePerms = []
for sampInd in range(outlayers[0].shape[0]):
    g = GIndtoGDict[sampInd]
    perm = np.zeros(Nsites, dtype=int)
    for siteInd in range(Nsites):
        Rnew, _ = superBCC.crys.g_pos(g, SiteIndtoR[siteInd], (0, 0))
        Rnew %= 8  # wrap the image back into the periodic supercell
        perm[siteInd] = RtoSiteInd[Rnew[0], Rnew[1], Rnew[2]]
    # Every site must be hit exactly once, otherwise the scatter below would
    # silently leave zeros or overwrite entries.
    assert np.unique(perm).size == Nsites, "group op {} does not permute the sites".format(sampInd)
    sitePerms.append(pt.tensor(perm).long())

for layer in range(Nlayers):

    layerOut = outlayers[layer]

    for ch in range(NchOuts[layer]):

        out0 = layerOut[0, ch]

        for sampInd in range(layerOut.shape[0]):
            # out0Transf[image of site] = out0[site]
            out0Transf = pt.zeros_like(out0)
            out0Transf[sitePerms[sampInd]] = out0

            assert pt.allclose(layerOut[sampInd, ch], out0Transf), \
                "layer {} channel {} sample {}".format(layer, ch, sampInd)

## Then we convolve explicitly and verify outputs

In [24]:
def layerConvTest(net, layerInd, Input, GOuts, SiteIndtoR, RtoSiteInd, dxNN,
                  groupOps=None, lattice=None, Nngb=None, supercellSize=8):
    """Recompute one layer's convolution with explicit loops and compare it to
    the per-group-operation outputs stored by the network.

    For every sample, output channel and group operation g, the value at each
    site is rebuilt as

        relu( sum_chIn [ sum_{ngb >= 1} Input[neighbour under g] * w[ngb]
                         + Input[site] * w[0] ] + bias )

    and asserted (pt.allclose) equal to GOuts[sample, chOut, gInd].

    Parameters
    ----------
    net : object exposing weightList, biasList (indexed by layer) and Nsites.
    layerInd : index of the layer to check.
    Input : tensor indexed as [sample, input channel, site].
    GOuts : tensor indexed as [sample, output channel, group op, site].
    SiteIndtoR : array mapping a site index to its lattice vector R.
    RtoSiteInd : array mapping R[0], R[1], R[2] to a site index.
    dxNN : array of neighbour jump vectors in lattice coordinates; row ngb-1
        belongs to weight column ngb.
    groupOps : optional dict {gInd: op with .cartrot}; defaults to the
        notebook-level GIndtoGDict.
    lattice : optional lattice matrix; defaults to superBCC.crys.lattice.
    Nngb : optional number of weight columns (site itself + neighbours);
        defaults to the notebook-level N_ngb.
    supercellSize : period used to wrap lattice vectors (default 8, the value
        previously hard-coded here).

    Raises
    ------
    AssertionError if any recomputed output differs from GOuts; the message is
    "<sample index> <group op index>".
    """
    # Fall back to the notebook-level objects when not supplied explicitly.
    if groupOps is None:
        groupOps = GIndtoGDict
    if lattice is None:
        lattice = superBCC.crys.lattice
    if Nngb is None:
        Nngb = N_ngb

    weightAll = net.weightList[layerInd]
    biasAll = net.biasList[layerInd]

    NchOut = weightAll.shape[0]
    NchIn = weightAll.shape[1]

    # Neighbour site of every site under every group operation. This depends
    # only on (g, site, ngb), so it is built once instead of inside the
    # sample / channel loops.
    latticeInv = np.linalg.inv(lattice)
    ngbSites = {}
    for gInd, g in groupOps.items():
        dxRotLatList = []
        for ngb in range(1, Nngb):
            dxCart = np.dot(lattice, dxNN[ngb - 1])
            dxCartRot = np.dot(g.cartrot, dxCart)
            # Round before casting: astype(int) truncates toward zero, so a
            # component such as 0.9999999 would otherwise become 0.
            dxRotLatList.append(np.rint(np.dot(latticeInv, dxCartRot)).astype(int))

        table = []
        for siteInd in range(net.Nsites):
            Rsite = SiteIndtoR[siteInd]
            row = []
            for dxRotLat in dxRotLatList:
                Rngb = (Rsite + dxRotLat) % supercellSize
                row.append(int(RtoSiteInd[Rngb[0], Rngb[1], Rngb[2]]))
            table.append(row)
        ngbSites[gInd] = table

    # Use the batch size of the tensor actually passed in rather than a
    # notebook-level variable that may be stale.
    for sampInd in range(Input.shape[0]):

        for chOut in range(NchOut):
            bias = biasAll[chOut][0]

            for gInd in groupOps:
                table = ngbSites[gInd]

                sampOut = pt.zeros(net.Nsites).double()
                for siteInd in range(net.Nsites):
                    sumSite = 0.

                    for chIn in range(NchIn):
                        psi_ch_in = weightAll[chOut, chIn]

                        for ngb in range(1, Nngb):
                            siteIndNgb = table[siteInd][ngb - 1]
                            sumSite += Input[sampInd, chIn, siteIndNgb] * psi_ch_in[ngb]

                        # Weight column 0 multiplies the site itself.
                        sumSite += Input[sampInd, chIn, siteInd] * psi_ch_in[0]

                    sumSite += bias

                    sampOut[siteInd] = F.relu(sumSite)

                assert pt.allclose(sampOut, GOuts[sampInd, chOut, gInd]), "{} {}".format(sampInd, gInd)
    print("Layer {} tests done".format(layerInd))

In [25]:
# Shape of the first entry of outlayers (used below as the layer-0 output after group averaging).
outlayers[0].shape

torch.Size([10, 2, 512])

In [26]:
# Let's do the 0th layer first

# get the outputs
outlayer0G = outlayersG[0] # Before group averaging
sumOutsLayer0 = outlayers[0] # After group averaging

# Layer 0 is checked with the state tensors themselves as its input.
layerConvTest(TestNet, 0, StateTensors, outlayer0G, SiteIndtoR, RtoSiteInd, dxNN)

Layer 0 tests done


In [27]:
# Shape of the layer-0 weights; layerConvTest reads axis 0 as output channels
# and axis 1 as input channels, and indexes the last axis by neighbour.
TestNet.weightList[0].shape

torch.Size([2, 1, 9])

In [28]:
# Then check the average across the group-operation axis (axis 2 of the
# per-group outputs) against the network's group-averaged layer-0 output.
# One tensor reduction replaces a Python loop issuing one allclose per scalar.
assert outlayer0G.shape[2] == Ng, "expected one slice per group operation"
avgLayer0 = outlayer0G.sum(dim=2) / Ng
assert avgLayer0.shape == sumOutsLayer0.shape, "group-averaged shape mismatch"
assert pt.allclose(avgLayer0, sumOutsLayer0), "layer 0 group average mismatch"

In [29]:
# Now let's check the layer-1 convolution
outlayer1G = outlayersG[1]  # per-group-operation outputs of layer 1
sumOutsLayer1 = outlayers[1]  # layer-1 entry of outlayers, checked against the group average below

# Layer 1 is checked with the group-averaged layer-0 output as its input.
layerConvTest(TestNet, 1, sumOutsLayer0, outlayer1G, SiteIndtoR, RtoSiteInd, dxNN)

Layer 1 tests done


In [30]:
# Check the average across the group-operation axis (axis 2 of the per-group
# outputs) against the network's group-averaged layer-1 output.
# One tensor reduction replaces a Python loop issuing one allclose per scalar.
assert outlayer1G.shape[2] == Ng, "expected one slice per group operation"
avgLayer1 = outlayer1G.sum(dim=2) / Ng
assert avgLayer1.shape == sumOutsLayer1.shape, "group-averaged shape mismatch"
assert pt.allclose(avgLayer1, sumOutsLayer1), "layer 1 group average mismatch"

In [31]:
## Next test layer 2 outputs
outlayer2G = outlayersG[2]  # per-group-operation outputs of layer 2
sumOutsLayer2 = outlayers[2]  # layer-2 entry of outlayers, checked against the group average below

# Layer 2 is checked with the group-averaged layer-1 output as its input.
layerConvTest(TestNet, 2, sumOutsLayer1, outlayer2G, SiteIndtoR, RtoSiteInd, dxNN)

Layer 2 tests done


In [32]:
# Check the average across the group-operation axis (axis 2 of the per-group
# outputs) against the network's group-averaged layer-2 output.
# One tensor reduction replaces a Python loop issuing one allclose per scalar.
assert outlayer2G.shape[2] == Ng, "expected one slice per group operation"
avgLayer2 = outlayer2G.sum(dim=2) / Ng
assert avgLayer2.shape == sumOutsLayer2.shape, "group-averaged shape mismatch"
assert pt.allclose(avgLayer2, sumOutsLayer2), "layer 2 group average mismatch"

In [33]:
# Number of sites (NNsiteList.shape[1]), shown before the site-wise R3 check.
Nsites

512

In [40]:
# Now we need to test the R3 convolution.
# Part 1 rebuilds the per-site 3-vector output explicitly from the layer-2
# output and compares it with outVecSites; part 2 checks that the output of
# sample i equals the output of sample 0 rotated by GIndtoGDict[i] and moved to
# the transformed site.

PsiR3 = TestNet.wtVC

out3Sites0 = outVecSites[0]

# Quantities that depend only on the group operation (not on the sample) are
# built once here instead of once per sample and site:
#   PsiRotAll[gInd][ngb]        -> g.cartrot applied to column ngb of PsiR3
#   ngbSitesAll[gInd][site][k]  -> site index of neighbour k+1 of `site` under g
latticeInv = np.linalg.inv(superBCC.crys.lattice)
PsiRotAll = {}
ngbSitesAll = {}
for gInd, g in GIndtoGDict.items():
    gRotTens = pt.tensor(g.cartrot).double()
    PsiRotAll[gInd] = [pt.matmul(gRotTens, PsiR3[:, ngb]) for ngb in range(N_ngb)]

    dxRotLatList = []
    for ngb in range(1, N_ngb):
        dxCart = np.dot(superBCC.crys.lattice, dxNN[ngb-1])
        dxCartRot = np.dot(g.cartrot, dxCart)
        # Round before casting: astype(int) truncates toward zero, so a
        # component such as 0.9999999 would otherwise become 0.
        dxRotLatList.append(np.rint(np.dot(latticeInv, dxCartRot)).astype(int))

    ngbTable = []
    for siteInd in range(Nsites):
        Rsite = SiteIndtoR[siteInd]
        ngbRow = []
        for dxRotLat in dxRotLatList:
            Rngb = (Rsite + dxRotLat)%8
            ngbRow.append(int(RtoSiteInd[Rngb[0], Rngb[1], Rngb[2]]))
        ngbTable.append(ngbRow)
    ngbSitesAll[gInd] = ngbTable

for sampInd in range(N_batch):

    R3OutSites = pt.zeros(3, Nsites).double()

    for siteInd in range(Nsites):
        sumSite = pt.zeros(3).double()

        # get the shell of this site
        shellInd = SitesToShells[siteInd]
        shellWeight = TestNet.ShellWeights[shellInd]

        for gInd in GIndtoGDict:
            PsiRot = PsiRotAll[gInd]
            ngbRow = ngbSitesAll[gInd][siteInd]

            for ngb in range(1, N_ngb):
                sumSite += sumOutsLayer2[sampInd, 0, ngbRow[ngb-1]]*PsiRot[ngb]

            # Column 0 of PsiR3 multiplies the site itself.
            sumSite += sumOutsLayer2[sampInd, 0, siteInd]*PsiRot[0]

        R3OutSites[:, siteInd] = shellWeight*sumSite/TestNet.Ng

    assert pt.allclose(R3OutSites, outVecSites[sampInd]), "R3 output mismatch for sample {}".format(sampInd)

    # Now check symmetry relationship
    gSamp = GIndtoGDict[sampInd]
    gSampTens = pt.tensor(gSamp.cartrot).double()
    for siteInd in range(Nsites):
        vec0 = out3Sites0[:, siteInd]
        Rsite = SiteIndtoR[siteInd]
        Rg, _ = superBCC.crys.g_pos(gSamp, Rsite, (0,0))
        Rg%=8
        siteIndg = RtoSiteInd[Rg[0], Rg[1], Rg[2]]
        assert pt.allclose(outVecSites[sampInd, :, siteIndg], pt.matmul(gSampTens, vec0)), \
            "R3 symmetry mismatch for sample {} site {}".format(sampInd, siteInd)

print("R3 conv tests passed")

R3 conv tests passed
