In [None]:
# Notebook setup: page styling, import path, and all imports used by later cells.

# Inject a CSS rule that sets elements with class `container` to 90% width.
from IPython.display import display, HTML
display(HTML("<style>.container { width:90% !important; }</style>"))


# Make the modules in ../warpspeed/ importable. The path is relative, so it
# depends on the directory the kernel was started in.
import sys 
sys.path.append('../warpspeed/')

# NOTE(review): the two star imports hide where names used in later cells
# (DeviceAmpere, Field, WarpspeedKernel, LaunchConfig, BasicMetrics,
# DerivedMetrics) come from; presumably predict_metrics / warpspeedkernel —
# TODO replace with explicit imports once the exporting modules are confirmed.
from predict_metrics import *
import sympy as sp
from warpspeedkernel import *
import matplotlib.pyplot as plt

In [None]:
# Launch parameters passed to LaunchConfig.compute in the sweep cell.
blockSize = (512, 1, 1)
blockingFactors = (1,) * 3

# Device description the metrics are computed against.
device = DeviceAmpere()

# Prediction cache, filled as preds[vectorCount][xy] by the sweep cell.
# Lives in its own cell so re-running the sweep does not discard earlier results.
preds = dict()


In [None]:
# Sweep the x/y extent of the domain for each vector count and store one
# prediction per (vectorCount, xy) pair in `preds`. Pairs already present in
# `preds` are skipped, so the cell can be re-run cheaply.
xySizes = [8, 16] + [i * 32 for i in range(1, 16)]

# Target number of grid points. The z extent is derived from it by floor
# division, so the actual volume is at most this value for every xy.
totalCells = 512 * 256 * 256
# One entry per point of the 3 x 3 x 3 neighbourhood built from stencilOffsets.
stencilPoints = 27
# Ordered tuple (not a set) so the generated load order is deterministic.
stencilOffsets = (-1, 0, 1)

for vectorCount in [1, 8]:
    cachedPreds = preds.setdefault(vectorCount, {})
    for xy in xySizes:
        print("NB=" + str(vectorCount))
        domain = (xy, xy, totalCells // (xy * xy))
        print(domain)
        if xy in cachedPreds:
            continue

        # Address expressions for X: every neighbour of (tidx, tidy, tidz),
        # with x varying fastest.
        xloads = [
            ("tidx + " + str(x), "tidy + " + str(y), "tidz + " + str(z))
            for z in stencilOffsets
            for y in stencilOffsets
            for x in stencilOffsets
        ]

        # NOTE(review): the positional Field arguments 8 / 4 are presumably
        # element sizes in bytes and the trailing 0 an alignment — TODO confirm
        # against Field's signature. X and Y are sized domain + 2 per dimension.
        loadFields = [
            Field("X", xloads, 8, [d + 2 for d in domain], 0, multiplicity=vectorCount)
        ]
        storeFields = [
            Field("Y", [("tidx", "tidy", "tidz")], 8, [d + 2 for d in domain], 0, multiplicity=vectorCount)
        ]

        # Linearised addresses for the matrix/index arrays: one access per
        # stencil row, each row offset by a full domain volume.
        matrixLoads = [
            (
                "(tidx + tidy * {0} + tidz * {0} * {1}) + {3} * {0} * {1} * {2}".format(
                    domain[0], domain[1], domain[2], row
                ),
                "0",
                "0",
            )
            for row in range(stencilPoints)
        ]

        matrixShape = (domain[0], domain[1], domain[2] * stencilPoints)
        loadFields.append(Field("mat", matrixLoads, 8, matrixShape, 0))
        loadFields.append(Field("idx", matrixLoads, 4, matrixShape, 0))

        # NOTE(review): the meaning of the positional 64 is not visible in this
        # notebook — confirm against WarpspeedKernel's signature.
        # flops: 2 per stencil point per vector.
        kernel = WarpspeedKernel(
            loadFields, storeFields, 64, flops=stencilPoints * 2 * vectorCount
        )

        lc = LaunchConfig.compute(kernel, blockSize, domain, blockingFactors, device)
        basic = BasicMetrics.compute(lc, device, kernel)
        pred = DerivedMetrics(lc, basic, device)

        print(basic)
        # display() renders the table itself and returns None, so it must not
        # be wrapped in print() (that emitted a stray "None" per iteration).
        display(HTML(pred.html()))

        cachedPreds[xy] = pred



In [None]:
# Plot the predicted perfV3 value against the x-y plane size (xy * xy),
# one line per vector count stored in `preds`.
fig, ax = plt.subplots(figsize=(8, 4), dpi=200)

for vectorCount, predsBySize in preds.items():
    ax.plot(
        [xy * xy for xy in predsBySize],
        [entry.perfV3 for entry in predsBySize.values()],
        ".-",
        # Label each series so the curves can be told apart in the legend.
        label="NB=" + str(vectorCount),
    )

ax.set_xscale("log")
# Anchor the y axis at zero while keeping the automatically chosen upper limit.
ax.set_ylim((0, ax.get_ylim()[1]))
ax.set_xlabel("xy * xy")
ax.set_ylabel("perfV3")
ax.legend()

plt.show()