In [1]:
import torch as pt
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pickle
from SymNet import SymNet
import h5py

In [2]:
# Load test data: only entry 0 of each HDF5 dataset is kept
with h5py.File(r'singleStep_50.h5',"r") as fl:
    # The generator is consumed by the unpacking while the file is still open
    state1Grid, state2Grid, rateArray, dispAll = (
        np.array(fl[dsetName][0])
        for dsetName in ("state1Grid", "state2Grid", "rateList", "dispList")
    )

# Lattice / symmetry bookkeeping arrays stored as .npy files
dxNN = np.load("nnJumpLatVecs.npy")
RtoSiteInd = np.load("RtoSiteInd.npy")
SiteIndtoR = np.load("SiteIndtoR.npy")
GpermNNIdx = np.load("GroupNNpermutations.npy")

NNsiteList = np.load("NNsites_sitewise.npy")
# Axis 0 is used as the neighbor count and axis 1 as the site count
N_ngb, Nsites = NNsiteList.shape[:2]

# NOTE(review): pickle.load can execute arbitrary code stored in the file;
# only unpickle files that come from a trusted source.
with open("GroupOpsIndices.pkl", "rb") as fl:
    GIndtoGDict = pickle.load(fl)

with open("supercellBCC.pkl", "rb") as fl:
    superBCC = pickle.load(fl)

In [3]:
# Display the two dimensions read from NNsiteList.shape (neighbor count, site count)
N_ngb, Nsites

(9, 512)

## Make the symmetry-related parameters

In [4]:
# Tensor versions of the loaded permutation table and neighbor-site list
GnnPerms = pt.tensor(GpermNNIdx).long()
NNsites = pt.tensor(NNsiteList)

# Network size settings handed to the SymNet constructor
Nlayers = 1
NchOuts = [1]

Ng = GnnPerms.shape[0]
Ndim = 3

# Block-diagonal matrix of size (Ng*Ndim, Ng*Ndim): the Ndim x Ndim diagonal
# block number gInd holds g.cartrot of the group operation stored under gInd.
gdiags = pt.zeros((Ng * Ndim, Ng * Ndim), dtype=pt.double)
for gInd, g in GIndtoGDict.items():
    diagBlock = slice(gInd * Ndim, (gInd + 1) * Ndim)
    gdiags[diagBlock, diagBlock] = pt.tensor(g.cartrot).double()

In [5]:
# Halve the occupancy grids and reshape them to (1, 1, Nsites). Nsites (read
# from NNsiteList) replaces the literal 512 so the shape follows the loaded
# supercell instead of silently assuming its size.
state1Tensor = pt.tensor(state1Grid/2.0).double().view(1, 1, Nsites)
state2Tensor = pt.tensor(state2Grid/2.0).double().view(1, 1, Nsites)
# Row 1 of the loaded displacement array
# NOTE(review): presumably the displacement of species 1 - confirm against the data file.
dispSpec1 = pt.tensor(dispAll[1, :]).double()
rate = pt.tensor(rateArray).double()

In [6]:
# Confirm the reshaped state tensor has the expected 3-axis shape
state2Tensor.shape

torch.Size([1, 1, 512])

In [7]:
# Make shells for symmetry-grouping sites: a shell collects every site index
# reached from one representative site by applying each group operation.
shells = []
AllSites = set()
for siteInd in range(Nsites):
    # A site already placed in an earlier shell does not start a new one
    if siteInd in AllSites:
        continue
    Rsite = SiteIndtoR[siteInd]
    newShell = set()
    for g in GIndtoGDict.values():
        # g_pos returns a pair; only its first element (the mapped position) is used
        Rnew = superBCC.crys.g_pos(g, Rsite, (0, 0))[0]
        # Rnew indexes RtoSiteInd directly (no explicit periodic wrap is applied here)
        newShell.add(RtoSiteInd[Rnew[0], Rnew[1], Rnew[2]])
    AllSites |= newShell
    shells.append(list(newShell))

In [8]:
# Next, let's sort them according to distance
def sortkey(shell):
    """
    Sort key for a shell: the smallest norm of (lattice . R) over its sites.

    :param shell: non-empty list of site indices.
    :return: minimum over the shell of ||superBCC.crys.lattice . SiteIndtoR[site]||.
    """
    lattice = superBCC.crys.lattice

    def normOf(site):
        # Norm of the site's position after multiplying by the lattice matrix
        return np.linalg.norm(np.dot(lattice, SiteIndtoR[site]))

    xmin = normOf(shell[0])
    for site in shell[1:]:
        xmin = min(xmin, normOf(site))
    return xmin

