In [1]:
import torch as pt
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pickle
from SymNet import SymNet
import h5py

In [2]:
# Load test data
# Every HDF5 dataset is indexed with [0], so only its first entry along the
# leading axis is read into memory.
with h5py.File(r'singleStep_50.h5',"r") as fl:
    state1Grid = np.array(fl["state1Grid"][0])
    state2Grid = np.array(fl["state2Grid"][0])
    rateArray = np.array(fl["rateList"][0])
    dispAll = np.array(fl["dispList"][0])

# Lattice bookkeeping arrays.  Later cells multiply dxNN by the lattice to get
# Cartesian jump vectors, index RtoSiteInd with three lattice coordinates to
# get a site index, and index SiteIndtoR with a site index to get R.
# GpermNNIdx: presumably the neighbour-index permutation for each group
# operation -- TODO confirm against the script that wrote the file.
dxNN = np.load("nnJumpLatVecs.npy")
RtoSiteInd = np.load("RtoSiteInd.npy")
SiteIndtoR = np.load("SiteIndtoR.npy")
GpermNNIdx = np.load("GroupNNpermutations.npy")

# NNsiteList has shape (N_ngb, Nsites)
NNsiteList = np.load("NNsites_sitewise.npy")
N_ngb = NNsiteList.shape[0]
Nsites = NNsiteList.shape[1]
# NOTE(review): pickle.load can execute arbitrary code contained in the file;
# only load pickles that come from a trusted source.
with open("GroupOpsIndices.pkl", "rb") as fl:
    GIndtoGDict = pickle.load(fl)

with open("supercellBCC.pkl", "rb") as fl:
    superBCC = pickle.load(fl)

In [3]:
# Show the two dimensions of NNsiteList read above
N_ngb, Nsites

(9, 512)

## Make the symmetry-related parameters

In [4]:
# Index tensors handed to the network
GnnPerms = pt.tensor(GpermNNIdx).long()
NNsites = pt.tensor(NNsiteList)

# Passed positionally to SymNet further down (presumably the number of layers
# and the output-channel count of each layer -- confirm against SymNet).
Nlayers = 1
NchOuts = [1]

# Block-diagonal matrix of Cartesian rotations: the block for group operation
# gInd occupies rows and columns [gInd*Ndim, (gInd+1)*Ndim).
Ng = GnnPerms.shape[0]
Ndim = 3
gdiags = pt.zeros(Ng * Ndim, Ng * Ndim, dtype=pt.float64)
for gInd, g in GIndtoGDict.items():
    block = slice(gInd * Ndim, (gInd + 1) * Ndim)
    gdiags[block, block] = pt.tensor(g.cartrot, dtype=pt.float64)

In [5]:
# Inputs for the network: batch size 1, one channel, one value per site.
# Grid values are halved before use.  The last axis is sized with Nsites
# (instead of a literal 512) so the tensors stay consistent with the
# neighbour list if a different supercell is loaded.
state1Tensor = pt.tensor(state1Grid/2.0).double().view(1, 1, Nsites)
state2Tensor = pt.tensor(state2Grid/2.0).double().view(1, 1, Nsites)
# Row 1 of the displacement array and the rate, both as double tensors
dispSpec1 = pt.tensor(dispAll[1, :]).double()
rate = pt.tensor(rateArray).double()

In [6]:
# Confirm the (batch, channel, site) layout produced by .view above
state2Tensor.shape

torch.Size([1, 1, 512])

In [7]:
# Make shells for symmetry-grouping sites: every site not yet assigned is
# mapped through all group operations, and the set of images forms one shell.
shells = []
AllSites = set()
for siteInd in range(Nsites):
    if siteInd in AllSites:
        # already a member of a shell built from an earlier site
        continue
    Rsite = SiteIndtoR[siteInd]
    newShell = set()
    for g in GIndtoGDict.values():
        # g_pos returns a pair; only the transformed lattice vector is needed
        Rnew = superBCC.crys.g_pos(g, Rsite, (0, 0))[0]
        newShell.add(RtoSiteInd[Rnew[0], Rnew[1], Rnew[2]])
    AllSites |= newShell
    shells.append(list(newShell))

