In [None]:
!pip install pycuda

In [None]:
import math
import numpy as np
#try:
#  import pycuda.driver as cuda
#  import pycuda.autoinit
#  from pycuda.compiler import SourceModule
#except ImportError:
#   pass
#try:
#  import Metal
#except ImportError:
#  print("couldnt get the metal")
#  pass

import pycuda.driver as cuda
import pycuda.autoinit
from pycuda.compiler import SourceModule

class buffer:
  """Thin wrapper around one backend device allocation.

  Attributes:
    data: the backend allocation handle; which kind depends on `d` (a Metal buffer for
          "Metal", the object passed to cl.enqueue_copy for "OpenCL", the object passed to
          cuda.memcpy_dtoh for "CUDA").
    size: allocation size. np() reads back ceil(size/4) float32 values (and Metal reads
          `size` raw bytes), so this is expected to be a byte count.
    d: backend name: "Metal", "OpenCL" or "CUDA".
    np_data: host-side float32 staging array reused by np() on OpenCL/CUDA
             (None until the first read).
  """

  def __init__(self,data,size,d):
    self.data = data
    self.size = size
    self.d = d
    self.np_data = None  # lazily allocated host staging array for np()

  def _host_array(self):
    """Return the cached float32 host array of ceil(size/4) elements, allocating it once."""
    if self.np_data is None:
      self.np_data = np.zeros(math.ceil(self.size/4), dtype=np.float32)
    return self.np_data

  def np(self,params=None):
    """Copy the device contents to the host and return them as a flat float32 array.

    params: backend params dict; OpenCL reads params["queue"], other backends ignore it.
    Returns None for an unrecognised backend.

    On OpenCL/CUDA the returned array is the cached `np_data`, so the next np() call on this
    buffer overwrites it in place -- copy the result if it must be kept.
    """
    if self.d == "Metal": #TODO np_cache
      raw = np.asarray(self.data.contents().as_buffer(self.size))
      return np.frombuffer(raw, dtype=np.float32)
    if self.d == "OpenCL":
      queue = params["queue"]
      host = self._host_array()
      cl.enqueue_copy(queue, host, self.data)
      return host
    if self.d == "CUDA":
      host = self._host_array()
      cuda.memcpy_dtoh(host, self.data)
      return host

  def copy(self,params):
    """Return a new device buffer on the same backend holding a snapshot of this one."""
    return create_buffer(self.np(params),self.d,params)

  def rand_like(x,params):
    """Return a new device buffer, same backend/element count as `x`, filled with U[0,1) float32."""
    return create_buffer(np.random.random(np.shape(x.np(params).flatten())).astype(np.float32),x.d,params)

  def delete(self):
    """Release the device allocation now instead of waiting for garbage collection.

    The buffer must not be used after this call.
    """
    if self.d == "Metal":
      self.data.setPurgeableState_(Metal.MTLPurgeableStateEmpty)
      self.data.release()
    elif self.d == "CUDA":
      # pycuda DeviceAllocation.free(): release device memory immediately.
      self.data.free()
    elif self.d == "OpenCL":
      # pyopencl MemoryObject.release(): drop our reference to the cl buffer.
      self.data.release()


def compile(prg_str,d,params):
  """Compile kernel source `prg_str` for backend `d` and return the program handle.

  Args:
    prg_str: full kernel source text.
    d: "Metal", "OpenCL" or "CUDA".
    params: backend params dict ("device" for Metal, "ctx" for OpenCL; unused for CUDA).

  Returns:
    Metal library / built OpenCL program / pycuda SourceModule.

  Raises:
    RuntimeError: Metal reported a compilation failure (returned no library).
    ValueError: `d` is not a supported backend.
  """
  if d == "Metal":
    library, err = params["device"].newLibraryWithSource_options_error_(prg_str, Metal.MTLCompileOptions.alloc().init(), None)
    # Metal returns no library on failure; surface the error instead of handing None to run().
    if library is None:
      raise RuntimeError(f"Metal kernel compilation failed: {err}")
    return library
  if d == "OpenCL":
    return cl.Program(params["ctx"],prg_str).build()
  if d == "CUDA":
    return SourceModule(prg_str)
  raise ValueError(f"unsupported device {d!r}; expected 'Metal', 'OpenCL' or 'CUDA'")

def run(prg,func,params,args,gs,ls,d):
  """Launch kernel `func` from compiled program `prg` on backend `d`.

  Args:
    prg: program returned by compile().
    func: kernel function name.
    params: backend params dict ("queue"/"device" for Metal, "queue" for OpenCL; unused for CUDA).
    args: `buffer` wrappers; their `.data` handles are bound in order.
    gs: number of work-groups / threadgroups / CUDA blocks.
    ls: threads per group (OpenCL global size becomes gs*ls).
    d: "Metal", "OpenCL" or "CUDA"; any other value is a no-op.

  Only the Metal path waits for completion (waitUntilCompleted); the OpenCL and CUDA
  paths return right after the launch call without an explicit synchronize.
  """
  if d == "Metal":
    cmd_queue = params["queue"]
    mtl_device = params["device"]
    kernel_fn = prg.newFunctionWithName_(func)
    cmd_buf = cmd_queue.commandBuffer()
    enc = cmd_buf.computeCommandEncoder()
    pso, _compile_err = mtl_device.newComputePipelineStateWithFunction_error_(kernel_fn, None)
    enc.setComputePipelineState_(pso)
    for slot, arg in enumerate(args):
      enc.setBuffer_offset_atIndex_(arg.data, 0, slot)
    groups = Metal.MTLSizeMake(gs, 1, 1)
    per_group = Metal.MTLSizeMake(ls, 1, 1)
    enc.dispatchThreadgroups_threadsPerThreadgroup_(groups, per_group)
    enc.endEncoding()
    cmd_buf.commit()
    cmd_buf.waitUntilCompleted()
  elif d == "OpenCL":
    global_size = gs * ls
    cl_queue = params["queue"]
    launch = getattr(prg, func)
    handles = [a.data for a in args]
    launch(cl_queue, (global_size, 1), (ls, 1), *handles)
  elif d == "CUDA":
    launch = prg.get_function(func)
    handles = [a.data for a in args]
    launch(*handles, block=(ls, 1, 1), grid=(gs, 1))
  return None

def run_test(prg,func,params,args,gs,ls,d):
  """Check that kernel `func` is deterministic, then run it for real on `args`.

  A reference run is made on fresh copies of `args`; three more runs on fresh copies are
  compared against it buffer by buffer (rtol=1e-6, failures raise AssertionError). All copies
  are released with buffer.delete() -- whether that frees device memory depends on that
  backend's delete() implementation. Finally the kernel is launched once on the original `args`.
  """
  def fresh_copies():
    return [src.copy(params) for src in args]

  reference = fresh_copies()
  run(prg,func,params,reference,gs,ls,d)
  for trial in range(3):
    print("test =",trial)
    candidate = fresh_copies()
    run(prg,func,params,candidate,gs,ls,d)
    for ref_buf, cand_buf in zip(reference, candidate):
      np.testing.assert_allclose(ref_buf.np(params), cand_buf.np(params), 1e-6)
    for cand_buf in candidate:
      cand_buf.delete()
  for ref_buf in reference:
    ref_buf.delete()
  run(prg,func,params,args,gs,ls,d)
  return None

def run_old(prg,func,params,args,gs,ls,d):
  """Legacy kernel launcher (the Kernels methods call `run` instead).

  Differences from `run`:
    * OpenCL: `gs` is used directly as the GLOBAL size (not multiplied by `ls`) and `args`
      are forwarded as-is rather than unwrapped via `.data`.
    * There is no CUDA branch: CUDA (or any other backend) is a no-op.
  The Metal path binds `arg.data` for each arg and waits for completion, like `run`.
  """
  if d == "OpenCL":
    cl_queue = params["queue"]
    launch = getattr(prg, func)
    launch(cl_queue, (gs, 1), (ls, 1), *args)
  elif d == "Metal":
    cmd_queue = params["queue"]
    mtl_device = params["device"]
    kernel_fn = prg.newFunctionWithName_(func)
    cmd_buf = cmd_queue.commandBuffer()
    enc = cmd_buf.computeCommandEncoder()
    pso, _compile_err = mtl_device.newComputePipelineStateWithFunction_error_(kernel_fn, None)
    enc.setComputePipelineState_(pso)
    for slot, arg in enumerate(args):
      enc.setBuffer_offset_atIndex_(arg.data, 0, slot)
    groups = Metal.MTLSizeMake(gs, 1, 1)
    per_group = Metal.MTLSizeMake(ls, 1, 1)
    enc.dispatchThreadgroups_threadsPerThreadgroup_(groups, per_group)
    enc.endEncoding()
    cmd_buf.commit()
    cmd_buf.waitUntilCompleted()
  return None

def create_buffer(a,d,params):
  """Allocate a device buffer on backend `d` initialised with the contents of array `a`.

  Args:
    a: numpy array to upload (the Metal path assumes 4-byte elements).
    d: "Metal", "OpenCL" or "CUDA".
    params: backend params dict ("device" for Metal, "ctx"/"mf" for OpenCL; unused for CUDA).

  Returns:
    A `buffer` whose `size` is the allocation size in BYTES on every backend (buffer.np()
    divides it by 4 to get the float32 element count), or None for an unknown backend.
  """
  if d == "Metal":
    nbytes = len(a.flatten())*4
    device = params["device"]
    a_buffer = device.newBufferWithLength_options_(nbytes ,1)
    m = a_buffer.contents().as_buffer(nbytes)
    m[:] = bytes(a)
    return buffer(a_buffer,nbytes,d)
  if d == "OpenCL":
    ctx = params["ctx"]
    mf = params["mf"]
    # READ_WRITE: these buffers may be kernel outputs (e.g. the copies run_test launches on);
    # writing to a READ_ONLY buffer from a kernel is undefined in OpenCL.
    data = cl.Buffer(ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=a)
    return buffer(data,a.nbytes,d)
  if d == "CUDA":
    a_gpu = cuda.mem_alloc(a.nbytes)
    cuda.memcpy_htod(a_gpu, a)
    return buffer(a_gpu,a.nbytes,d)
  return None

def create_buffer_empty(size,d,params):
    """Allocate an uninitialised device buffer of `size` bytes on backend `d`.

    Args:
      size: allocation size in bytes (stored as buffer.size).
      d: "Metal", "OpenCL" or "CUDA".
      params: backend params dict ("device" for Metal, "ctx"/"mf" for OpenCL; unused for CUDA).

    Returns:
      A `buffer`, or None for an unknown backend.
    """
    if d == "Metal":
        device = params["device"]
        a_buffer = device.newBufferWithLength_options_(size ,1)
        return buffer(a_buffer,size,d)
    if d == "OpenCL":
        ctx = params["ctx"]
        mf = params["mf"]
        # READ_WRITE: these buffers are used as kernel outputs (logits, results, scratch);
        # kernel writes to a READ_ONLY buffer are undefined in OpenCL.
        data = cl.Buffer(ctx, mf.READ_WRITE, size)
        return buffer(data,size,d)
    if d == "CUDA":
        return buffer(cuda.mem_alloc(size),size,d)
    return None


In [None]:
import numpy as np
import time
import math
# Threads per work-group (OpenCL) / threadgroup (Metal) / block (CUDA) passed as `ls` to run()
# by Kernels.add and Kernels.tok_emb.
ls = 256

# Per-backend source fragments spliced into the kernel f-strings built by Kernels.
# Text prepended to each generated program.
kernel_prefix = dict(
    OpenCL="",
    Metal="#include <metal_stdlib>\n#include <metal_simdgroup_matrix>\nusing namespace metal;\n",
    CUDA="",
)
# Extra trailing kernel parameter: Metal receives the thread position as an argument.
uint3_arg = dict(OpenCL="", Metal=", uint3 gid [[thread_position_in_grid]]", CUDA="")
# Kernel function qualifier.
func_dec = dict(OpenCL="__kernel", Metal="kernel", CUDA="__global__")  #TODO vs local cuda?
# Address-space qualifier placed before pointer (buffer) parameters.
var_dec = dict(OpenCL="__global", Metal="device", CUDA="")
# Group-wide barrier statement.
barrier = dict(
    OpenCL="barrier(CLK_LOCAL_MEM_FENCE);",
    Metal="threadgroup_barrier(mem_flags::mem_threadgroup);",
    CUDA=" __syncthreads();",
)
# Expression giving the flattened global thread index along x.
global_idx = dict(
    OpenCL="get_global_id(0)",
    Metal="gid.x",
    CUDA="threadIdx.x+blockIdx.x*blockDim.x",
)
# Qualifier for group-shared variables declared inside kernels.
local_var = dict(
    OpenCL="__attribute__ ((aligned (16))) __local",
    Metal="threadgroup",
    CUDA="__shared__",
)