shellsSorted = sorted(shells, key=sortkey)

In [9]:
# Now let's assign shells to sites: entry i holds the index (in shellsSorted)
# of the shell that contains site i.
SitesToShells = pt.zeros(Nsites, dtype=pt.long)
for shellInd, shell in enumerate(shellsSorted):
    # One indexed assignment per shell instead of a per-site inner loop
    shellSites = pt.tensor(np.asarray(shell, dtype=np.int64))
    SitesToShells[shellSites] = shellInd
# Shell indices together with the number of sites in each shell
pt.unique(SitesToShells, return_counts=True)

(tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28]),
 tensor([ 1,  8,  6, 12, 24,  8,  6, 24, 24, 24,  8, 24, 12, 48, 24,  6, 24, 24,
         24,  4, 24, 24, 24, 24, 48,  3, 12, 12,  6]))

## Let's do backpropagation tests to make sure gradients are calculated correctly

In [10]:
# Dummy loss for the gradient check: the rate-weighted squared norm of
# (dx + dy), divided by 6. It is smallest when dy cancels dx.
def loss_fn(rate, dx, dy):
    """
    Return sum(rate * ||dx + dy||^2) / 6.

    :param rate: tensor (or scalar) multiplying the squared norm.
    :param dx: tensor added to dy before taking the norm.
    :param dy: tensor of the same shape as dx.
    :return: zero-dimensional tensor holding the loss value.
    """
    netDisp = dx + dy
    sqNorm = pt.norm(netDisp)**2
    return pt.sum(rate * sqNorm)/6.

In [11]:
# Build the network from the symmetry data prepared above and cast it to double precision
TestNet = SymNet(Nlayers, NchOuts, GnnPerms, GIndtoGDict, gdiags, NNsites,
                 SitesToShells, Ndim, act="relu").double()

# The optimizer is only used here to zero the gradients; no optimizer step is taken in this cell
optimizer = pt.optim.Adam(TestNet.parameters(), lr = 0.001)
# NOTE(review): RotateParams is defined in SymNet (not visible here); presumably it
# prepares the symmetry-rotated weights before the forward pass - confirm.
TestNet.RotateParams()

# Now let's first evaluate gradients with the network

# We only use the first sample for now
optimizer.zero_grad()
# With Test=True, forward returns five values; later cells use the third
# (outlayers*) and this cell uses the fifth (y*).
InLayers1, outlayersG1, outlayers1, outVecSites1, y1 = TestNet.forward(state1Tensor,
                                                                       UseShellWeights=False, Test=True)
InLayers2, outlayersG2, outlayers2, outVecSites2, y2 = TestNet.forward(state2Tensor,
                                                                       UseShellWeights=False, Test=True)
# Difference of the two network outputs, flattened to a 3-vector
dy = (y2-y1).view(3)
l = loss_fn(rate, dispSpec1, dy)
# Populate .grad on the network parameters
l.backward()

  allow_unreachable=True)  # allow_unreachable flag


In [12]:
# Copy of the autograd gradient stored on the first entry of TestNet.weightList
gradLayer1 = TestNet.weightList[0].grad.data.numpy().copy()

In [13]:
# Copy of the autograd gradient stored on TestNet.wtVC
# NOTE(review): the name gradR3 is rebound to a function by the later
# `def gradR3(InputR3)` cell, so cells that index this array only work if run
# before that definition. Consider a distinct name for one of the two.
gradR3 = TestNet.wtVC.grad.data.numpy().copy()

In [14]:
# Display the copied gradient of the first weightList entry
gradLayer1

array([[[0., 0., 0., 0., 0., 0., 0., 0., 0.]]])

In [15]:
# Print the wtVC gradient one neighbor slot at a time (column i of the array).
# Requires gradR3 to still be the array copied from TestNet.wtVC.grad.
for i in range(N_ngb):
    print(gradR3[:, i])

