<a href="https://colab.research.google.com/github/sohamch/VKMC/blob/SymmNets/Lattice_Gas/CE_Symmetry/Numba_Cuda/Numba_GConv.ipynb/Numba_GConv.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import math
from numba import cuda, jit, float32, float64, int64, uint8, int16

In [2]:
# Mount Google Drive at /content/drive; the data directory FP used by this notebook
# lives underneath that mount point.
# NOTE(review): google.colab is only importable inside a Colab runtime, so this cell
# fails on any other kernel.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
# Directory prefix that is prepended to every .npy file name loaded in this notebook.
# NOTE(review): hard-coded, Colab-specific absolute path - adjust when running elsewhere.
FP = "/content/drive/My Drive/Colab/Numba_Cuda/GConv/"

In [5]:
# Compile-time upper bounds: the Gconv kernel sizes its statically allocated
# shared/local arrays with these values, so real problem sizes must not exceed them.

# Most lattice sites one input channel may contain.
NsitesMax = 1024

# Most output channels one filter bank may contain.
NChMax = 100

# Largest neighbourhood: 12 neighbours in close-packed structures, plus the site itself.
NngbMax = 13

# Flattened filter length per input channel (one weight per channel/neighbour pair).
NChSitesMax = NngbMax*NChMax

In [6]:
# Device functions for the softplus activation and its derivative.
# Both apply their cut-offs to the scaled argument beta*x (not to x itself), so the
# approximation error and the overflow protection are independent of beta.
@cuda.jit(device=True)
def softPlus(x, beta, threshold):
    """Softplus activation, log(1 + exp(beta*x))/beta, with saturation cut-offs.

    Parameters
    ----------
    x : float
        Pre-activation value.
    beta : float
        Sharpness of the softplus; must be non-zero.
    threshold : float
        Cut-off on the scaled argument beta*x.

    Returns
    -------
    float
        0.0 when beta*x < -threshold, x when beta*x > threshold (where softplus is
        numerically indistinguishable from the identity), and the exact softplus
        value otherwise.
    """
    z = beta*x
    if z < -threshold:
        return 0.0
    elif z > threshold:
        return x
    else:
        # |z| <= threshold here, so exp(z) cannot overflow for a sensible threshold
        return math.log(1.0 + math.exp(z))/beta

@cuda.jit(device=True)
def gradSoftPlus(x, beta, threshold):
    """Derivative of softPlus with respect to x, using the same cut-offs.

    Returns 0.0 when beta*x < -threshold, 1.0 when beta*x > threshold, and the
    logistic function 1/(1 + exp(-beta*x)) in between.
    """
    z = beta*x
    if z < -threshold:
        return 0.0
    elif z > threshold:
        return 1.0
    else:
        return 1./(1.0 + math.exp(-z))

