In [1]:
import numpy as np
import torch as pt
import pickle
from onsager import crystal, cluster, supercell

In [2]:
# Load the data
# All inputs are read from the current working directory; they are produced
# outside this notebook.
TestStates = np.load("TestStates.npy")  # axis 1 sets the filter length used for Psi below
dxNN = np.load("nnJumpLatVecs.npy")  # not used in the cells shown here
RtoSiteInd = np.load("RtoSiteInd.npy")  # indexed as [R0, R1, R2] -> site index
SiteIndtoR = np.load("SiteIndtoR.npy")  # indexed as [site index] -> lattice vector R
GpermNNIdx = np.load("GroupNNpermutations.npy")  # one row of gather indices per group operation

# Group operations keyed by an integer index (see the inversion into
# GtoGIndDict further down).
# NOTE(review): pickle.load executes arbitrary code on load - only use files
# from a trusted source.
with open("GroupOpsIndices.pkl", "rb") as fl:
    GIndtoGDict = pickle.load(fl)

# Supercell object; only its `.crys.g_pos` method is used in this notebook.
with open("supercellBCC.pkl", "rb") as fl:
    superBCC = pickle.load(fl)

In [43]:
# Move the loaded arrays into torch: halved states as float64, and the
# group permutation table as int64 so it can be used as an index tensor.
# NOTE(review): the reason for the 1/2 scaling of the states is not visible
# here - presumably a normalisation of the occupancy encoding; confirm.
StateTensors = pt.tensor(TestStates / 2.0, dtype=pt.double)
permTensor = pt.tensor(GpermNNIdx, dtype=pt.long)

In [44]:
# Make a randomized filter (single channel first)
# Create the filter directly in float64. Calling `.double()` on a float32
# tensor created with requires_grad=True returns a *non-leaf* tensor, so
# `Psi.grad` would never be populated by backward() and an optimizer would
# refuse it as a parameter. Passing dtype here keeps Psi a leaf.
Psi = pt.rand(TestStates.shape[1], dtype=pt.double, requires_grad=True)

In [51]:
# Row i holds Psi with its entries reordered by the i-th row of permTensor,
# i.e. Psi_repeat_perm[i, j] = Psi[permTensor[i, j]].
# Indexing Psi directly gives the same tensor as gathering from a repeated
# copy, without materialising one copy of Psi per group operation first.
Psi_repeat_perm = Psi[permTensor]

CPU times: user 872 µs, sys: 107 µs, total: 979 µs
Wall time: 1.61 ms


In [54]:
# Pick the two test states to compare and the group operation relating them.
# NOTE(review): the check at the end of the notebook relies on state1 being
# state0 acted on by u (group op with index 1) - confirm against how
# TestStates.npy was generated.
state0, state1 = StateTensors[0], StateTensors[1]
u = GIndtoGDict[1]

In [66]:
# Group convolution of the original state: one output row per group
# operation (rows of Psi_repeat_perm), one column per column of the state,
# followed by a softplus non-linearity.
out0 = pt.nn.functional.softplus(Psi_repeat_perm @ state0)

# Same convolution applied to the rotated state.
out1 = pt.nn.functional.softplus(Psi_repeat_perm @ state1)

In [67]:
# Zero-filled buffer (same shape, dtype and device as out0) that will receive
# the group-transformed version of out0.
out0Transf = out0.new_zeros(out0.shape)

In [68]:
# Reverse lookup: group operation -> its integer index (inverse of GIndtoGDict).
GtoGIndDict = {groupOp: opIndex for opIndex, groupOp in GIndtoGDict.items()}

In [69]:
# Multiply all Group ops with u^-1 and store new indices
# Fills out0Transf[g, x] = out0[u^-1 g, u^-1 x], i.e. the output of the
# original state with both its group index and its site index acted on by
# u^-1, so it can be compared against the output of the rotated state.
uInv = u.inv()

# The site relabelling x -> u^-1 x is the same for every group element, so it
# is computed once here instead of once per (group op, site) pair.
siteIndsNew = []
for siteInd in range(state0.shape[1]):
    Rsite = SiteIndtoR[siteInd]
    # g_pos returns (new lattice vector, (chemistry, site index)); only the
    # lattice vector is needed. (0, 0) is the (chemistry, index) pair passed in.
    RsiteNew = superBCC.crys.g_pos(uInv, Rsite, (0, 0))[0]
    # Wrap back into the periodic cell.
    # NOTE(review): 8 is presumably the supercell size along each lattice
    # direction - confirm it matches the shape of RtoSiteInd.
    RsiteNew = RsiteNew % 8
    siteIndsNew.append(int(RtoSiteInd[RsiteNew[0], RsiteNew[1], RsiteNew[2]]))
sitePerm = pt.tensor(siteIndsNew, dtype=pt.long, device=out0.device)

# For each group element g, look up the index of u^-1 * g and copy that row of
# out0 with its columns permuted by the site relabelling.
for g, gInd in GtoGIndDict.items():
    newGInd = GtoGIndDict[uInv * g]
    out0Transf[gInd] = out0[newGInd][sitePerm]

In [70]:
# Equivariance check: the transformed output of the original state should
# match the output of the rotated state (up to floating-point tolerance).
out0Transf.allclose(out1)

True