In [10]:
# Use IPython's built-in tab completer instead of the jedi-based one.
%config Completer.use_jedi = False

In [1]:
import os
import sys

# Directory containing the local SymmLayers module imported below.  It can be
# overridden through the SYMM_NETWORK_DIR environment variable; the default is
# the original author's checkout location.
SYMM_NETWORK_DIR = os.environ.get(
    "SYMM_NETWORK_DIR",
    "/home/sohamc2/VKMC/SymNetworkRuns/CE_Symmetry/Symm_Network/",
)
# Guard so that re-running this cell does not keep appending duplicates.
if SYMM_NETWORK_DIR not in sys.path:
    sys.path.append(SYMM_NETWORK_DIR)

In [2]:
import numpy as np
import pickle
import torch as pt
import torch.nn as nn
import torch.nn.functional as F
from SymmLayers import GConv, R3Conv, R3ConvSites, GAvg

In [3]:
class GCNet(nn.Module):
    """Group-convolutional network followed by an ``R3ConvSites`` output layer.

    ``self.net`` is an ``nn.Sequential`` made of ``nl + 2`` repetitions of
    ``GConv -> Softplus -> GAvg`` followed by one ``R3ConvSites`` module:

    * first block:  ``NSpec -> nch`` channels,
    * ``nl`` hidden blocks: ``nch -> nch`` channels,
    * last block:   ``nch -> 1`` channel,
    * ``R3ConvSites`` output layer (``self.net[-1]``).

    Because every block has exactly three modules, the GConv layers sit at
    indices ``0, 3, ..., 3 * (nl + 1)`` of ``self.net``.  Code that indexes
    ``self.net`` with a stride of 3 relies on this layout.

    Args:
        GnnPerms: forwarded to every GConv layer and to ``R3ConvSites``.
        gdiags: forwarded to ``R3ConvSites`` only.
        NNsites: forwarded to every GConv layer and to ``R3ConvSites``.
        SitesToShells: forwarded to ``R3ConvSites`` only.
        dim: forwarded to ``R3ConvSites`` (the notebook passes ``Ndim = 3``).
        N_ngb: forwarded to every GConv layer and to ``R3ConvSites``.
        NSpec: number of input channels of the first GConv layer.
        mean, std: forwarded to every GConv layer and to ``R3ConvSites``
            (presumably weight-initialisation parameters; see SymmLayers).
        b: ``beta`` of every ``nn.Softplus`` activation.
        nl: number of hidden ``nch -> nch`` GConv blocks.
        nch: channel count of the hidden GConv blocks.
    """

    def __init__(self, GnnPerms, gdiags, NNsites, SitesToShells,
                dim, N_ngb, NSpec, mean=0.0, std=0.1, b=1.0, nl=3, nch=8):

        super().__init__()

        def gconv_block(n_in, n_out):
            # One GConv -> Softplus -> GAvg triplet.  The order (and hence the
            # construction order of any randomly initialised weights) matches
            # the original hand-written layout exactly.
            return [GConv(n_in, n_out, GnnPerms, NNsites, N_ngb, mean=mean, std=std),
                    nn.Softplus(beta=b), GAvg()]

        modules = gconv_block(NSpec, nch)
        for _ in range(nl):
            modules += gconv_block(nch, nch)
        modules += gconv_block(nch, 1)

        modules.append(R3ConvSites(SitesToShells, GnnPerms, gdiags, NNsites, N_ngb,
                                   dim, mean=mean, std=std))

        self.net = nn.Sequential(*modules)

    def forward(self, InState):
        """Run ``InState`` through the whole layer stack and return the result.

        In this notebook ``InState`` is a double tensor of shape
        ``(samples, NSpec, sites)``.
        """
        y = self.net(InState)
        return y

In [4]:
# Directory (relative to the notebook's working directory) holding the
# precomputed crystal data; the .npy/.pkl files in the next cell are read from it.
CrysDatPath = "../CrysDat/"

