In [1]:
import numpy as np
import torch as pt
import torch.nn as nn
import torch.nn.functional as F
import pickle

In [3]:
from onsager import crystal, cluster, supercell
from SymmLayers import GConv, GAvg, R3ConvSites, R3Conv

## Load crystal data

In [6]:
# Directory holding the pre-computed BCC crystal data. File names are appended
# by plain string concatenation, so the trailing "/" is required.
CrysDatPath = "../CrysDat_BCC/"

# Batch of random states that are related to each other by symmetry.
TestStates = np.load(CrysDatPath + "TestStates.npy")

# Symmetry operations (keyed by group index) used to construct TestStates.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open(CrysDatPath + "GroupOpsIndices.pkl", "rb") as fl:
    GIndtoGDict = pickle.load(fl)

# Lattice / site bookkeeping arrays, loaded in a fixed order.
(dxNN,
 RtoSiteInd,
 SiteIndtoR,
 GpermNNIdx,
 siteShellIndices) = (np.load(CrysDatPath + fileName)
                      for fileName in ("nnJumpLatVecs_2nn.npy",
                                       "RtoSiteInd.npy",
                                       "SiteIndtoR.npy",
                                       "GroupNNpermutations_2nn.npy",
                                       "SitesToShells.npy"))

# Neighbor table: N_ngb is read from axis 0, Nsites from axis 1.
NNsiteList = np.load(CrysDatPath + "NNsites_sitewise_2nn.npy")
N_ngb, Nsites = NNsiteList.shape[0], NNsiteList.shape[1]

# Pickled supercell object (presumably an onsager supercell; its .crys is used
# below for symmetry operations on positions).
with open(CrysDatPath + "supercellBCC.pkl", "rb") as fl:
    superBCC = pickle.load(fl)

In [7]:
# Display the neighbor count and the number of sites.
(N_ngb, Nsites)

(15, 512)

In [8]:
# NOTE(review): N_units is not referenced in any of the visible cells (the
# network below hard-codes 8 channels). TODO confirm what this constant controls.
N_units = 8

## Build the symmetry parameters (permutation tensors and rotation matrices)

In [9]:
# Torch versions of the symmetry data consumed by the symmetry layers.
GnnPerms = pt.tensor(GpermNNIdx)
NNsites = pt.tensor(NNsiteList)
SitesToShells = pt.tensor(siteShellIndices).long()

# Block-diagonal matrix of the Cartesian rotations of all group operations:
# block gInd (rows/cols gInd*Ndim up to (gInd+1)*Ndim, exclusive) holds
# GIndtoGDict[gInd].cartrot.
Ng = GnnPerms.shape[0]
Ndim = 3
# Every block must be filled exactly once. A missing group index would
# otherwise leave a silent zero block, and an out-of-range one an opaque error.
assert sorted(GIndtoGDict) == list(range(Ng)), \
    "GIndtoGDict keys must be exactly 0..Ng-1 (Ng = GnnPerms.shape[0])"
gdiags = pt.block_diag(
    *(pt.tensor(GIndtoGDict[gInd].cartrot, dtype=pt.float64) for gInd in range(Ng))
)
assert gdiags.shape == (Ng * Ndim, Ng * Ndim), "each cartrot must be Ndim x Ndim"

## Build the network

In [10]:
# for now, let's make an 8-channel, 4-layer network
class GCNet(nn.Module):
    def __init__(self, GnnPerms, NNsites, SitesToShells,
                dim=3, N_ngb=N_ngb, mean=0., std=0.1):
        
        super().__init__()
        
        self.net = nn.Sequential(
            GConv(1, 8, GnnPerms, NNsites, N_ngb, mean=mean, std=std),
            nn.Softplus(),
            GAvg(),
            
            GConv(8, 8, GnnPerms, NNsites, N_ngb, mean=mean, std=std),
            nn.Softplus(),
            GAvg(),

            GConv(8, 8, GnnPerms, NNsites, N_ngb, mean=mean, std=std),
            nn.Softplus(),
            GAvg(),

            GConv(8, 8, GnnPerms, NNsites, N_ngb, mean=mean, std=std),
            nn.Softplus(),
            GAvg(),

            # The last GConv layer must have a single out channel
            GConv(8, 1, GnnPerms, NNsites, N_ngb, mean=mean, std=std),
            nn.Softplus(),
            GAvg(),

            R3ConvSites(SitesToShells, GnnPerms, gdiags, NNsites, N_ngb,
                   dim, mean=mean, std=std*10)        
        )
    
    def forward(self, InState):
        y = self.net(InState)
        return y

In [11]:
# mean / std forwarded to every layer of GCNet.
m = 0.
s = .2
# Fix the torch RNG so the (presumably random) initial weights are reproducible
# across Restart & Run All.
pt.manual_seed(0)
gNet = GCNet(GnnPerms.long(), NNsites, SitesToShells,
             dim=3, N_ngb=N_ngb, mean=m, std=s).double()

