# Thread Folding

Compute multiple output elements per thread ("thread folding"), particularly along the Z direction.

In [None]:

import sys 
sys.path.append('../pystencils')
sys.path.append('../genpredict')

%load_ext autoreload
%autoreload 1
%aimport pystencils.warpspeed.warpspeed
%aimport predict
%aimport griditeration
%aimport volumes_isl
%aimport pystencils.astnodes
%aimport plot_utils



In [None]:
from tinydb import TinyDB, Query
# TinyDB file that holds measured ("m...") and predicted ("p...") volumes.
# NOTE: truncate() discards all previously stored records, so the
# "already in db -> skip" check in the sweep only avoids repeated work within
# the current session.
DB_FILE = './db.json'
db = TinyDB(DB_FILE)
db.truncate()

In [None]:
import cProfile
import re


import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import sys
from subprocess import run, PIPE

from pystencils.transformations import loop_blocking

import pystencils as ps
from pystencils.slicing import add_ghost_layers, make_slice, remove_ghost_layers
from pystencils.warpspeed.warpspeed import PyStencilsWarpSpeedKernel, getFieldExprs, lambdifyExprs, simplifyExprs
from griditeration import *
from volumes_isl import *


from plot_utils import *
from meas_utils import *
from pystencils_stencil_utils import PS3DStencil



import sympy as sp

import pycuda
import pycuda.autoinit
import pycuda.gpuarray as gpuarray
import pycuda.driver as drv

import timeit


In [None]:
# Domain extent, ordered like the thread-block tuples (x, y, z).
# NOTE(review): meaning of PS3DStencil's second argument (2) is not visible here.
size = (
    600,
    600,
    512,
)
SS = PS3DStencil(size, 2)

In [None]:
# Build one star-stencil kernel (block (16, 16, 2), r = 2, blocking factors
# (1, 1, 2)) and print its SASS. The prelude defines the FUNC_PREFIX and
# RESTRICT macros that the generated code refers to.
cuda_prelude = "#define FUNC_PREFIX __global__\n#define RESTRICT __restrict__\n"
kernel = SS.getStarKernel((16, 16, 2), 2, (1, 1, 2))
kernel_source = ps.get_code_str(kernel)
printSASS(cuda_prelude + kernel_source)

In [None]:
import copy
import sympy as sp
from pystencils.data_types import TypedSymbol, get_type_of_expression
import pystencils.astnodes as astn
# Experiment: duplicate ("fold") the kernel body blockCount times, renaming
# every temporary per copy with a "_b<k>" suffix.
blockCount = 2

ast = ps.create_kernel(SS.getStarAssignments(1), target="gpu")

# NOTE(review): assumes the first node of the kernel body is the conditional
# whose true_block holds the assignments -- confirm for other stencils.
innerBlock =  ast.body._nodes[0].true_block
nodes = innerBlock._nodes

#storeExprs = []
#for n in nodes:
#    if isinstance(n.lhs, TypedSymbol):
#       assignments.append(n)
#    else:
#        storeExprs.append(n)

#storeStubs = []
#for s in storeExprs:
#    tempSymbol = TypedSymbol( str(s.lhs.base) + "_" + str(s.lhs.indices[0]), get_type_of_expression(s.rhs) )
#    assignments.append(ps.Assignment(tempSymbol, s.rhs))
#    storeStubs.append(ps.Assignment(s.lhs, tempSymbol))

# Only temporaries (TypedSymbol left-hand sides) are renamed. Select them with
# the same filter used to create their replacements: previously `symbols`
# also held field-store targets, which shifted the (old, new) pairs whenever a
# store preceded a temporary.
symbols = [n.lhs for n in nodes if isinstance(n.lhs, ps.TypedSymbol)]

newNodes = []