[0. 0. 0.]
[3.88578059e-16 3.88578059e-16 3.88578059e-16]
[-3.88578059e-16 -3.88578059e-16 -3.88578059e-16]
[ 3.88578059e-16 -3.88578059e-16 -3.88578059e-16]
[-3.88578059e-16  3.88578059e-16  3.88578059e-16]
[ 3.88578059e-16 -3.88578059e-16  3.88578059e-16]
[-3.88578059e-16  3.88578059e-16 -3.88578059e-16]
[ 3.88578059e-16  3.88578059e-16 -3.88578059e-16]
[-3.88578059e-16 -3.88578059e-16  3.88578059e-16]


In [16]:
# Multiply every row of dxNN by the lattice matrix.
# NOTE(review): presumably this converts the nearest-neighbor jump vectors from
# lattice to Cartesian coordinates - confirm against the crystal class.
np.dot(superBCC.crys.lattice, dxNN.T).T

array([[ 0.5,  0.5,  0.5],
       [-0.5, -0.5, -0.5],
       [ 0.5, -0.5, -0.5],
       [-0.5,  0.5,  0.5],
       [ 0.5, -0.5,  0.5],
       [-0.5,  0.5, -0.5],
       [ 0.5,  0.5, -0.5],
       [-0.5, -0.5,  0.5]])

In [None]:
# Now let's evaluate the gradients explicitly

In [17]:
def gradR3(InputR3):
    """
    Explicitly evaluate a symmetry-summed quantity from a site-valued input.

    For every neighbor slot ngb (slot 0 is the zero vector, slot k > 0 is
    dxNN[k-1]) and every index alpha in 0..2, this computes

        grad[ngb, alpha, :] = sum_sites sum_g  s * g.cartrot[:, alpha]

    where s = InputR3[0, 0, siteNgb] and siteNgb is the site reached from
    `site` by the jump vector rotated by the group operation g.

    :param InputR3: tensor indexed as InputR3[0, 0, site]; only these entries are read.
    :return: numpy array of shape (N_ngb, Ndim, Ndim).

    NOTE(review): this function shares its name with the array copied from
    TestNet.wtVC.grad in an earlier cell; defining it rebinds that name.
    """
    lattice = superBCC.crys.lattice
    # Loop-invariant: invert the lattice matrix once instead of per (site, g)
    latticeInv = np.linalg.inv(lattice)

    grad = np.ones((N_ngb, Ndim, Ndim))

    for ngb in range(N_ngb):
        # get the nn vector (slot 0 is the site itself)
        dx = np.zeros(3) if ngb == 0 else dxNN[ngb-1]
        dxCart = np.dot(lattice, dx)

        # The rotated jump depends only on (ngb, g), so build it once here.
        rotatedJumps = []
        for gInd, g in GIndtoGDict.items():
            dxCartRot = np.dot(g.cartrot, dxCart)
            # Round to the nearest integer before casting: a bare astype(int)
            # truncates toward zero, so 0.9999999 would become 0.
            dxRotLat = np.rint(np.dot(latticeInv, dxCartRot)).astype(int)
            rotatedJumps.append((dxRotLat, g.cartrot))

        for alpha in range(3):
            sumOverSites = np.zeros(3)
            for siteInd in range(Nsites):
                Rsite = SiteIndtoR[siteInd]

                sumOverG = np.zeros(3)
                for dxRotLat, cartrot in rotatedJumps:
                    # get the neighboring site; the modulus 8 is the hard-coded
                    # periodic wrap (NOTE(review): presumably the supercell
                    # edge length - confirm it matches RtoSiteInd)
                    Rngb = (Rsite + dxRotLat)%8
                    siteNgb = RtoSiteInd[Rngb[0], Rngb[1], Rngb[2]]
                    # get the Input value at the nn site
                    s = InputR3[0, 0, siteNgb].detach().item()

                    # Column alpha of the rotation matrix
                    gvec = np.array([cartrot[0, alpha],
                                     cartrot[1, alpha], cartrot[2, alpha]])

                    sumOverG += s * gvec

                sumOverSites += sumOverG

            grad[ngb, alpha, :] = sumOverSites

    return grad

In [18]:
# Evaluate the explicit sums on the first entry of each state's `outlayers`
# list returned by TestNet.forward
gradR3_state1 = gradR3(outlayers1[0])
gradR3_state2 = gradR3(outlayers2[0])

In [19]:
# Display the explicit result for state 1 (one Ndim x Ndim block per neighbor slot)
gradR3_state1