In [8]:
# Next, let's sort them according to distance
def sortkey(shell):
    """Return the smallest Cartesian norm among the sites of ``shell``.

    shell: non-empty sequence of site indices.  Each index is looked up in
    SiteIndtoR and the resulting vector is multiplied by the supercell's
    lattice matrix before its norm is taken.
    """
    lattice = superBCC.crys.lattice
    return min(
        np.linalg.norm(np.dot(lattice, SiteIndtoR[siteInd]))
        for siteInd in shell
    )
# Order shells by the minimum distance returned by sortkey; sorted() is
# stable, so shells with equal distance keep their construction order.
shellsSorted = sorted(shells, key=sortkey)

In [9]:
# Now let's assign shells to sites: entry i holds the position (in the
# distance-sorted list) of the shell containing site i.
SitesToShells = pt.zeros(Nsites, dtype=pt.long)
for rank, members in enumerate(shellsSorted):
    for member in members:
        SitesToShells[member] = rank
# Displayed result: the shell labels and the number of sites in each shell
pt.unique(SitesToShells, return_counts=True)

(tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28]),
 tensor([ 1,  8,  6, 12, 24,  8,  6, 24, 24, 24,  8, 24, 12, 48, 24,  6, 24, 24,
         24,  4, 24, 24, 24, 24, 48,  3, 12, 12,  6]))

## Let's do backpropagation tests to make sure gradients are calculated correctly

In [10]:
# Dummy loss for the gradient checks: it vanishes exactly when dy == -dx,
# so minimising it pushes the predicted change dy towards -dx.
def loss_fn(rate, dx, dy):
    """Return sum(rate * ||dx + dy||^2) / 6 as a tensor."""
    residual = dx + dy
    squaredNorm = pt.norm(residual) ** 2
    return pt.sum(rate * squaredNorm) / 6.

In [16]:
# Build the network in double precision from the symmetry data prepared above.
TestNet = SymNet(Nlayers, NchOuts, GnnPerms, GIndtoGDict, gdiags, NNsites,
                 SitesToShells, Ndim, act="relu").double()

optimizer = pt.optim.Adam(TestNet.parameters(), lr = 0.001)
# NOTE(review): RotateParams is defined in SymNet (not visible here);
# presumably it builds the symmetry-rotated copies of the weights -- confirm.
TestNet.RotateParams()

# Now let's first evaluate gradients with the network

# We only use the first sample for now
optimizer.zero_grad()
# With Test=True the forward call returns five values here; the last one
# (y1 / y2) is reshaped to a 3-vector below.
InLayers1, outlayersG1, outlayers1, outVecSites1, y1 = TestNet.forward(state1Tensor,
                                                                       UseShellWeights=False, Test=True)
InLayers2, outlayersG2, outlayers2, outVecSites2, y2 = TestNet.forward(state2Tensor,
                                                                       UseShellWeights=False, Test=True)
# Difference of the network outputs between the two states
dy = (y2-y1).view(3)
l = loss_fn(rate, dispSpec1, dy)
# Populates .grad on the network parameters
l.backward()

In [17]:
# Snapshot (copy) of the layer-0 weight gradient from the network's backward pass
gradLayer1 = TestNet.weightList[0].grad.data.numpy().copy()

In [18]:
# Snapshot (copy) of the gradient of TestNet.wtVC from the same backward pass.
# NOTE(review): the name gradR3 is rebound to a function by a later cell.
gradR3 = TestNet.wtVC.grad.data.numpy().copy()

In [19]:
# Display the stored layer-0 weight gradient
gradLayer1

array([[[0., 0., 0., 0., 0., 0., 0., 0., 0.]]])

In [20]:
# Print the stored wtVC gradient one neighbour column at a time
for i in range(N_ngb):
    print(gradR3[:, i])

