In [1]:
import sys
sys.path.append("../")

In [2]:
import numpy as np
import torch as pt
import torch.nn as nn
import torch.nn.functional as F
import pickle

In [3]:
from onsager import crystal, cluster, supercell
from SymmLayers import GConv, GAvg, R3Conv

In [4]:
# Read all the crystal data from a single directory (change CRYSTAL_DIR to
# point the notebook at a different data set).
CRYSTAL_DIR = "../CrystalDat"

# Load test data
TestStates = np.load(f"{CRYSTAL_DIR}/TestStates.npy")
dxNN = np.load(f"{CRYSTAL_DIR}/nnJumpLatVecs.npy")
RtoSiteInd = np.load(f"{CRYSTAL_DIR}/RtoSiteInd.npy")
SiteIndtoR = np.load(f"{CRYSTAL_DIR}/SiteIndtoR.npy")
GpermNNIdx = np.load(f"{CRYSTAL_DIR}/GroupNNpermutations.npy")

# Neighbor table: the code takes axis 0 as the neighbor count and axis 1 as
# the number of sites.
NNsiteList = np.load(f"{CRYSTAL_DIR}/NNsites_sitewise.npy")
N_ngb = NNsiteList.shape[0]
Nsites = NNsiteList.shape[1]

# NOTE: pickle.load can execute arbitrary code -- only load these files if they
# come from a trusted source (e.g. generated locally).
with open(f"{CRYSTAL_DIR}/GroupOpsIndices.pkl", "rb") as fl:
    GIndtoGDict = pickle.load(fl)  # group-op index -> group operation object

with open(f"{CRYSTAL_DIR}/supercellBCC.pkl", "rb") as fl:
    superBCC = pickle.load(fl)

In [5]:
# Display (neighbor count, site count) as read from NNsiteList's shape.
(N_ngb, Nsites)

(9, 512)

In [6]:
# Supercell size used to wrap lattice coordinates periodically (passed to
# `periodic` / `periodNp` below).
# NOTE(review): presumably Nsites == N_units**3 (8**3 = 512 matches the
# Nsites displayed above) -- confirm against how the supercell was built.
N_units = 8

## Try constructing shells with half-translation (lattice coordinates wrapped into ±N_units/2)

In [7]:
def periodic(x, N_units):
    """Fold one lattice coordinate back toward the origin of a periodic cell.

    Values above ``N_units // 2`` are shifted down by one period, values
    below ``-(N_units // 2)`` are shifted up by one period, and anything in
    ``[-(N_units // 2), N_units // 2]`` (both ends included) is returned
    unchanged. Only a single shift is ever applied, so inputs are expected
    to lie within one period of that range (e.g. coordinates in [0, N_units)).
    """
    half = N_units // 2
    if x > half:
        return x - N_units
    if x < -half:
        return x + N_units
    return x


# Element-wise (ufunc) version of `periodic`: 2 inputs, 1 output; numpy's
# frompyfunc always yields object-dtype arrays.
periodNp = np.frompyfunc(periodic, 2, 1)

# Next, let's sort them according to distance
def sortkey(shell):
    """Sort key for a shell: the smallest Cartesian length among its sites.

    Each site index is mapped to its lattice vector with ``SiteIndtoR``,
    wrapped with ``periodNp`` (using the notebook-level ``N_units``), and
    turned into Cartesian coordinates with ``superBCC.crys.lattice``. The
    key is the minimum norm over all sites in ``shell``. Ties keep the first
    minimum, as before.

    Raises:
        ValueError: if ``shell`` is empty.
    """
    def _site_length(site):
        # Cartesian length of one site's wrapped lattice vector
        R = periodNp(SiteIndtoR[site], N_units)
        return np.linalg.norm(np.dot(superBCC.crys.lattice, R))

    return min(_site_length(site) for site in shell)

In [8]:
# Make shells for symmetry-grouping sites: a shell is the set of site indices
# obtained by applying every group operation in GIndtoGDict (via g_pos) to a
# representative site. Sites already covered by an earlier shell are skipped.
shells = []
BCvec = np.ones(3, dtype=int)  # NOTE(review): not used anywhere in this notebook
AllSites = set()
for siteInd in range(Nsites):
    if siteInd in AllSites:
        continue
    Rsite = periodNp(SiteIndtoR[siteInd], N_units)
    orbit = set()
    for g in GIndtoGDict.values():
        Rnew = periodNp(superBCC.crys.g_pos(g, Rsite, (0, 0))[0], N_units)
        orbit.add(RtoSiteInd[Rnew[0], Rnew[1], Rnew[2]])
    AllSites.update(orbit)
    shells.append(list(orbit))

In [9]:
# Order shells by `sortkey` (smallest distance first). list.sort is stable,
# so shells with equal keys keep the order in which they were discovered.
shellsSorted = list(shells)
shellsSorted.sort(key=sortkey)

In [10]:
# Now let's assign shells to sites: SitesToShells[siteInd] is the index of the
# (distance-sorted) shell that contains siteInd.
# Start from -1 rather than 0 so that a site missing from every shell cannot
# silently be counted as part of shell 0.
SitesToShells = pt.full((Nsites,), -1, dtype=pt.long)
for shellInd, shell in enumerate(shellsSorted):
    for siteInd in shell:
        SitesToShells[siteInd] = shellInd
assert (SitesToShells >= 0).all(), "some sites were not assigned to any shell"
pt.unique(SitesToShells, return_counts=True)

(tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28]),
 tensor([ 1,  8,  6, 12, 24,  8,  6, 24, 24, 24,  8, 24, 12, 48, 24,  6, 24, 24,
         24,  4, 24, 24, 24, 24, 48,  3, 12, 12,  6]))

## Next we make symmetry parameters

In [20]:
GnnPerms = pt.tensor(GpermNNIdx)
NNsites = pt.tensor(NNsiteList)

# Block-diagonal matrix of the Cartesian rotations g.cartrot: the block for
# group op gInd occupies rows/cols [gInd*Ndim, (gInd+1)*Ndim).
Ng = GnnPerms.shape[0]
Ndim = 3
# Each block must be filled exactly once. A missing index would leave a zero
# block, and an index >= Ng would overflow the matrix, so require the keys to
# be exactly 0..Ng-1.
assert sorted(GIndtoGDict.keys()) == list(range(Ng)), \
    "GIndtoGDict keys must be exactly 0..Ng-1 to match GnnPerms"
gdiags = pt.zeros(Ng*Ndim, Ng*Ndim).double()
for gInd, g in GIndtoGDict.items():
    rowStart = gInd * Ndim
    rowEnd = (gInd + 1) * Ndim
    gdiags[rowStart : rowEnd, rowStart : rowEnd] = pt.tensor(g.cartrot)

In [21]:
# Now let's build the network
class GCNet(nn.Module):
    """Symmetry network assembled from the SymmLayers building blocks.

    Layers, in order: GConv(1, 4) -> Softplus -> GAvg -> GConv(4, 1) ->
    Softplus -> GAvg -> R3Conv. In this notebook it is applied to a
    (batch, 1, Nsites) double tensor, and the last cell checks that each
    output row transforms with ``g.cartrot``.

    Args:
        GnnPerms: group permutation table of neighbor indices (passed to
            GConv and R3Conv).
        NNsites: neighbor-site table (passed to every conv layer).
        SitesToShells: site -> shell index map (passed to R3Conv).
        dim: dimension argument for R3Conv (default 3).
        N_ngb: neighbor count passed to every conv layer.
        mean, std: forwarded as ``mean=``/``std=`` to every conv layer
            (presumably weight-initialisation parameters -- see SymmLayers).
            Previously these arguments were accepted but ignored.
        gdiagMat: block-diagonal group rotation matrix for R3Conv. If None
            (default), the notebook-level ``gdiags`` is used, which matches
            the original behavior.
    """
    def __init__(self, GnnPerms, NNsites, SitesToShells,
                 dim=3, N_ngb=N_ngb, mean=1.0, std=0.1, gdiagMat=None):

        super().__init__()

        if gdiagMat is None:
            gdiagMat = gdiags

        self.net = nn.Sequential(
            GConv(1, 4, GnnPerms, NNsites, N_ngb, mean=mean, std=std),
            nn.Softplus(),
            GAvg(),
            GConv(4, 1, GnnPerms, NNsites, N_ngb, mean=mean, std=std),
            nn.Softplus(),
            GAvg(),
            R3Conv(SitesToShells, GnnPerms, gdiagMat, NNsites, N_ngb,
                   dim, mean=mean, std=std),
        )

    def forward(self, InState):
        """Run the input through the sequential layer stack."""
        return self.net(InState)

In [22]:
# Build the network (permutation table cast to int64) and convert all of its
# parameters to float64 with .double().
gNet = GCNet(
    GnnPerms.long(),
    NNsites,
    SitesToShells,
    dim=3,
    N_ngb=N_ngb,
    mean=1.0,
    std=0.1,
).double()

In [35]:
# Scale the raw test states by 1/2 and insert a singleton channel axis, giving
# shape (Nbatch, 1, TestStates.shape[1]); the first GConv is GConv(1, ...).
# NOTE(review): the reason for the /2 scaling is not documented here.
Nbatch = TestStates.shape[0]
StateTensors = pt.tensor(TestStates / 2.0, dtype=pt.double).view(
    Nbatch, 1, TestStates.shape[1]
)
StateTensors.shape

torch.Size([48, 1, 512])

In [36]:
# Forward pass on the test batch. Only inference is needed here (nothing in
# this notebook calls backward), so skip building the autograd graph.
with pt.no_grad():
    y = gNet(StateTensors)

In [37]:
# Detach from autograd (.detach() replaces the legacy .data accessor) and copy,
# so y_np does not share memory with the tensor y.
y_np = y.detach().numpy().copy()

In [40]:
# Equivariance check: the output for batch entry gInd must equal the output
# for entry 0 rotated by g.cartrot.
# NOTE(review): assumes TestStates[gInd] is TestStates[0] transformed by group
# operation gInd -- confirm against how TestStates.npy was generated.
y0 = y_np[0].copy()
for gInd, g in GIndtoGDict.items():
    # Use a distinct name so the network output `y` from the forward-pass cell
    # is not clobbered (re-running the y_np cell would otherwise break).
    y_g = y_np[gInd]
    # Same tolerances/semantics as np.allclose(np.dot(g.cartrot, y0), y_g), but
    # not stripped under `python -O` and reports which op failed.
    np.testing.assert_allclose(
        np.dot(g.cartrot, y0), y_g,
        rtol=1e-05, atol=1e-08, equal_nan=False,
        err_msg=f"symmetry check failed for group operation {gInd}",
    )
print("Symmetry assertions passed")

Symmetry assertions passed