class Kernels:
    def __init__(self,dim,n_heads,max_context,device):
        self.d = device
        self.prg_cache = {}
        self.dim = dim
        self.n_heads = n_heads
        self.max_context = max_context
        if device == "Metal":
            self.device = Metal.MTLCreateSystemDefaultDevice()
            self.queue = self.device.newCommandQueue()
            self.params = {"queue":self.queue,"device":self.device}
        if device == "OpenCL":
            platform = cl.get_platforms()
            my_gpu_devices = platform[0].get_devices(device_type=cl.device_type.GPU)
            ctx = cl.Context(devices=my_gpu_devices)
            self.queue = cl.CommandQueue(ctx)
            mf = cl.mem_flags
            prg = None
            self.params = {"ctx":ctx,"mf":mf,"queue":self.queue}
        if device == "CUDA":
            self.params = None

    def add(self,a_g,b_g,b_s=0,a_s=0):
        if hasattr(self, 'add_res_g') == False:
            self.add_res_g = create_buffer_empty(self.dim*4,self.d,self.params)
        prg_str = f"""
        {kernel_prefix[self.d]}
        {func_dec[self.d]} void add(
            {var_dec[self.d]} const float *a, {var_dec[self.d]} const float *b, {var_dec[self.d]} float *res{uint3_arg[self.d]})
        {{
        int gidx0 = {global_idx[self.d]};
        if(gidx0 < {self.dim}) {{
            res[gidx0] = a[{int(a_s)*self.dim} + gidx0] + b[gidx0 + {b_s*self.dim}];
        }}
        }}
        """
        prg = compile(prg_str,self.d,self.params)
        g = math.ceil(self.dim / ls)
        run(prg,"add",self.params,[a_g, b_g,self.add_res_g],g,ls,self.d) #TODO this breaks it all (CUDA?)
        return self.add_res_g

    def tok_emb(self,tokens,weight_g,weight_2_g,no_tokens):
        tokens_g = create_buffer(tokens.astype(np.int32),self.d,self.params)
        size = no_tokens*self.dim
        tok_emb_g = create_buffer_empty(no_tokens*self.dim*4,self.d,self.params)
        prg_str = f"""
        {kernel_prefix[self.d]}
        {func_dec[self.d]} void mm(
            {var_dec[self.d]} int *tokens, {var_dec[self.d]} const float *weight, {var_dec[self.d]} const float *weight2,  {var_dec[self.d]} float *tok_emb{uint3_arg[self.d]})
        {{
            int gidx0 = {global_idx[self.d]};
            int i = gidx0 / {self.dim};
            int j = gidx0 % {self.dim};
            tok_emb[i*{self.dim} + j] = weight[tokens[i]*{self.dim} + j] + weight2[i*{self.dim} + j];
        }}
        """
        library = compile(prg_str,self.d,self.params)
        gs = math.ceil(size / ls)
        run(library,"mm",self.params,[tokens_g,weight_g,weight_2_g,tok_emb_g],gs,ls,self.d)
        return tok_emb_g

    def kernel_1(self,h_g,weight_g,bias_g,weight2_g,temperature,random_num):
        """Layer-norm a single hidden vector, project it to logits and sample an index.

        Device pipeline (see the launches at the bottom):
          mm4    - layer-normalise h_g in place with weight_g/bias_g (one group of ls threads).
          matvec - logits[c] = sum_j h[j] * weight2[j*cols + c] / temperature.
          mm5/mm6/mm5/mm8/mm9/mm8 - softmax: shift by logits[0], exp, divide by the sum.
          mm10   - serial inclusive prefix sum (CDF) over the probabilities.
          mm11   - logits[c] = 1 if cdf[c] < random_num else 0.
          mm9    - res = number of ones, i.e. the inverse-CDF sample index.

        Per-logit kernels are launched with ceil(cols/ls)*ls threads, so each of them checks
        gidx0 < cols before touching the cols-element logits buffer.

        NOTE(review): the softmax shift uses logits[0] rather than the max logit, so exp() can
        overflow in float32 if another logit exceeds logits[0] by more than ~88.

        Returns:
          Host copy of self.res_g (a 1-element float32 array holding the sampled index).
        """
        ls = 256
        seg = int(self.dim / ls)  # floats per thread in mm4; assumes dim is a multiple of ls
        rows = self.dim
        cols = 50257  # number of logits (presumably the GPT-2 vocabulary size)
        seg2 = math.ceil(cols / ls)  # logits per thread in mm9, and groups for per-logit kernels
        if not hasattr(self, 'logits_g'):
            self.logits_g = create_buffer_empty(cols*4,self.d,self.params)
        if not hasattr(self, 'res'):
            self.res = np.zeros(1).astype(np.float32)  # not read here; kept for existing callers
        if not hasattr(self, 'res_g'):
            self.res_g = create_buffer_empty(1*4,self.d,self.params)
        prg_str = f"""
        {kernel_prefix[self.d]}
        {func_dec[self.d]} void mm4(
            {var_dec[self.d]} float *h, {var_dec[self.d]} const float *weight, {var_dec[self.d]} const float *bias{uint3_arg[self.d]})
        {{
            {local_var[self.d]} float temp[{ls}];
            {local_var[self.d]} float mean;
            int lidx0 = {global_idx[self.d]};
            float total = 0;
            for(int i = 0; i < {seg}; i++) {{
                total += h[lidx0*{seg} + i];
            }}
            temp[lidx0] = total;
            {barrier[self.d]}
            if(lidx0==0) {{
                total = 0;
                for(int i = 0; i < {ls}; i++) {{
                    total += temp[i];
                }}
                mean = total / {self.dim};
            }}
            {barrier[self.d]}
            for(int i = 0; i < {seg}; i++) {{
                h[i + lidx0*{seg}] -= mean;
            }}
            {barrier[self.d]}
            total = 0;
            for(int i = 0; i < {seg}; i++) {{
                total += pow(h[lidx0*{seg} + i],2);
            }}
            temp[lidx0] = total;
            {barrier[self.d]}
            if(lidx0==0) {{
                total = 0;
                for(int i = 0; i < {ls}; i++) {{
                    total += temp[i];
                }}
                mean = pow(total / {self.dim} + 1e-5,0.5);
            }}
            {barrier[self.d]}
            for(int i = 0; i < {seg}; i++) {{
                h[i + lidx0*{seg}] = (h[i + lidx0*{seg}] * weight[i + lidx0*{seg}]) / mean + bias[i + lidx0*{seg}];
            }}
        }}
        {func_dec[self.d]} void matvec(
            {var_dec[self.d]} const float *h, {var_dec[self.d]} const float *weight2 , {var_dec[self.d]} float *res{uint3_arg[self.d]})
        {{
            int gidx0 = {global_idx[self.d]};
            if(gidx0 < {cols}) {{
                float acc = 0;
                for(int j = 0; j < {rows}; j++) {{
                    acc += h[j] * weight2[gidx0 + j*{cols}];
                }}
                res[gidx0] = acc / {temperature};
            }}
        }}
        {func_dec[self.d]} void mm5(
            {var_dec[self.d]} const float *a, {var_dec[self.d]} float *res{uint3_arg[self.d]})
        {{
            res[0] = a[0]; //todo why is this needed?, used to be a MAX
        }}

        {func_dec[self.d]} void mm6(
        {var_dec[self.d]} float *a, {var_dec[self.d]} float *res{uint3_arg[self.d]})
        {{
            int gidx0 = {global_idx[self.d]};
            if(gidx0 < {cols}) {{
                a[gidx0] = exp(a[gidx0] - res[0]);
            }}
        }}

        {func_dec[self.d]} void mm8(
        {var_dec[self.d]} float *a, {var_dec[self.d]} const float *res{uint3_arg[self.d]})
        {{
            int gidx0 = {global_idx[self.d]};
            if(gidx0 < {cols}) {{
                a[gidx0] = a[gidx0] / res[0];
            }}
        }}

        {func_dec[self.d]} void mm9(
        {var_dec[self.d]} const float *a, {var_dec[self.d]} float *res{uint3_arg[self.d]})
        {{
            {local_var[self.d]} float temp[{ls}];
            int lidx0 = {global_idx[self.d]};
            temp[lidx0] = 0;
            for(int i = 0; i < {seg2}; i++) {{
                if(lidx0*{seg2} + i < {cols}){{
                temp[lidx0] += a[lidx0*{seg2} + i];
                }}
            }}
            {barrier[self.d]}
            float t = 0;
            if(lidx0 == 0) {{
                for(int i = 0; i < {ls}; i++) {{
                    t+=temp[i];
                }}
                res[0] = t;
            }}
        }}

        {func_dec[self.d]} void mm10(
        {var_dec[self.d]} float *a{uint3_arg[self.d]})
        {{
            for(int i = 1; i < {cols}; i++) {{
                a[i] += a[i-1];
            }}
        }}
        """

        if prg_str not in self.prg_cache:
            library = compile(prg_str,self.d,self.params)
            self.prg_cache[prg_str] = library
        prg = self.prg_cache[prg_str]

        # random_num is baked into this kernel's source, so it is compiled on every call
        # (deliberately not cached: the cache would grow with every distinct random_num).
        prg_str = f"""
        {kernel_prefix[self.d]}
        {func_dec[self.d]} void mm11(
        {var_dec[self.d]} float *a{uint3_arg[self.d]})
        {{
            int gidx0 = {global_idx[self.d]};
            if(gidx0 < {cols}) {{
                if(a[gidx0] < {random_num}) {{ //TODO, used to be (a[gidx0] / a[50256])/{random_num}
                    a[gidx0] = 1;
                }} else {{
                    a[gidx0] = 0;
                }}
            }}
        }}
        """
        prg2 = compile(prg_str,self.d,self.params)

        gs = seg2
        run(prg,"mm4",self.params,[h_g, weight_g, bias_g],1,ls,self.d)
        run(prg,"matvec",self.params,[h_g, weight2_g,self.logits_g],gs,ls,self.d)
        run(prg,"mm5",self.params,[self.logits_g,self.res_g],1,1,self.d)
        run(prg,"mm6",self.params,[self.logits_g,self.res_g],gs,ls,self.d)
        run(prg,"mm5",self.params,[self.logits_g,self.res_g],1,1,self.d)
        run(prg,"mm8",self.params,[self.logits_g,self.res_g],gs,ls,self.d)
        run(prg,"mm9",self.params,[self.logits_g,self.res_g],1,ls,self.d)
        run(prg,"mm8",self.params,[self.logits_g,self.res_g],gs,ls,self.d)
        run(prg,"mm10",self.params,[self.logits_g],1,1,self.d)
        run(prg2,"mm11",self.params,[self.logits_g],gs,ls,self.d)
        run(prg,"mm9",self.params,[self.logits_g,self.res_g],1,ls,self.d)
        return self.res_g.np(self.params)

    def kernel_3(self,x_g,weight_g,bias_g,attn_weight_g,attn_bias_g,new_cache_g\
        ,ln_f_weight_g,ln_f_bias_g,n_tokens,max_content,lm_head_weight_g,temperature,random_num):
        """Process an n_tokens-row input: per-row layer-norm + projection into new_cache_g, then
        final layer-norm of the last row, projection to logits and sampling.

        Device pipeline (see the launches at the bottom):
          mm     - copy each row r of x_g into x0_g and layer-normalise it with weight_g/bias_g
                   (one group of ls threads per row; per-thread partial sums are kept in
                   registers and only thread 0 writes the shared row statistic).
          mm2    - c[y*3*dim + i] = dot(x0[y, :], attn_weight[i, :]) + attn_bias[i], i < 3*dim.
          mm4    - for rows y < n_tokens copy c[y, dim:2*dim] to new_cache_g[y*dim ...] and
                   c[y, 2*dim:3*dim] to new_cache_g[max_content*n_heads*64 + y*dim ...]
                   (presumably the key and value caches).
          mm5    - layer-normalise the LAST row of x_g with ln_f weights and write it into row 0
                   of x_g (x_g is modified in place).
          matmul - logits[c] = dot(x_g[0, :], lm_head_weight[c, :]) / temperature.
          mm6/mm7/mm6/mm9/mm10/mm9 - softmax: shift by logits[0], exp, divide by the sum.
          mm11   - serial inclusive prefix sum (CDF); mm12 - 1 where cdf < random_num else 0.
          mm10   - res = number of ones, i.e. the inverse-CDF sample index.
        mm3 is compiled but not launched.

        Per-element kernels are launched with whole groups of ls threads, so each checks its
        index against the buffer extent before reading/writing.

        Returns:
          Host copy of res_g (a 1-element float32 array holding the sampled index).
        """
        ls = 256
        size = self.dim
        b_cols2 = 50257  # number of logits (presumably the vocabulary size)
        b_rows2 = self.dim
        seg2 = math.ceil(50257 / ls)
        b_cols = self.dim*3 #todo
        b_rows = self.dim
        seg = int(size / ls) #todo -- assumes dim is a multiple of ls
        # Scratch buffers are allocated fresh on every call.
        x0_g = create_buffer_empty(n_tokens*self.dim*4,self.d,self.params)
        logits_g = create_buffer_empty(50257*4,self.d,self.params)
        c_g = create_buffer_empty(max_content*b_cols*4,self.d,self.params) #todo, can this be smaller?
        res_g = create_buffer_empty(1*4,self.d,self.params)

        prg_str = f"""
        {kernel_prefix[self.d]}
        {func_dec[self.d]} void mm({var_dec[self.d]} const float *x_in,
            {var_dec[self.d]} float *x, {var_dec[self.d]} const float *weight, {var_dec[self.d]} const float *bias{uint3_arg[self.d]})
        {{
            {local_var[self.d]} float temp[{ls}];
            {local_var[self.d]} float row_stat;
            int gidx0 = {global_idx[self.d]};
            int lidx0 = gidx0 % {ls};
            int r = gidx0 / {ls};
            float t = 0;
            for(int i = 0; i < {seg}; i++) {{
                x[{self.dim}*r + lidx0*{seg} + i] = x_in[{self.dim}*r + lidx0*{seg} + i];
                t += x[{self.dim}*r + lidx0*{seg} + i];
            }}
            temp[lidx0] = t;
            {barrier[self.d]}
            if(lidx0 == 0) {{
                t = 0;
                for(int i = 0; i < {ls}; i++) {{
                    t += temp[i];
                }}
                row_stat = t / {size};
            }}
            {barrier[self.d]}
            for(int i = 0; i < {seg}; i++) {{
                x[{self.dim}*r + i + lidx0*{seg}] -= row_stat;
            }}
            {barrier[self.d]}
            t = 0;
            for(int i = 0; i < {seg}; i++) {{
                t += pow(x[{self.dim}*r + lidx0*{seg} + i],2);
            }}
            temp[lidx0] = t;
            {barrier[self.d]}
            if(lidx0 == 0) {{
                t = 0;
                for(int i = 0; i < {ls}; i++) {{
                    t += temp[i];
                }}
                row_stat = pow(t / {size} + 1e-5,0.5);
            }}
            {barrier[self.d]}
            for(int i = 0; i < {seg}; i++) {{
                x[{self.dim}*r + i + lidx0*{seg}] = (x[{self.dim}*r + i + lidx0*{seg}] * weight[i + lidx0*{seg}]) / row_stat + bias[i + lidx0*{seg}];
            }}
        }}
        {func_dec[self.d]} void mm2(
            {var_dec[self.d]} const float *x, {var_dec[self.d]} const float *attn_weight, {var_dec[self.d]} const float *attn_bias,{var_dec[self.d]} float *res{uint3_arg[self.d]})
        {{
            if({global_idx[self.d]} < {b_cols*n_tokens}) {{ //TODO
            int gidx0 = {global_idx[self.d]};
            int i = gidx0 / {n_tokens};
            int y = gidx0 % {n_tokens};
            float total = 0;
            for(int k = 0; k < {b_rows}; k++) {{
                total += x[y*{b_rows} + k] * attn_weight[i*{b_rows} + k];
            }}
            res[y*{b_cols} + i] = total + attn_bias[i];
            }}
        }}
        {func_dec[self.d]} void mm3(
            {var_dec[self.d]} const float *xqkv, {var_dec[self.d]} float *new_cache{uint3_arg[self.d]})
        {{
            int gidx0 = {global_idx[self.d]};
            int i = gidx0 / {self.n_heads*64};
            int j = gidx0 % {self.n_heads*64};
            new_cache[i*{self.n_heads*64} + j] = xqkv[i*{self.n_heads*64*3} + {self.dim*1} + j];
            new_cache[{max_content*self.n_heads*64} + i*{self.n_heads*64} + j] = xqkv[i*{self.n_heads*64*3} + {self.dim}*2 + j];
        }}
        {func_dec[self.d]} void mm4(
            {var_dec[self.d]} const float *xqkv, {var_dec[self.d]} float *new_cache{uint3_arg[self.d]})
        {{
            int gidx0 = {global_idx[self.d]};
            if(gidx0 < {n_tokens*self.n_heads*64}) {{
            int i = gidx0 / {self.n_heads*64};
            int j = gidx0 % {self.n_heads*64};
            new_cache[i*{self.n_heads*64} + j] = xqkv[i*{self.n_heads*64*3} + j + {self.dim}];
            new_cache[{max_content*self.n_heads*64} + i*{self.n_heads*64} + j] = xqkv[i*{self.n_heads*64*3} + j + {self.dim}*2];
            }}
        }}
        {func_dec[self.d]} void mm5(
            {var_dec[self.d]} float *x, {var_dec[self.d]} const float *ln_f_weight, {var_dec[self.d]} const float *ln_f_bias{uint3_arg[self.d]})
        {{
            {local_var[self.d]} float temp[{ls}];
            {local_var[self.d]} float mean;
            int lidx0 = {global_idx[self.d]};
            float total = 0;
            for(int i = 0; i < {seg}; i++) {{
                total += x[lidx0*{seg} + i + {(n_tokens - 1)*self.dim}];
            }}
            temp[lidx0] = total;
            {barrier[self.d]}
            if(lidx0==0) {{
                total = 0;
                for(int i = 0; i < {ls}; i++) {{
                    total += temp[i];
                }}
                mean = total / {size};
            }}
            {barrier[self.d]}
            for(int i = 0; i < {seg}; i++) {{
                x[i + lidx0*{seg} + {(n_tokens - 1)*self.dim}] -= mean;
            }}
            {barrier[self.d]}
            total = 0;
            for(int i = 0; i < {seg}; i++) {{
                total += pow(x[lidx0*{seg} + i + {(n_tokens - 1)*self.dim}],2);
            }}
            temp[lidx0] = total;
            {barrier[self.d]}
            if(lidx0==0) {{
                total = 0;
                for(int i = 0; i < {ls}; i++) {{
                    total += temp[i];
                }}
                mean = pow(total / {size} + 1e-5,0.5);
            }}
            {barrier[self.d]}
            for(int i = 0; i < {seg}; i++) {{
                x[i + lidx0*{seg}] = (x[i + lidx0*{seg} + {(n_tokens - 1)*self.dim}] * ln_f_weight[i + lidx0*{seg}]) / mean + ln_f_bias[i + lidx0*{seg}];
            }}
        }}
        {func_dec[self.d]} void matmul(
            {var_dec[self.d]} const float *a, {var_dec[self.d]} const float *b, {var_dec[self.d]} float *res{uint3_arg[self.d]})
        {{
            if({global_idx[self.d]} < {b_cols2}) {{
                int x = {global_idx[self.d]};
                float total = 0;
                for(int k = 0; k < {b_rows2}; k++) {{
                    total += a[k] * b[x*{b_rows2} + k];
                }}
                res[x] = total / {temperature};
            }}
        }}
        {func_dec[self.d]} void mm6(
            {var_dec[self.d]} const float *a, {var_dec[self.d]} float *res{uint3_arg[self.d]})
        {{
            res[0] = a[0]; //todo why is this needed?, used to be a MAX
        }}

        {func_dec[self.d]} void mm7(
        {var_dec[self.d]} float *a, {var_dec[self.d]} const float *res{uint3_arg[self.d]})
        {{
            if({global_idx[self.d]} < 50257) {{ //TODO
            int gidx0 = {global_idx[self.d]};
            a[gidx0] = exp(a[gidx0] - res[0]);
            }}
        }}

        {func_dec[self.d]} void mm9(
        {var_dec[self.d]} float *a, {var_dec[self.d]} const float *res{uint3_arg[self.d]})
        {{
            if({global_idx[self.d]} < 50257) {{
            int gidx0 = {global_idx[self.d]};
            a[gidx0] = a[gidx0] / res[0];
            }}
        }}

        {func_dec[self.d]} void mm10(
        {var_dec[self.d]} const float *a, {var_dec[self.d]} float *res{uint3_arg[self.d]})
        {{
            {local_var[self.d]} float temp[{ls}];
            int lidx0 = {global_idx[self.d]};
            float t = 0;
            for(int i = 0; i < {seg2}; i++) {{
                if(lidx0*{seg2} + i < 50257) {{
                t += a[lidx0*{seg2} + i];
                }}
            }}
            temp[lidx0] = t;
            {barrier[self.d]}
            if(lidx0 == 0) {{
                t = 0;
                for(int i = 0; i < {ls}; i++) {{
                    t += temp[i];
                }}
                res[0] = t;
            }}
        }}

        {func_dec[self.d]} void mm11(
        {var_dec[self.d]} float *a{uint3_arg[self.d]})
        {{
            for(int i = 1; i < 50257; i++) {{
                a[i] += a[i-1];
            }}
        }}

        {func_dec[self.d]} void mm12(
        {var_dec[self.d]} float *a{uint3_arg[self.d]})
        {{
            int gidx0 = {global_idx[self.d]};
            if(gidx0 < 50257) {{
                if(a[gidx0] < {random_num}) {{ //TODO, used to be (a[gidx0] / a[50256])/{random_num}
                    a[gidx0] = 1;
                }} else {{
                    a[gidx0] = 0;
                }}
            }}
        }}
        """

        if prg_str not in self.prg_cache:
            library = compile(prg_str,self.d,self.params)
            self.prg_cache[prg_str] = library
        prg = self.prg_cache[prg_str]

        run(prg,"mm",self.params,[x_g, x0_g, weight_g, bias_g],n_tokens,ls,self.d)
        gs = math.ceil(b_cols*n_tokens / ls)
        run(prg,"mm2",self.params,[x0_g, attn_weight_g,attn_bias_g,c_g],gs,ls,self.d)
        run(prg,"mm4",self.params,[c_g, new_cache_g],gs,ls,self.d)
        run(prg,"mm5",self.params,[x_g, ln_f_weight_g, ln_f_bias_g],1,ls,self.d)
        run(prg,"matmul",self.params,[x_g, lm_head_weight_g,logits_g],math.ceil(b_cols2 / ls),ls,self.d)
        run(prg,"mm6",self.params,[logits_g,res_g],1,1,self.d)
        gs = math.ceil(50257 / ls)
        run(prg,"mm7",self.params,[logits_g,res_g],gs,ls,self.d)
        run(prg,"mm6",self.params,[logits_g,res_g],1,1,self.d)
        run(prg,"mm9",self.params,[logits_g,res_g],gs,ls,self.d)
        run(prg,"mm10",self.params,[logits_g,res_g],1,ls,self.d)
        run(prg,"mm9",self.params,[logits_g,res_g],gs,ls,self.d)
        run(prg,"mm11",self.params,[logits_g],1,1,self.d)
        run(prg,"mm12",self.params,[logits_g],gs,ls,self.d)
        run(prg,"mm10",self.params,[logits_g,res_g],1,ls,self.d)
        return res_g.np(self.params)

    def kernel_0(self,a_g,c_g,d_g,e_g,xqkv_g,keys_values_g,weight_g,bias_g,\
        weight2_g,bias2_g,weight3_g,bias3_g,weight4_g,bias4_g,start_pos,g,j=0):
        ls = 256
        ls = 256 #TODO why is 256 fastet than 32?
        seg3 = math.ceil(self.dim / ls) #todo
        seg = math.ceil(self.dim / ls)
        if hasattr(self, 'temp_g') == False:
            self.temp_g = create_buffer_empty(self.n_heads*self.max_context*4,self.d,self.params)
        if hasattr(self, 'xq_temp_g') == False:
            self.xq_temp_g = create_buffer_empty((self.n_heads*(self.max_context+1)*64 + 64)*4,self.d,self.params) #TODO can this be smaller?
        #{barrier[self.d]}
        prg_str = f"""
        {kernel_prefix[self.d]}
        {func_dec[self.d]} void mm(
            {var_dec[self.d]} float *a,
            {var_dec[self.d]} float *mean{uint3_arg[self.d]})
        {{
            {local_var[self.d]} float temp[{ls}];
            int lidx0 = {global_idx[self.d]};
            float t = 0;
            for(int i = 0; i < {seg}; i++) {{
                t += a[lidx0*{seg} + i];
            }}
            temp[lidx0] = t;
            {barrier[self.d]}
            if(lidx0==0) {{
                t = 0;
                for(int i = 0; i < {ls}; i++) {{
                    t += temp[i];
                }}
                mean[0] = t / {self.dim};
            }}
            {barrier[self.d]}
            t = 0;
            for(int i = 0; i < {seg}; i++) {{
                a[i + lidx0*{seg}] -= mean[0];
                t += pow(a[lidx0*{seg} + i],2);
            }}
            temp[lidx0] = t;
            {barrier[self.d]}
            if(lidx0==0) {{
                t = 0;
                for(int i = 0; i < {ls}; i++) {{
                    t += temp[i];
                }}
                mean[0] = pow(t / {self.dim} + 1e-5,0.5);
            }}
        }}

        {func_dec[self.d]} void mm4(
            {var_dec[self.d]} float *a,
            {var_dec[self.d]} const float *weight2, {var_dec[self.d]} const float *bias2,
            {var_dec[self.d]} const float *weight3, {var_dec[self.d]} const float *bias3,
            {var_dec[self.d]} float *mean,
            {var_dec[self.d]} float *h_temp, {var_dec[self.d]} float *h, {var_dec[self.d]} float *bias3_temp{uint3_arg[self.d]})
        {{
            int lidx0 = {global_idx[self.d]} % {ls};
            int i = {global_idx[self.d]} / {ls};
            bias3_temp[i + lidx0*{math.ceil(self.dim*4 / ls)}] = bias3[i + lidx0*{math.ceil(self.dim*4 / ls)}];
            for(int j = 0; j < {self.dim}; j++) {{
                bias3_temp[i + lidx0*{math.ceil(self.dim*4 / ls)}] += ((h[j] * weight2[j]) / mean[0] + bias2[j]) * weight3[(i + lidx0*{math.ceil(self.dim*4 / ls)})*{self.dim} + j];
            }}
            float tth = bias3_temp[i + lidx0*{math.ceil(self.dim*4 / ls)}] * 0.7978845608\
            * (1 + 0.044715 * pow(bias3_temp[i + lidx0*{math.ceil(self.dim*4 / ls)}],2));
            float th = tanh(tth);
            if(isnan(th) && tth < 0) {{ th = -1;}}
            if(isnan(th) && tth >= 0) {{ th = 1;}}
            bias3_temp[i + lidx0*{math.ceil(self.dim*4 / ls)}] = 0.5 * bias3_temp[i + lidx0*{math.ceil(self.dim*4 / ls)}]\
            * (1 + th);
        }}
        {func_dec[self.d]} void mm5(
            {var_dec[self.d]} float *a,
            {var_dec[self.d]} const float *weight4,{var_dec[self.d]} const float *bias4,
            {var_dec[self.d]} float *h_temp, {var_dec[self.d]} float *bias3_temp{uint3_arg[self.d]})
        {{
            {local_var[self.d]} float bias4_temp[{self.dim*3}];
            int lidx0 = {global_idx[self.d]} % {ls};
            int i = {global_idx[self.d]} / {ls};
            bias4_temp[lidx0 + i*{ls}] = bias4[lidx0 + i*{ls}];
            for(int j = 0; j < {self.dim*4}; j++) {{
                bias4_temp[lidx0 + i*{ls}] += bias3_temp[j] * weight4[lidx0 + i*{ls} + j*{self.dim}];
            }}
            a[lidx0 + i*{ls}] = bias4_temp[lidx0 + i*{ls}] + h_temp[lidx0 + i*{ls}];
        }}
        """

        if prg_str not in self.prg_cache:
            library = compile(prg_str,self.d,self.params)
            self.prg_cache[prg_str] = library
        prg = self.prg_cache[prg_str]

        prg_str = f"""
        {kernel_prefix[self.d]}
        {func_dec[self.d]} void mm1(
            {var_dec[self.d]} const float *a, {var_dec[self.d]} const float *c, {var_dec[self.d]} const float *d, {var_dec[self.d]} const float *e,
            {var_dec[self.d]} const float *xqkv, {var_dec[self.d]} float *keys_values,
            {var_dec[self.d]} float *xq_temp, {var_dec[self.d]} float *mean{uint3_arg[self.d]})
        {{
            int lidx0 = {global_idx[self.d]} % {ls};
            int i = {global_idx[self.d]} / {ls};
            float t = 0;
            xq_temp[lidx0*{math.ceil(self.dim*3 / ls)} + i] = xqkv[lidx0*{int(self.dim*3 / ls)} + i];
            for(int k = 0; k < {self.dim}; k++) {{
                t += ((a[k] * c[k]) / mean[0] + d[k]) * e[(lidx0*{int(self.dim*3 / ls)} + i)*{self.dim} + k];
            }}
            if((lidx0*{int(self.dim*3 / ls)} + i) < {g}) {{
                xq_temp[lidx0*{int(self.dim*3 / ls)} + i] += t;
                }}
            if((lidx0*{int(self.dim*3 / ls)} + i) >= {g} && (lidx0*{int(self.dim*3 / ls)} + i) < {2*g}) {{
                keys_values[{start_pos}*{self.dim} + lidx0*{int(self.dim*3 / ls)} + i - {g}] = xqkv[{self.dim} + lidx0*{int(self.dim*3 / ls)} + i - {g}] + t;
            }}
            if((lidx0*{int(self.dim*3 / ls)} + i) >= {2*g}) {{
                keys_values[{self.dim*self.max_context} + {start_pos}*{self.dim} + lidx0*{int(self.dim*3 / ls)} + i - {2*g}] = xqkv[{self.dim*2} + lidx0*{int(self.dim*3 / ls)} + i - {2*g}] + t;
            }}
        }}
        {func_dec[self.d]} void mm2(
            {var_dec[self.d]} const float *keys_values, {var_dec[self.d]} float *temp3, {var_dec[self.d]} const float *xq_temp{uint3_arg[self.d]})
        {{
            if({global_idx[self.d]} < {self.n_heads*(start_pos+1)*(start_pos+1)}) {{
            int lidx0 = {global_idx[self.d]};
            int x = (lidx0) % {start_pos+1};
            int k = (lidx0) / {start_pos+1};
            float acc0 = 0;
            for(int i = 0; i < 64; i++) {{
                acc0 += xq_temp[i + 64*k] * keys_values[x*{self.n_heads*64} + i + 64*k];
            }}
            if(x + k*{start_pos+1} < {self.n_heads*self.max_context}) {{
            temp3[x + k*{start_pos+1}] = acc0 / 8; //hardcoded math.sqrt(64)
            }}
            }}
        }}
        {func_dec[self.d]} void mm3(
            {var_dec[self.d]} float *a,
            {var_dec[self.d]} float *keys_values,
            {var_dec[self.d]} const float *weight,{var_dec[self.d]} const float *bias,
            {var_dec[self.d]} const float *weight2, {var_dec[self.d]} const float *bias2,
            {var_dec[self.d]} const float *weight3, {var_dec[self.d]} const float *bias3,
            {var_dec[self.d]} const float *weight4,
            {var_dec[self.d]} float *bias4,
            {var_dec[self.d]} float *temp3, {var_dec[self.d]} float *xq_temp, {var_dec[self.d]} float *mean,
            {var_dec[self.d]} float *h_temp, {var_dec[self.d]} float *h{uint3_arg[self.d]})
        {{
            {local_var[self.d]} float temp[{ls}];
            int lidx0 = {global_idx[self.d]};
            if(lidx0 < {self.n_heads}){{
            float m = -INFINITY;
            for(int i = 0; i < {start_pos+1}; i++) {{
                m = max(m,temp3[i + lidx0*{start_pos+1}]);
            }}
            float t = 0;
            for(int i = 0; i < {start_pos+1}; i++) {{
                temp3[i + lidx0*{start_pos+1}] = exp(temp3[i + lidx0*{start_pos+1}] - m);
                t += temp3[i + lidx0*{start_pos+1}];
            }}
            for(int i = 0; i < {start_pos+1}; i++) {{
                temp3[i + lidx0*{start_pos+1}] /= t;
            }}
            }}
            {barrier[self.d]}
            for(int g = 0; g < {seg3}; g++) {{
                float acc0 = 0;
                for(int i = 0; i < {start_pos+1}; i++) {{
                    acc0 += temp3[i + {start_pos+1}*((g + lidx0*{seg3}) / 64)] * keys_values[{self.dim*self.max_context} + i*{self.n_heads*64} + g + lidx0*{seg3}];
                }}
                xq_temp[g + lidx0*{seg3}] = acc0;
            }}
            {barrier[self.d]}
            for(int i = 0; i < {seg3}; i++) {{
                float acc = 0;
                for(int x = 0; x < {self.dim}; x++) {{
                    acc += xq_temp[x] * weight[x*{self.dim} + lidx0*{seg3} + i];
                }}
                h[lidx0*{seg3} + i] = a[lidx0*{seg3} + i] + acc + bias[lidx0*{seg3} + i];
                h_temp[lidx0*{seg3} + i] = h[lidx0*{seg3} + i];
            }}
            {barrier[self.d]}
            float total = 0;
            for(int i = 0; i < {seg3}; i++) {{
                total += h[lidx0*{seg3} + i];
            }}
            temp[lidx0] = total;
            {barrier[self.d]}
            if(lidx0==0) {{
                total = 0;
                for(int i = 0; i < {ls}; i++) {{
                    total += temp[i];
                }}
                mean[0] = total / {self.dim};
            }}
            {barrier[self.d]}
            total = 0;
            for(int i = 0; i < {seg3}; i++) {{
                h[i + lidx0*{seg3}] = h[i + lidx0*{seg3}] - mean[0];
                total += pow(h[lidx0*{seg3} + i],2);
            }}
            temp[lidx0] = total;
            {barrier[self.d]}
            if(lidx0==0) {{
                total = 0;
                for(int i = 0; i < {ls}; i++) {{
                    total += temp[i];
                }}
                mean[0] = pow(total / {self.dim} + 1e-5,0.5);
            }}
            }}
        """
        if prg_str not in self.prg_cache:
            library = compile(prg_str,self.d,self.params)
            self.prg_cache[prg_str] = library
        prg2 = self.prg_cache[prg_str]

        if hasattr(self, 'total') == False:
            self.total = create_buffer_empty(1*4,self.d,self.params)

        if hasattr(self, 'bias3_temp') == False:
            self.bias3_temp = create_buffer_empty(self.dim*4*4,self.d,self.params)
        if hasattr(self, 'mean') == False:
            self.mean = create_buffer_empty(1*4,self.d,self.params)
        if hasattr(self, 'h_temp') == False:
            self.h_temp = create_buffer_empty(self.dim*4,self.d,self.params)
        if hasattr(self, 'h') == False:
            self.h = create_buffer_empty(self.dim*4,self.d,self.params)

        run(prg,"mm",self.params,[a_g,self.mean],1,ls,self.d)
        run(prg2,"mm1",self.params,[a_g,c_g,d_g,e_g,xqkv_g,keys_values_g,self.xq_temp_g,self.mean],math.ceil(self.dim*3 / ls),ls,self.d)
        run(prg2,"mm2",self.params,[keys_values_g ,self.temp_g, self.xq_temp_g],math.ceil((self.n_heads*(start_pos+1)*(start_pos+1)) / ls),ls,self.d)
        run(prg2,"mm3",self.params,[a_g,keys_values_g,weight_g\
        ,bias_g,weight2_g,bias2_g,weight3_g,bias3_g,weight4_g,bias4_g,self.temp_g, self.xq_temp_g,self.mean,self.h_temp,self.h],1,ls,self.d)
        run(prg,"mm4",self.params,[a_g,weight2_g,bias2_g,\
        weight3_g,bias3_g,self.mean,self.h_temp,self.h,self.bias3_temp],int(self.dim*4 / ls),ls,self.d)
        run(prg,"mm5",self.params,[a_g,weight4_g,bias4_g,self.h_temp,self.bias3_temp],int(self.dim / ls),ls,self.d)
        return a_g

    def kernel_2(self,x_g,ln_1_weight_g,ln_1_bias_g,attn_weight_g,attn_bias_g,cache_kv_g,attn_c_proj_weight_g,attn_c_proj_bias_g,ln_2_weight_g,ln_2_bias_g,c_fc_weight_g,c_fc_bias_g\
        ,c_proj_weight_g,c_proj_bias_g,num_tokens,max_content,j=0):
        """Run one transformer block over all ``num_tokens`` prompt positions at once.

        This is the prefill path (``Transformer.forward`` calls it when
        ``start_pos == 0``). Besides computing the block output it copies the
        keys/values of every prompt position into ``cache_kv_g``.

        Kernel pipeline (all launched on backend ``self.d``):
          mm    : LayerNorm (ln_1) of ``x_g`` in place, saving the un-normalised rows in ``self.h_g``.
          mm2   : xqkv = x . attn_weight + attn_bias -> ``self.xqkv_g`` (num_tokens rows of 3*dim).
          mm3   : copy the K and V parts of xqkv into ``cache_kv_g`` (V at offset max_content*n_heads*64).
          tr    : gather Q and V into per-head layout (n_heads, num_tokens, 64).
          ms0   : attention scores Q.K / 8 per head -> ``self.xq_g_temp``.
          ms    : causal mask (future positions -> -inf) followed by exp.
          ms3/4 : divide each score row by its sum (softmax).
          ms5   : scores . V per head -> ``self.c_g``.
          ms6   : transpose heads back to (num_tokens, n_heads*64).
          ms7   : attention output projection + bias, added into ``self.h_g`` (residual).
          mm    : LayerNorm (ln_2) of ``self.h_g`` in place, saving the un-normalised rows in ``self.h2_g``.
          ms9   : c_fc projection + bias with tanh-approximated GELU -> ``self.d_g``.
          ms10  : c_proj projection + bias, added into ``self.h2_g`` (residual).

        Args:
            x_g: input hidden states (num_tokens rows of ``self.dim`` floats).
                Overwritten in place with its ln_1-normalised values.
            ln_1_weight_g, ln_1_bias_g, ln_2_weight_g, ln_2_bias_g: LayerNorm parameters.
            attn_weight_g, attn_bias_g: fused QKV projection used by mm2.
            cache_kv_g: KV-cache buffer written by mm3.
            attn_c_proj_weight_g, attn_c_proj_bias_g: attention output projection (ms7).
            c_fc_weight_g, c_fc_bias_g: MLP up-projection (ms9).
            c_proj_weight_g, c_proj_bias_g: MLP down-projection (ms10).
            num_tokens: number of prompt positions processed.
            max_content: KV-cache capacity in positions; it also sizes the scratch buffers.
            j: layer index passed by the caller; unused here.

        Returns:
            ``self.h2_g``, the block output (num_tokens rows of ``self.dim`` floats).

        Raises:
            ValueError: if ``num_tokens`` does not fit the KV cache, or if the
                per-head score matrix does not fit the ``xq_g_temp`` scratch buffer.

        Notes:
            * Head size is hard-coded to 64. The ``/ 8`` in ms0 is sqrt(64).
            * The softmax applies ``exp`` without subtracting the row maximum, so
              very large scores could overflow.
            * Scratch buffers are allocated once, sized by the first call's ``max_content``.
        """
        if num_tokens > max_content:
            raise ValueError(f"num_tokens={num_tokens} exceeds the KV-cache capacity max_content={max_content}")
        if num_tokens*num_tokens > 64*max_content:
            # ms0 writes n_heads*num_tokens**2 scores into xq_g_temp, which only
            # holds n_heads*64*max_content floats; larger prompts would write out of bounds.
            raise ValueError(f"num_tokens={num_tokens} is too long for the attention scratch buffer "
                             f"(needs num_tokens**2 <= 64*max_content={64*max_content})")

        # Scratch buffers are created on first use and reused by every layer.
        # Sizes are in bytes (4 bytes per float32 element).
        scratch_bytes = {
            'h_g': max_content*self.dim*4,
            'h2_g': max_content*self.dim*4,
            'xq_g': self.n_heads*64*max_content*4,
            'xq_g_temp': self.n_heads*64*max_content*4,
            'xv_g': self.n_heads*64*max_content*4,
            'c_g': self.n_heads*64*max_content*4,
            'xqt_g': self.n_heads*64*max_content*4,
            'res_g': max_content*self.n_heads*4,
            'xqkv_g': max_content*self.dim*3*4,
            'd_g': max_content*self.dim*4*4,
        }
        for name, nbytes in scratch_bytes.items():
            if not hasattr(self, name):
                setattr(self, name, create_buffer_empty(nbytes,self.d,self.params))

        ls = 256  # work-group size used by every launch below
        a_rows = num_tokens
        a_cols = 64  # head size
        b_rows = self.dim
        b_cols = self.dim*3  # width of the fused QKV projection
        b_cols_2 = self.dim*4  # width of the MLP hidden layer
        size = self.dim
        n_qkv = num_tokens*self.n_heads*64  # element count of one of Q, K or V
        prg_str = f"""
        {kernel_prefix[self.d]}
        {func_dec[self.d]} void mm( //layernorm: one work-group per row r, input row is copied out first
            {var_dec[self.d]} float *x, {var_dec[self.d]} const float *weight, {var_dec[self.d]} const float *bias,
            {var_dec[self.d]} float *copy{uint3_arg[self.d]})
        {{
            {local_var[self.d]} float temp[{ls}];
            {local_var[self.d]} float row_stat;
            int gidx0 = {global_idx[self.d]};
            int lidx0 = gidx0 % {ls};
            int r = gidx0 / {ls};
            float acc = 0;
            for(int i = lidx0; i < {size}; i += {ls}) {{
                copy[{self.dim}*r + i] = x[{self.dim}*r + i];
                acc += x[{self.dim}*r + i];
            }}
            temp[lidx0] = acc;
            {barrier[self.d]}
            if(lidx0 == 0) {{
                float s = 0;
                for(int i = 0; i < {ls}; i++) {{
                    s += temp[i];
                }}
                row_stat = s / {size};
            }}
            {barrier[self.d]}
            float mu = row_stat;
            acc = 0;
            for(int i = lidx0; i < {size}; i += {ls}) {{
                float v = x[{self.dim}*r + i] - mu;
                x[{self.dim}*r + i] = v;
                acc += v * v;
            }}
            temp[lidx0] = acc;
            {barrier[self.d]}
            if(lidx0 == 0) {{
                float s = 0;
                for(int i = 0; i < {ls}; i++) {{
                    s += temp[i];
                }}
                row_stat = pow(s / {size} + 1e-5,0.5);
            }}
            {barrier[self.d]}
            float sd = row_stat;
            for(int i = lidx0; i < {size}; i += {ls}) {{
                x[{self.dim}*r + i] = (x[{self.dim}*r + i] * weight[i]) / sd + bias[i];
            }}
        }}
        {func_dec[self.d]} void mm2(
            {var_dec[self.d]} const float *x, {var_dec[self.d]} const float *attn_weight, {var_dec[self.d]} const float *attn_bias,{var_dec[self.d]} float *res{uint3_arg[self.d]})
        {{
            int gidx0 = {global_idx[self.d]};
            if(gidx0 < {b_cols*num_tokens}) {{
            int i = gidx0 / {num_tokens};
            int y = gidx0 % {num_tokens};
            float total = 0;
            for(int k = 0; k < {b_rows}; k++) {{
                total += x[y*{b_rows} + k] * attn_weight[i*{b_rows} + k];
            }}
            res[y*{b_cols} + i] = total + attn_bias[i];
            }}
        }}
        {func_dec[self.d]} void mm3(
            {var_dec[self.d]} const float *xqkv, {var_dec[self.d]} float *new_cache{uint3_arg[self.d]})
        {{
            int gidx0 = {global_idx[self.d]};
            if(gidx0 < {n_qkv}) {{
            int i = gidx0 / {self.n_heads*64};
            int j = gidx0 % {self.n_heads*64};
            new_cache[i*{self.n_heads*64} + j] = xqkv[i*{self.n_heads*64*3} + {self.dim}*1 + j];
            new_cache[{max_content}*{self.n_heads*64} + i*{self.n_heads*64} + j] = xqkv[i*{self.n_heads*64*3} + {self.dim}*2 + j];
            }}
        }}
        {func_dec[self.d]} void tr(
            {var_dec[self.d]} const float *xqkv, {var_dec[self.d]} float *xq, {var_dec[self.d]} float *xv{uint3_arg[self.d]})
        {{
            int gidx0 = {global_idx[self.d]};
            if(gidx0 < {n_qkv}) {{
            int i = (gidx0 / {64}) / {num_tokens};
            int j = (gidx0 / {64}) % {num_tokens};
            int k = gidx0 % 64;
            xq[i*{num_tokens}*64 + j*64 + k] = xqkv[i*64 + j*{64*self.n_heads*3} + k];
            xv[i*{num_tokens}*64 + j*64 + k] = xqkv[i*64 + j*{64*self.n_heads*3} + k + {64*self.n_heads*2}];
            }}
        }}
        {func_dec[self.d]} void ms0(
            {var_dec[self.d]} float *xq_temp, {var_dec[self.d]} const float *xq, {var_dec[self.d]} const float *xqkv{uint3_arg[self.d]})
        {{
                int gidx0 = {global_idx[self.d]};
                if(gidx0 < {self.n_heads*a_rows*a_rows}) {{
                int x = (gidx0 / {a_rows}) % {a_rows};
                int z = gidx0 / ({a_rows}*{a_rows});
                int y = gidx0 % {a_rows};
                float total = 0;
                for(int k = 0; k < {a_cols}; k++) {{
                    total += xq[y*{a_cols} + k + z*{a_rows}*{a_cols}] * xqkv[x*{64*self.n_heads*3} + k + z*64 + {self.dim}];
                }}
                xq_temp[y*{a_rows} + x + z*{a_rows}*{a_rows}] = total / 8; //sqrt 64 input shape xq
                }}
        }}
        {func_dec[self.d]} void ms(
            {var_dec[self.d]} float *xq{uint3_arg[self.d]})
        {{
        int gidx0 = {global_idx[self.d]};
        if(gidx0 < {self.n_heads*num_tokens*num_tokens}) {{
            int x = (gidx0 / {num_tokens}) / {num_tokens};
            int y = (gidx0 / {num_tokens}) % {num_tokens};
            int z = gidx0 % {num_tokens};
            if(z > y) {{ //todo, this can probably be 2x faster
                xq[x*{num_tokens}*{num_tokens} + y*{num_tokens} + z] = -INFINITY;
            }}
            xq[x*{num_tokens}*{num_tokens} + y*{num_tokens} + z] = exp(xq[x*{num_tokens}*{num_tokens} + y*{num_tokens} + z]);
        }}
        }}
        {func_dec[self.d]} void ms3(
            {var_dec[self.d]} const float *xq, {var_dec[self.d]} float *mx{uint3_arg[self.d]})
        {{
        int gidx0 = {global_idx[self.d]};
        if(gidx0 < {num_tokens*self.n_heads}) {{
        int x = gidx0 / {num_tokens};
        int y = gidx0 % {num_tokens};
            float m = 0;
            for(int z = 0; z < {num_tokens}; z++) {{
                m += xq[x*{num_tokens}*{num_tokens} + y*{num_tokens} + z];
            }}
            mx[x*{num_tokens} + y] = m;
            }}
        }}
        {func_dec[self.d]} void ms4(
            {var_dec[self.d]} float *xq, {var_dec[self.d]} const float *mx{uint3_arg[self.d]})
        {{
        int gidx0 = {global_idx[self.d]};
        if(gidx0 < {num_tokens*num_tokens*self.n_heads}) {{
            int x = (gidx0 / {num_tokens}) / {num_tokens};
            int y = (gidx0 / {num_tokens}) % {num_tokens};
            int z = gidx0 % {num_tokens};
            xq[x*{num_tokens}*{num_tokens} + y*{num_tokens} + z] /= mx[x*{num_tokens} + y];
        }}
        }}
        {func_dec[self.d]} void ms5(
            {var_dec[self.d]} const float *xq, {var_dec[self.d]} const float *xv, {var_dec[self.d]} float *res{uint3_arg[self.d]})
        {{
            int gidx0 = {global_idx[self.d]};
            if(gidx0 < {self.n_heads*a_cols*num_tokens}) {{
            int z = (gidx0 / {num_tokens}) / {a_cols};
            int x = (gidx0 / {num_tokens}) % {a_cols};
            int y = gidx0 % {num_tokens};
            float total = 0;
            for(int k = 0; k < {num_tokens}; k++) {{
                total += xq[y*{num_tokens} + k + z*{num_tokens}*{num_tokens}] * xv[x + k*{a_cols} + z*{num_tokens}*{a_cols}];
            }}
            res[y*{a_cols} + x + z*{a_cols}*{num_tokens}] = total;
            }}
        }}
        {func_dec[self.d]} void ms6( //transpose
            {var_dec[self.d]} const float *xq, {var_dec[self.d]} float *xqt{uint3_arg[self.d]})
        {{
            int gidx0 = {global_idx[self.d]};
            if(gidx0 < {n_qkv}) {{
            int i = (gidx0 / 64) / {num_tokens};
            int j = (gidx0 / 64) % {num_tokens};
            int k = gidx0 % 64;
            xqt[i*64 + j*{self.n_heads*64} + k] = xq[i*{num_tokens}*64 + j*64 + k];
            }}
        }}
        {func_dec[self.d]} void ms7(
            {var_dec[self.d]} const float *xq, {var_dec[self.d]} const float *attn_weight,{var_dec[self.d]} const float *attn_bias, {var_dec[self.d]} float *res{uint3_arg[self.d]})
        {{
            if({global_idx[self.d]} < {b_rows*num_tokens}) {{ //TODO don't allow larger local size? wasteful
            int gidx0 = {global_idx[self.d]};
            int x = gidx0 / {num_tokens};
            int y = gidx0 % {num_tokens};
            float total = 0;
            for(int k = 0; k < {b_rows}; k++) {{
                total += xq[y*{b_rows} + k] * attn_weight[x*{b_rows} + k];
            }}
            res[y*{b_rows} + x] += total + attn_bias[x];
            }}
        }}
        {func_dec[self.d]} void ms9(
            {var_dec[self.d]} const float *a, {var_dec[self.d]} const float *c_fc_weight,{var_dec[self.d]} const float *c_fc_bias, {var_dec[self.d]} float *res{uint3_arg[self.d]})
        {{
            int gidx0 = {global_idx[self.d]};
            if(gidx0 < {b_cols_2*num_tokens}) {{
            int x = gidx0 / {num_tokens};
            int y = gidx0 % {num_tokens};
            float total = 0;
            for(int k = 0; k < {b_rows}; k++) {{
                total += a[y*{b_rows} + k] * c_fc_weight[x*{b_rows} + k];  //TODO A LEADS TO NANs
            }}
            float tth = (total + c_fc_bias[x]) * 0.7978845608\
                * (1 + 0.044715 * pow((total + c_fc_bias[x]),2));
            float th = tanh(tth);
            if(isnan(th) && tth < 0) {{
                th = -1;
            }}
            if(isnan(th) && tth >= 0) {{
                th = 1;
            }}
            res[y*{b_cols_2} + x] = 0.5 * (total + c_fc_bias[x])\
                * (1 + th);
            }}
        }}
        {func_dec[self.d]} void ms10(
            {var_dec[self.d]} const float *a, {var_dec[self.d]} const float *c_proj_weight,{var_dec[self.d]} const float *c_proj_bias, {var_dec[self.d]} float *res{uint3_arg[self.d]})
        {{
            if({global_idx[self.d]} < {b_rows*num_tokens}) {{ //TODO, wasteful?
            int gidx0 = {global_idx[self.d]};
            int x = gidx0 / {num_tokens};
            int y = gidx0 % {num_tokens};
            float total = 0;
            for(int k = 0; k < {b_cols_2}; k++) {{
                total += a[y*{b_cols_2} + k] * c_proj_weight[x*{b_cols_2} + k];
            }}
            res[y*{b_rows} + x] += total + c_proj_bias[x];
            }}
        }}
        """
        # TODO: caching the compiled program in self.prg_cache (as other kernels in
        # this class do) was observed to change the output. The shared-memory races
        # that used to be in the LayerNorm kernels may have been the cause; re-test
        # before enabling the cache. Until then the program is compiled on every call.
        prg = compile(prg_str,self.d,self.params)
        run(prg,"mm",self.params,[x_g,ln_1_weight_g,ln_1_bias_g,self.h_g],num_tokens,ls,self.d)
        run(prg,"mm2",self.params,[x_g,attn_weight_g,attn_bias_g,self.xqkv_g],math.ceil(b_cols*num_tokens / ls),ls,self.d)
        run(prg,"mm3",self.params,[self.xqkv_g, cache_kv_g],math.ceil(n_qkv / ls),ls,self.d)
        run(prg,"tr",self.params,[self.xqkv_g, self.xq_g, self.xv_g],math.ceil(n_qkv / ls),ls,self.d)
        run(prg,"ms0",self.params,[self.xq_g_temp,self.xq_g, self.xqkv_g],math.ceil(self.n_heads*num_tokens*num_tokens/ls),ls,self.d)
        run(prg,"ms",self.params,[self.xq_g_temp],math.ceil(self.n_heads*num_tokens*num_tokens / ls),ls,self.d)
        run(prg,"ms3",self.params,[self.xq_g_temp,self.res_g],math.ceil(self.n_heads*num_tokens/ls),min(self.n_heads*num_tokens,ls),self.d)
        run(prg,"ms4",self.params,[self.xq_g_temp,self.res_g],math.ceil(self.n_heads*num_tokens*num_tokens / ls),ls,self.d)
        run(prg,"ms5",self.params,[self.xq_g_temp,self.xv_g,self.c_g],math.ceil(self.n_heads*a_cols*num_tokens / ls),ls,self.d)
        run(prg,"ms6",self.params,[self.c_g,self.xqt_g],math.ceil(n_qkv / ls),ls,self.d)
        run(prg,"ms7",self.params,[self.xqt_g,attn_c_proj_weight_g,attn_c_proj_bias_g,self.h_g],math.ceil(b_rows*num_tokens / ls),ls,self.d)
        # ln_2 uses the same LayerNorm kernel as ln_1: h_g is normalised in place
        # and its un-normalised rows are copied into h2_g for the residual add in ms10.
        run(prg,"mm",self.params,[self.h_g, ln_2_weight_g, ln_2_bias_g,self.h2_g],num_tokens,ls,self.d)
        run(prg,"ms9",self.params,[self.h_g, c_fc_weight_g,c_fc_bias_g,self.d_g],math.ceil(b_cols_2*num_tokens / ls),ls,self.d)
        run(prg,"ms10",self.params,[self.d_g, c_proj_weight_g,c_proj_bias_g,self.h2_g],math.ceil(b_rows*num_tokens / ls) ,ls,self.d)
        return self.h2_g

    @staticmethod
    def time_it(func,a,b,i=100):
        """Call ``func(a, b)`` ``i`` times and report the fastest wall-clock run.

        Args:
            func: callable taking two positional arguments.
            a, b: arguments forwarded unchanged to ``func`` on every call.
            i: number of timed repetitions; must be >= 1.

        Returns:
            ``(ret, best)``. ``ret`` is the value returned by the last call and
            ``best`` is the shortest single-call duration in seconds.

        Raises:
            ValueError: if ``i`` < 1, because there would be no result to return.
        """
        import time  # local import so the method does not rely on a module-level import

        if i < 1:
            raise ValueError(f"time_it needs at least one repetition, got i={i}")
        best = None
        ret = None
        for _ in range(i):
            st = time.perf_counter()
            ret = func(a,b)
            elapsed = time.perf_counter() - st
            if best is None or elapsed < best:
                best = elapsed
        return ret,best


