In [1]:
import torch, gc
import time
import math

In [None]:
# strassen CPU attempt

# Two random 2x2 operands; the first correctness check below multiplies them.
a, b = torch.randn(2, 2), torch.randn(2, 2)

# ------

def naive(a,b):
  """Multiply two square matrices with the textbook triple loop.

  The side length is taken from ``len(a)``. Each output entry is built up
  one scalar product at a time inside a fresh ``torch.zeros`` matrix, which
  is returned.
  """
  size = int(len(a))
  product = torch.zeros(size, size)

  for i in range(size):
    a_row = a[i]
    out_row = product[i]  # a view, so the updates below land in `product`
    for j in range(size):
      for k in range(size):
        out_row[j] += a_row[k] * b[k][j]

  return product


def strassen2x2(a,b):
  """Multiply two 2x2 matrices using Strassen's seven scalar products.

  The four entries of each operand are read by index, combined into the
  products m1..m7, and the result is returned as a new 2x2 tensor built
  with ``torch.tensor``.
  """
  a11, a12 = a[0][0], a[0][1]
  a21, a22 = a[1][0], a[1][1]
  b11, b12 = b[0][0], b[0][1]
  b21, b22 = b[1][0], b[1][1]

  m1 = (a11 + a22) * (b11 + b22)
  m2 = (a21 + a22) * b11
  m3 = a11 * (b12 - b22)
  m4 = a22 * (b21 - b11)
  m5 = (a11 + a12) * b22
  m6 = (a21 - a11) * (b11 + b12)
  m7 = (a12 - a22) * (b21 + b22)

  top_row = [m1 + m4 - m5 + m7 , m3 + m5]
  bottom_row = [m2 + m4 , m1 - m2 + m3 + m6]
  return torch.tensor([top_row, bottom_row])

def strassen(a,b):
  """Multiply two square matrices of equal size with Strassen's algorithm.

  Sizes that are not a power of two are zero-padded up to the next power of
  two, multiplied recursively on quadrants, and the result is cropped back
  to the original size.

  Args:
    a: square 2-D tensor, the left operand.
    b: square 2-D tensor with the same shape as ``a``, the right operand.

  Returns:
    A tensor with the shape of ``a`` holding the matrix product. For sizes
    above 2x2 it is a slice (view) of a freshly allocated buffer.

  Raises:
    ValueError: if the operands are not square 2-D tensors of equal shape.
  """
  if a.dim() != 2 or a.shape != b.shape or a.shape[0] != a.shape[1]:
    raise ValueError(
        "strassen expects two square matrices of equal size, got "
        + str(tuple(a.shape)) + " and " + str(tuple(b.shape)))

  n0 = a.shape[0]

  # 0x0 and 1x1 cannot be split into quadrants (the split would recurse on
  # empty blocks). For these sizes the elementwise product is the matrix
  # product, so return it directly.
  if n0 <= 1:
    return a * b

  if n0 == 2:
    return strassen2x2(a,b)

  # Next power of two >= n0, computed with integers to avoid float log2/pow.
  newSize = 1 << (n0 - 1).bit_length()
  if n0 != newSize:
    # Pad with zeros, keeping each operand's dtype and device.
    a2 = a.new_zeros((newSize,newSize))
    b2 = b.new_zeros((newSize,newSize))
    a2[:n0,:n0] = a
    b2[:n0,:n0] = b
    a = a2
    b = b2

  n = newSize // 2

  a11 = a[:n,:n]
  a12 = a[:n,n:]
  a21 = a[n:,:n]
  a22 = a[n:,n:]

  b11 = b[:n,:n]
  b12 = b[:n,n:]
  b21 = b[n:,:n]
  b22 = b[n:,n:]

  # Seven recursive products instead of the eight a plain block split needs.
  m1 = strassen((a11 + a22) , (b11 + b22))
  m2 = strassen((a21 + a22) , b11)
  m3 = strassen(a11 , (b12 - b22))
  m4 = strassen(a22 , (b21 - b11))
  m5 = strassen((a11 + a12) , b22)
  m6 = strassen((a21 - a11) , (b11 + b12))
  m7 = strassen((a12 - a22) , (b21 + b22))

  # Every quadrant is written below, so an uninitialised buffer is enough.
  ret = a.new_empty((newSize,newSize))
  ret[:n,:n] = m1 + m4 - m5 + m7
  ret[:n,n:] = m3 + m5
  ret[n:,:n] = m2 + m4
  ret[n:,n:] = m1 - m2 + m3 + m6

  # Drop the rows/columns that only exist because of the zero padding.
  return ret[:n0,:n0]