In [5]:
# Load the precomputed crystal data arrays (np.load keeps allow_pickle=False,
# so these are plain numeric arrays).
dxNN = np.load(CrysDatPath + "nnJumpLatVecs.npy")
RtoSiteInd = np.load(CrysDatPath + "RtoSiteInd.npy")
SiteIndtoR = np.load(CrysDatPath + "SiteIndtoR.npy")
GpermNNIdx = np.load(CrysDatPath + "GroupNNpermutations.npy")
NNsiteList = np.load(CrysDatPath + "NNsites_sitewise.npy")
siteShellIndices = np.load(CrysDatPath + "SitesToShells.npy")

# Sizes read from the first two axes of NNsiteList
# (presumably laid out as (neighbour, site) -- TODO confirm).
N_ngb = NNsiteList.shape[0]
Nsites = NNsiteList.shape[1]

# Maps group-operation index -> Cartesian matrix (used below as an
# Ndim x Ndim block).  NOTE: pickle.load can execute arbitrary code; only
# load files produced by a trusted source.
with open(CrysDatPath + "GroupCartIndices.pkl", "rb") as fl:
    GIndtoGDict = pickle.load(fl)

"""## Select the torch device"""

# Prefer the first CUDA GPU when one is visible; otherwise fall back to the CPU.
device = pt.device("cuda:0" if pt.cuda.is_available() else "cpu")

# Report the chosen device and the installed torch version, in that order.
for info in (device, pt.__version__):
    print(info)
"""## Make the required crystal data tensors - use double for now"""

# Integer (long) index tensors for the SymmLayers modules, created on `device`.
SitesToShells = pt.tensor(siteShellIndices, dtype=pt.long, device=device)
GnnPerms = pt.tensor(GpermNNIdx, dtype=pt.long, device=device)
NNsites = pt.tensor(NNsiteList, dtype=pt.long, device=device)

# gdiags is an (Ng*Ndim, Ng*Ndim) double matrix, block-diagonal, whose gInd-th
# Ndim x Ndim diagonal block is GIndtoGDict[gInd].
Ng = GnnPerms.shape[0]
Ndim = 3

# Every group operation 0..Ng-1 needs a matrix.  A missing key would otherwise
# silently leave an all-zero diagonal block in gdiags.
assert set(GIndtoGDict) == set(range(Ng)), (
    f"GroupCartIndices.pkl keys do not match group operations 0..{Ng - 1}"
)

gdiagsCpu = pt.zeros(Ng*Ndim, Ng*Ndim).double()
for gInd, gCart in GIndtoGDict.items():
    rowStart = gInd * Ndim
    rowEnd = (gInd + 1) * Ndim
    gdiagsCpu[rowStart : rowEnd, rowStart : rowEnd] = pt.tensor(gCart)
gdiags = gdiagsCpu.to(device)

cpu
1.7.1


In [6]:
# Number of input channels of the first GConv layer and number of hidden
# nch -> nch GConv blocks.
NSpec = 5
nLayers = 3

# Remaining GCNet hyper-parameters, gathered in one place.
gNetKwargs = dict(mean=0.02, std=0.02, b=1.0, nl=nLayers, nch=16)

gNet = GCNet(GnnPerms, gdiags, NNsites, SitesToShells, Ndim, N_ngb, NSpec,
             **gNetKwargs)
# Module.double() / Module.to() modify the module in place and return it.
gNet = gNet.double().to(device)

In [7]:
# GConv layers occupy every third slot of gNet.net (each is followed by a
# Softplus and a GAvg).  The final module is the R3ConvSites output layer,
# which is skipped here.
for gconvLayer in list(gNet.net)[:-1:3]:
    print(gconvLayer.Psi.shape)
    print(gconvLayer.bias.shape)
    print()

torch.Size([16, 5, 13])
torch.Size([16, 1])

torch.Size([16, 16, 13])
torch.Size([16, 1])

