<a href="https://colab.research.google.com/github/sohamch/VKMC/blob/SymmNets/Lattice_Gas/CE_Symmetry/Numba_Cuda/Numba_GConv.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import math
from numba import cuda, jit, float32, float64, int64, uint8, int16

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
FP = "/content/drive/My Drive/Colab/Numba_Cuda/GConv/"

In [4]:
NsitesMax = 1024
NChMax = 100
NngbMax = 13 # maximum for close packed structs +1 for the site itself     
NChSitesMax = NChMax*NngbMax                                               

In [5]:
# write device function for softplus
@cuda.jit(device=True)
def softPlus(x, beta, threshold):
    if x < -threshold:
        return 0.0
    elif x > threshold:
        return x
    else:
        return math.log(1.0 + math.exp(beta*x))/beta

@cuda.jit(device=True)
def gradSoftPlus(x, beta, threshold):
    if x < -threshold:
        return 0.0
    elif x > threshold:
        return 1.0
    else:
        return 1./(1.0 + math.exp(-beta*x))

In [13]:
@cuda.jit
def Gconv(InputImage, Psi, OutImage, SiteNeighbors, GnnPerms,
          Nsites, NInCh, NOutCh, N_ngb, Ng,
          sp_beta, sp_threshold):
    # first get locations
    ty = cuda.threadIdx.y
    bx, by, bz = cuda.blockIdx.x, cuda.blockIdx.y, cuda.blockIdx.z

    bSizeX, bSizeY, bSizeZ = cuda.blockDim.x, cuda.blockDim.y, cuda.blockDim.z

    # Get the necessary output indices
    batchInd = bx # which sample the thread is working with
    outCh = (by*bSizeY + ty)//Nsites  # thread's output channel
    siteInd = (by*bSizeY + ty)%Nsites # which site the conv is over
    gInd = bz  # which group operation the thread is handling

    # Get the neighborhood of the current site
    # in the thread's local memory
    NgbIndices = cuda.local.array(shape=(NngbMax,), dtype=int64)
    
    NgbIndices[:N_ngb] = SiteNeighbors[siteInd, :]

    # create the shared arrays to store filters and input elements
    InChannel = cuda.shared.array(shape=(NsitesMax,), dtype=float64)
    Filter = cuda.shared.array(shape=(NChSitesMax,), dtype=float64)

    # Store the Group rotations of nns into shared memory
    gnnRotShared = cuda.local.array(shape=(NngbMax,), dtype=uint8)
    if ty < N_ngb:
        gnnRotShared[ty] = GnnPerms[gInd, ty]

    linSum = 0.
    for inCh in range(NInCh):
        # First read the input channel into shared memory
        for sweep in range(Nsites//bSizeY + 1):
            threadSiteInd = sweep*bSizeY + ty
            if threadSiteInd < Nsites:
                InChannel[threadSiteInd] = InputImage[batchInd, inCh, threadSiteInd]
        
        # Then read the filter for this input channel
        # Apply Group permutation to it as well
        for sweep in range((NOutCh*N_ngb)//bSizeY + 1):
            threadElemInd = sweep*bSizeY + ty
            if threadElemInd < NOutCh*N_ngb:
                Filter[threadElemInd] = Psi[inCh, threadElemInd]
        
        # synchronize the block
        cuda.syncthreads()

        # Reading phase is done - now convolve
        for ngb in range(N_ngb):
            ngbSite = NgbIndices[ngb]
            linSum += Filter[outCh*N_ngb + gnnRotShared[ngb]] * InChannel[ngbSite]
    
    # Apply non-linearity
    if bx == 0 and by == 0 and bz == 0 and ty == 0:
        print(linSum)
    nonLin = softPlus(linSum,sp_beta, sp_threshold)/Ng
    # atomically sum out the group channel
    cuda.atomic.add(OutImage, (batchInd, outCh, siteInd), nonLin)

In [7]:
# load the data
NNSites = np.load(FP + "NNsites_sitewise.npy").T
GNNperms = np.load(FP + "GroupNNpermutations.npy")
RtoSiteInd = np.load(FP + "RtoSiteInd.npy")
SiteIndtoR = np.load(FP + "SiteIndtoR.npy")
(Nsites, N_ngb) = NNSites.shape
Ng = GNNperms.shape[0]
print(N_ngb, Nsites, Ng)

9 512 48


In [8]:
# # load the pickle files
# import pickle
# with open(FP + "supercellBCC.pkl", "rb") as fl:
#     superBCC = pickle.load(fl)

# with open(FP + "GroupOpsIndices.pkl", "rb") as fl:
#     GIndices = pickle.load(fl)

# with open(FP + "jnetBCC.pkl") as fl:
#     jNetBCC = pickle.load(fl)

In [9]:
# Create a random input and output image map
Nsites = 512
Nbatch = 512
Nch = 16
InImage = np.random.rand(Nbatch, Nch, Nsites)
OutImage = np.zeros_like(InImage)

In [10]:
# Now set up a random filter
NchIn = Nch
NchOut = Nch
Psi = np.random.rand(NchIn, NchOut*N_ngb)

In [11]:
%set_env NUMBA_ENABLE_CUDASIM=0

env: NUMBA_ENABLE_CUDASIM=0


In [14]:
ty = 512
bX = Nbatch
bY = int(np.ceil((NchOut*Nsites)/ty))
bZ = Ng
Gconv[(bX, bY, bZ), (1, ty, 1)](InImage, Psi, OutImage, NNSites, GNNperms,
          Nsites, NchIn, NchOut, N_ngb, Ng,
          1.0, 20.0)

LoweringError: ignored

In [17]:
NNSites[0]#.shape

array([  0,  73, 511, 448,  64,   8,  56,   1,   7])

In [16]:
N_ngb

9