def check_close(label, actual, expected, atol=1e-4):
  """Assert that every entry of `actual` is within `atol` of `expected`.

  `label` names the comparison in the failure message, which also carries
  the largest absolute difference that was found.
  """
  worst = (actual - expected).abs().max().item()
  assert worst < atol , (label, worst)


# 2x2 base case, using the operands defined at the top of this cell.
# The tolerance matches the larger cases: an absolute bound of 1e-6 is
# tighter than float32 rounding allows for Strassen's sums and differences,
# so it could fail on unlucky random draws.
o1 = strassen(a,b)
o2 = torch.matmul(a,b)
check_close("strassen 2x2 vs matmul", o1, o2)

# 16 is a power of two; 20 is not, so it goes through the zero-padding path.
for size in (16, 20):
  a = torch.randn(size,size)
  b = torch.randn(size,size)

  o1 = strassen(a,b)
  o2 = torch.matmul(a,b)
  o3 = naive(a,b)

  check_close("strassen " + str(size) + "x" + str(size) + " vs matmul", o1, o2)
  check_close("naive " + str(size) + "x" + str(size) + " vs matmul", o3, o2)

# SPEED TESTS

n = 128
a = torch.randn(n,n)
b = torch.randn(n,n)

# perf_counter is the high-resolution clock intended for short intervals.
start_time = time.perf_counter()
o = torch.matmul(a,b)
t = time.perf_counter() - start_time
print("matmul " + str(n) + "x" + str(n) + " took",t,"seconds")