for block in range(blockCount):
    new_symbols = [ps.TypedSymbol(s.name + "_b" + str(block), s.dtype) for s in symbols]
    substitutions = list(zip(symbols, new_symbols))

    blockNodes = []
    for n in nodes:
        print(n)
        blockNodes.append(astn.SympyAssignment(n.lhs.subs(substitutions),
                                               n.rhs.subs(substitutions)))
        print(blockNodes[-1])
        print()
    # Field stores are rebuilt explicitly so the substitution also reaches the
    # field base symbol of the ResolvedFieldAccess.
    # TODO(review): the copies still address the same elements (indices are
    # not offset per copy).
    for n in blockNodes:
        if isinstance(n.lhs, astn.ResolvedFieldAccess):
            n.lhs = astn.ResolvedFieldAccess(n.lhs.base.subs(substitutions), n.lhs.indices[0],
                                             n.lhs.field, n.lhs.offsets, n.lhs.idx_coordinate_values)

    newNodes.extend(blockNodes)
    print()
print()
ast.body._nodes[0].true_block = astn.Block(newNodes)

print(ast.body._nodes[0].true_block)

print(ast.__dict__)

#kernel = ast.compile()
ps.show_code(ast)


In [None]:
# Predicted/measured volumes and compiled-kernel caches, keyed by
# configuration tuples built in the sweep.
predValues, measValues, kernelCache, wsKernelCache = {}, {}, {}, {}

In [None]:

# Sweep of 1024-thread block shapes and blocking (thread-folding) factors for a
# radius-r star stencil: build the kernel, predict its data volumes with
# getVolumes, measure them with measureMetrics and upsert both into `db`.
# Configurations already present in `db` are skipped.

xticks = []
xtickLabels = []
xtickCounter = 0

print()
print( "                      mem     mem      L2      L2")
print( "                     load   store    load   store       L1")
print()