torch.Size([16, 16, 13])
torch.Size([16, 1])

torch.Size([16, 16, 13])
torch.Size([16, 1])

torch.Size([1, 16, 13])
torch.Size([1, 1])



In [8]:
# Freeze the Psi weights and biases of every GConv layer (every third module,
# excluding the final R3ConvSites layer), then list the parameters that are
# still trainable.
for frozenLayer in list(gNet.net)[:-1:3]:
    for frozenParam in (frozenLayer.Psi, frozenLayer.bias):
        frozenParam.requires_grad = False
print([p for p in gNet.parameters() if p.requires_grad])

[Parameter containing:
tensor([[ 0.0265,  0.0281,  0.0437,  0.0386,  0.0497,  0.0072,  0.0125,  0.0316,
         -0.0057,  0.0008,  0.0442,  0.0411,  0.0178],
        [ 0.0329, -0.0093,  0.0099,  0.0052,  0.0500,  0.0340,  0.0621,  0.0356,
          0.0228,  0.0387,  0.0199, -0.0173,  0.0203],
        [ 0.0255, -0.0066,  0.0225,  0.0496,  0.0047,  0.0216,  0.0031,  0.0137,
          0.0060,  0.0279,  0.0157,  0.0093,  0.0007]], dtype=torch.float64,
       requires_grad=True), Parameter containing:
tensor([-0.0017,  0.0354,  0.0036,  0.0100,  0.0041,  0.0189,  0.0334,  0.0069,
        -0.0123,  0.0151,  0.0315,  0.0159,  0.0506,  0.0083,  0.0258,  0.0475,
         0.0441,  0.0144,  0.0008,  0.0045,  0.0311,  0.0291,  0.0149,  0.0203,
         0.0055, -0.0037,  0.0494, -0.0033,  0.0366], dtype=torch.float64,
       requires_grad=True)]


In [11]:
# The R3ConvSites output layer is the last module of the stack; display its
# two parameter tensors.
R3C = list(gNet.net)[-1]
(R3C.wtVC, R3C.ShellWeights)

(Parameter containing:
 tensor([[ 0.0265,  0.0281,  0.0437,  0.0386,  0.0497,  0.0072,  0.0125,  0.0316,
          -0.0057,  0.0008,  0.0442,  0.0411,  0.0178],
         [ 0.0329, -0.0093,  0.0099,  0.0052,  0.0500,  0.0340,  0.0621,  0.0356,
           0.0228,  0.0387,  0.0199, -0.0173,  0.0203],
         [ 0.0255, -0.0066,  0.0225,  0.0496,  0.0047,  0.0216,  0.0031,  0.0137,
           0.0060,  0.0279,  0.0157,  0.0093,  0.0007]], dtype=torch.float64,
        requires_grad=True),
 Parameter containing:
 tensor([-0.0017,  0.0354,  0.0036,  0.0100,  0.0041,  0.0189,  0.0334,  0.0069,
         -0.0123,  0.0151,  0.0315,  0.0159,  0.0506,  0.0083,  0.0258,  0.0475,
          0.0441,  0.0144,  0.0008,  0.0045,  0.0311,  0.0291,  0.0149,  0.0203,
          0.0055, -0.0037,  0.0494, -0.0033,  0.0366], dtype=torch.float64,
        requires_grad=True))

## Verify the network's symmetry (equivariance) under a group operation

In [17]:
# Random one-hot test occupancies.  In each of N_SAMPLES configurations, every
# site 1..N_SITES-1 is occupied by one channel drawn uniformly from NSpec.
# Site 0 keeps all channels zero (presumably the vacancy site -- TODO confirm).
# A fixed seed makes the symmetry check below reproducible.
N_SAMPLES = 10
N_SITES = 512  # must match the supercell the network was built for
rng = np.random.default_rng(0)