array([[[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]],

       [[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]],

       [[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]],

       [[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]],

       [[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]],

       [[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]],

       [[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]],

       [[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]],

       [[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]]])

In [28]:
# Per-site breakdown of the explicit sum for one neighbor slot and alpha = 0.
# The loop index is named ngbInd, not `nn`, so that the `nn` alias of torch.nn
# imported at the top of the notebook is not overwritten.
np.set_printoptions(precision=5)
alpha = 0

# Loop-invariant: invert the lattice matrix once
latticeInv = np.linalg.inv(superBCC.crys.lattice)

for ngbInd in range(1, 2):
    dx = dxNN[ngbInd-1]
    dxCart = np.dot(superBCC.crys.lattice, dx)
    # Column siteInd holds the sum over group operations for that site
    SiteGrads = np.zeros((3, Nsites))
    siteSumVec = np.zeros(3)
    for siteInd in range(Nsites):
        Rsite = SiteIndtoR[siteInd]
        sumVec = np.zeros(3)
        for gInd, g in GIndtoGDict.items():
            dxCartRot = np.dot(g.cartrot, dxCart)
            # Round before casting: astype(int) alone truncates toward zero
            dxRotLat = np.rint(np.dot(latticeInv, dxCartRot)).astype(int)

            # Hard-coded periodic wrap with modulus 8
            Rngb = (Rsite + dxRotLat)%8

            siteNgb = RtoSiteInd[Rngb[0], Rngb[1], Rngb[2]]

            s = outlayers1[0][0, 0, siteNgb].detach().item()

            # Column alpha of the rotation matrix
            gvec = np.array([g.cartrot[0, alpha], g.cartrot[1, alpha], g.cartrot[2, alpha]])

            sumVec += s*gvec
        SiteGrads[:, siteInd] = sumVec
        siteSumVec += sumVec
    print("{}".format(siteSumVec))

[0. 0. 0.]


In [30]:
# Per-site vectors for the first four sites (columns 0-3 of SiteGrads)
SiteGrads[:, :4]

array([[ -5.18219,   1.36567,  -4.77371,   1.22542],
       [ 11.18132,   7.3648 ,  15.95503, -15.95503],
       [-10.36437,   7.3648 ,   5.59066,  26.3194 ]])

In [39]:
# Sum the per-site vectors over all sites; this is the same quantity that was
# accumulated in siteSumVec when SiteGrads was filled
np.sum(SiteGrads,axis=1)

array([0., 0., 0.])

In [51]:
# Term-by-term breakdown for site 2: list every group operation whose
# g.cartrot[1, alpha] is non-zero and accumulate s * g.cartrot[1, alpha].
# Relies on dxCart and alpha left behind by the cell that filled SiteGrads.
# NOTE(review): this cell is a copy of the site-3 cell and carries a later
# execution count; a shared helper taking the site index would avoid both issues.
sumVec = 0.
Rsite = SiteIndtoR[2]
# Loop-invariant: invert the lattice matrix once
latInv = np.linalg.inv(superBCC.crys.lattice)
for gInd, g in GIndtoGDict.items():
    dxCartRot = np.dot(g.cartrot, dxCart)
    # Round before casting: astype(int) alone truncates toward zero
    dxRotLat = np.rint(np.dot(latInv, dxCartRot)).astype(int)

    # Hard-coded periodic wrap with modulus 8
    Rngb = (Rsite + dxRotLat)%8

    siteNgb = RtoSiteInd[Rngb[0], Rngb[1], Rngb[2]]

    s = outlayers1[0][0, 0, siteNgb].detach().item()

    # Column alpha of the rotation matrix
    gvec = np.array([g.cartrot[0, alpha], g.cartrot[1, alpha], g.cartrot[2, alpha]])
    if not np.allclose(gvec[1], 0.0):
        print("{}\t{}\t{}\t{}".format(siteNgb, s, gvec[1], s*gvec[1]))
        sumVec += s*gvec[1]
sumVec

