GitHub path: Lattice_Gas/CE_Symmetry/Numba_Cuda/

In [1]:
import numpy as np
import math
from numba import cuda, jit, float32, float64, int64, uint8, int16
# import cupy

In [2]:
import sys
# Make modules stored in the SymNetTrials directory importable from here.
sys.path.extend(["../../Symm_Network/SymNetTrials/"])

In [3]:
# Directory (relative to this notebook) holding the precomputed .npy tables.
FP = "../../Symm_Network/SymNetTrials/"

In [4]:
# Compile-time capacities of the kernels' fixed-size shared/local arrays.
NsitesMax = 1024  # length of the per-site shared buffers (sites per image)
NChMax = 100      # channel capacity of the Filter shared buffer
NngbMax = 13      # neighbourhood size: 12 neighbours in close-packed structures + the site itself
NChSitesMax = NngbMax*NChMax  # Filter shared-buffer length

In [5]:
# write device function for softplus
@cuda.jit(device=True)
def softPlus(x, beta, threshold):
    """Device softplus: log(1 + exp(beta*x)) / beta.

    The cut-offs are applied to the scaled input z = beta*x (as in
    torch.nn.Softplus), not to x itself. This means a large beta cannot
    overflow exp(), and a small beta is not truncated while the true
    value is still far from 0 or x. For beta == 1 the result is unchanged
    up to floating-point rounding.

    Returns 0 when z < -threshold, x when z > threshold, and the exact
    expression otherwise. Assumes beta > 0.
    """
    z = beta*x
    if z < -threshold:
        return 0.0
    elif z > threshold:
        return x
    else:
        # log1p keeps precision when exp(z) is small
        return math.log1p(math.exp(z))/beta

@cuda.jit(device=True)
def gradSoftPlus(x, beta, threshold):
    """Device derivative of log(1 + exp(beta*x))/beta with respect to x.

    This is sigmoid(beta*x). The cut-offs are applied to z = beta*x:
    returns 0 when z < -threshold, 1 when z > threshold, and the exact
    sigmoid otherwise. For beta == 1 the result is unchanged.
    Assumes beta > 0.
    """
    z = beta*x
    if z < -threshold:
        return 0.0
    elif z > threshold:
        return 1.0
    else:
        return 1.0/(1.0 + math.exp(-z))