rand_state = np.zeros((N_SAMPLES, NSpec, N_SITES))
occupiedSites = np.arange(1, N_SITES)
speciesIdx = rng.integers(0, NSpec, size=(N_SAMPLES, occupiedSites.size))
# Broadcast (sample, channel, site) index arrays: exactly one channel set per site.
rand_state[np.arange(N_SAMPLES)[:, None], speciesIdx, occupiedSites[None, :]] = 1.0

In [18]:
# Convert the float64 numpy array to a (double) torch tensor for the network.
# pt.as_tensor makes this cell safe to re-run: on an already-converted tensor it
# is a no-op, whereas pt.tensor would emit a copy-construct warning.  (The
# numpy array is no longer referenced afterwards, so sharing memory is harmless.)
rand_state = pt.as_tensor(rand_state)
rand_state.shape

torch.Size([10, 5, 512])

In [19]:
# Inference only: skip autograd bookkeeping, run on the network's device, and
# bring the result back to the CPU so it can be converted to numpy (which also
# works when `device` is a GPU).
with pt.no_grad():
    y = gNet(rand_state.to(device)).cpu().numpy()

In [20]:
# Cartesian matrix of the group operation used for the symmetry check.
# Index 20 is a single, presumably arbitrary, choice; any key of GIndtoGDict
# can be substituted here.
gTestInd = 20
G = GIndtoGDict[gTestInd]

In [21]:
# Supercell object, used below for lattice <-> Cartesian conversions.  It is
# read from CrysDatPath like the other crystal data files.
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
with open(CrysDatPath + "supercellFCC.pkl", "rb") as fl:
    superFCC = pickle.load(fl)

In [22]:
# Site permutation induced by G: the site at lattice position R moves to the
# site at lattice position of G·R (wrapped into the supercell).  rand_state_G
# is rand_state with every site's occupancy moved accordingly:
#     rand_state_G[:, :, sitePerm[s]] == rand_state[:, :, s]
nSites = rand_state.shape[2]
supercellN = 8  # lattice positions are wrapped modulo 8 along each axis

sitePerm = np.zeros(nSites, dtype=int)
for site in range(nSites):
    RsiteCart = np.dot(superFCC.crys.lattice, SiteIndtoR[site])
    # Index the returned pair instead of unpacking into `_`, which would
    # shadow IPython's last-output variable.
    RsiteNew = superFCC.crys.cart2pos(np.dot(G, RsiteCart))[0] % supercellN
    sitePerm[site] = RtoSiteInd[RsiteNew[0], RsiteNew[1], RsiteNew[2]]

# G must map the supercell onto itself; if sitePerm is not a permutation, the
# equivariance check below would be meaningless.
assert np.array_equal(np.sort(sitePerm), np.arange(nSites)), \
    "G does not permute the supercell sites"

rand_state_G = pt.zeros_like(rand_state)
rand_state_G[:, :, pt.as_tensor(sitePerm, dtype=pt.long, device=rand_state.device)] = rand_state

In [23]:
# Inference only: skip autograd bookkeeping, run on the network's device, and
# bring the result back to the CPU so it can be converted to numpy (which also
# works when `device` is a GPU).
with pt.no_grad():
    y2 = gNet(rand_state_G.to(device)).cpu().numpy()

In [24]:
# Equivariance check: transforming the input by G must transform each site's
# output vector by G and move it to the permuted site, i.e.
#     G @ y[samp, :, site]  ==  y2[samp, :, sitePerm[site]]   for all samp, site.
# The check is vectorised over samples and sites.  The tolerances match
# np.allclose (|actual - desired| <= atol + rtol*|desired|), and NaNs count as
# failures, as they do in np.allclose.  On failure, the message reports the
# mismatching entries.
Gy = np.einsum("ij,sjn->sin", G, y)
np.testing.assert_allclose(Gy, y2[:, :, sitePerm], rtol=1e-5, atol=1e-8, equal_nan=False)
print("Symmetry assertion complete")

Symmetry assertion complete