for r in [2]:
    for blocking_factors in [(1,1,1), (1,1,2), (1,1,4)]:
        for xblock in [4, 8, 16, 32, 64, 128, 256]:
            for yblock in [1, 2, 4, 8, 16, 32, 64, 256]:
                for zblock in [1, 2, 4, 8, 16]:
                    # Only thread blocks with exactly 1024 threads are studied.
                    if xblock*yblock*zblock != 1024:
                        continue

                    block = (xblock, yblock, zblock)
                    # One key for every per-configuration dict. It must contain
                    # the blocking factors; otherwise configurations differing
                    # only in thread folding overwrite each other and the
                    # measValues cache lookup never matches. key[:4] == (r, *block).
                    key = (r, *block, *blocking_factors)

                    print("block:" + str(block))
                    print("blocking_factors:" + str(blocking_factors))

                    # The same selector drives the "already measured" check and
                    # the upsert, so both look at exactly the same records.
                    User = Query()
                    selector = ((User.block == list(block)) &
                                (User.range == r) & (User.dim == 3) &
                                (User.device == "V100") &
                                (User.blocking_factors == list(blocking_factors)) &
                                (User.stype == "star"))
                    if db.search(selector):
                        print()
                        continue

                    if key in kernelCache:
                        kernel = kernelCache[key]
                    else:
                        kernel = SS.getStarKernel(block, r, blocking_factors)
                        kernelCache[key] = kernel

                    if key in wsKernelCache:
                        wsKernel = wsKernelCache[key]
                    else:
                        wsKernel = PyStencilsWarpSpeedKernel(kernel.ast)
                        wsKernel.registers = kernel.num_regs
                        wsKernelCache[key] = wsKernel

                    runFunc = SS.getRunFunc(kernel)

                    # NOTE(review): floor division, and blocking_factors are not
                    # applied here -- confirm getVolumes accounts for both.
                    grid = tuple(SS.size[i] // block[i] for i in range(3))

                    print("Registers: " + str(kernel.num_regs))
                    predV = getVolumes(wsKernel, block, grid, (r, r, r, *SS.size), blocking_factors)
                    if key in measValues:
                        measV = measValues[key]
                    else:
                        measV = measureMetrics(runFunc, SS.size)

                    print("r={}  {:12}   {:5.2f}   {:5.2f}   {:5.2f}   {:5.2f}".format(r, str(block), measV["memLoad"], measV["memStore"], measV["L2Load"], measV["L2Store"] ))
                    print("            {:5.2f} / {:4.2f}   {:5.2f}   {:5.2f}   {:5.2f}   {:6.1f}".format(predV["memLoad"], predV["memLoadISL"], predV["memStore"], predV["L2Load"], predV["L2Store"], predV["L1cycles"]))

                    predValues[key] = predV
                    measValues[key] = measV

                    db.upsert({  "sort": "ordered", "block" : list(block), "range": r, "dim" : 3, "device" : "V100", "blocking_factors" : list(blocking_factors), "stype" : "star",
                              "mMemLoad" : measV["memLoad"], "mMemStore" : measV["memStore"], "mL2Load" : measV["L2Load"], "mL2Store" : measV["L2Store"],
                               "pL1LoadAllocated" : predV["L1AllocatedLoad"], "pL1Load" : predV["L1Load"], "pL1WarpLoad" : predV["L1WarpLoad"], "pL2Load" : predV["L2Load"], "pL2LoadExt" : predV["L2LoadExt"], "pL2Store" : predV["L2Store"],
                              "pMemLoad" : predV["memLoad"], "pMemLoadISL" : predV["memLoadISL"],"pMemLoadISLext" : predV["memLoadISLext"],  "pMemStore" : predV["memStore"], "pL2LoadAllocated" : predV["L2LoadAllocated"]},
                              selector)

                    #print((measV["memStore"] - predV["memStore"]) / (predV["L2Store"] - predV["memStore"]))
                    print()

In [None]:
import math

# L1 capacity-spill plot: the share of the gap between predicted L1 load and
# predicted L2 load that appears as extra measured L2 load, plotted against the
# predicted allocated L1 volume.
User = Query()
records = db.search(User.range>0)
print(len(records))
if records:
    print(records[0])

# NOTE(review): 2048 * bf_z * 8 presumably is threads/SM x folded elements x
# 8-byte values -- confirm.
L1AllocatedVolumes = [d["pL1LoadAllocated"] + 2048 * d["blocking_factors"][2]*8 for d in records]

L2LoadCapacitySpillRatio = []
for d in records:
    span = d["pL1Load"] - d["pL2Load"]
    # When the model predicts no L1 reduction the ratio is undefined; use NaN
    # (not drawn by matplotlib) instead of aborting the cell.
    L2LoadCapacitySpillRatio.append(
        (d["mL2Load"] - d["pL2Load"]) / span if span != 0 else math.nan)

# RGB derived from log(block x), log(block y) and log(z blocking factor).
colors = [( math.log(d["block"][0])/7 , math.log(d["block"][1]) / 6, math.log(d["blocking_factors"][2])/3) for d in records ]
#colors = [(1,0,0) if d["sort"] == "random" else ((0,1,0) if d["sort"] == "ordered" else (0,0,1)) for d in records]

fig, ax = plt.subplots(figsize=(10,8))
plt.scatter(L1AllocatedVolumes, L2LoadCapacitySpillRatio, marker="*", c = colors)
plt.scatter(L1AllocatedVolumes, L2LoadCapacitySpillRatio, marker="o", c = colors, s = [1729] * len(colors), alpha=0.01, edgecolors="None")
# 128 KiB marker -- presumably the per-SM L1 capacity of the V100.
plt.vlines(128*1024, 0, 0.8, color="gray")
plt.grid()
plt.xscale("log")
points = np.arange(128*1024, 1000000, 4096)
plt.plot( points,   (points - 128*1024) / points *0.4 )




#for e in db.search(User.pL1LoadAllocated > 200*1024):
#    print(e)

In [None]:
import math

# L2 capacity-spill plot: measured memory load above the prediction, relative
# to the predicted L2-to-memory reduction, against the estimated L2 volume.
User = Query()
records = db.search(User.range>0)
print(len(records))
if records:
    print(records[0])

# NOTE(review): 80 * 2048 presumably is SM count x threads/SM on the V100.
L2Volumes = [d["pL2LoadAllocated"] + d["pMemStore"] * 80 * 2048 * d["blocking_factors"][2] for d in records]

MemLoadCapacitySpillRatio = []
for d in records:
    span = d["pL2Load"] - d["pMemLoad"]
    # NOTE(review): numerator uses pMemLoadISL, denominator pMemLoad -- intended?
    # A zero span makes the ratio undefined; use NaN instead of aborting.
    MemLoadCapacitySpillRatio.append(
        (d["mMemLoad"] - d["pMemLoadISL"]) / span if span != 0 else math.nan)

# RGB derived from log(block x), log(block z) and log(z blocking factor).
colors = [( math.log(d["block"][0]) / 7 , math.log(d["block"][2]) / 3, math.log(d["blocking_factors"][2])/3) for d in records ]
#colors = [(1,0,0) if d["sort"] == "random" else ((0,1,0) if d["sort"] == "ordered" else (0,0,1)) for d in records]

fig, ax = plt.subplots(figsize=(10,8))
plt.scatter(L2Volumes, MemLoadCapacitySpillRatio, marker="*", c = colors)
plt.scatter(L2Volumes, MemLoadCapacitySpillRatio, marker="o", c = colors, s = [1729] * len(colors), alpha=0.02, edgecolors="None")
# 6 MiB marker -- presumably the V100 L2 capacity.
plt.vlines(6*1024*1024, 0, 1.0, color="gray")
plt.grid()
#plt.xscale("log")
points = np.arange(6*1024*1024, 5*6*1024*1024, 1024*1024)
plt.plot( points,   (points - 6*1024*1024) / points *0.5 )




#for e in db.search(User.pL1LoadAllocated > 200*1024):
#    print(e)

In [None]:
def getArray(name, records=None):
    """Map each stored configuration to the value of one record field.

    Parameters
    ----------
    name : str
        Record field to extract, e.g. "mMemLoad" or "pL2Load".
    records : iterable of dict, optional
        Records to read. Defaults to every record in the global TinyDB `db`.

    Returns
    -------
    dict
        ``{(blocking_factors[2], *block): record[name]}``. The key does not
        include the range, so records that differ only in range map to the same
        key and the later one wins.
    """
    if records is None:
        records = db.all()
    return {(r["blocking_factors"][2], *r["block"]): r[name] for r in records}
    


In [None]:
# Measured memory load volume against each stored prediction variant.
for _pred_field, _title in (
        ("pMemLoad", "Memory Load Volumes"),
        ("pMemLoadISL", "Memory Load Volumes ISL"),
        ("pMemLoadISLext", "Memory Load Volumes ISL EXT")):
    volumeScatterPlot(getArray("mMemLoad"), getArray(_pred_field), _title)

In [None]:
# Measured L2 load volume against the plain and extended predictions.
for _pred_field, _title in (("pL2Load", "L2 Load Volumes"),
                            ("pL2LoadExt", "L2 Load Volumes Ext")):
    volumeScatterPlot(getArray("mL2Load"), getArray(_pred_field), _title)


In [None]:
# Store volumes from this session's in-memory results: (measured field,
# predicted field, plot title).
for _meas_field, _pred_field, _title in (
        ("memStore", "memStore", "Memory Store Volumes"),
        ("memStore", "memStoreExt", "Memory Store Volumes Ext"),
        ("L2Store", "L2Store", "L2 Store Volumes")):
    volumeScatterPlot({k: vals[_meas_field] for k, vals in measValues.items()},
                      {k: vals[_pred_field] for k, vals in predValues.items()},
                      _title)