In [1]:
import numpy as np
import torch as pt
import pickle
from onsager import crystal, cluster, supercell
from tqdm import tqdm

In [2]:
# Load the data
TestStates = np.load("TestStates.npy")[:10]
dxNN = np.load("nnJumpLatVecs.npy")
RtoSiteInd = np.load("RtoSiteInd.npy")
SiteIndtoR = np.load("SiteIndtoR.npy")
GpermNNIdx = np.load("GroupNNpermutations.npy")

NNsiteList = np.load("NNsites_sitewise.npy")
N_ngb = NNsiteList.shape[0]

with open("GroupOpsIndices.pkl", "rb") as fl:
    GIndtoGDict = pickle.load(fl)

with open("supercellBCC.pkl", "rb") as fl:
    superBCC = pickle.load(fl)

## Expand the states

In [3]:
StateTensors = pt.tensor(TestStates/2.0).double().view(TestStates.shape[0], 1, TestStates.shape[1])
StateTensors.shape

torch.Size([10, 1, 512])

In [4]:
StateTensors_repeat = StateTensors.repeat_interleave(N_ngb, dim=1)
StateTensors_repeat.shape

torch.Size([10, 9, 512])

In [5]:
for i in range(0, StateTensors.shape[0]):
    for j in range(N_ngb):
        assert pt.allclose(StateTensors[i], StateTensors_repeat[i,j]), "{} {}".format(i,j)

In [6]:
# Next, repeat the NNsites tensor
NNsiteTensor = pt.tensor(NNsiteList).long()
NNsitesBatchRepeat = NNsiteTensor.unsqueeze(0).repeat(StateTensors_repeat.shape[0], 1, 1)

In [7]:
NNsiteTensor.shape

torch.Size([9, 512])

In [8]:
NNsitesBatchRepeat.shape

torch.Size([10, 9, 512])

In [9]:
for i in range(0, StateTensors_repeat.shape[0]):
    assert pt.equal(NNsiteTensor, NNsitesBatchRepeat[i])

In [10]:
StatesTensorFull = pt.gather(StateTensors_repeat, 2, NNsitesBatchRepeat)
StatesTensorFull.shape

torch.Size([10, 9, 512])

In [11]:
for stateInd in range(StateTensors.shape[0]):
    stateFull = StatesTensorFull[stateInd]
    for i in range(N_ngb):
        indices=NNsiteTensor[i] # i^th nn site indices for all sites
        nnSpecies = StateTensors[stateInd, 0][indices]
        assert pt.equal(stateFull[i], nnSpecies)

## Create a random filter and test input image convolution

In [13]:
# Make a randomized filter with Nch channels
Nch = 1
Psi = pt.rand(Nch, 1, N_ngb, requires_grad=True).double()

In [14]:
# Expand the kernels by adding group permutations of nearest neighbor translations
GnnPermTensor= pt.tensor(GpermNNIdx).long()
Ng = GnnPermTensor.shape[0]

# repeat the permutations Nch times for each channel
GnnPermTensor_repeat = GnnPermTensor.repeat(Nch, 1).view(-1, 1, N_ngb)

# Now repeat each filter channel Ng times and permute according to group ops
Psi_repeat = Psi.repeat_interleave(Ng, dim=0)
Psi_repeat_perm = pt.gather(Psi_repeat, 2, GnnPermTensor_repeat).view(-1, 1*N_ngb)

In [15]:
Psi_repeat_perm.shape

torch.Size([48, 9])

In [16]:
# Test the repeated kernels
for ch in range(Nch):
    FullPsiCh = Psi_repeat_perm[ch*Ng : (ch+1)*Ng]
    for g in range(Ng):
        Psiperm_g = Psi[ch,0][GnnPermTensor[g]]
        assert pt.equal(FullPsiCh[g], Psiperm_g)

In [17]:
pt.sum(Psi_repeat_perm, dim=0)

tensor([17.6272, 26.7087, 26.7087, 26.7087, 26.7087, 26.7087, 26.7087, 26.7087,
        26.7087], dtype=torch.float64, grad_fn=<SumBackward1>)

In [18]:
Psi[0]

tensor([[0.3672, 0.7994, 0.5222, 0.4142, 0.9901, 0.2632, 0.2498, 0.9120, 0.3006]],
       dtype=torch.float64, grad_fn=<SelectBackward>)

In [19]:
0.3672*48

17.625600000000002