def report_scaling(label, fn, n):
  """Time `fn` on random n x n and (n/2) x (n/2) inputs and print the scaling.

  Each run gets its own timer start, so a measurement never includes work
  done before it. Halving the size divides an O(n^p) runtime by 2^p, so
  log2 of the time ratio estimates p. That estimate is printed and returned.
  """
  timings = []
  for size in (n, n // 2):
    lhs = torch.randn(size,size)
    rhs = torch.randn(size,size)

    start = time.perf_counter()
    fn(lhs,rhs)
    elapsed = time.perf_counter() - start

    print(label + " " + str(size) + "x" + str(size) + " took",elapsed,"seconds")
    timings.append(elapsed)

  ratio = timings[0] / timings[1]
  print(ratio,"times slower")

  exponent = math.log2(ratio)
  print(label + " CPU -> n^" + str(exponent) + "\n\n")
  return exponent


# NAIVE SPEED

p = report_scaling("naive", naive, n)

# STRASSEN SPEED

p = report_scaling("strassen", strassen, n)

In [None]:
#pytorch CPU

# Time one large square matmul and one at half the size; log2 of the time
# ratio estimates the exponent p in O(n^p).
times = []
for n in (8000, 4000):
  m1 = torch.randn(n,n)
  m2 = torch.randn(n,n)

  # The result tensor is not pre-allocated: torch.matmul returns a new
  # tensor, so a randn placeholder would only cost memory and setup time.
  start_time = time.perf_counter()
  r = torch.matmul(m1,m2)
  elapsed = time.perf_counter() - start_time

  print("pytorch CPU time "+str(n)+"x"+str(n),elapsed)
  times.append(elapsed)

t1, t2 = times

x = t1 / t2
print(x,"times slower")

p = math.log(x) / math.log(2)

print("pytorch CPU -> n^" + str(p) + "\n\n")

#pytorch GPU

device = torch.device("cuda")

if not torch.cuda.is_available():
  print("pytorch GPU benchmark skipped: CUDA is not available")
else:
  # Warm-up: the first CUDA matmul typically pays one-time setup costs that
  # would otherwise be counted in the first measurement.
  warm = torch.randn(256,256,device=device)
  torch.matmul(warm,warm)
  torch.cuda.synchronize()
  del warm

  times = []
  for n in (24000, 12000):
    # Create the operands directly on the GPU instead of building them on
    # the CPU and copying them over.
    m1 = torch.randn(n,n,device=device)
    m2 = torch.randn(n,n,device=device)

    # Kernels launch asynchronously: wait for the operands to be ready
    # before starting the clock, and for the matmul before stopping it.
    torch.cuda.synchronize()
    start_time = time.perf_counter()
    r = torch.matmul(m1,m2)
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start_time

    print("pytorch GPU time "+str(n)+"x"+str(n),elapsed)
    times.append(elapsed)

    # Drop the references first; empty_cache can only release memory that
    # no live tensor still uses.
    del m1, m2, r
    gc.collect()
    torch.cuda.empty_cache()

  t1, t2 = times

  x = t1 / t2
  print(x,"times slower")

  p = math.log(x) / math.log(2)

  print("pytorch GPU -> n^" + str(p))

In [None]:
!pip install pycuda

In [20]:
# MY OWN CUDA ATTEMPT
import pycuda.compiler as comp
import pycuda.driver as cuda
import pycuda.autoinit  # imported for its side effect (CUDA initialisation)

# Kernel `matmul`: each thread computes one output element.
#   row / col come from the thread's y / x position in the launch grid;
#   threads with col >= ncB or row >= nrA do nothing.
#   The element is the sum over i < ncA of
#     weights[startW + row*ncA + i] * nodesA[startn0 + i*ncB + col]
#   and is stored at nodesD[startD + row*ncB + col].
# So `weights` is read as a row-major nrA x ncA left operand, `nodesA` as a
# row-major ncA x ncB right operand, and `nodesD` receives the nrA x ncB
# product; the start* arguments are element offsets into the three buffers.
mod = comp.SourceModule(
    """
  __global__ void matmul(float *nodesD, float *weights, float *nodesA, int ncA, int ncB, int nrA, int startn0, int startD, int startW)
{
  int row = threadIdx.y + blockDim.y * blockIdx.y;
  int col = threadIdx.x + blockDim.x * blockIdx.x;
  float t = 0;
  if(col < ncB && row < nrA)
  {
  for(int i = 0; i < ncA; i++){
    t += weights[startW + (row * ncA) + i] * nodesA[startn0 + col + (i * ncB)];
  }
    nodesD[startD + (row * ncB) + col] = t;
  }
}
"""
)
# Use the public pycuda.driver enum rather than the private pycuda._driver module.
MAX_THREADS_PER_BLOCK = \
    cuda.Device(0).get_attribute(cuda.device_attribute.MAX_THREADS_PER_BLOCK)
cudaMatMul = mod.get_function("matmul")


n = 4
m1 = torch.randn(n,n)
m2 = torch.randn(n,n)

# Reference product computed by PyTorch on the CPU.
r = torch.matmul(m1,m2)

# Tensor.numpy() returns arrays that share memory with the tensors.
m1np = m1.numpy()
m2np = m2.numpy()
rnp = r.numpy()

m1cuda = cuda.mem_alloc(m1np.nbytes)
m2cuda = cuda.mem_alloc(m2np.nbytes)
rcuda = cuda.mem_alloc(rnp.nbytes)

cuda.memcpy_htod(m1cuda,m1np)
cuda.memcpy_htod(m2cuda,m2np)
cuda.memcpy_htod(rcuda,rnp)


### check copy works ok
# Copy back into a separate, zero-filled host buffer. Copying into `rnp`
# would overwrite the memory `r` lives in, so comparing the two afterwards
# could never fail.
rback = torch.zeros_like(r).numpy()
cuda.memcpy_dtoh(rback,rcuda)

for y in range(len(rnp)):
  for x in range(len(rnp[y])):
    assert rback[y][x] == rnp[y][x], "copy didn't work"