[0. 0. 0.]
[8.32667268e-17 8.32667268e-17 8.32667268e-17]
[-8.32667268e-17 -8.32667268e-17 -8.32667268e-17]
[ 8.32667268e-17 -8.32667268e-17 -8.32667268e-17]
[-8.32667268e-17  8.32667268e-17  8.32667268e-17]
[ 8.32667268e-17 -8.32667268e-17  8.32667268e-17]
[-8.32667268e-17  8.32667268e-17 -8.32667268e-17]
[ 8.32667268e-17  8.32667268e-17 -8.32667268e-17]
[-8.32667268e-17 -8.32667268e-17  8.32667268e-17]


In [21]:
# Jump vectors dxNN multiplied by the lattice matrix, one resulting vector per row
np.dot(superBCC.crys.lattice, dxNN.T).T

array([[ 0.5,  0.5,  0.5],
       [-0.5, -0.5, -0.5],
       [ 0.5, -0.5, -0.5],
       [-0.5,  0.5,  0.5],
       [ 0.5, -0.5,  0.5],
       [-0.5,  0.5, -0.5],
       [ 0.5,  0.5, -0.5],
       [-0.5, -0.5,  0.5]])

In [None]:
# Now write an explicit convolution function
# we'll test for a single sample: Input with batch size 1

def conv(net, Input, layerInd, Gdict):
    """Loop-based reference for one group-averaged convolution layer.

    For every output channel, site and group operation g, the input at the
    site and at its g-rotated neighbours is weighted by the layer's filter,
    summed over input channels, shifted by the bias and passed through ReLU.
    The ReLU values are summed over the group operations and divided by
    ``net.Ng``.

    Parameters
    ----------
    net : object exposing ``weightList``, ``biasList``, ``N_ngb`` and ``Ng``.
    Input : tensor read as ``Input[0, channel, site]`` (batch size 1).
    layerInd : index into ``net.weightList`` and ``net.biasList``.
    Gdict : mapping from group index to an operation with a ``cartrot`` matrix.

    Returns
    -------
    Double tensor of shape ``(1, weights.shape[0], Input.shape[2])``.

    Reads the notebook globals ``SiteIndtoR``, ``RtoSiteInd``, ``dxNN`` and
    ``superBCC``.  Neighbour coordinates are wrapped with ``% 8``, i.e. the
    supercell is taken to be 8 cells wide along every lattice vector.
    """
    weights = net.weightList[layerInd]
    bias = net.biasList[layerInd]

    Nsites = Input.shape[2]
    lattice = superBCC.crys.lattice
    # The inverse lattice does not change inside the loops: compute it once.
    invLattice = np.linalg.inv(lattice)

    # Rotated neighbour jumps in lattice coordinates, one list per group
    # operation.  np.rint is applied before the integer cast: a bare
    # astype(int) truncates, so a round-off result such as 0.999999999999
    # would become 0 and silently select the wrong neighbour site.
    rotatedJumps = {}
    for gInd, g in Gdict.items():
        jumps = []
        for ngb in range(1, net.N_ngb):
            dxCart = np.dot(lattice, dxNN[ngb - 1])
            dxCartRot = np.dot(g.cartrot, dxCart)
            jumps.append(np.rint(np.dot(invLattice, dxCartRot)).astype(int))
        rotatedJumps[gInd] = jumps

    Out = pt.zeros(1, weights.shape[0], Nsites).double()

    for chOut in range(weights.shape[0]):
        for siteInd in range(Nsites):
            Rsite = SiteIndtoR[siteInd]
            for gInd in Gdict:
                sumSite = 0.
                for chIn in range(weights.shape[1]):
                    # filter for this (output, input) channel pair;
                    # psi[0] weights the site itself, psi[ngb] its neighbours
                    psi = weights[chOut, chIn]

                    for ngb, dxRotLat in enumerate(rotatedJumps[gInd], start=1):
                        Rngb = (Rsite + dxRotLat) % 8
                        siteIndNgb = RtoSiteInd[Rngb[0], Rngb[1], Rngb[2]]
                        sumSite += Input[0, chIn, siteIndNgb] * psi[ngb]

                    sumSite += Input[0, chIn, siteInd] * psi[0]

                # bias and activation are applied once per group operation
                sumSite += bias[chOut, 0]
                Out[0, chOut, siteInd] += F.relu(sumSite)
    return Out / net.Ng

