In [1]:
import torch, gc
import time
import math

In [56]:
# strassen CPU attempt

# Fixed seed so the Strassen-vs-matmul comparison in this cell is
# reproducible on Restart & Run All.
torch.manual_seed(0)
a = torch.randn(2,2)
b = torch.randn(2,2)

# ------

def strassen2x2(a,b):
  """Compute ``a @ b`` with one level of Strassen's algorithm (7 block products).

  Both operands are split into a 2x2 grid of quadrants.  Strassen's seven
  products M1..M7 are formed from sums/differences of those quadrants and
  recombined into the four output quadrants.  For 2x2 inputs every quadrant
  is a 1x1 block, so this is exactly the scalar 2x2 Strassen formula; larger
  operands work too as long as every split dimension is even.

  Args:
    a: 2-D tensor (or nested list) of shape (2p, 2q).
    b: 2-D tensor (or nested list) of shape (2q, 2r).

  Returns:
    Tensor of shape (2p, 2r) in the promoted dtype of ``a`` and ``b``, on
    the inputs' device.  It is assembled with ``torch.cat``, so it stays on
    the autograd graph (``torch.tensor`` would copy the values into a new
    tensor with no autograd history).

  Raises:
    ValueError: if an operand is not 2-D, the inner dimensions differ, or
      any dimension is odd.
  """
  a = torch.as_tensor(a)
  b = torch.as_tensor(b)
  if a.dim() != 2 or b.dim() != 2:
    raise ValueError("strassen2x2 expects two 2-D operands")
  if a.shape[1] != b.shape[0]:
    raise ValueError("inner dimensions of a and b must match")
  if any(d % 2 for d in (*a.shape, b.shape[1])):
    raise ValueError("all dimensions must be even to split into quadrants")

  # Block matmul needs a common dtype; elementwise scalar products in the
  # 2x2 case would have promoted implicitly, so promote explicitly here.
  dtype = torch.promote_types(a.dtype, b.dtype)
  a, b = a.to(dtype), b.to(dtype)

  h_rows, h_inner = a.shape[0] // 2, a.shape[1] // 2
  h_cols = b.shape[1] // 2
  a11, a12 = a[:h_rows, :h_inner], a[:h_rows, h_inner:]
  a21, a22 = a[h_rows:, :h_inner], a[h_rows:, h_inner:]
  b11, b12 = b[:h_inner, :h_cols], b[:h_inner, h_cols:]
  b21, b22 = b[h_inner:, :h_cols], b[h_inner:, h_cols:]

  m1 = (a11 + a22) @ (b11 + b22)
  m2 = (a21 + a22) @ b11
  m3 = a11 @ (b12 - b22)
  m4 = a22 @ (b21 - b11)
  m5 = (a11 + a12) @ b22
  m6 = (a21 - a11) @ (b11 + b12)
  m7 = (a12 - a22) @ (b21 + b22)

  c11 = m1 + m4 - m5 + m7
  c12 = m3 + m5
  c21 = m2 + m4
  c22 = m1 - m2 + m3 + m6
  return torch.cat([torch.cat([c11, c12], dim=1),
                    torch.cat([c21, c22], dim=1)], dim=0)


# Compare against torch.matmul.  assert_close uses dtype-aware tolerances
# (float32 defaults: rtol=1.3e-6, atol=1e-5) and reports the worst mismatch.
# A fixed 1e-6 absolute bound can fail spuriously in float32 because Strassen
# adds/subtracts partial products in a different order than matmul does.
o1 = strassen2x2(a,b)
o2 = torch.matmul(a,b)
torch.testing.assert_close(o1, o2)

In [None]:
#pytorch CPU
#
# Empirical scaling exponent of torch.matmul: time an n x n product and an
# (n/2) x (n/2) product.  If cost ~ n^p then t(n) / t(n/2) = 2^p, so
# p = log2(t(n) / t(n/2)).


def _timed_matmul(lhs, rhs):
  """Return ``(torch.matmul(lhs, rhs), elapsed_seconds)`` for one product.

  A tiny untimed matmul on the same device/dtype runs first so one-off
  start-up costs (e.g. CUDA/cuBLAS initialisation on the first GPU call)
  are not billed to the timed product.  CUDA kernels launch asynchronously,
  so on a CUDA device the stream is synchronized before starting and before
  stopping the clock.  ``time.perf_counter`` is the clock intended for
  measuring intervals (``time.time`` follows the adjustable system clock).
  """
  warm = torch.ones(8, 8, device=lhs.device, dtype=lhs.dtype)
  torch.matmul(warm, warm)
  if lhs.is_cuda:
    torch.cuda.synchronize()
  start = time.perf_counter()
  out = torch.matmul(lhs, rhs)
  if lhs.is_cuda:
    torch.cuda.synchronize()
  return out, time.perf_counter() - start


n = 8000
m1 = torch.randn(n, n)
m2 = torch.randn(n, n)
# No pre-allocation of `r`: torch.matmul returns a fresh tensor anyway.
r, t1 = _timed_matmul(m1, m2)
print("pytorch CPU time " + str(n) + "x" + str(n), t1)

n = n // 2
m1 = torch.randn(n, n)
m2 = torch.randn(n, n)
r, t2 = _timed_matmul(m1, m2)
print("pytorch CPU time " + str(n) + "x" + str(n), t2)

x = t1 / t2
print(x, "times slower")

p = math.log2(x)

print("pytorch CPU -> n^" + str(p) + "\n\n")

#pytorch GPU

device = torch.device("cuda")

if torch.cuda.is_available():
  n = 24000
  # Allocate directly on the GPU: avoids a host-side n x n buffer and a
  # host-to-device copy per operand.
  m1 = torch.randn(n, n, device=device)
  m2 = torch.randn(n, n, device=device)
  r, t1 = _timed_matmul(m1, m2)
  print("pytorch GPU time " + str(n) + "x" + str(n), t1)

  n = n // 2

  # gc.collect()/empty_cache() can only release memory nothing references,
  # so drop the large operands and result before clearing the cache.
  del m1, m2, r
  gc.collect()
  torch.cuda.empty_cache()

  m1 = torch.randn(n, n, device=device)
  m2 = torch.randn(n, n, device=device)
  r, t2 = _timed_matmul(m1, m2)
  print("pytorch GPU time " + str(n) + "x" + str(n), t2)

  x = t1 / t2
  print(x, "times slower")

  p = math.log2(x)

  print("pytorch GPU -> n^" + str(p))
else:
  print("pytorch GPU skipped: CUDA is not available")


# strassen CPU attempt