In [1]:
import numpy as np
import torch as pt
import pickle
from onsager import crystal, cluster, supercell
from tqdm import tqdm

In [4]:
# Load the data
# Every input file lives in one directory; edit DATA_DIR to point the notebook elsewhere.
DATA_DIR = "../CrystalDat"

# Only the first 10 stored states are used in this notebook.
TestStates = np.load(DATA_DIR + "/TestStates.npy")[:10]
dxNN = np.load(DATA_DIR + "/nnJumpLatVecs.npy")
RtoSiteInd = np.load(DATA_DIR + "/RtoSiteInd.npy")
SiteIndtoR = np.load(DATA_DIR + "/SiteIndtoR.npy")
GpermNNIdx = np.load(DATA_DIR + "/GroupNNpermutations.npy")

NNsiteList = np.load(DATA_DIR + "/NNsites_sitewise.npy")
# Size of the leading axis of the neighbor table; used below as the number of
# neighbor slots per site.
N_ngb = NNsiteList.shape[0]

# NOTE: pickle.load can execute arbitrary code - only load files from a trusted source.
with open(DATA_DIR + "/GroupOpsIndices.pkl", "rb") as fl:
    GIndtoGDict = pickle.load(fl)

with open(DATA_DIR + "/supercellBCC.pkl", "rb") as fl:
    superBCC = pickle.load(fl)

In [5]:
GpermNNIdx.shape

(48, 9)

## Expand the states

In [6]:
# Halve the raw state values, cast to float64, then insert a singleton axis at dim 1
# so each state is laid out as (1, n_sites).
StateTensors = pt.tensor(TestStates / 2.0).double()
StateTensors = StateTensors.view(TestStates.shape[0], 1, TestStates.shape[1])
StateTensors.shape

torch.Size([10, 1, 512])

In [7]:
# Copy each state N_ngb times along dim 1 (one copy per neighbor slot).
StateTensors_repeat = pt.repeat_interleave(StateTensors, N_ngb, dim=1)
StateTensors_repeat.shape

torch.Size([10, 9, 512])

In [8]:
# Every repeated copy must agree with the state it was copied from.
for i in range(StateTensors.shape[0]):
    reference = StateTensors[i, 0]
    for j in range(N_ngb):
        assert pt.allclose(reference, StateTensors_repeat[i, j]), "{} {}".format(i,j)

In [17]:
# Next, repeat the NNsites tensor
# The neighbor table is cast to int64 because it is used as a gather index below.
NNsiteTensor = pt.tensor(NNsiteList).long()
# Add a leading batch axis and tile it once per state.
NNsitesBatchRepeat = NNsiteTensor[None].repeat(StateTensors_repeat.shape[0], 1, 1)

In [18]:
NNsiteTensor.shape

torch.Size([9, 512])

In [19]:
NNsitesBatchRepeat.shape

torch.Size([10, 9, 512])

In [20]:
# Each batch entry of the tiled index tensor must equal the original neighbor table.
for i in range(StateTensors_repeat.shape[0]):
    batchSlice = NNsitesBatchRepeat[i]
    assert pt.equal(NNsiteTensor, batchSlice)

In [21]:
# Row i of each state now holds the state's values at the i-th neighbor of every site.
StatesTensorFull = StateTensors_repeat.gather(2, NNsitesBatchRepeat)
StatesTensorFull.shape

torch.Size([10, 9, 512])

In [22]:
# Verify the gather against plain fancy indexing, one neighbor slot at a time.
for stateInd in range(StateTensors.shape[0]):
    flatState = StateTensors[stateInd, 0]
    for i in range(N_ngb):
        # i^th nn site indices for all sites
        expected = flatState[NNsiteTensor[i]]
        assert pt.equal(StatesTensorFull[stateInd, i], expected)

## Create a random filter and test input image convolution

In [23]:
# Make a randomized filter with Nch channels
Nch = 4
# Create the filter directly in float64. Calling .double() on a float32 tensor that
# already has requires_grad=True returns a non-leaf copy, so Psi.grad would stay
# None after backward().
Psi = pt.rand(Nch, 1, N_ngb, dtype=pt.double, requires_grad=True)