In [None]:
!pip install tinygrad
!pip install requests

In [None]:
#!/usr/bin/env python3 #for tinygrad repo, get rid of libs etc
# can I beat https://github.com/jaymody/xpicoGPT.git?
# beating https://github.com/WAUthethird/stupidGPT should be easy
from typing import Union, Tuple
from tqdm import trange
import numpy as np
import os
import pickle
from tinygrad.nn.state import torch_load
from tinygrad.helpers import fetch
from transformers import AutoModelForCausalLM, AutoTokenizer
import requests
# Compute backend used by every kernel/buffer call below: "CUDA", "OpenCL" or "Metal".
# pycuda is imported in the kernel cell. The OpenCL and Metal branches also need
# `pyopencl as cl` or `Metal` (and `metal_kernels_large` for Metal) imported
# before this cell runs.
d = "CUDA"
folder = ""  # asset sub-folder; the Metal setup used "metal/"

if d == "Metal":
    device = Metal.MTLCreateSystemDefaultDevice()
    queue = device.newCommandQueue()
    params = {"queue":queue,"device":device}
elif d == "OpenCL":
    platform = cl.get_platforms()
    my_gpu_devices = platform[0].get_devices(device_type=cl.device_type.GPU)
    ctx = cl.Context(devices=my_gpu_devices)
    mf = cl.mem_flags
    params = {"ctx":ctx,"mf":mf,"queue":cl.CommandQueue(ctx)}