In [None]:
# Let's do the convs explicitly and look at the gradients
# Clear the gradients left by the network's backward pass so the explicit
# computation below starts from zero, then call RotateParams again.
optimizer.zero_grad()
TestNet.RotateParams()

In [None]:
# Explicit layer-0 convolution of state 1
out1 = conv(TestNet, state1Tensor, 0, GIndtoGDict)

In [None]:
# Shapes of the explicit result and of the network's layer-0 output
out1.shape, outlayers1[0].shape

In [None]:
# Check the explicit loop convolution of state 1 against the layer-0 output
# that the network's forward call returned (outlayers1[0]).
assert pt.allclose(out1, outlayers1[0])

In [None]:
# Explicit layer-0 convolution of state 2
out2 = conv(TestNet, state2Tensor, 0, GIndtoGDict)

In [None]:
# Check the explicit loop convolution of state 2 against the layer-0 output
# that the network's forward call returned (outlayers2[0]).
assert pt.allclose(out2, outlayers2[0])

In [None]:
# from out1 and out2 construct R3 convolution
def R3Conv(Input, weights):
    """Loop-based reference for the vector-valued (R3) convolution.

    For every site, each column ``weights[:, ngb]`` is rotated by every group
    operation's ``cartrot`` and multiplied by the input at the matching
    (rotated) neighbour; column 0 pairs with the site itself.  The sum over
    neighbours and group operations is scaled by the site's entry of
    ``TestNet.SiteShellWeights`` and divided by the global ``Ng``.

    Parameters
    ----------
    Input : tensor read as ``Input[0, 0, site]`` (batch size 1, channel 0).
    weights : double tensor indexed as ``weights[:, ngb]`` for ngb in
        ``range(N_ngb)``; each column must have 3 entries.

    Returns
    -------
    Double tensor of shape ``(3, Input.shape[2])``.

    Reads the notebook globals ``SiteIndtoR``, ``RtoSiteInd``, ``dxNN``,
    ``superBCC``, ``GIndtoGDict``, ``N_ngb``, ``Ng`` and ``TestNet``.
    Neighbour coordinates are wrapped with ``% 8``.
    """
    Nsites = Input.shape[2]
    lattice = superBCC.crys.lattice
    # Loop-invariant: invert the lattice once instead of once per neighbour.
    invLattice = np.linalg.inv(lattice)

    # Per group operation: the rotated weight columns and the rotated
    # neighbour jumps.  Neither depends on the site.
    groupData = []
    for g in GIndtoGDict.values():
        gTens = pt.tensor(g.cartrot).double()
        rotatedWeights = [pt.mv(gTens, weights[:, ngb]) for ngb in range(N_ngb)]
        jumps = []
        for ngb in range(1, N_ngb):
            dxCart = np.dot(lattice, dxNN[ngb - 1])
            dxCartRot = np.dot(g.cartrot, dxCart)
            # np.rint before the cast: astype(int) alone truncates, turning a
            # round-off value like 0.999999999999 into 0 (wrong neighbour).
            jumps.append(np.rint(np.dot(invLattice, dxCartRot)).astype(int))
        groupData.append((rotatedWeights, jumps))

    R3OutSites = pt.zeros(3, Nsites).double()
    for siteInd in range(Nsites):
        Rsite = SiteIndtoR[siteInd]
        sumG = pt.zeros(3).double()
        for rotatedWeights, jumps in groupData:
            sumNN = pt.zeros(3).double()
            for ngb, dxRotLat in enumerate(jumps, start=1):
                Rngb = (Rsite + dxRotLat) % 8
                siteIndNgb = RtoSiteInd[Rngb[0], Rngb[1], Rngb[2]]
                sumNN += rotatedWeights[ngb] * Input[0, 0, siteIndNgb]
            # on-site term uses weight column 0
            sumNN += rotatedWeights[0] * Input[0, 0, siteInd]
            sumG += sumNN
        # get the shell weight of the site
        shellweight = TestNet.SiteShellWeights[siteInd]
        R3OutSites[:, siteInd] = shellweight * sumG / Ng
    return R3OutSites