58	9.359555155038834	1.0	9.359555155038834
450	8.064008824527264	-1.0	-8.064008824527264
3	10.655101485550404	1.0	10.655101485550404
10	6.768462494015694	-1.0	-6.768462494015694
66	10.655101485550404	1.0	10.655101485550404
1	10.655101485550404	-1.0	-10.655101485550404
58	9.359555155038834	1.0	9.359555155038834
3	10.655101485550404	1.0	10.655101485550404
450	8.064008824527264	-1.0	-8.064008824527264
66	10.655101485550404	1.0	10.655101485550404
505	7.859771877527237	-1.0	-7.859771877527237
75	10.655101485550404	1.0	10.655101485550404
505	7.859771877527237	-1.0	-7.859771877527237
1	10.655101485550404	-1.0	-10.655101485550404
75	10.655101485550404	1.0	10.655101485550404
10	6.768462494015694	-1.0	-6.768462494015694


15.955029860138893

In [50]:
# Term-by-term breakdown for site 3: list every group operation whose
# g.cartrot[1, alpha] is non-zero and accumulate s * g.cartrot[1, alpha].
# Relies on dxCart and alpha left behind by the cell that filled SiteGrads.
# NOTE(review): this cell is a copy of the site-2 cell; a shared helper taking
# the site index would remove the duplication.
sumVec = 0.
Rsite = SiteIndtoR[3]
# Loop-invariant: invert the lattice matrix once
latInv = np.linalg.inv(superBCC.crys.lattice)
for gInd, g in GIndtoGDict.items():
    dxCartRot = np.dot(g.cartrot, dxCart)
    # Round before casting: astype(int) alone truncates toward zero
    dxRotLat = np.rint(np.dot(latInv, dxCartRot)).astype(int)

    # Hard-coded periodic wrap with modulus 8
    Rngb = (Rsite + dxRotLat)%8

    siteNgb = RtoSiteInd[Rngb[0], Rngb[1], Rngb[2]]

    s = outlayers1[0][0, 0, siteNgb].detach().item()

    # Column alpha of the rotation matrix
    gvec = np.array([g.cartrot[0, alpha], g.cartrot[1, alpha], g.cartrot[2, alpha]])

    if not np.allclose(gvec[1], 0.0):
        print("{}\t{}\t{}\t{}".format(siteNgb, s, gvec[1], s*gvec[1]))
        sumVec += s*gvec[1]
sumVec

59	6.564225547015667	1.0	6.564225547015667
451	9.359555155038834	-1.0	-9.359555155038834
4	8.064008824527264	1.0	8.064008824527264
11	11.950647816061974	-1.0	-11.950647816061974
67	10.450864538550377	1.0	10.450864538550377
2	13.246194146573544	-1.0	-13.246194146573544
59	6.564225547015667	1.0	6.564225547015667
4	8.064008824527264	1.0	8.064008824527264
451	9.359555155038834	-1.0	-9.359555155038834
67	10.450864538550377	1.0	10.450864538550377
506	9.155318208038807	-1.0	-9.155318208038807
76	10.655101485550404	1.0	10.655101485550404
506	9.155318208038807	-1.0	-9.155318208038807
2	13.246194146573544	-1.0	-13.246194146573544
76	10.655101485550404	1.0	10.655101485550404
11	11.950647816061974	-1.0	-11.950647816061974


-15.955029860138893

In [52]:
# Look up the stored coordinates of site 58 (one of the sites printed in the breakdown above)
SiteIndtoR[58]

array([0, 7, 2])

In [53]:
# Look up the stored coordinates of site 451 (one of the sites printed in the breakdown above)
SiteIndtoR[451]

array([7, 0, 3])

In [58]:
# Print the state2Tensor value at each site reached from site 451 by a dxNN
# jump vector (periodic wrap with the hard-coded modulus 8)
centreR = SiteIndtoR[451]
for ngb in dxNN:
    Rngb = (centreR + ngb)%8
    ngbSite = RtoSiteInd[tuple(Rngb)]
    print(state2Tensor[0, 0, ngbSite].item())

0.5
0.0
0.0
0.0
0.0
0.5
0.5
0.0


In [59]:
# Print the state2Tensor value at each site reached from site 58 by a dxNN
# jump vector (periodic wrap with the hard-coded modulus 8)
centreR = SiteIndtoR[58]
for ngb in dxNN:
    Rngb = (centreR + ngb)%8
    ngbSite = RtoSiteInd[tuple(Rngb)]
    print(state2Tensor[0, 0, ngbSite].item())

0.5
0.0
0.5
0.5
0.0
0.5
0.5
0.0