elif d == "CUDA":
    params = None  # compile() and buffer.np() ignore params for CUDA
else:
    raise ValueError(f"unsupported backend d={d!r}; expected 'CUDA', 'OpenCL' or 'Metal'")

# GPT-2 vocabulary: one token per line. A literal "/n" inside a token stands
# for a newline character (see decode/encode).
response = requests.get("https://raw.githubusercontent.com/roryclear/transformer/main/tokens.txt", timeout=60)
response.raise_for_status()  # don't silently save an HTTP error page as the vocabulary
with open("tokens2.txt", "wb") as file:
  file.write(response.content)
with open('tokens2.txt','r',encoding="utf-8") as file:
  tokens = file.readlines()
# To use a local copy instead, read 'tokens.txt' here.

# Map each decoded token string to its id (a later duplicate overrides an earlier
# one) and record the longest token for encode's longest-prefix search.
token_dict = dict()
max_token_length = -1
for i, line in enumerate(tokens):
  s = line.replace("\n","").replace("/n","\n")
  token_dict[s] = i
  max_token_length = max(max_token_length, len(s))
def decode(index):
  """Convert a sequence of token ids back into text.

  Each vocabulary line loses its trailing newline, and the two-character
  marker "/n" is turned back into a real newline.
  """
  return "".join(tokens[tid].replace("\n","").replace("/n","\n") for tid in index)