In [None]:
# Apply the explicit R3 convolution to the explicit layer-0 output of state 1,
# using the network's own wtVC parameter as the weights.
R3weights = TestNet.wtVC
R3Out1 = R3Conv(out1, R3weights)

In [None]:
# Same explicit R3 convolution for state 2
R3Out2 = R3Conv(out2, R3weights)

In [None]:
# Compare the network's per-site vectors for state 1 with the explicit result.
# NOTE(review): unlike the layer-0 checks this is only displayed, not asserted.
pt.allclose(outVecSites1[0], R3Out1)

In [None]:
# Compare the network's per-site vectors for state 2 with the explicit result.
# NOTE(review): unlike the layer-0 checks this is only displayed, not asserted.
pt.allclose(outVecSites2[0], R3Out2)

In [None]:
# Display the explicit per-site vectors of state 1 (3 x Nsites)
R3Out1

In [None]:
# Sum the explicit per-site vectors over sites and divide by the site count
y1Exp = R3Out1.sum(dim=1) / Nsites
y2Exp = R3Out2.sum(dim=1) / Nsites

In [None]:
# Explicit counterpart of dy: change in the site-averaged vector between states
dyExp = (y2Exp - y1Exp).view(3)

In [None]:
# Show the explicit output difference next to the displacement fed to the loss
dyExp, dispSpec1

In [None]:
# Same loss, now built from the explicit convolutions.  backward() adds into
# the parameters' .grad, which optimizer.zero_grad() cleared beforehand.
l = loss_fn(rate, dispSpec1, dyExp)
l.backward()

In [None]:
# Gradient of wtVC obtained through the explicit computation
TestNet.wtVC.grad

In [None]:
# Gradient of wtVC stored earlier from the network's own backward pass.
# NOTE(review): a later cell rebinds the name gradR3 to a function, so this
# cell only shows the array if it runs before that definition.
gradR3

In [None]:
# Now let's evaluate the gradients explicitly

In [26]:
def gradR3(InputR3):
    """Accumulate input-weighted rotation columns for every neighbour index.

    For neighbour index ``ngb`` (0 is the site itself) and Cartesian index
    ``alpha`` the result is::

        grad[ngb, alpha, :] = sum over sites and group ops g of
                              InputR3[0, 0, neighbour(site, g, ngb)] * g.cartrot[:, alpha]

    Parameters
    ----------
    InputR3 : tensor read as ``InputR3[0, 0, site]``.

    Returns
    -------
    numpy array of shape ``(N_ngb, Ndim, Ndim)``; ``Ndim`` must equal the
    size of ``g.cartrot``.

    Reads the notebook globals ``N_ngb``, ``Ndim``, ``Nsites``, ``dxNN``,
    ``SiteIndtoR``, ``RtoSiteInd``, ``GIndtoGDict`` and ``superBCC``.
    Neighbour coordinates are wrapped with ``% 8``.

    NOTE(review): this definition rebinds the name ``gradR3``, which an
    earlier cell assigns to the stored wtVC gradient array.
    """
    lattice = superBCC.crys.lattice
    # Loop-invariant: invert the lattice once.
    invLattice = np.linalg.inv(lattice)
    groupOps = list(GIndtoGDict.values())
    # Row alpha of cartrot.T is column alpha of cartrot, so accumulating
    # s * cartrot.T handles every alpha in one pass over sites and group ops.
    rotColumns = [np.asarray(g.cartrot).T for g in groupOps]

    grad = np.zeros((N_ngb, Ndim, Ndim))

    for ngb in range(N_ngb):
        # get the nn vector (the zero vector selects the site itself)
        dx = np.zeros(3) if ngb == 0 else dxNN[ngb - 1]
        dxCart = np.dot(lattice, dx)

        # Rotated jump in lattice coordinates for each group operation.
        # np.rint before the cast: astype(int) alone truncates, so a round-off
        # value like 0.999999999999 would become 0 and pick the wrong site.
        jumps = [
            np.rint(np.dot(invLattice, np.dot(g.cartrot, dxCart))).astype(int)
            for g in groupOps
        ]

        sumOverSites = np.zeros((Ndim, Ndim))
        for siteInd in range(Nsites):
            Rsite = SiteIndtoR[siteInd]

            sumOverG = np.zeros((Ndim, Ndim))
            for dxRotLat, gCols in zip(jumps, rotColumns):
                # get the neighboring site
                Rngb = (Rsite + dxRotLat) % 8
                siteNgb = RtoSiteInd[Rngb[0], Rngb[1], Rngb[2]]
                # Input value at the neighbouring site, as a plain float
                s = InputR3[0, 0, siteNgb].detach().item()
                sumOverG += s * gCols

            sumOverSites += sumOverG

        grad[ngb] = sumOverSites

    return grad