In [24]:
# Expand the kernels by adding group permutations of nearest neighbor translations
GnnPermTensor = pt.tensor(GpermNNIdx).long()
Ng = GnnPermTensor.shape[0]

# Tile the whole permutation table once per channel: (Nch*Ng, 1, N_ngb)
GnnPermTensor_repeat = GnnPermTensor.repeat(Nch, 1).view(-1, 1, N_ngb)

# Duplicate every filter channel Ng times in a row, then reorder the neighbor
# entries of each copy with the matching group permutation.
Psi_repeat = pt.repeat_interleave(Psi, Ng, dim=0)
Psi_repeat_perm = Psi_repeat.gather(2, GnnPermTensor_repeat).view(-1, N_ngb)

In [25]:
Psi_repeat_perm.shape

torch.Size([192, 9])

In [26]:
# Test the repeated kernels
# Row ch*Ng + g of the expanded kernel must be channel ch permuted by group op g.
for ch in range(Nch):
    chStart = ch * Ng
    for g in range(Ng):
        expected = Psi[ch, 0][GnnPermTensor[g]]
        assert pt.equal(Psi_repeat_perm[chStart + g], expected)

In [27]:
# Now let's do the bias terms
# One bias per channel, created in float64 to match the kernel and state tensors.
bias = pt.rand(Nch, dtype=pt.double).unsqueeze(1)
# Every group-transformed copy of a channel shares that channel's bias.
bias_repeat = bias.repeat_interleave(Ng, dim=0)

In [28]:
# Do the convolution with the original state
linOut = pt.matmul(Psi_repeat_perm, StatesTensorFull) + bias_repeat
# Split the stacked (channel, group op) axis; the site count is taken from the
# input states instead of being hard-coded.
out = pt.nn.functional.softplus(linOut).view(StateTensors.shape[0], Nch, Ng, StateTensors.shape[2])

In [29]:
out.shape

torch.Size([10, 4, 48, 512])

In [30]:
# Inverse lookup of GIndtoGDict: group operation -> its index.
GtoGIndDict = {g: gInd for gInd, g in GIndtoGDict.items()}

In [31]:
# Multiply all Group ops with u^-1 and store new indices
#
# For each tested sample (u = GIndtoGDict[SampInd]) and channel, assert that
#     out[SampInd, channel][g, site] == out[0, channel][u^-1 * g, u^-1(site)]
# for every group op g and every site.

NtestSym = 5 #StateTensors.shape[0]
Nsites = StateTensors.shape[2]
for channel in range(Nch):
    print(channel, flush=True)
    out0 = out[0, channel]
    for SampInd in tqdm(range(NtestSym), position=0, leave=True):
        u = GIndtoGDict[SampInd]
        uInv = u.inv()
        outSamp = out[SampInd, channel]

        # The site mapping depends only on uInv, so build it once per sample
        # rather than once for every (group op, site) pair.
        siteIndsNew = []
        for siteInd in range(Nsites):
            Rsite = SiteIndtoR[siteInd]
            RsiteNew, _ = superBCC.crys.g_pos(uInv, Rsite, (0,0))
            # Wrap into the supercell (hard-coded period 8 along each axis).
            RsiteNew %= 8
            siteIndsNew.append(int(RtoSiteInd[RsiteNew[0], RsiteNew[1], RsiteNew[2]]))
        siteIndsNew = pt.tensor(siteIndsNew).long()

        out0Transf = pt.zeros(Ng, Nsites).double()
        for g, gInd in GtoGIndDict.items():
            newGInd = GtoGIndDict[uInv*g]
            # Fill the whole row for this group op in one indexed read.
            out0Transf[gInd] = out0[newGInd, siteIndsNew]

        assert pt.equal(outSamp, out0Transf), "sample: {}, channel: {}".format(SampInd, channel)

print("Symmetry of 1st layer outputs correct")

0


100%|██████████| 5/5 [00:05<00:00,  1.12s/it]

1



100%|██████████| 5/5 [00:05<00:00,  1.13s/it]

2



100%|██████████| 5/5 [00:05<00:00,  1.09s/it]

3



100%|██████████| 5/5 [00:05<00:00,  1.10s/it]

Symmetry of 1st layer outputs correct





## Next, group averaging (pooling)