def encode(x):
  """Greedily tokenize ``x`` by longest-prefix match against ``token_dict``.

  At each step the longest prefix of the remaining text (at most
  ``max_token_length`` characters) that is a key of ``token_dict`` is emitted.

  Args:
    x: text to tokenize.

  Returns:
    list of int token ids.

  Raises:
    ValueError: if no non-empty prefix of the remaining text is a known token.
      The previous implementation looped forever in that case.
  """
  ret = []
  while len(x) > 0:
    limit = min(max_token_length, len(x))
    for length in range(limit, 0, -1):
      s = x[:length]
      if s in token_dict:
        ret.append(token_dict[s])
        x = x[length:]
        break
    else:
      raise ValueError(f"cannot encode text starting at {x[:20]!r}: no matching token")
  return ret

class Rand:
  """Deterministic source of uniform float32 samples in [0, 1).

  Every call bumps ``seed`` by one and draws a single value from a fresh
  ``np.random.default_rng(seed)``, so the sequence is reproducible.
  """

  def __init__(self):
    self.seed = 420

  def rand(self):
    """Advance the seed and return one np.float32 sample."""
    self.seed = self.seed + 1
    sample = np.random.default_rng(self.seed).random(size=1, dtype=np.float32)
    return sample.astype(dtype=np.float32, copy=False)[0]

