In [2]:
import sys
sys.path.append("../")

In [3]:
import torch as pt
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pickle
from SymNet import SymNet

In [5]:
# Load test data
TestStates = np.load("../CrystalDat/TestStates.npy")
dxNN = np.load("../CrystalDat/nnJumpLatVecs.npy")
RtoSiteInd = np.load("../CrystalDat/RtoSiteInd.npy")
SiteIndtoR = np.load("../CrystalDat/SiteIndtoR.npy")
GpermNNIdx = np.load("../CrystalDat/GroupNNpermutations.npy")

NNsiteList = np.load("../CrystalDat/NNsites_sitewise.npy")
N_ngb = NNsiteList.shape[0]
Nsites = NNsiteList.shape[1]
with open("../CrystalDat/GroupOpsIndices.pkl", "rb") as fl:
    GIndtoGDict = pickle.load(fl)

with open("../CrystalDat/supercellBCC.pkl", "rb") as fl:
    superBCC = pickle.load(fl)

In [6]:
N_ngb, Nsites

(9, 512)

In [11]:
# Make shells for symmetry-grouping sites
shells = []
AllSites = set()
for siteInd in range(Nsites):
    if siteInd not in AllSites:
        Rsite = SiteIndtoR[siteInd]
        newShell = set()
        for gIng, g in GIndtoGDict.items():
            Rnew, _ = superBCC.crys.g_pos(g, Rsite, (0, 0))
            siteIndNew = RtoSiteInd[Rnew[0], Rnew[1], Rnew[2]]
            newShell.add(siteIndNew)
        AllSites.update(newShell)
        shells.append(list(newShell))

In [12]:
# Next, let's sort them according to distance
def sortkey(shell):
    R = SiteIndtoR[shell[0]]
    xmin = np.linalg.norm(np.dot(superBCC.crys.lattice, R))
    for siteInd in shell:
        R = SiteIndtoR[siteInd]
        x = np.linalg.norm(np.dot(superBCC.crys.lattice, R))
        if x < xmin:
            xmin = x
    return xmin
shellsSorted = sorted(shells, key=sortkey)

In [13]:
# Now let's assign shells to sites
SitesToShells = pt.zeros(Nsites).long()
for shellInd, shell in enumerate(shellsSorted):
    for siteInd in shell:
        SitesToShells[siteInd] = shellInd
pt.unique(SitesToShells, return_counts=True)

(tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28]),
 tensor([ 1,  8,  6, 12, 24,  8,  6, 24, 24, 24,  8, 24, 12, 48, 24,  6, 24, 24,
         24,  4, 24, 24, 24, 24, 48,  3, 12, 12,  6]))

## Try constructing with half-translation

In [14]:
# Make shells for symmetry-grouping sites
shells2 = []
N_units = 8
BCvec = np.ones(3, dtype=int)
AllSites = set()
for siteInd in range(Nsites):
    if siteInd not in AllSites:
        Rsite = (SiteIndtoR[siteInd] + BCvec*N_units//2)%N_units - BCvec*N_units//2
        newShell = set()
        for gIng, g in GIndtoGDict.items():
            Rnew, _ = superBCC.crys.g_pos(g, Rsite, (0, 0))
            siteIndNew = RtoSiteInd[Rnew[0], Rnew[1], Rnew[2]]
            newShell.add(siteIndNew)
        AllSites.update(newShell)
        shells2.append(list(newShell))

In [15]:
# Next, let's sort them according to distance
def sortkey2(shell):
    R = SiteIndtoR[shell[0]]
    xmin = np.linalg.norm(np.dot(superBCC.crys.lattice, R))
    for siteInd in shell:
        R = (SiteIndtoR[siteInd] + BCvec*N_units//2)%N_units - BCvec*N_units//2
        x = np.linalg.norm(np.dot(superBCC.crys.lattice, R))
        if x < xmin:
            xmin = x
    return xmin
shellsSorted2 = sorted(shells2, key=sortkey2)

In [16]:
# Now let's assign shells to sites
SitesToShells2 = pt.zeros(Nsites).long()
for shellInd, shell in enumerate(shellsSorted2):
    for siteInd in shell:
        SitesToShells2[siteInd] = shellInd
pt.unique(SitesToShells2, return_counts=True)

(tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28]),
 tensor([ 1,  8,  6, 12, 24,  8,  6, 24, 24, 24,  8, 24, 12, 48, 24,  6, 24, 24,
         24,  4, 24, 24, 24, 24, 48,  3, 12, 12,  6]))