In [32]:
# Average over the group-operation axis (dim 2).
outSum = out.sum(dim=2) / Ng

In [33]:
# Now for each state, check if the result is correct
# Recompute the group average with an explicit accumulation loop and compare.
for sampInd in range(StateTensors.shape[0]):
    for channel in range(Nch):
        outSampChannel = out[sampInd, channel]
        sumSamp = pt.zeros(outSum.shape[2]).double()
        for i in range(Ng):
            sumSamp += outSampChannel[i]
        expectedAvg = sumSamp / Ng
        assert pt.allclose(expectedAvg, outSum[sampInd, channel])
print("Group averaging correct")

Group averaging correct


In [34]:
outSum.shape

torch.Size([10, 4, 512])

In [25]:
# For every sample, move state 0's averaged output to the sites given by
# g = GIndtoGDict[sampInd] and compare with that sample's averaged output.
for sampInd in tqdm(range(StateTensors.shape[0]), position=0, leave=True):
    g = GIndtoGDict[sampInd]

    # The site mapping under g is the same for every channel: compute it once per sample.
    siteIndsNew = []
    for siteInd in range(outSum.shape[2]):
        Rsite = SiteIndtoR[siteInd]
        RsiteNew, _ = superBCC.crys.g_pos(g, Rsite, (0, 0))
        RsiteNew %= 8 #apply PBC
        siteIndsNew.append(RtoSiteInd[RsiteNew[0], RsiteNew[1], RsiteNew[2]])

    for channel in range(Nch):
        # Now transform the original state's output
        out0New = pt.zeros(outSum.shape[2]).double()
        for siteInd, siteIndNew in enumerate(siteIndsNew):
            out0New[siteIndNew] = outSum[0, channel, siteInd].item()

        assert pt.allclose(out0New, outSum[sampInd, channel])

print("Symmetry of Group averaging correct", flush=True)

100%|██████████| 10/10 [00:00<00:00, 11.47it/s]

Symmetry of Group averaging correct





## Let's first work out the kernel of the second layer
### Input channels - Nch, output channels - Nch2

In [35]:
Nch2 = 3 # Output of the layer will be 3 channels
# Create the kernel directly in float64 so it stays a leaf tensor; a trailing
# .double() on a requires_grad float32 tensor would return a non-leaf copy whose
# .grad is never populated.
Psi2 = pt.rand(Nch2, Nch, N_ngb, dtype=pt.double, requires_grad=True)
Psi2.shape

torch.Size([3, 4, 9])

In [36]:
# Now rearrange the kernel into the full matrix correctly

# First, let's do the repetition: each output channel's kernel is copied Ng times in a row
Psi2_repeat = pt.repeat_interleave(Psi2, Ng, dim=0)
Psi2_repeat.shape

torch.Size([144, 4, 9])

In [37]:
# Then we do the permutation
# Tile the permutation table Nch2 times along rows and Nch times along columns,
# then fold the columns into (input channel, neighbor) axes.
GnnPermTiled = GnnPermTensor.repeat(Nch2, Nch)
GnnPermTensor_repeat_2 = GnnPermTiled.view(-1, Nch, N_ngb)

In [40]:
GnnPermTensor_repeat_2.shape

torch.Size([144, 4, 9])

In [41]:
GnnPermTensor.shape

torch.Size([48, 9])

In [None]:
GnnPermTensor.repeat(Nch2, Nch)[1]

In [None]:
GnnPermTensor

In [36]:
# Permute the neighbor axis of every repeated kernel, then flatten the
# (input channel, neighbor) axes into one row per (output channel, group op).
Psi2_repeat_perm = Psi2_repeat.gather(2, GnnPermTensor_repeat_2).view(-1, Nch * N_ngb)

In [41]:
# Test the repeated kernels again
for ch2 in range(Nch2):
    for ch1 in range(Nch):
        # Get the portion of the kernel we need to test
        FullPsiCh2 = Psi2_repeat_perm[ch2*Ng : (ch2+1)*Ng, ch1*N_ngb:(ch1+1)*N_ngb]
        for g in range(Ng):
            # Build the expected row from the ORIGINAL kernel. Taking it from
            # Psi2_repeat_perm itself would make the check self-referential and only
            # valid when group op 0 happens to be the identity permutation.
            Psi2perm_g = Psi2[ch2, ch1][GnnPermTensor[g]]
            assert pt.equal(FullPsiCh2[g], Psi2perm_g), "{} {} {}".format(ch2, ch1, g)