class Transformer:
  """GPT-2 weights plus the driver that runs them on the selected backend.

  Weight attributes (ln_1_weight, attn_c_attn_weight, ...) are assigned from
  outside, as get_model() does. to_buffer() then replaces the host arrays with
  device buffers, and forward() dispatches to kernels on the module-level
  ``metalk`` object (defined elsewhere in the notebook).
  """
  def __init__(self):
    # Intentionally empty: all attributes are attached externally (see get_model).
    return None

  def to_buffer(self,n_heads,dim):
    """Upload every weight to the backend ``d`` using the module-level ``params``.

    Per-layer lists are converted element-wise in place. ``len(self.ln_1_weight)``
    serves as the layer count for every list. Some matrices are transposed and/or
    flattened before upload, and the layout of the attention/MLP output
    projections differs per backend (see the branches at the end). It also
    allocates one KV-cache buffer per layer of ``2*MAX_CONTEXT*n_heads*64``
    float32 values (``MAX_CONTEXT`` is a module-level constant).

    Args:
      n_heads: number of attention heads (stored on self).
      dim: model width (stored on self).
    """
    self.n_heads = n_heads
    self.dim = dim

    print("copying ln_1_weight")
    for i in range(len(self.ln_1_weight)):
      self.ln_1_weight[i] = create_buffer(self.ln_1_weight[i],d,params)

    print("copying ln_1_bias")
    for i in range(len(self.ln_1_weight)):
      self.ln_1_bias[i] = create_buffer(self.ln_1_bias[i],d,params)

    print("copying attn_c_attn_weight")
    # uploaded transposed and flattened
    for i in range(len(self.ln_1_weight)):
      self.attn_c_attn_weight[i] = create_buffer(self.attn_c_attn_weight[i].transpose(1,0).flatten(),d,params)

    print("copying attn_c_attn_bias")
    for i in range(len(self.ln_1_weight)):
      self.attn_c_attn_bias[i] = create_buffer(self.attn_c_attn_bias[i],d,params)

    print("copying attn_c_proj_bias")
    for i in range(len(self.ln_1_weight)):
      self.attn_c_proj_bias[i] = create_buffer(self.attn_c_proj_bias[i],d,params)

    print("copying ln_2_weight")
    for i in range(len(self.ln_1_weight)):
      self.ln_2_weight[i] = create_buffer(self.ln_2_weight[i],d,params)

    print("copying ln_2_bias")
    for i in range(len(self.ln_1_weight)):
      self.ln_2_bias[i] = create_buffer(self.ln_2_bias[i],d,params)

    print("copying mlp_c_fc_bias")
    for i in range(len(self.ln_1_weight)):
      self.mlp_c_fc_bias[i] = create_buffer(self.mlp_c_fc_bias[i],d,params)

    print("copying mlp_c_proj_bias")
    for i in range(len(self.ln_1_weight)):
      self.mlp_c_proj_bias[i] = create_buffer(self.mlp_c_proj_bias[i],d,params)

    print("copying ln_f_weight")
    self.ln_f_weight = create_buffer(self.ln_f_weight,d,params)

    print("copying ln_f_bias")
    self.ln_f_bias = create_buffer(self.ln_f_bias,d,params)

    print("copying mlp_c_fc_weight")
    # uploaded transposed and flattened
    for i in range(len(self.ln_1_weight)):
      self.mlp_c_fc_weight[i] = create_buffer(self.mlp_c_fc_weight[i].transpose(1,0).flatten(),d,params)

    # lm_head is uploaded twice: transposed (_unf) and flattened. The
    # transposed copy must be made first, because the next statement replaces
    # the host array with a buffer.
    print("copying lm_head_weight_unf")
    self.lm_head_weight_unf = create_buffer(self.lm_head_weight.transpose(),d,params)

    print("copying lm_head_weight")
    self.lm_head_weight = create_buffer(self.lm_head_weight.flatten(),d,params)

    print("copying self_wte_weight")
    self.wte_weight = create_buffer(self.wte_weight.astype(np.float32),d,params)

    print("copying self_wpe_weight")
    self.wpe_weight = create_buffer(self.wpe_weight,d,params)

    print("creating attn_cache_kv")
    # one zero-initialised (presumably) K+V cache per layer; size in bytes
    self.attn_cache_kv = []
    for i in range(len(self.ln_1_weight)):
      self.attn_cache_kv.append(create_buffer_empty(2*MAX_CONTEXT*n_heads*64*4,d,params))

    # Output projections are kept in two layouts: an np.asfortranarray copy
    # (attn_c_proj_weight2 / mlp_c_proj_weight_unf, used by the prefill kernel)
    # and a flattened copy (used by the single-token kernel). The Fortran-order
    # copy presumably uploads the transposed byte layout; this depends on
    # create_buffer, so confirm. Each layout must be derived before the host
    # array is replaced.
    if d == "OpenCL" or d=="CUDA":
      print("copying attn_c_proj_weight") #TODO
      self.attn_c_proj_weight2 = []
      for i in range(len(self.ln_1_weight)):
        self.attn_c_proj_weight2.append(create_buffer(np.asfortranarray(self.attn_c_proj_weight[i]),d,params))
        self.attn_c_proj_weight[i] = create_buffer(self.attn_c_proj_weight[i].flatten(),d,params)

      print("copying mlp_c_proj_weight_unf") #TODO
      self.mlp_c_proj_weight_unf = []
      for i in range(len(self.ln_1_weight)):
        self.mlp_c_proj_weight_unf.append(create_buffer(np.asfortranarray(self.mlp_c_proj_weight[i]),d,params))
        self.mlp_c_proj_weight[i] = create_buffer(self.mlp_c_proj_weight[i].flatten(),d,params)
      return

    if d == "Metal":
      print("copying attn_c_proj_weight") #TODO
      self.attn_c_proj_weight2 = []
      for i in range(len(self.ln_1_weight)):
        self.attn_c_proj_weight2.append(create_buffer(np.asfortranarray(self.attn_c_proj_weight[i].transpose()),d,params))
        self.attn_c_proj_weight[i] = create_buffer(self.attn_c_proj_weight[i],d,params)

      print("copying mlp_c_proj_weight_unf") #TODO
      self.mlp_c_proj_weight_unf = []
      for i in range(len(self.ln_1_weight)):
        self.mlp_c_proj_weight_unf.append(create_buffer(self.mlp_c_proj_weight[i].transpose(),d,params))
        self.mlp_c_proj_weight[i] = create_buffer(self.mlp_c_proj_weight[i].flatten(),d,params)

  def forward(self, tokens, start_pos, temperature:float=0.8,n_tokens=444):
    """Run the model and return one sampled token id as ``np.int32``.

    start_pos > 0: single-token step. ``tokens[0]`` at position ``start_pos``
      goes through ``metalk.kernel_0`` for every layer and then through
      ``metalk.kernel_1``, which receives ln_f, lm_head_weight and the sampling
      inputs.
    start_pos == 0: prefill of all ``n_tokens`` prompt tokens. Every layer but
      the last runs through ``metalk.kernel_2``. The last layer's ln_1/c_attn
      weights and KV cache, plus ln_f and lm_head_weight_unf, go to
      ``metalk.kernel_3``, which returns the sample. Its exact computation is
      defined elsewhere.

    ``metalk`` and ``rand`` are module-level globals defined elsewhere in the
    notebook. ``rand.rand()`` supplies the uniform sample passed to the
    sampling kernel.
    """
    if start_pos > 0:
      h = metalk.add(self.wte_weight,self.wpe_weight,start_pos,tokens[0])
      attn_dim = self.dim
      for i in range(0,len(self.ln_1_weight)):
        h = metalk.kernel_0(h,self.ln_1_weight[i],\
        self.ln_1_bias[i],self.attn_c_attn_weight[i],\
        self.attn_c_attn_bias[i],\
        self.attn_cache_kv[i],\
        self.attn_c_proj_weight[i],self.attn_c_proj_bias[i],\
        self.ln_2_weight[i], self.ln_2_bias[i],\
        self.mlp_c_fc_weight[i],self.mlp_c_fc_bias[i],\
        self.mlp_c_proj_weight[i],self.mlp_c_proj_bias[i],start_pos,attn_dim,i)
      unif_samples = rand.rand()
      ret = metalk.kernel_1(h,self.ln_f_weight, self.ln_f_bias,self.lm_head_weight,temperature,unif_samples).astype(np.int32)[0]
      return ret
    else:
      x = metalk.tok_emb(tokens,self.wte_weight,self.wpe_weight,n_tokens)
      # all layers except the last; the last one is handled by kernel_3 below
      for i in range(len(self.ln_1_weight)-1):
        x = metalk.kernel_2(x,self.ln_1_weight[i], self.ln_1_bias[i],self.attn_c_attn_weight[i],self.attn_c_attn_bias[i],self.attn_cache_kv[i],self.attn_c_proj_weight2[i],self.attn_c_proj_bias[i],self.ln_2_weight[i], self.ln_2_bias[i],\
        self.mlp_c_fc_weight[i],self.mlp_c_fc_bias[i],self.mlp_c_proj_weight_unf[i],self.mlp_c_proj_bias[i],n_tokens,MAX_CONTEXT,i)
    # only reached on the prefill path (the start_pos > 0 branch returns above)
    unif_samples = rand.rand()
    ret = metalk.kernel_3(x,self.ln_1_weight[-1], self.ln_1_bias[-1],self.attn_c_attn_weight[-1],self.attn_c_attn_bias[-1],self.attn_cache_kv[-1]\
    ,self.ln_f_weight, self.ln_f_bias,n_tokens,MAX_CONTEXT,self.lm_head_weight_unf,temperature,unif_samples).astype(np.int32)[0]
    return ret

  def __call__(self, tokens, start_pos, temperature:np.float32=0.0,n_tokens=1):
    """Alias for forward(); note the different defaults (temperature 0.0, n_tokens 1)."""
    return self.forward(tokens, start_pos, temperature,n_tokens)