In [7]:
@cuda.jit
def Gconv(InputImage, Psi, OutImage, SiteNeighbors, GnnPerms,
          Nsites, NInCh, NOutCh, N_ngb, Ng,
          sp_beta, sp_threshold):
    """Group-averaged neighbourhood convolution followed by a softplus.

    For every sample b, output channel o, site s and group operation g the kernel
    evaluates

        lin = sum_{i, k} Psi[i, o*N_ngb + GnnPerms[g, k]] * InputImage[b, i, SiteNeighbors[s, k]]

    and atomically adds softPlus(lin)/Ng to OutImage[b, o, s]. OutImage is only
    ever accumulated into, so it must be zeroed by the caller before the launch.

    Launch layout (as used in this notebook):
        grid  = (number of samples, ceil(NOutCh*Nsites / blockDim.y), Ng)
        block = (1, blockDim.y, 1)
    threadIdx.x is not used, so blockDim.x > 1 would add every result repeatedly.

    Parameters
    ----------
    InputImage : array indexed [sample, input channel, site]
    Psi : array indexed [input channel, outCh*N_ngb + neighbour slot]
    OutImage : array indexed [sample, output channel, site]; accumulated into
    SiteNeighbors : array indexed [site, neighbour slot] -> site index
    GnnPerms : array indexed [group operation, neighbour slot] -> neighbour slot
    Nsites, NInCh, NOutCh, N_ngb, Ng : int
        Problem sizes. The statically sized buffers require Nsites <= NsitesMax,
        N_ngb <= NngbMax and NOutCh*N_ngb <= NChSitesMax (not checked on the device).
    sp_beta, sp_threshold : float
        Parameters forwarded to softPlus.
    """
    # first get locations
    ty = cuda.threadIdx.y
    bx, by, bz = cuda.blockIdx.x, cuda.blockIdx.y, cuda.blockIdx.z
    bSizeY = cuda.blockDim.y

    # Get the necessary output indices
    batchInd = bx  # which sample the block is working with
    # The batch index is identical for all threads of a block, so leaving here is
    # uniform and cannot strand other threads at a barrier further down.
    if batchInd >= InputImage.shape[0] or batchInd >= OutImage.shape[0]:
        return

    flatInd = by*bSizeY + ty
    outCh = flatInd//Nsites   # thread's output channel
    siteInd = flatInd%Nsites  # which site the conv is over
    gInd = bz                 # which group operation the thread is handling

    # Threads padding the last block along y have no output element. They still
    # take part in the cooperative loads and barriers, but must not index the
    # filter or write a result.
    inRange = outCh < NOutCh

    # Get the neighborhood of the current site in the thread's local memory
    NgbIndices = cuda.local.array(shape=(NngbMax,), dtype=int64)
    for ngb in range(N_ngb):
        NgbIndices[ngb] = SiteNeighbors[siteInd, ngb]

    # create the shared arrays to store filters and input elements
    InChannel = cuda.shared.array(shape=(NsitesMax,), dtype=float64)
    Filter = cuda.shared.array(shape=(NChSitesMax,), dtype=float64)

    # Store the group permutation of the neighbour slots in shared memory. It has
    # to be a block-wide array: every thread reads all N_ngb entries, while each
    # thread only writes the slots it is assigned here. Visibility is guaranteed
    # by the first barrier inside the channel loop.
    gnnRotShared = cuda.shared.array(shape=(NngbMax,), dtype=uint8)
    for slot in range(ty, N_ngb, bSizeY):
        gnnRotShared[slot] = GnnPerms[gInd, slot]

    linSum = 0.
    for inCh in range(NInCh):
        # First read the input channel into shared memory (block-strided)
        for threadSiteInd in range(ty, Nsites, bSizeY):
            InChannel[threadSiteInd] = InputImage[batchInd, inCh, threadSiteInd]

        # Then read the filter for this input channel; the group permutation is
        # applied when the filter is indexed below
        for threadElemInd in range(ty, NOutCh*N_ngb, bSizeY):
            Filter[threadElemInd] = Psi[inCh, threadElemInd]

        # all loads must be complete before any thread starts reading
        cuda.syncthreads()

        # Reading phase is done - now convolve
        if inRange:
            for ngb in range(N_ngb):
                ngbSite = NgbIndices[ngb]
                linSum += Filter[outCh*N_ngb + gnnRotShared[ngb]] * InChannel[ngbSite]

        # all reads must be complete before the buffers are overwritten with the
        # next input channel
        cuda.syncthreads()

    if inRange:
        nonLin = softPlus(linSum, sp_beta, sp_threshold)/Ng
        # atomically sum out the group channel
        cuda.atomic.add(OutImage, (batchInd, outCh, siteInd), nonLin)

In [8]:
# load the data
NNSites = np.load(FP + "NNsites_sitewise.npy").T   # indexed [site, neighbour slot] after the transpose
GNNperms = np.load(FP + "GroupNNpermutations.npy")  # indexed [group operation, neighbour slot]
RtoSiteInd = np.load(FP + "RtoSiteInd.npy")
SiteIndtoR = np.load(FP + "SiteIndtoR.npy")
(Nsites, N_ngb) = NNSites.shape
Ng = GNNperms.shape[0]

# The Gconv kernel uses statically sized buffers and performs no bounds checks, so
# verify here that the data fits; a violation would otherwise corrupt memory silently.
assert Nsites <= NsitesMax, "Nsites exceeds NsitesMax - enlarge the kernel's static buffers"
assert N_ngb <= NngbMax, "N_ngb exceeds NngbMax - enlarge the kernel's static buffers"
assert GNNperms.shape[1] >= N_ngb, "every group operation needs one entry per neighbour slot"
assert NNSites.min() >= 0 and NNSites.max() < Nsites, "neighbour table refers to a non-existent site"
assert GNNperms[:, :N_ngb].min() >= 0 and GNNperms[:, :N_ngb].max() < N_ngb, \
    "group permutation refers to a non-existent neighbour slot"

print(N_ngb, Nsites, Ng)

9 512 48


In [None]:
# # load the pickle files
# import pickle
# with open(FP + "supercellBCC.pkl", "rb") as fl:
#     superBCC = pickle.load(fl)

# with open(FP + "GroupOpsIndices.pkl", "rb") as fl:
#     GIndices = pickle.load(fl)

# with open(FP + "jnetBCC.pkl", "rb") as fl:
#     jNetBCC = pickle.load(fl)

In [52]:
# Create a random input and output image map
# The number of sites is taken from the neighbour table instead of being hard-coded,
# so the images can never get out of step with the lattice the kernel indexes into.
Nsites = NNSites.shape[0]
Nbatch = 512   # number of samples
NchIn = 32     # input channels
NchOut = 16    # output channels

# The kernel keeps one input channel's filter (NchOut*N_ngb weights) in a static
# shared buffer of NChSitesMax elements.
assert NchOut*N_ngb <= NChSitesMax, "filter does not fit the kernel's static shared buffer"