In [42]:
Psi2_repeat_perm.shape

torch.Size([144, 36])

In [43]:
Psi2_repeat_perm[1]

tensor([0.1577, 0.5958, 0.1759, 0.9108, 0.3299, 0.5226, 0.6978, 0.6980, 0.9042,
        0.3951, 0.1354, 0.4617, 0.6133, 0.8857, 0.5069, 0.4653, 0.7044, 0.2454,
        0.8610, 0.8028, 0.2826, 0.4200, 0.4498, 0.4143, 0.0896, 0.4331, 0.0417,
        0.8866, 0.7633, 0.7502, 0.6293, 0.4389, 0.7738, 0.7603, 0.2279, 0.4851],
       dtype=torch.float64, grad_fn=<SelectBackward>)

In [44]:
# Now let's do the bias terms
# One bias per output channel, created in float64 to match the layer-2 kernel.
bias2 = pt.rand(Nch2, dtype=pt.double).unsqueeze(1)
# Every group-transformed copy of an output channel shares that channel's bias.
bias2_repeat = bias2.repeat_interleave(Ng, dim=0)

## Next, we work out rearranging the output of the first layer into a suitable input for the second layer

In [45]:
outSum.shape

torch.Size([10, 4, 512])

In [46]:
# The channels are already stacked in outSum - so let's first repeat
input2 = outSum
# Each channel is copied N_ngb times in a row along dim 1.
input2_repeat = pt.repeat_interleave(input2, N_ngb, dim=1)
input2_repeat.shape

torch.Size([10, 36, 512])

In [47]:
input2.shape

torch.Size([10, 4, 512])

In [48]:
# Now we have to put in appropriate nearest neighbor sites
# Tile the neighbor table over the batch (dim 0) and once per input channel (dim 1).
NNbatchRepeatL2 = NNsiteTensor[None].repeat(input2_repeat.shape[0], Nch, 1)

In [49]:
NNbatchRepeatL2.shape

torch.Size([10, 36, 512])

In [50]:
# Pull, for every (channel, neighbor slot) row, the values at that neighbor of each site.
input2_nnStack = input2_repeat.gather(2, NNbatchRepeatL2)

In [51]:
input2_nnStack.shape

torch.Size([10, 36, 512])

In [52]:
# Now let's see if nearest neighbors have been gathered properly
for sampInd in range(StateTensors.shape[0]):
    for ch in range(Nch):
        outSamp = outSum[sampInd, ch]
        rowStart = ch * N_ngb
        for ngb in range(N_ngb):
            # Reference: index the channel's output directly with the ngb-th neighbor sites.
            ngbSites = outSamp[NNsiteTensor[ngb]]
            gathered = input2_nnStack[sampInd, rowStart + ngb]
            assert pt.equal(ngbSites, gathered), "{} {} {}".format(sampInd, ch, ngb)

In [53]:
# Now let's do the convolution
# Linear part, then softplus, then split the stacked (output channel, group op) axis.
linOut2 = pt.matmul(Psi2_repeat_perm, input2_nnStack) + bias2_repeat
out2 = pt.nn.functional.softplus(linOut2).view(
    StateTensors.shape[0], Nch2, Ng, StateTensors.shape[2]
)

In [54]:
# Now average group dimension again
out2Sum = out2.sum(dim=2) / Ng
out2Sum.shape

torch.Size([10, 3, 512])

In [55]:
# Now check symmetry
# For every sample, move state 0's layer-2 output to the sites given by
# g = GIndtoGDict[sampInd] and compare with that sample's layer-2 output.
for sampInd in tqdm(range(StateTensors.shape[0]), position=0, leave=True):
    g = GIndtoGDict[sampInd]

    # The site mapping under g is the same for every channel: compute it once per sample.
    siteIndsNew = []
    for siteInd in range(out2Sum.shape[2]):
        Rsite = SiteIndtoR[siteInd]
        RsiteNew, _ = superBCC.crys.g_pos(g, Rsite, (0, 0))
        RsiteNew %= 8 #apply PBC
        siteIndsNew.append(RtoSiteInd[RsiteNew[0], RsiteNew[1], RsiteNew[2]])

    for channel in range(Nch2):
        # Now transform the original state's output
        out0New = pt.zeros(out2Sum.shape[2]).double()
        for siteInd, siteIndNew in enumerate(siteIndsNew):
            out0New[siteIndNew] = out2Sum[0, channel, siteInd].item()

        assert pt.allclose(out0New, out2Sum[sampInd, channel])