def delete_buffers(m):
    """Release every device buffer held by the Transformer ``m``.

    Model-wide buffers are freed first. Then, for each layer index (counted
    from ``m.ln_1_bias``), every per-layer buffer is freed in a fixed order.
    """
    shared_names = ("wpe_weight", "ln_f_weight", "ln_f_bias",
                    "wte_weight", "lm_head_weight", "lm_head_weight_unf")
    layer_names = ("mlp_c_proj_bias", "mlp_c_proj_weight", "mlp_c_proj_weight_unf",
                   "mlp_c_fc_bias", "mlp_c_fc_weight",
                   "attn_c_proj_bias", "attn_c_proj_weight", "attn_c_proj_weight2",
                   "attn_c_attn_bias", "attn_c_attn_weight",
                   "ln_1_bias", "ln_1_weight", "ln_2_bias", "ln_2_weight",
                   "attn_cache_kv")
    for name in shared_names:
        getattr(m, name).delete()
    for layer in range(len(m.ln_1_bias)):
        for name in layer_names:
            getattr(m, name)[layer].delete()

VOCAB_SIZE = 50257
class GPT2:
  """Text-generation wrapper around a callable model (a Transformer in this notebook)."""

  @staticmethod
  def build():
    """Construct a GPT-2 small model.

    NOTE(review): Transformer.__init__ in this notebook takes no arguments, so
    this call raises TypeError as written. get_model() is the path that builds
    a usable model.
    """
    model = Transformer(n_layers=12,n_heads=12,dim=768,norm_eps=1e-5,vocab_size=VOCAB_SIZE) #small
    return GPT2(model)

  def __init__(self, model):
    self.model = model

  def generate(self, prompt:str, max_length:int, temperature:float, timing:bool=False, batch_size:int=1,expected_tokens=None):
    """Encode ``prompt``, sample ``max_length`` new tokens and return the decoded text.

    The first model call feeds the whole prompt (start_pos == 0). Each later
    call feeds only the most recent token, with ``start_pos`` equal to the
    number of tokens processed so far.

    Args:
      prompt: input text, tokenised with ``encode``.
      max_length: number of tokens to generate.
      temperature: sampling temperature forwarded to the model.
      timing: when truthy, the progress bar is hidden.
      batch_size: only 1 enables the single-token path.
      expected_tokens: optional sequence (list or ndarray) of expected
        generated ids. Each generated token is checked against it, and a
        mismatch raises AssertionError.

    Returns:
      The decoded prompt followed by the generated text.
    """
    toks = encode(prompt)
    start_pos = 0
    n_tokens = len(toks)
    for _ in trange(max_length, disable=bool(timing)):
      if batch_size == 1 and len(toks) - start_pos == 1:
        tokens = np.array([toks[start_pos]])
      else:
        tokens = np.array(toks)
      tok = self.model(tokens, start_pos, temperature, n_tokens).tolist()
      start_pos = len(toks)
      # `is not None` (not `!= None`) so ndarray expectations don't raise an
      # ambiguous-truth-value error.
      if expected_tokens is not None:
        np.testing.assert_equal(tok,expected_tokens[start_pos-n_tokens])
      toks.append(tok)
    return decode(toks)