In [27]:
# Evaluate the explicit sums on the network's layer-0 outputs of both states
gradR3_state1 = gradR3(outlayers1[0])
gradR3_state2 = gradR3(outlayers2[0])

In [28]:
# Display the number of Cartesian dimensions in use
Ndim

3

In [40]:
# Running sum, site by site, of the input-weighted rotation columns for one
# neighbour index and one Cartesian index (alpha).  Uses the network's layer-0
# output of state 1.
np.set_printoptions(precision=5)
alpha = 0
# Loop-invariant: invert the lattice once.
invLattice = np.linalg.inv(superBCC.crys.lattice)
for nn in range(1, 2):
    dx = dxNN[nn-1]
    dxCart = np.dot(superBCC.crys.lattice, dx)
    siteSumVec = np.zeros(3)
    # The shape must be passed as a tuple: np.zeros(3, Nsites) would treat
    # Nsites as the dtype argument and raise a TypeError.
    SiteContribs = np.zeros((3, Nsites))
    for siteInd in range(Nsites):
        Rsite = SiteIndtoR[siteInd]
        sumVec = np.zeros(3)
        for gInd, g in GIndtoGDict.items():
            dxCartRot = np.dot(g.cartrot, dxCart)
            # np.rint before the cast: astype(int) alone truncates, so a
            # round-off value like 0.999999999999 would become 0.
            dxRotLat = np.rint(np.dot(invLattice, dxCartRot)).astype(int)

            Rngb = (Rsite + dxRotLat)%8

            siteNgb = RtoSiteInd[Rngb[0], Rngb[1], Rngb[2]]

            s = outlayers1[0][0, 0, siteNgb].detach().item()

            # column alpha of the rotation matrix
            gvec = np.array([g.cartrot[0, alpha], g.cartrot[1, alpha], g.cartrot[2, alpha]])

            sumVec += s*gvec
        # keep this site's own contribution, then update the running total
        SiteContribs[:, siteInd] = sumVec
        siteSumVec += sumVec
        print("{}".format(siteSumVec))