print("Symmetry of layer 2 output correct")

100%|██████████| 10/10 [00:00<00:00, 15.19it/s]

Symmetry of layer 2 output correct





In [56]:
out2Sum.shape

torch.Size([10, 3, 512])

## Now let's test R3 convolution

In [57]:
# First let's make the kernel transformation matrix
# Block-diagonal matrix with one 3x3 Cartesian rotation block per group operation.
# Sized from Ng rather than a hard-coded 48 so it follows the loaded group.
gdiags = pt.zeros(Ng*3, Ng*3).double()
for gInd, g in GIndtoGDict.items():
    rowStart = gInd * 3
    rowEnd = (gInd + 1) * 3
    gdiags[rowStart : rowEnd, rowStart : rowEnd] = pt.tensor(g.cartrot).double()

In [58]:
# Now, let's initialize the kernel
# Created directly in float64 so Psi3 stays a leaf tensor (a trailing .double() on a
# requires_grad float32 tensor returns a non-leaf copy whose .grad is never filled).
Psi3 = pt.rand(3, N_ngb, dtype=pt.double, requires_grad=True)

# Repeat it
# Stack Ng copies of the 3-row kernel on top of each other.
Psi3_repeat = Psi3.repeat(Ng, 1)

In [59]:
# construct the permutations
# Each group op's permutation row is repeated 3 times in a row, once per kernel row.
GnnPermTensor_R3_repeat = pt.repeat_interleave(GnnPermTensor, 3, dim=0)

In [60]:
# Permute the neighbor columns of every kernel copy, then apply the
# block-diagonal rotation matrix to the stacked rows.
Psi3_permuted = Psi3_repeat.gather(1, GnnPermTensor_R3_repeat)
Psi3_repeat_transf = pt.matmul(gdiags, Psi3_permuted)

In [78]:
# Now check if the transformations were correct
Psi3_0 = Psi3
for gInd, g in GIndtoGDict.items():
    cartRotTens = pt.tensor(g.cartrot).double()
    # The three rows of the expanded kernel that belong to this group op
    Psi3_g = Psi3_repeat_transf[3*gInd : 3*(gInd + 1), :]
    for nn in range(N_ngb):
        nnPerm = GnnPermTensor[gInd, nn]
        # the on-site term should only be rotated
        if nn == 0:
            assert nnPerm == 0
        # The rest of the NNs should be permuted and rotated
        expectedCol = pt.matmul(cartRotTens, Psi3_0[:, nnPerm])
        assert pt.allclose(expectedCol, Psi3_g[:, nn])

In [77]:
Psi3.shape, Psi3_0.shape

(torch.Size([3, 9]), torch.Size([3, 9]))

In [62]:
# For the output, we only select the 0-channel result of out2Sum
# Keep a singleton channel axis so the tensor has the same layout as StateTensors.
Inps3 = out2Sum.select(1, 0).view(StateTensors.shape[0], 1, StateTensors.shape[2])
Inps3.shape

torch.Size([10, 1, 512])

In [63]:
# Next, repeat and arrange input sites with nearest neighbors
Inps3_repeat = pt.repeat_interleave(Inps3, N_ngb, dim=1)
Inps3_repeat.shape

torch.Size([10, 9, 512])

In [64]:
# Reuse the batch-tiled neighbor table from the first layer to gather neighbor values.
Inps3_nn = Inps3_repeat.gather(2, NNsitesBatchRepeat)

In [65]:
Inps3_nn.shape

torch.Size([10, 9, 512])

In [66]:
Psi3_repeat_transf.shape

torch.Size([144, 9])