def get_model(model_size):
  """Convert a Hugging Face GPT-2 checkpoint to float32 numpy arrays and pickle it.

  Loads ``model_size`` with ``AutoModelForCausalLM.from_pretrained``. Copies
  every parameter the Transformer needs onto a blank ``GPT2`` wrapper as float32
  numpy arrays (per-layer parameters become lists indexed by layer). Pickles
  the wrapper to ``folder + model_size + ".pickle"``.

  The layer count is read from the loaded checkpoint (``len(model.transformer.h)``)
  rather than a fixed name-to-depth table. Checkpoints other than
  gpt2/-medium/-large/-xl therefore no longer fail with KeyError.

  Args:
    model_size: model id passed to ``from_pretrained``, e.g. "gpt2-medium".
      It is also used to build the cache file name, so pass a bare id rather
      than a path that already includes ``folder``.
  """
  import operator

  def _np32(t):
    # detach from autograd, move to host, convert to a numpy float32 array
    return t.detach().cpu().numpy().astype(np.float32)

  gpt2_blank = GPT2(None)
  gpt2_blank.model = Transformer()
  model = AutoModelForCausalLM.from_pretrained(model_size)
  print(type(model))
  blocks = model.transformer.h

  print("converting wpe_weight")
  gpt2_blank.model.wpe_weight = _np32(model.transformer.wpe.weight)

  print("converting ln_f.weight")
  gpt2_blank.model.ln_f_weight = _np32(model.transformer.ln_f.weight)

  print("converting ln_f.bias")
  gpt2_blank.model.ln_f_bias = _np32(model.transformer.ln_f.bias)

  # (attribute on our Transformer, progress label, dotted path inside each HF block);
  # order matches the original conversion order and progress messages
  per_layer = [
    ("mlp_c_proj_bias", "mlp_c_proj.bias", "mlp.c_proj.bias"),
    ("mlp_c_proj_weight", "mlp_c_proj.weight", "mlp.c_proj.weight"),
    ("mlp_c_fc_bias", "mlp_c_fc.bias", "mlp.c_fc.bias"),
    ("mlp_c_fc_weight", "mlp_c_fc.weight", "mlp.c_fc.weight"),
    ("attn_c_proj_bias", "attn_c_proj.bias", "attn.c_proj.bias"),
    ("attn_c_proj_weight", "attn_c_proj.weight", "attn.c_proj.weight"),
    ("attn_c_attn_bias", "attn_c_attn.bias", "attn.c_attn.bias"),
    ("attn_c_attn_weight", "attn_c_attn.weight", "attn.c_attn.weight"),
    ("ln_1_bias", "ln_1.bias", "ln_1.bias"),
    ("ln_1_weight", "ln_1.weight", "ln_1.weight"),
    ("ln_2_bias", "ln_2.bias", "ln_2.bias"),
    ("ln_2_weight", "ln_2.weight", "ln_2.weight"),
  ]
  for attr, label, path in per_layer:
    print(f"converting {label}")
    getter = operator.attrgetter(path)
    setattr(gpt2_blank.model, attr, [_np32(getter(block)) for block in blocks])

  print("converting wte.weight")
  gpt2_blank.model.wte_weight = _np32(model.transformer.wte.weight)

  print("converting lm_head.weight")
  gpt2_blank.model.lm_head_weight = _np32(model.lm_head.weight).transpose(1,0)

  with open(folder+model_size+".pickle", 'wb') as outp:
      pickle.dump(gpt2_blank, outp)

# **** main code ****

# NOTE(review): only the Rand() call below is inside the __main__ guard; every
# statement after this block is at module level and runs unconditionally.
if __name__ == "__main__":
  rand = Rand()

# Prompt used by the generation runs below.
default_prompt = "What is the answer to life, the universe, and everything?"
# Alternative prompt; its previously observed output is recorded below.
#default_prompt = "What happened in 1939?"
# should output:
# .... The Jewish people rejected

# NOTE(review): "(tg random)" presumably refers to a different random sampler — confirm.
#(tg random) should output:
#It was a very fateful day.
#When the Nazis occupied Poland in 1939....

# Seed numpy's global RNG so any np.random draws (e.g. buffer.rand_like) are reproducible.
np.random.seed(28)

# Reference token ids that generate() can check via its expected_tokens argument
# (each generated token is compared with np.testing.assert_equal).
# NOTE(review): presumably recorded for gpt2 small with default_prompt — confirm.
# The runs below currently pass expected_tokens=None, so these lists are not checked.
expected_tokens = [198, 198, 1532, 345, 547, 281, 48782,\
  893, 48187, 11, 393, 655, 257, 33013, 11, 534, 3280,\
  1244, 307, 257, 1643, 1180, 13, 1114, 530, 11, 345,\
  1244, 1011, 257, 2392, 1570, 286, 262, 6881, 13,\
  887, 329, 584, 661, 851, 1390, 5519, 11, 7912,\
  11, 290, 584, 287, 12, 14108, 12, 2550, 661, 851,\
  534, 3280, 1244, 307, 1290, 517, 588, 25, 5155, 1595,\
  470, 2152, 379, 477, 13, 198, 198, 25153, 345, 389, 257,\
  1862, 1048, 508, 655, 18303, 422, 3504, 1524, 290, 468,\
  1239, 1107, 19189, 257, 3451, 287, 48782, 23154, 13, 921, 821, 319, 281, 3624]

# TODO: this is the output currently produced, but it is known to be wrong.
# NOTE(review): presumably for the "What happened in 1939?" prompt — confirm.
expected_tokens_b = [198, 198, 1026, 373, 257, 845, 46873, 1110, 13, 198,
198, 2215, 262, 19147, 12030, 12873, 287, 24414, 11, 262,
6771, 547, 407, 3142, 284, 670, 287, 262, 17590, 11,
645, 2300, 703, 881, 484, 2227, 284, 13, 383, 1917,
2627, 1598, 618, 262, 5103, 1664, 286, 262, 309, 9116,
4623, 268, 4618, 11, 543, 925, 281, 3113, 329, 262,
11908, 12, 1273, 14414, 41460, 11, 3414, 617, 19008, 284,
262, 24830, 4572, 13, 198, 198, 464, 4479, 286, 7570,
21773, 2066, 82, 1908, 734, 11628, 284, 262, 31062, 13,
1881, 373, 1912, 319, 257, 1080, 286, 2829, 12370, 4827]

# Reference tokens, presumably for gpt2-medium (per the name) — confirm.
expected_tokens_med = [198, 198, 1544, 468, 262, 2694, 290, 262, 481, 284, 3853, 475, 339, 2391,\
  2314, 2222, 2241, 284, 466, 340, 13, 679, 318, 7787, 284, 307, 3436, 290,\
  7787, 284, 2222, 1854, 656, 340, 13, 679, 318, 7787, 284, 307, 33046, 290,\
  7787, 284, 307, 8606, 13, 198, 198, 4864, 11, 339, 318, 407, 3436, 287,\
  465, 3252, 286, 5287, 13, 198, 198, 22210, 4952, 502, 326, 3252, 2125,\
  470, 262, 6808, 2728, 286, 262, 1917, 13, 198, 198, 2025, 37848, 284,\
  674, 10251, 481, 1282, 611, 356, 12553, 262, 4950, 2000, 1176, 356,\
  423, 13, 198, 198, 2215, 345]

# Reference tokens, presumably for gpt2-large (per the name) — confirm.
expected_tokens_large = [198, 198, 1532, 345, 550, 257, 40663, 11, 345, 561,
2192, 1382, 340, 656, 262, 6766, 13, 2293, 477, 11,
345, 714, 655, 4829, 262, 1468, 2272, 18556, 656, 262,
8137, 290, 1956, 340, 7382, 13, 198, 198, 1537, 326,
40663, 561, 779, 262, 976, 3716, 1080, 284, 366, 33327,
1, 262, 3404, 319, 262, 4417, 286, 262, 5440, 13,
921, 460, 470, 2824, 3404, 416, 17997, 3404, 656, 262,
1633, 960, 14108, 40663, 318, 517, 588, 257, 16285, 508,
46561, 262, 3404, 510, 422, 262, 2323, 13, 5455, 11,
345, 561, 7925, 1657, 12, 6551, 5696, 422, 262, 3668]

#a = create_buffer_empty(1*4,d,params) #TODO can't run medium in isolation without doing this first?

def _load_gpt2(model_size):
  """Return the pickled GPT2 wrapper for ``model_size``, converting the checkpoint first if the cache is missing.

  ``model_size`` must be a bare model id such as "gpt2-medium". It is passed to
  get_model() unchanged, and get_model() builds ``folder + model_size + ".pickle"`` itself.
  """
  path = folder + model_size + ".pickle"
  if not os.path.exists(path):
    get_model(model_size)
  # NOTE(review): pickle.load can execute arbitrary code; only load caches written by get_model().
  with open(path, 'rb') as filehandler:  # closed deterministically (was leaked before)
    return pickle.load(filehandler)


def _generate_and_free(gpt2, prompt):
  """Generate 100 tokens for ``prompt``, print the response, then release the model's buffers."""
  text = gpt2.generate(prompt=prompt, max_length=100, temperature=np.float32(0.8), timing=None, batch_size=1,expected_tokens=None)
  print("Response:", text)  # previously printed a stray ("Response:", "green") tuple
  delete_buffers(gpt2.model)
  return text

# The statement order inside each run (Rand(), Kernels(...), to_buffer) is kept as-is
# because MAX_CONTEXT, metalk and rand are module-level globals other code may read.

# --- gpt2 (small), default prompt ---
rand = Rand()
MAX_CONTEXT = len(encode(default_prompt))+100
metalk = Kernels(dim=768,n_heads=12,max_context=MAX_CONTEXT,device=d)
gpt2 = _load_gpt2("gpt2")
gpt2.model.to_buffer(12,768)
text = _generate_and_free(gpt2, default_prompt)

# --- gpt2 (small), 1939 prompt ---
rand = Rand()
MAX_CONTEXT = len(encode("What happened in 1939?"))+100
metalk = Kernels(dim=768,n_heads=12,max_context=MAX_CONTEXT,device=d)
gpt2 = _load_gpt2("gpt2")
gpt2.model.to_buffer(12,768)
text = _generate_and_free(gpt2, "What happened in 1939?")

# --- gpt2-medium ---
MAX_CONTEXT = len(encode(default_prompt))+100
metalk = Kernels(dim=1024,n_heads=16,max_context=MAX_CONTEXT,device=d)
# bare id: passing folder+"gpt2-medium" broke the layer lookup and doubled the cache path
gpt2 = _load_gpt2("gpt2-medium")
#gpt2.model.to_buffer2()
gpt2.model.to_buffer(16,1024)
rand = Rand()
text = _generate_and_free(gpt2, default_prompt)

# --- gpt2-large ---
MAX_CONTEXT = len(encode(default_prompt))+100
dim = 1280
n_heads = 20
rand = Rand()
metalk = Kernels(dim=1280,n_heads=20,max_context=MAX_CONTEXT,device=d)
gpt2 = _load_gpt2("gpt2-large")
gpt2.model.to_buffer(20,1280)
text = _generate_and_free(gpt2, default_prompt)

# --- gpt2-large again with the Metal-specific large-model kernels ---
if d == "Metal":
  MAX_CONTEXT = len(encode(default_prompt))+100
  dim = 1280
  n_heads = 20
  metalk = metal_kernels_large.Metal_Kernels(dim=1280,n_heads=20,max_context=MAX_CONTEXT)
  gpt2 = _load_gpt2("gpt2-large")
  gpt2.model.to_buffer(20,1280)
  rand = Rand()
  text = _generate_and_free(gpt2, default_prompt)

# Disabled gpt2-xl run: the triple-quoted string below is a bare expression and is never executed.
# NOTE(review): before re-enabling, `metal_Metal_Kernels` looks like a typo, and the
# pickle paths lack the `folder` prefix that the other runs use.
''' TODO
MAX_CONTEXT = len(encode(default_prompt))+100
dim = 1600
n_heads = 25
metalk = metal_Metal_Kernels(dim=1600,n_heads=25,max_context=MAX_CONTEXT)
if os.path.exists("gpt2-xl.pickle") == False:
  get_model("gpt2-xl")
filehandler = open("gpt2-xl.pickle", 'rb')
gpt2 = pickle.load(filehandler)
gpt2.model.to_buffer(25,1600)
rand = Rand()
text = gpt2.generate(prompt=default_prompt, max_length=100, temperature=np.float32(0.8), timing=None, batch_size=1,expected_tokens=None)
print((f"Response:", "green"), text)
delete_buffers(gpt2.model)
'''