In [18]:
pt.equal(SitesToShells, SitesToShells2)

True

# Testing

## First, we test for symmetry

In [15]:
GnnPerms = pt.tensor(GpermNNIdx).long()
NNsites = pt.tensor(NNsiteList)

Nlayers = 5
NchOuts = [4, 4, 4, 4, 1]

Ng = GnnPerms.shape[0]
Ndim = 3
gdiags = pt.zeros(Ng*Ndim, Ng*Ndim).double()
for gInd, g in GIndtoGDict.items():
    rowStart = gInd * Ndim
    rowEnd = (gInd + 1) * Ndim
    gdiags[rowStart : rowEnd, rowStart : rowEnd] = pt.tensor(g.cartrot).double()

TestNet = SymNet(Nlayers, NchOuts, GnnPerms, GIndtoGDict, gdiags, NNsites,
                 SitesToShells, Ndim, act="relu").double()

In [22]:
for weight in TestNet.weightList:
    print(weight.shape[0]*weight.shape[1]*weight.shape[2])

36
144
144
144
36


In [23]:
TestNet.wtVC.shape

torch.Size([3, 9])

In [32]:
sum(list(p.numel() for p in TestNet.parameters()))

577

In [31]:
dict(TestNet.named_parameters())

{'wtVC': Parameter containing:
 tensor([[1.9749, 3.0354, 3.6529, 1.8340, 1.1075, 1.7974, 0.7462, 2.4259, 1.3908],
         [1.6630, 3.2360, 3.7552, 2.2002, 1.7305, 1.1286, 1.9091, 2.3309, 1.8651],
         [1.3223, 1.6829, 1.9491, 1.4509, 1.0662, 2.2843, 2.9345, 1.4059, 1.1303]],
        dtype=torch.float64, requires_grad=True),
 'ShellWeights': Parameter containing:
 tensor([ 1.7806,  2.0272,  1.9919,  1.9572,  1.3736,  0.4948,  1.5764,  3.6946,
          4.4155,  1.6355,  1.9060,  2.6997,  1.1356,  1.2409,  2.2100,  2.0653,
          3.1895,  3.1247,  0.9288,  2.7076,  1.7366, -0.7505,  2.2041,  1.3572,
          3.2588,  1.6586,  1.5657,  2.4159,  4.3161], dtype=torch.float64,
        requires_grad=True),
 'weightList.0': Parameter containing:
 tensor([[[1.5276, 1.5334, 0.0286, 2.2256, 1.4351, 1.2056, 2.8965, 0.5994,
           2.2503]],
 
         [[2.7141, 2.4980, 3.3050, 1.5546, 3.6777, 3.7705, 2.3494, 1.7919,
           2.1461]],
 
         [[1.9423, 1.3845, 1.9641, 1.4264, 1.60

In [9]:
StateTensors = pt.tensor(TestStates/2.0).double().view(TestStates.shape[0], 1, TestStates.shape[1])
N_batch = StateTensors.shape[0]
StateTensors.shape

torch.Size([10, 1, 512])

In [10]:
InLayers, outlayersG, outlayers, outVecSites, out = TestNet.forward(StateTensors, Test=True)

In [11]:
out

tensor([[-11.7782,   6.6165,  -2.1869],
        [ -6.6165,   2.1869,  11.7782],
        [ 11.7782,   6.6165,  -2.1869],
        [  2.1869,  -6.6165,  11.7782],
        [ -6.6165, -11.7782,   2.1869],
        [  2.1869,   6.6165, -11.7782],
        [ 11.7782,  -6.6165,  -2.1869],
        [ 11.7782,   2.1869,   6.6165],
        [ -2.1869,  11.7782,  -6.6165],
        [ -2.1869,  -6.6165, -11.7782]], dtype=torch.float64,
       grad_fn=<DivBackward0>)

In [12]:
# First, Let's test for symmetry conformity of each layer
for layer in range(Nlayers):
    
    layerOut = outlayers[layer]
    
    for ch in range(NchOuts[layer]):
        
        out0 = layerOut[0, ch]
        
        for sampInd in range(layerOut.shape[0]):
            outsamp = layerOut[sampInd, ch]
            g = GIndtoGDict[sampInd]
            out0Transf = pt.zeros_like(out0)
            for siteInd in range(Nsites):
                Rsite = SiteIndtoR[siteInd]
                Rnew, _ = superBCC.crys.g_pos(g, Rsite, (0, 0))
                
                Rnew %= 8
                
                siteIndNew = RtoSiteInd[Rnew[0], Rnew[1], Rnew[2]]
                
                out0Transf[siteIndNew] = out0[siteInd]
                
            assert pt.allclose(outsamp, out0Transf)

## Then we convolve explicitly and verify outputs

In [24]:
def layerConvTest(net, layerInd, Input, GOuts, SiteIndtoR, RtoSiteInd, dxNN):
        
    weightAll = net.weightList[layerInd]
    biasAll = net.biasList[layerInd]
    
    NchOut = weightAll.shape[0]
    NchIn = weightAll.shape[1]
    
    for sampInd in range(N_batch):

        for chOut in range(NchOut):
            bias = biasAll[chOut][0]

            for gInd, g in GIndtoGDict.items():

                sampOut = pt.zeros(net.Nsites).double()
                for siteInd in range(net.Nsites):
                    Rsite = SiteIndtoR[siteInd]
                    sumSite = 0.

                    for chIn in range(NchIn):    
                        psi_ch_in = weightAll[chOut, chIn]

                        for ngb in range(1, N_ngb):
                            dxCart = np.dot(superBCC.crys.lattice, dxNN[ngb-1])
                            dxCartRot = np.dot(g.cartrot, dxCart)
                            dxRotLat = np.dot(np.linalg.inv(superBCC.crys.lattice), dxCartRot).astype(int)
                            Rngb = (Rsite + dxRotLat)%8
                            siteIndNgb = RtoSiteInd[Rngb[0], Rngb[1], Rngb[2]]

                            sumSite += Input[sampInd, chIn, siteIndNgb] * psi_ch_in[ngb]

                        sumSite += Input[sampInd, chIn, siteInd] * psi_ch_in[0]

                    sumSite += bias

                    sampOut[siteInd] = F.relu(sumSite)

                assert pt.allclose(sampOut, GOuts[sampInd, chOut, gInd]), "{} {}".format(sampInd, gInd)
    print("Layer {} tests done".format(layerInd))

In [25]:
outlayers[0].shape

torch.Size([10, 2, 512])

In [26]:
# Let's do the 0th layer first

# get the outputs
outlayer0G = outlayersG[0] # Before group averaging
sumOutsLayer0 = outlayers[0] # After group averaging

layerConvTest(TestNet, 0, StateTensors, outlayer0G, SiteIndtoR, RtoSiteInd, dxNN)

Layer 0 tests done


In [27]:
TestNet.weightList[0].shape

torch.Size([2, 1, 9])

In [28]:
# Then check summing across G channels
for sampInd in range(StateTensors.shape[0]):
    for chInd in range(sumOutsLayer0.shape[1]):        
        for siteInd in range(Nsites):
            sumNet = sumOutsLayer0[sampInd, chInd, siteInd]
            sumCalc = 0.
            for gInd in range(Ng):
                sumCalc += outlayer0G[sampInd, chInd, gInd, siteInd]
            assert pt.allclose(sumCalc/Ng, sumNet)

In [29]:
# Now let's check the 1th layer convolution
outlayer1G = outlayersG[1]
sumOutsLayer1 = outlayers[1]

layerConvTest(TestNet, 1, sumOutsLayer0, outlayer1G, SiteIndtoR, RtoSiteInd, dxNN)

Layer 1 tests done


In [30]:
for sampInd in range(StateTensors.shape[0]):
    for chInd in range(sumOutsLayer1.shape[1]):        
        for siteInd in range(Nsites):
            sumNet = sumOutsLayer1[sampInd, chInd, siteInd]
            sumCalc = 0.
            for gInd in range(Ng):
                sumCalc += outlayer1G[sampInd, chInd, gInd, siteInd]
            assert pt.allclose(sumCalc/Ng, sumNet)

In [31]:
## Next test layer 2 outputs
outlayer2G = outlayersG[2]
sumOutsLayer2 = outlayers[2]

layerConvTest(TestNet, 2, sumOutsLayer1, outlayer2G, SiteIndtoR, RtoSiteInd, dxNN)

Layer 2 tests done


In [32]:
for sampInd in range(StateTensors.shape[0]):
    for chInd in range(sumOutsLayer2.shape[1]):        
        for siteInd in range(Nsites):
            sumNet = sumOutsLayer2[sampInd, chInd, siteInd]
            sumCalc = 0.
            for gInd in range(Ng):
                sumCalc += outlayer2G[sampInd, chInd, gInd, siteInd]
            assert pt.allclose(sumCalc/Ng, sumNet)

In [33]:
Nsites

512

In [40]:
# Now we need to test the R3 convolution

PsiR3 = TestNet.wtVC

out3Sites0 = outVecSites[0]

for sampInd in range(N_batch):
    
    R3OutSites = pt.zeros(3, Nsites).double()
    
    for siteInd in range(Nsites):
        Rsite = SiteIndtoR[siteInd]
        sumSite = pt.zeros(3).double()
        
        # get the shell of this site
        shellInd = SitesToShells[siteInd]
        shellWeight = TestNet.ShellWeights[shellInd]
        
        for gInd, g in GIndtoGDict.items():
            
            gRotTens = pt.tensor(g.cartrot).double()            
            for ngb in range(1, N_ngb):
                PsiRot = pt.matmul(gRotTens, PsiR3[:, ngb])
                dxCart = np.dot(superBCC.crys.lattice, dxNN[ngb-1])
                dxCartRot = np.dot(g.cartrot, dxCart)
                dxRotLat = np.dot(np.linalg.inv(superBCC.crys.lattice), dxCartRot).astype(int)
                Rngb = (Rsite + dxRotLat)%8
                siteIndNgb = RtoSiteInd[Rngb[0], Rngb[1], Rngb[2]]
                sumSite += sumOutsLayer2[sampInd, 0, siteIndNgb]*PsiRot
            
            sumSite += sumOutsLayer2[sampInd, 0, siteInd]*pt.matmul(gRotTens, PsiR3[:, 0])
        
        R3OutSites[:, siteInd] = shellWeight*sumSite/TestNet.Ng
        
    assert pt.allclose(R3OutSites, outVecSites[sampInd])
    
    # Now check symmetry relationship
    gSamp = GIndtoGDict[sampInd]
    gSampTens = pt.tensor(gSamp.cartrot).double()
    for siteInd in range(Nsites):
        vec0 = out3Sites0[:, siteInd]
        Rsite = SiteIndtoR[siteInd]
        Rg, _ = superBCC.crys.g_pos(gSamp, Rsite, (0,0))
        Rg%=8
        siteIndg = RtoSiteInd[Rg[0], Rg[1], Rg[2]]
        assert pt.allclose(outVecSites[sampInd, :, siteIndg], pt.matmul(gSampTens, vec0))

print("R3 conv tests passed")

R3 conv tests passed