InImage = np.random.rand(Nbatch, NchIn, Nsites)
# The kernel only accumulates into its output, so it has to start from zeros
OutImage = np.zeros((Nbatch, NchOut, Nsites))

In [53]:
# Now set up a random filter
# One row per input channel. Within a row, the kernel reads the weights of output
# channel `outCh` from the N_ngb consecutive columns starting at outCh*N_ngb.
Psi = np.random.rand(NchIn, NchOut*N_ngb)

In [54]:
# Launch geometry for Gconv: block = (1, ty, 1), grid = (bX, bY, bZ)
ty = 512          # threads per block along y
NbatchRun = 512   # number of samples the grid spans
bX = NbatchRun    # one grid column per sample
# enough blocks along y to give every (output channel, site) pair its own thread
bY = (NchOut*Nsites + ty - 1)//ty
bZ = Ng           # one grid layer per group operation

In [55]:
%%time
# Copy the host arrays to the device; the d_* handles are what the kernel launch receives.
d_input = cuda.to_device(InImage)
# Gconv only accumulates into the output, so it is uploaded as zeros.
d_output = cuda.to_device(OutImage)
d_NNSites = cuda.to_device(NNSites)
d_GNNperms = cuda.to_device(GNNperms)
d_Psi = cuda.to_device(Psi)

# NOTE(review): the kernel launch in this notebook passes the plain Python scalars
# (Nsites, NchIn, NchOut, N_ngb, Ng), so the scalar device copies below appear to
# be unused - confirm before removing.
d_Nsites = cuda.to_device(Nsites)
d_NchIn = cuda.to_device(NchIn)
d_NchOut = cuda.to_device(NchOut)
d_Nngb = cuda.to_device(N_ngb)
d_ng = cuda.to_device(Ng)

CPU times: user 30.3 ms, sys: 3.67 ms, total: 33.9 ms
Wall time: 35.8 ms


In [63]:
%%time
Gconv[(bX, bY, bZ), (1, ty, 1)](d_input[:128], d_Psi, d_output, d_NNSites, d_GNNperms,
          Nsites, NchIn, NchOut, N_ngb, Ng,
          1.0, 20.0)

CPU times: user 1.19 ms, sys: 0 ns, total: 1.19 ms
Wall time: 1.37 ms


In [64]:
# Copy the output to the host
# NOTE: Gconv adds into d_output atomically, so the values copied here are the sum
# over every launch executed since d_output was last uploaded.
HostOut = d_output.copy_to_host()

In [None]:
# Now let's do the conv explicitly on the host and compare it with the GPU result.
# We'll test one randomly chosen sample for time considerations.

def softPlusHost(x, beta, threshold):
    """Vectorised host softplus used only for this reference check.

    Deliberately not called `softPlus`: that name belongs to the CUDA device
    function and must stay bound to it for the kernel to compile. With beta = 1.0,
    as used for the launch, the cut-offs coincide with those of the device function.
    """
    x = np.asarray(x, dtype=np.float64)
    z = beta*x
    # clip before exponentiating so the discarded branches cannot overflow
    soft = np.log(1.0 + np.exp(np.clip(z, -threshold, threshold)))/beta
    return np.where(z > threshold, x, np.where(z < -threshold, 0.0, soft))

# Only samples that were actually handed to the kernel can be compared.
# NOTE(review): keep this in sync with the input slice used in the launch cell
# (d_input[:128] at the time of writing).
NbatchChecked = 128

# select a random sample
sampInd = np.random.randint(0, NbatchChecked)

ngbTable = NNSites.astype(np.intp)           # [site, neighbour slot] -> site
perms = GNNperms[:, :N_ngb].astype(np.intp)  # [group op, neighbour slot] -> neighbour slot

# gathered[i, s, k] = InImage[sampInd, i, ngbTable[s, k]]
gathered = InImage[sampInd][:, ngbTable]
# filt[i, o, j] = Psi[i, o*N_ngb + j]
filt = Psi.reshape(NchIn, NchOut, N_ngb)

outSamp = np.zeros((NchOut, Nsites))
for g in range(Ng):
    # filtG[i, o, k] = Psi[i, o*N_ngb + perms[g, k]], the indexing used by the kernel
    filtG = filt[:, :, perms[g]]
    # sum over the input channels and the neighbourhood
    linSum = np.einsum('iok,isk->os', filtG, gathered)
    outSamp += softPlusHost(linSum, 1.0, 20.0)/Ng

# Report instead of asserting: a mismatch also occurs when the launch cell was run
# more than once without resetting the output, because the kernel accumulates.
print("sample", sampInd,
      "max |GPU - CPU| =", np.max(np.abs(HostOut[sampInd] - outSamp)),
      "allclose:", np.allclose(HostOut[sampInd], outSamp))
                