In [67]:
# Apply the expanded kernel, then split the stacked rows into
# (group op, vector component) axes.
Outs3_flat = pt.matmul(Psi3_repeat_transf, Inps3_nn)
Outs3 = Outs3_flat.view(StateTensors.shape[0], Ng, 3, StateTensors.shape[2])

In [68]:
# Sum (not average) over the group-operation axis.
outs3GSum = Outs3.sum(dim=1)

In [69]:
outs3GSum.shape

torch.Size([10, 3, 512])

In [70]:
# Check if for rotated states, vectors have also been rotated correctly
out3_0 = outs3GSum[0]
count = 0
for sampInd in tqdm(range(StateTensors.shape[0]), position=0, leave=True):
    g = GIndtoGDict[sampInd]
    gTens = pt.tensor(g.cartrot).double()
    for siteInd in range(StateTensors.shape[2]):
        # Rotate state 0's vector at this site with the Cartesian rotation of g
        vecNew = pt.matmul(gTens, out3_0[:, siteInd])

        # Find the site this one is mapped to under g, wrapped into the supercell
        Rnew, _ = superBCC.crys.g_pos(g, SiteIndtoR[siteInd], (0, 0))
        Rnew %= 8
        siteIndNew = RtoSiteInd[Rnew[0], Rnew[1], Rnew[2]]

        # The sample's vector at the mapped site must equal the rotated vector
        assert pt.allclose(vecNew, outs3GSum[sampInd, :, siteIndNew])
        count += 1

100%|██████████| 10/10 [00:00<00:00, 14.60it/s]


In [71]:
count

5120

In [72]:
# Add up the per-site vectors of every state (sum over the site axis).
outs3SiteSum = outs3GSum.sum(dim=2)
outs3SiteSum.shape

torch.Size([10, 3])

In [73]:
outs3SiteSum

tensor([[ 7.9581e-13,  2.0464e-12,  0.0000e+00],
        [ 6.7075e-12,  4.5475e-13, -2.8990e-12],
        [ 1.1937e-12, -9.0949e-13, -4.2064e-12],
        [ 1.3358e-12, -1.7906e-12, -2.4443e-12],
        [-1.9895e-12, -1.5916e-12, -1.7621e-12],
        [-4.6043e-12,  5.6843e-14, -2.4443e-12],
        [ 1.4779e-12,  4.5475e-13,  3.0695e-12],
        [ 1.2790e-12, -1.4921e-12, -1.0800e-12],
        [-3.4674e-12,  6.2528e-13, -2.7853e-12],
        [-3.0695e-12, -2.2737e-12, -1.9895e-12]], dtype=torch.float64,
       grad_fn=<SumBackward1>)

In [85]:
RtoSiteInd[1,1,1], RtoSiteInd[-1,-1,-1]

(73, 511)

In [86]:
outs3GSum[0,:,RtoSiteInd[1,1,1]], outs3GSum[0,:,RtoSiteInd[-1,-1,-1]]

(tensor([ 50.1469, -59.5655, -22.8830], dtype=torch.float64,
        grad_fn=<SelectBackward>),
 tensor([ 38.6426, -12.0880,  -7.6604], dtype=torch.float64,
        grad_fn=<SelectBackward>))

In [87]:
pt.sum(outs3GSum[0,0])

tensor(2.5011e-12, dtype=torch.float64, grad_fn=<SumBackward0>)

In [88]:
outs3GSum.shape

torch.Size([10, 3, 512])

In [89]:
# State-0 output vector at the site index that RtoSiteInd stores for (1, 1, 1).
# NOTE(review): out30 is not read again in the visible part of this cell - confirm
# whether anything later depends on it.
out30 = outs3GSum[0, :, RtoSiteInd[1,1,1]]
# sites[i] collects every site index j whose state-0 vector cancels the vector at
# site i (vec_i + vec_j is all-close to zero).
sites = {}
for siteInd1 in range(outs3GSum.shape[2]):
    vec1 = outs3GSum[0, :, siteInd1]
    sites[siteInd1] = []
    # Brute-force scan over all sites: one allclose call per (site, site) pair.
    for siteInd2 in range(outs3GSum.shape[2]):
        vec2 = outs3GSum[0, :, siteInd2]
        if pt.allclose(vec1+vec2, pt.zeros_like(vec1).double()):
            sites[siteInd1].append(siteInd2)