## Pass the input states through the network and get the output

In [13]:
# Shape the batch as (sample, 1 channel, site), matching the single input
# channel of the first GConv layer.
# NOTE(review): states are scaled by 1/2 first; TODO document why.
Nbatch = len(TestStates)
StateTensors = (
    pt.tensor(TestStates / 2.0)
    .double()
    .view(Nbatch, 1, TestStates.shape[1])
)
StateTensors.shape

torch.Size([48, 1, 512])

In [14]:
# Inference only: no autograd graph is needed for the symmetry checks.
with pt.no_grad():
    y = gNet(StateTensors)

In [15]:
# Independent NumPy copy of the network output. .detach() is the supported
# replacement for the legacy .data attribute (and matches the later C3Conv cell).
y_np = y.detach().numpy().copy()

In [16]:
# Shape of the NumPy network output.
np.shape(y_np)

(48, 3, 512)

In [17]:
# Check that subsequent y vectors are related by symmetry. For every group
# operation g (index gInd), rotating sample 0's vector at each site by
# g.cartrot must reproduce sample gInd's vector at the site g maps it to.
y0 = y_np[0].copy()
nSitesOut = y0.shape[-1]  # was hard-coded as 512
for gInd, g in GIndtoGDict.items():
    # Site permutation induced by g: siteMap[site] = index of the image of site.
    siteMap = np.empty(nSitesOut, dtype=int)
    for site in range(nSitesOut):
        RsiteNew, _ = superBCC.crys.g_pos(g, SiteIndtoR[site], (0, 0))
        # bring back into supercell
        # NOTE(review): 8 is presumably the supercell repeat per axis (8**3 = 512).
        RsiteNew = RsiteNew % 8
        siteMap[site] = RtoSiteInd[RsiteNew[0], RsiteNew[1], RsiteNew[2]]
    # Rotate every site vector with one matrix product instead of one per site.
    rotated = np.dot(g.cartrot, y0)
    assert np.allclose(rotated, y_np[gInd][:, siteMap]), \
        "vector output not equivariant under group op {}".format(gInd)

print("Vector Layer Symmetry assertions passed")

Vector Layer Symmetry assertions passed


In [18]:
def checkSymLayer(out, Nch):
    for ch in range(Nch):
        outsamp0 = out[0, ch, :]
        for gInd, g in GIndtoGDict.items():
            outsamp = pt.zeros_like(outsamp0)
            for siteInd in range(512):
                RSite = SiteIndtoR[siteInd]
                Rnew, _ = superBCC.crys.g_pos(g, RSite, (0,0))
                Rnew %= 8
                siteIndNew = RtoSiteInd[Rnew[0], Rnew[1], Rnew[2]]
                outsamp[siteIndNew] = outsamp0[siteInd]
            assert pt.allclose(outsamp, out[gInd, ch, :])

In [19]:
# Check symmetries of internal layers. Push the states through each
# [GConv, Softplus, GAvg] triple of gNet.net and verify equivariance after it.
# The last module of gNet.net is the vector layer, so it is excluded; the
# layer count is derived from the network instead of being hard-coded.
nScalarModules = len(gNet.net) - 1
out = pt.clone(StateTensors)
with pt.no_grad():
    for layerNum, start in enumerate(range(0, nScalarModules, 3), start=1):
        conv = gNet.net[start]
        # Psi.shape[1] / Psi.shape[0] are checked as the in / out channel counts.
        assert out.shape[1] == conv.Psi.shape[1]
        for module in gNet.net[start:start + 3]:
            out = module(out)  # call the module (not .forward) so hooks run
        Nch = out.shape[1]
        assert Nch == conv.Psi.shape[0]
        checkSymLayer(out, Nch)
        print("Layer {} symmetry assertion passed".format(layerNum))

Layer 1 symmetry assertion passed
Layer 2 symmetry assertion passed
Layer 3 symmetry assertion passed
Layer 4 symmetry assertion passed
Layer 5 symmetry assertion passed


In [20]:
# Shape of the final single-channel scalar output.
out.size()

torch.Size([48, 1, 512])

In [21]:
# Check site-weighted vector sum: apply a freshly initialized R3Conv layer to
# the final single-channel scalar output `out`. R3Conv is already imported in
# the import cell, so the redundant re-import here has been dropped.
C3Conv = R3Conv(SitesToShells, GnnPerms.long(), gdiags, NNsites, N_ngb,
                dim=3, mean=m, std=s).double()

y = C3Conv(out).detach().numpy()

In [22]:
# Each group operation g must rotate sample 0's vector into sample gInd's vector.
for gInd, g in GIndtoGDict.items():
    expectedVec = np.dot(g.cartrot, y[0])
    assert np.allclose(expectedVec, y[gInd])
print("Weighted sum of vector tests passed")

Weighted sum of vector tests passed