[-3.65644  7.47129 -7.31287]
[-2.06584 12.87674 -1.90742]
[-5.64307 23.92525  1.82822]
[-5.40545 12.87674 20.18961]
[ 0.15841  3.65644 36.88119]
[14.70495 11.04852 40.61684]
[11.12772 18.4406  40.69604]
[14.94257 33.38317 36.88119]
[ 9.2995  53.8104  35.05298]
[-10.8109   63.1099   36.88119]
[-16.29555  72.40941  31.23813]
[-12.79753  90.69159  23.76684]
[-20.26882  90.69159  34.73615]
[-31.23813  87.19357  45.86388]
[-36.72278  89.02178  40.37922]
[-36.88119  99.9911   40.37922]
[-27.58169 105.47575  53.17675]
[-23.92525 112.78862  64.30447]
[-20.34802 112.86783  64.38368]
[-29.40991 114.77526  58.81982]
[-38.551   120.25991  53.17675]
[-45.94308 123.83714  60.41042]
[-40.53763 122.08813  62.15943]
[-40.45843 118.35249  54.92576]
[-22.09703 129.401    55.16338]
[-11.04852 125.66536  77.18121]
[-10.96931 118.11487  95.5426 ]
[ -1.82822 108.97378  94.0312 ]
[  3.65644 106.98715  99.51586]
[ 3.65644 99.67427 99.67427]
[  9.2995  101.50249 108.81536]
[ 18.5198  107.06635 114.22081]
[ 18.5

[ 16.53317 -16.53317 -19.87279]
[  5.32624 -20.26882 -16.13714]
[ 14.62575 -21.93862 -14.30892]
[ 16.69158 -16.37476  -5.24704]
[ 10.96931 -18.28219  -3.49803]
[  5.56386 -20.18961 -20.18961]
[  7.47129 -21.93862 -40.37922]
[  5.56386 -20.18961 -38.63021]
[ -5.48466 -23.76684 -42.20744]
[-18.36139 -25.67427 -51.42774]
[-23.68763 -23.84605 -56.91239]
[-23.92525 -23.76684 -64.30447]
[-34.65695 -34.81536 -67.8817 ]
[-27.34407 -38.47179 -64.38368]
[-21.93862 -22.09703 -73.60398]
[-32.90793 -18.28219 -80.91685]
[-47.61289 -21.85942 -69.86833]
[-53.25596 -20.0312  -64.38368]
[-44.11486 -14.54654 -73.52477]
[-40.53763 -14.62575 -77.26041]
[-24.16287 -16.69158 -71.69655]
[-31.47574 -16.69158 -67.8817 ]
[-24.16287 -20.18961 -49.59952]
[-16.77079  -9.2995  -49.3619 ]
[-15.02178  -7.55049 -43.95645]
[-20.42723  -9.2995  -38.39259]
[ -7.39208 -11.04852 -32.98714]
[ -1.82822 -12.79753 -27.42328]
[-18.12377 -25.59506 -14.62575]
[-18.20298 -22.01783 -10.89011]
[-5.32624 -9.14109 -8.98268]
[  3.97326 

In [None]:
shell = 2
# Site indices stored in RtoSiteInd at each nearest-neighbour jump vector
# (negative components index from the end of each axis).
sitesNN = {RtoSiteInd[dxR[0], dxR[1], dxR[2]] for dxR in dxNN}

In [None]:
# Same input-weighted rotation-column sum as the all-sites cell, restricted to
# the sites in sitesNN and using the explicit layer-0 output out1.
# (A preliminary `nn = 7` was dropped: the loop overwrote it immediately.)
np.set_printoptions(precision=5)
alpha = 0

# Loop-invariant: invert the lattice once.
invLattice = np.linalg.inv(superBCC.crys.lattice)

for nn in range(1, 2):
    dx = dxNN[nn-1]
    dxCart = np.dot(superBCC.crys.lattice, dx)
    siteSumVec = np.zeros(3)
    for siteInd in (sitesNN):
        Rsite = SiteIndtoR[siteInd]
        sumVec = np.zeros(3)
        for gInd, g in GIndtoGDict.items():
            dxCartRot = np.dot(g.cartrot, dxCart)
            # np.rint before the cast: astype(int) alone truncates, so a
            # round-off value like 0.999999999999 would become 0.
            dxRotLat = np.rint(np.dot(invLattice, dxCartRot)).astype(int)

            Rngb = (Rsite + dxRotLat)%8

            siteNgb = RtoSiteInd[Rngb[0], Rngb[1], Rngb[2]]

            s = out1[0, 0, siteNgb].detach().item()

            # column alpha of the rotation matrix
            gvec = np.array([g.cartrot[0, alpha], g.cartrot[1, alpha], g.cartrot[2, alpha]])

            sumVec += s*gvec
        siteSumVec += sumVec
    # total over the selected sites for this neighbour index
    print("{}".format(siteSumVec))

In [None]:
# Display the set of neighbour site indices collected above
sitesNN

In [None]:
# Sum of the explicit state-1 output vectors over the sites in sitesNN
smVec = sum((R3Out1[:, site] for site in sitesNN), 0.)

In [None]:
# Add the vector at site 0 to the sum over the sitesNN sites (state 1)
smVec + R3Out1[:, 0]

In [None]:
# Sum of the explicit state-2 output vectors over the sites in sitesNN
smVec2 = sum((R3Out2[:, site] for site in sitesNN), 0.)

In [None]:
# Add the vector at site 0 to the sum over the sitesNN sites (state 2)
smVec2 + R3Out2[:, 0]