In [6]:
# @cuda.jit
# def Gcorr(InputImage, Psi, OutImage, OutF, SiteNeighbors, GnnPerms, N_array,
#           sp_beta, sp_threshold):
@cuda.jit
def Gcorr(InputImage, Psi, OutImage, OutF, SiteNeighbors, GnnPerms, N_array,
          sp_beta, sp_threshold):
    """Group correlation + softplus + group average (forward pass).

    Launch layout: Gcorr[(Nbatch, bY, Ng), (1, ty, 1)].
    - blockIdx.x is the sample and blockIdx.z the group operation.
    - The flattened y index blockIdx.y*blockDim.y + threadIdx.y encodes
      (outCh, site) as outCh*Nsites + site.
    - Threads whose flattened index is >= NOutCh*Nsites (when bY*ty
      overhangs) still take part in the shared loads and barriers, but
      compute and write nothing.

    For its (b, outCh, g, site) a thread computes
        linSum = sum_{inCh, ngb} Psi[inCh, outCh*N_ngb + GnnPerms[g, ngb]]
                                * InputImage[b, inCh, SiteNeighbors[site, ngb]]
    It stores linSum in OutF[b, outCh, g, site] and atomically adds
    softPlus(linSum)/Ng into OutImage[b, outCh, site]. OutImage is
    accumulated into, so it must be zeroed before each launch.

    N_array = [Nsites, NInCh, NOutCh, N_ngb, Ng]. The shared buffers
    require Nsites <= NsitesMax, NOutCh*N_ngb <= NChSitesMax and
    N_ngb <= NngbMax; these limits are not checked here.
    """
    ty = cuda.threadIdx.y
    bx, by, bz = cuda.blockIdx.x, cuda.blockIdx.y, cuda.blockIdx.z
    bSizeY = cuda.blockDim.y

    Nsites = N_array[0]
    NInCh = N_array[1]
    NOutCh = N_array[2]
    N_ngb = N_array[3]
    Ng = N_array[4]

    # Get the necessary output indices
    batchInd = bx  # which sample the thread is working with
    flatInd = by*bSizeY + ty
    # the last y-block may overhang the (outCh, site) range
    active = flatInd < NOutCh*Nsites
    outCh = flatInd//Nsites  # thread's output channel
    siteInd = flatInd%Nsites  # which site the conv is over (always in range)
    gInd = bz  # which group operation the thread is handling

    # Neighbourhood of the current site, in the thread's local memory
    NgbIndices = cuda.local.array(shape=(NngbMax,), dtype=int64)
    for ngb in range(N_ngb):
        NgbIndices[ngb] = SiteNeighbors[siteInd, ngb]

    # Shared buffers for the current input channel and filter row
    InChannel = cuda.shared.array(shape=(NsitesMax,), dtype=float64)
    Filter = cuda.shared.array(shape=(NChSitesMax,), dtype=float64)

    # Neighbour permutation for this block's group operation. The loop is
    # strided so the table is complete even when blockDim.y < N_ngb.
    gnnRotShared = cuda.shared.array(shape=(NngbMax,), dtype=uint8)
    for k in range(ty, N_ngb, bSizeY):
        gnnRotShared[k] = GnnPerms[gInd, k]
    cuda.syncthreads()

    linSum = 0.
    for inCh in range(NInCh):
        # Cooperatively read this input channel of the block's sample
        for sweep in range(Nsites//bSizeY + 1):
            threadSiteInd = sweep*bSizeY + ty
            if threadSiteInd < Nsites:
                InChannel[threadSiteInd] = InputImage[batchInd, inCh, threadSiteInd]

        # Cooperatively read the unpermuted filter row for this input
        # channel; the group permutation is applied when indexing it below
        for sweep in range((NOutCh*N_ngb)//bSizeY + 1):
            threadElemInd = sweep*bSizeY + ty
            if threadElemInd < NOutCh*N_ngb:
                Filter[threadElemInd] = Psi[inCh, threadElemInd]

        cuda.syncthreads()

        if active:
            for ngb in range(N_ngb):
                linSum += Filter[outCh*N_ngb + gnnRotShared[ngb]] * InChannel[NgbIndices[ngb]]

        # all reads must finish before the next channel overwrites the buffers
        cuda.syncthreads()

    if active:
        OutF[batchInd, outCh, gInd, siteInd] = linSum
        nonLin = softPlus(linSum, sp_beta, sp_threshold)/Ng
        # blocks for different group operations add into the same element
        cuda.atomic.add(OutImage, (batchInd, outCh, siteInd), nonLin)

In [7]:
# load the data
# Load the precomputed lattice tables from FP.
NNSites = np.load(FP + "NNsites_sitewise.npy").T  # indexed as [site, neighbour]
GNNperms = np.load(FP + "GroupNNpermutations.npy")
RtoSiteInd = np.load(FP + "RtoSiteInd.npy")
SiteIndtoR = np.load(FP + "SiteIndtoR.npy")
Nsites, N_ngb = NNSites.shape
Ng = len(GNNperms)
print(N_ngb, Nsites, Ng)

9 512 48


In [None]:
# # load the pickle files
# import pickle
# with open(FP + "supercellBCC.pkl", "rb") as fl:
#     superBCC = pickle.load(fl)

# with open(FP + "GroupOpsIndices.pkl", "rb") as fl:
#     GIndices = pickle.load(fl)

# with open(FP + "jnetBCC.pkl") as fl:
#     jNetBCC = pickle.load(fl)

In [8]:
# Create a random input and output image map
# Create a random input image, zeroed output buffers and a random filter.
# Nsites is taken from the loaded neighbour table, not hard-coded, so the
# kernel's site indexing always agrees with NNSites.
Nbatch = 512
NchIn = 16
NchOut = 16

# the kernels' shared buffers have fixed compile-time sizes
assert Nsites <= NsitesMax, "increase NsitesMax for this supercell"
assert N_ngb <= NngbMax, "increase NngbMax for this neighbourhood"
assert NchOut*N_ngb <= NChSitesMax, "increase NChMax for this many output channels"

rng = np.random.default_rng(42)  # fixed seed for reproducible runs
InImage = rng.random((Nbatch, NchIn, Nsites))
OutImage = np.zeros((Nbatch, NchOut, Nsites))
OutF = np.zeros((Nbatch, NchOut, Ng, Nsites))

# Now set up a random filter
Psi = rng.random((NchIn, NchOut*N_ngb))

In [9]:
# Launch geometry: `ty` threads per block along y cover the flattened
# (out channel, site) index; x runs over samples, z over group operations.
ty = 512
NbatchRun = Nbatch
bX = NbatchRun
bY = (NchOut*Nsites + ty - 1)//ty  # ceiling division
bZ = Ng

In [10]:
%%time
# Copy the read-only kernel inputs and the size table to the GPU once.
d_input, d_Psi = cuda.to_device(InImage), cuda.to_device(Psi)
d_NNSites, d_GNNperms = cuda.to_device(NNSites), cuda.to_device(GNNperms)

N_array = np.array([Nsites, NchIn, NchOut, N_ngb, Ng], dtype=int)
d_Narray = cuda.to_device(N_array)

CPU times: user 10.5 ms, sys: 497 ms, total: 508 ms
Wall time: 603 ms


In [23]:
%%time
# device copies of the (zero-initialised) output buffers
d_output, d_outF = cuda.to_device(OutImage), cuda.to_device(OutF)

CPU times: user 51.7 ms, sys: 29.2 ms, total: 80.9 ms
Wall time: 89.4 ms


In [24]:
import time

In [26]:
# Average wall time of one Gcorr launch. A warm-up launch runs first so
# that JIT compilation is excluded from the measurement. d_output keeps
# accumulating across these launches (the kernel uses atomic adds), so the
# comparison below re-zeroes it.
Nrun = 5
launch_args = (d_input, d_Psi, d_output, d_outF, d_NNSites,
               d_GNNperms, d_Narray, 1.0, 20.0)
Gcorr[(bX, bY, bZ), (1, ty, 1)](*launch_args)
cuda.synchronize()

start = time.perf_counter()  # monotonic, high-resolution clock for timing
for i in range(Nrun):
    Gcorr[(bX, bY, bZ), (1, ty, 1)](*launch_args)
    cuda.synchronize()
print("time : {}".format((time.perf_counter() - start)/Nrun))

time : 0.294343090057373


In [11]:
# Run Gcorr once on freshly zeroed outputs (the kernel accumulates into
# d_output) so the result can be compared with the CPU reference.
OutImage = np.zeros((Nbatch, NchOut, Nsites))
OutF = np.zeros((Nbatch, NchOut, Ng, Nsites))
d_output, d_outF = cuda.to_device(OutImage), cuda.to_device(OutF)

cuda.synchronize()
Gcorr[(bX, bY, bZ), (1, ty, 1)](
    d_input, d_Psi, d_output, d_outF, d_NNSites, d_GNNperms, d_Narray, 1.0, 20.0
)
cuda.synchronize()

In [12]:
@jit(nopython=True)
def softPlusCPU(x, beta, threshold):
    """CPU (nopython) softplus log(1 + exp(beta*x)) / beta, used by CorrSerial.

    The cut-offs are applied to z = beta*x, so a large beta cannot overflow
    exp() and a small beta is not truncated early. Returns 0 when
    z < -threshold, x when z > threshold, and the exact expression
    otherwise. For beta == 1 the result is unchanged up to rounding.
    Assumes beta > 0.
    """
    z = beta*x
    if z < -threshold:
        return 0.0
    elif z > threshold:
        return x
    else:
        # log1p keeps precision when exp(z) is small
        return math.log1p(math.exp(z))/beta

@jit(nopython=True)
def CorrSerial(InImage, Psi, NchOut, NchIn, Nsites, Ng, NNSites, GNNperms,
               sp_beta=1.0, sp_threshold=20.0):
    """Serial CPU reference for the Gcorr kernel.

    For every sample, output channel, site and group operation g:
        linSum = sum_{inCh, ngb} Psi[inCh, outCh*N_ngb + GNNperms[g, ngb]]
                                * InImage[samp, inCh, NNSites[site, ngb]]
    Returns (outSamp, outSampF):
    - outSampF[samp, outCh, g, site] = linSum
    - outSamp[samp, outCh, site] = mean over g of softPlusCPU(linSum).

    N_ngb is read from NNSites.shape[1] rather than from a global, so the
    result cannot silently depend on a stale notebook variable. sp_beta and
    sp_threshold default to the values previously hard-coded (1, 20).
    """
    N_ngb = NNSites.shape[1]
    Nsamp = InImage.shape[0]
    outSamp = np.zeros((Nsamp, NchOut, Nsites))
    outSampF = np.zeros((Nsamp, NchOut, Ng, Nsites))
    for sampInd in range(Nsamp):
        for outCh in range(NchOut):
            for siteInd in range(Nsites):
                gsum = 0.
                for gInd in range(Ng):
                    # linear correlation for this group operation
                    linSum = 0.
                    for inCh in range(NchIn):
                        for ngb in range(N_ngb):
                            filt = Psi[inCh, outCh*N_ngb + GNNperms[gInd, ngb]]
                            linSum += filt * InImage[sampInd, inCh, NNSites[siteInd, ngb]]

                    outSampF[sampInd, outCh, gInd, siteInd] = linSum
                    gsum += softPlusCPU(linSum, sp_beta, sp_threshold)/Ng

                outSamp[sampInd, outCh, siteInd] = gsum
    return outSamp, outSampF

In [13]:
%%time
out, outF = CorrSerial(InImage, Psi, NchOut, NchIn, Nsites, Ng,
                       NNSites, GNNperms)

CPU times: user 55.4 s, sys: 180 ms, total: 55.6 s
Wall time: 55.6 s


In [14]:
# bring both GPU outputs back to the host for the comparison
HostOut, HostOutF = d_output.copy_to_host(), d_outF.copy_to_host()

In [19]:
# The GPU results must match the serial CPU reference. The tolerances and
# NaN handling are the same as np.allclose's defaults. np.testing reports
# the mismatching entries on failure and, unlike a bare assert, is not
# stripped under `python -O`.
np.testing.assert_allclose(HostOut, out, rtol=1e-5, atol=1e-8, equal_nan=False)
np.testing.assert_allclose(HostOutF, outF, rtol=1e-5, atol=1e-8, equal_nan=False)

In [None]:
@cuda.jit
def GConvBack(dL_dOutImg, InImg, Psi, dL_dInImg, NbgSites, Narray,
              sp_beta, sp_threshold, GnnPerms):
    """Backward pass of the Gcorr forward kernel with respect to its input image.

    Launch layout is the same as Gcorr: grid (Nbatch, bY, Ng), block (1, ty, 1).
    blockIdx.x is the sample, blockIdx.z the group operation, and the
    flattened y index encodes (outCh, site) as outCh*Nsites + site.
    Threads whose flattened index is >= NOutCh*Nsites only take part in the
    shared loads and barriers.

    Each thread first recomputes its forward linear response linSum. Then,
    for every input channel and neighbour, it accumulates
        dL_dOutImg[b, outCh, site] * softPlus'(linSum) / Ng
            * Psi[inCh, outCh*N_ngb + GnnPerms[g, ngb]]
    into dL_dInImg[b, inCh, NbgSites[site, ngb]]. The 1/Ng matches the
    group average in Gcorr. Contributions are summed per block in shared
    memory and then flushed to global memory with atomic adds, so dL_dInImg
    must be zeroed before the launch.

    Narray = [Nsites, NInCh, NOutCh, N_ngb, Ng]. GnnPerms (the neighbour
    permutation table also given to Gcorr) is appended as the last parameter
    so the existing argument positions are unchanged. The gradient with
    respect to Psi is not computed here.
    """
    ty = cuda.threadIdx.y
    bx, by, bz = cuda.blockIdx.x, cuda.blockIdx.y, cuda.blockIdx.z
    bSizeY = cuda.blockDim.y

    Nsites = Narray[0]
    NInCh = Narray[1]
    NOutCh = Narray[2]
    N_ngb = Narray[3]
    Ng = Narray[4]

    # Get the necessary output indices
    batchInd = bx  # which sample the thread is working with
    flatInd = by*bSizeY + ty
    # the last y-block may overhang the (outCh, site) range
    active = flatInd < NOutCh*Nsites
    outCh = flatInd//Nsites  # thread's output channel
    siteInd = flatInd%Nsites  # which site the conv is over (always in range)
    gInd = bz  # which group operation the thread is handling

    # Neighbourhood of the current site, in the thread's local memory
    NgbIndices = cuda.local.array(shape=(NngbMax,), dtype=int64)
    for ngb in range(N_ngb):
        NgbIndices[ngb] = NbgSites[siteInd, ngb]

    InChannel = cuda.shared.array(shape=(NsitesMax,), dtype=float64)  # 8 kB at NsitesMax=1024
    Filter = cuda.shared.array(shape=(NChSitesMax,), dtype=float64)
    # block-level partial gradient for one input channel
    dL_dIn_shared = cuda.shared.array(shape=(NsitesMax,), dtype=float64)

    # Neighbour permutation for this block's group operation. The loop is
    # strided so the table is complete even when blockDim.y < N_ngb.
    gnnRotShared = cuda.shared.array(shape=(NngbMax,), dtype=uint8)
    for k in range(ty, N_ngb, bSizeY):
        gnnRotShared[k] = GnnPerms[gInd, k]
    cuda.syncthreads()

    # ---- Phase 1: recompute the forward linear correlation ----
    linSum = 0.
    for inCh in range(NInCh):
        for sweep in range(Nsites//bSizeY + 1):
            threadSiteInd = sweep*bSizeY + ty
            if threadSiteInd < Nsites:
                InChannel[threadSiteInd] = InImg[batchInd, inCh, threadSiteInd]

        for sweep in range((NOutCh*N_ngb)//bSizeY + 1):
            threadElemInd = sweep*bSizeY + ty
            if threadElemInd < NOutCh*N_ngb:
                Filter[threadElemInd] = Psi[inCh, threadElemInd]

        cuda.syncthreads()

        if active:
            for ngb in range(N_ngb):
                linSum += Filter[outCh*N_ngb + gnnRotShared[ngb]] * InChannel[NgbIndices[ngb]]

        # all reads must finish before the next channel overwrites the buffers
        cuda.syncthreads()

    # ---- Phase 2: scatter dL/dlinSum back onto the input sites ----
    # dL/dlinSum, including the 1/Ng group average of the forward pass
    upstream = 0.
    if active:
        upstream = (dL_dOutImg[batchInd, outCh, siteInd]
                    * gradSoftPlus(linSum, sp_beta, sp_threshold) / Ng)

    for inCh in range(NInCh):
        # zero the block sum for this input channel
        for sweep in range(Nsites//bSizeY + 1):
            threadSiteInd = sweep*bSizeY + ty
            if threadSiteInd < Nsites:
                dL_dIn_shared[threadSiteInd] = 0.

        # read the filter row for this input channel
        for sweep in range((NOutCh*N_ngb)//bSizeY + 1):
            threadElemInd = sweep*bSizeY + ty
            if threadElemInd < NOutCh*N_ngb:
                Filter[threadElemInd] = Psi[inCh, threadElemInd]

        # zeroing and filter reads must be complete before accumulating
        cuda.syncthreads()

        if active:
            for ngb in range(N_ngb):
                val = upstream * Filter[outCh*N_ngb + gnnRotShared[ngb]]
                cuda.atomic.add(dL_dIn_shared, NgbIndices[ngb], val)

        # wait for all shared-memory atomic adds
        cuda.syncthreads()

        # Flush every site of the block sum, not just the sites the block's
        # threads own: neighbour contributions can land on any site, and a
        # site owned by several threads must be added only once.
        # Blocks for other out channels and group operations add into the
        # same global entries, hence the atomics.
        for sweep in range(Nsites//bSizeY + 1):
            threadSiteInd = sweep*bSizeY + ty
            if threadSiteInd < Nsites:
                cuda.atomic.add(dL_dInImg, (batchInd, inCh, threadSiteInd),
                                dL_dIn_shared[threadSiteInd])

        # flush must finish before the next channel re-zeroes the shared sum
        cuda.syncthreads()