In [None]:
import numpy as np
import torch
import time
import torch.nn as nn
from torchvision import datasets
from torchvision import transforms
import torchvision

In [None]:
# ALGO1: nn.conv2d
def myconv2d(input, weight, bias=None, stride=(1,1), padding=(0,0), dilation=(1,1), groups=1):
    """
    Dense 2-D convolution implemented as unfold + matmul, with a multiplication count.

    Args:
        input: tensor of shape (batch, in_channels, in_h, in_w).
        weight: tensor of shape (out_channels, in_channels, kh, kw).
        bias: optional tensor broadcastable over the out_channels axis; added
            to every output position when given.
        stride, padding, dilation: (height, width) pairs forwarded to
            ``torch.nn.Unfold``.
        groups: accepted for signature compatibility; not used.

    Returns:
        (out, mul_count): ``out`` is a float tensor of shape
        (batch, out_channels, out_h, out_w); ``mul_count`` is
        batch * out_channels * out_h * out_w * in_channels * kh * kw.
    """
    batch_size = input.shape[0]
    in_h, in_w = input.shape[2], input.shape[3]
    out_channels, in_channels, kh, kw = weight.shape

    # A dilated kernel spans dilation * (k - 1) + 1 input positions; the output
    # size must account for that or it disagrees with what Unfold produces.
    out_h = (in_h + 2 * padding[0] - dilation[0] * (kh - 1) - 1) // stride[0] + 1
    out_w = (in_w + 2 * padding[1] - dilation[1] * (kw - 1) - 1) // stride[1] + 1

    unfold = torch.nn.Unfold(kernel_size=(kh, kw), dilation=dilation, padding=padding, stride=stride)
    inp_unf = unfold(input)  # (batch, in_channels * kh * kw, out_h * out_w)

    # reshape (not view) so non-contiguous weights are accepted as well.
    w_ = weight.reshape(out_channels, -1).t()

    out_unf = inp_unf.transpose(1, 2).matmul(w_)  # (batch, out_h * out_w, out_channels)
    if bias is not None:
        out_unf = out_unf + bias
    out = out_unf.transpose(1, 2).reshape(batch_size, out_channels, out_h, out_w)

    # The bias addition does not add multiplications, so the count is the same either way.
    mul_count = batch_size*out_channels*out_h*out_w*in_channels*kh*kw
    return (out.float(), mul_count)

##############################################################################################

class comp_vector():
    """
    Per-channel compressed (index, value) representation of a 3-D tensor.

    Attributes:
        c, y, x: sizes of dimensions 0, 1 and 2 of ``arr``.
        index_vector: one NumPy array per channel holding the flat (row-major)
            positions of that channel's non-zero entries.
        data_vector: one 1-D tensor per channel holding the values found at
            those positions, in the same order.
    """

    def __init__(self, arr):
        self.x = arr.size(dim=2)
        self.y = arr.size(dim=1)
        self.c = arr.size(dim=0)
        self.index_vector = []
        self.data_vector = []
        for i in range(self.c):
            channel = arr[i]
            # detach() first: tensors that track gradients (e.g. layer
            # weights) cannot be converted to NumPy directly.
            nonzero_idx = np.flatnonzero(channel.detach().cpu().numpy())
            self.index_vector.append(nonzero_idx)
            self.data_vector.append(channel.ravel()[nonzero_idx])

    def get_index_vector(self):
        """Return the list of per-channel non-zero flat-index arrays."""
        return self.index_vector

    def get_data_vector(self):
        """Return the list of per-channel non-zero value tensors."""
        return self.data_vector


def conv_compressed(comp_inp, comp_wt, stride=1):
    """
    Cross-correlate a compressed input with one compressed filter, visiting
    only pairs of non-zero input and non-zero weight entries.

    Args:
        comp_inp: compressed input exposing ``x`` (number of columns), ``y``
            (number of rows), ``get_index_vector()`` and ``get_data_vector()``.
        comp_wt: compressed filter with the same interface; ``comp_wt.c`` is
            the number of channels summed into the result.
        stride: a single int used for both axes, or a (row, column) pair.

    Returns:
        (acc_buf, mult_count): ``acc_buf`` is a float tensor of shape
        (rows_out, cols_out); ``mult_count`` is the number of products that
        landed inside the output.
    """
    try:
        stride_r, stride_c = stride
    except TypeError:
        stride_r = stride_c = stride

    rows_out = (comp_inp.y - comp_wt.y) // stride_r + 1
    cols_out = (comp_inp.x - comp_wt.x) // stride_c + 1
    acc_buf = torch.zeros(rows_out, cols_out, dtype=torch.float32)
    mult_count = 0

    inp_index_vector = comp_inp.get_index_vector()
    inp_data_vector = comp_inp.get_data_vector()
    wt_index_vector = comp_wt.get_index_vector()
    wt_data_vector = comp_wt.get_data_vector()

    for c in range(comp_wt.c):
        inp_vals = inp_data_vector[c]
        wt_vals = wt_data_vector[c]
        # Flat row-major index -> (row, col): divide and take the remainder by
        # the number of COLUMNS. Decoded once per channel, not once per pair.
        wt_pos = [divmod(int(idx), comp_wt.x) for idx in wt_index_vector[c]]
        for i, flat in enumerate(inp_index_vector[c]):
            inp_r, inp_c = divmod(int(flat), comp_inp.x)
            for j, (wt_r, wt_c) in enumerate(wt_pos):
                d_r = inp_r - wt_r
                d_c = inp_c - wt_c
                # The pair contributes only if it lines up with a stride step.
                if d_r % stride_r != 0 or d_c % stride_c != 0:
                    continue
                out_r = d_r // stride_r
                out_c = d_c // stride_c
                if 0 <= out_r < rows_out and 0 <= out_c < cols_out:
                    acc_buf[out_r, out_c] += float(inp_vals[i] * wt_vals[j])
                    mult_count += 1

    return (acc_buf, mult_count)

def myconv2d_sparse(input, weight, bias=None, stride=(1,1), padding=(0,0), dilation=(1,1), groups=1):
    """
    Convolution that multiplies only non-zero input / non-zero weight pairs.

    Only ``input[0]`` is processed, so the result always has batch size 1.
    ``dilation`` and ``groups`` are accepted for signature compatibility but
    are not used.

    Args:
        input: tensor of shape (batch, channels, rows, cols).
        weight: tensor of shape (filters, channels, kernel_rows, kernel_cols).
        bias: optional per-filter bias, indexed by filter number.
        stride, padding: (row, column) pairs.

    Returns:
        (out, mult_count): ``out`` has shape (1, filters, rows_out, cols_out);
        ``mult_count`` is the total number of multiplications performed.
    """
    input = torch.nn.functional.pad(input, (padding[1], padding[1], padding[0], padding[0]), "constant", 0)
    comp_in = comp_vector(input[0])

    in_rows = input.size(dim=2)
    in_cols = input.size(dim=3)
    wt_rows = weight.size(dim=2)
    wt_cols = weight.size(dim=3)
    k = weight.size(dim=0)

    rows_out = (in_rows - wt_rows) // stride[0] + 1
    cols_out = (in_cols - wt_cols) // stride[1] + 1
    out = torch.empty(size=(1, k, rows_out, cols_out))

    mult_count = 0
    for i in range(k):
        comp_wt = comp_vector(weight[i])
        channel_out, num = conv_compressed(comp_in, comp_wt, stride)
        if bias is not None:
            channel_out = channel_out + bias[i]
        out[0, i] = channel_out
        mult_count += num
    return (out, mult_count)

######################################################################################################

# ALGO 3
def compute_weight_list(kernel):
    """
    Collect the negative weights of every filter, most negative first.

    ``kernel`` is indexed as kernel[filter][channel][row][col]. For each
    filter the result holds a list of (weight, channel, row, col) tuples for
    the entries that are < 0, sorted by ascending weight.

    Returns a list with one such sorted list per filter.
    """
    n_filters, n_channels, n_rows, n_cols = (kernel.shape[d] for d in range(4))
    per_filter = []
    for f in range(n_filters):
        negatives = []
        for k in range(n_channels):
            plane = kernel[f][k]
            for i in range(n_rows):
                for j in range(n_cols):
                    value = plane[i][j]
                    if value < 0:
                        negatives.append((value, k, i, j))
        # Stable sort on the weight only, so ties keep their scan order.
        negatives.sort(key=lambda entry: entry[0])
        per_filter.append(negatives)
    return per_filter

def compute_conv_onlypred(img, weight_list, weights, r, c, bias=0):
    """
    Compute one output cell, stopping early once the running sum goes negative.

    All positive weights of the filter are applied first, then ``bias`` is
    added, then the negative weights from ``weight_list`` are applied in the
    given order. Accumulation stops as soon as the sum drops below zero.
    NOTE(review): the early exit presumably relies on non-negative inputs and
    a following ReLU - confirm with the caller.

    Args:
        img: input indexed as img[0][channel][row][col]; only batch 0 is read.
        weight_list: (weight, channel, row, col) tuples of the negative weights.
        weights: one filter, indexed as weights[channel][row][col].
        r, c: top-left corner of the window in ``img``.
        bias: value added after the positive contributions.

    Returns:
        (img_out_cell, conv_mult_count): the (possibly partial) sum and the
        number of multiplications performed.
    """
    img_out_cell = 0
    conv_mult_count = 0
    depth = weights.shape[0]
    height = weights.shape[1]
    width = weights.shape[2]
    for k in range(depth):
        # i walks rows (dim 1) and j walks columns (dim 2), matching the
        # weights[k][i][j] indexing, so non-square kernels work.
        for i in range(height):
            for j in range(width):
                if weights[k][i][j] > 0:
                    conv_mult_count += 1
                    img_out_cell += img[0][k][r+i][c+j]*weights[k][i][j]

    img_out_cell += bias

    for tup in weight_list:
        conv_mult_count += 1
        img_out_cell += tup[0]*img[0][tup[1]][r+tup[2]][c+tup[3]]
        if img_out_cell < 0:
            break
    return img_out_cell, conv_mult_count

def compute_filter_conv_onlypred(img, weight_list, weights, kernel_id,stride=(1,1), padding=(0,0), bias=0):
    """
    Slide one filter over ``img`` and build its output channel.

    Args:
        img: already-padded input of shape (batch, channels, rows, cols).
        weight_list: sorted negative-weight tuples for this filter.
        weights: the filter, shape (channels, kernel_rows, kernel_cols).
        kernel_id, padding: accepted for signature compatibility; not used.
        stride: (row, column) step of the window.
        bias: bias for this filter, forwarded to ``compute_conv_onlypred``.

    Returns:
        (img_out_channel, filter_mult_count): a (rows_out, cols_out) tensor
        and the total multiplication count over all windows.
    """
    rows_out = (img.shape[2]-weights.shape[1])//stride[0]+1
    cols_out = (img.shape[3]-weights.shape[2])//stride[1]+1
    # Rows first: the buffer is indexed [row][col] below.
    img_out_channel = torch.zeros(rows_out, cols_out)
    filter_mult_count = 0
    row_starts = range(0, img.shape[2]-weights.shape[1]+1, stride[0])
    col_starts = range(0, img.shape[3]-weights.shape[2]+1, stride[1])
    for r_out, r in enumerate(row_starts):
        for c_out, c in enumerate(col_starts):
            # The bias is added inside compute_conv_onlypred.
            img_out_channel[r_out][c_out], mult_count = compute_conv_onlypred(img, weight_list, weights, r, c, bias)
            filter_mult_count += mult_count
    return img_out_channel, filter_mult_count

def myconv2d_onlypred(img, weights, bias=None, stride=(1,1), padding=(0,0), dilation=(1,1), groups=1):
    """
    Convolution with per-cell early termination on the negative weights.

    Only ``img[0]`` is read by the helpers, so the result has batch size 1.
    ``dilation`` and ``groups`` are accepted for signature compatibility but
    are not used. A progress line is printed for every eighth filter.

    Args:
        img: tensor of shape (batch, channels, rows, cols).
        weights: tensor of shape (filters, channels, kernel_rows, kernel_cols).
        bias: optional per-filter bias; treated as 0 when None.
        stride, padding: (row, column) pairs.

    Returns:
        (img_conv_output, layer_mult_count): output of shape
        (1, filters, rows_out, cols_out) and the total multiplication count.
    """
    img = torch.nn.functional.pad(img, (padding[1], padding[1], padding[0], padding[0]), "constant", 0)
    layer_mult_count = 0
    filter_count = weights.shape[0]
    height = weights.shape[2]
    width = weights.shape[3]
    rows_out = (img.shape[2]-height)//stride[0]+1
    cols_out = (img.shape[3]-width)//stride[1]+1
    img_conv_output = torch.zeros(1, filter_count, rows_out, cols_out)
    filters_list = compute_weight_list(weights)
    for kernel_id in range(filter_count):
        if kernel_id%8==0:
            print("kernel_id", kernel_id)
        weight_list = filters_list[kernel_id]
        # The signature allows bias=None, so fall back to a zero bias.
        filter_bias = 0 if bias is None else bias[kernel_id]
        img_conv_channel, mult_count = compute_filter_conv_onlypred(img, weight_list, weights[kernel_id], kernel_id, stride, padding, filter_bias)
        img_conv_output[0][kernel_id] = img_conv_channel
        layer_mult_count += mult_count
    return (img_conv_output, layer_mult_count)


######################################################################################


# ALGO 4
class comp_vector_pred():
    """
    Split a 3-D filter into its positive and negative entries.

    Attributes:
        c, x, y: sizes of dimensions 0, 1 and 2 of ``arr``.
        pos_vector: (value, k, i, j) tuples of entries > 0, in scan order.
        neg_vector: (value, k, i, j) tuples of entries < 0, sorted by
            ascending value (most negative first).
    Zero entries are dropped.
    """

    def __init__(self, arr):
        self.y = arr.size(dim=2)
        self.x = arr.size(dim=1)
        self.c = arr.size(dim=0)

        positives = []
        negatives = []
        coords = (
            (k, i, j)
            for k in range(self.c)
            for i in range(self.x)
            for j in range(self.y)
        )
        for k, i, j in coords:
            value = arr[k][i][j]
            if value > 0:
                positives.append((value, k, i, j))
            elif value < 0:
                negatives.append((value, k, i, j))

        # Stable sort on the value only, so equal weights keep scan order.
        negatives.sort(key=lambda entry: entry[0])
        self.pos_vector = positives
        self.neg_vector = negatives

    def get_pos_vector(self):
        """Return the positive-entry tuples."""
        return self.pos_vector

    def get_neg_vector(self):
        """Return the negative-entry tuples, most negative first."""
        return self.neg_vector


def compute_conv_sparsepred(input, weight, comp_wt, r, c, bias=0):
    """
    Compute one output cell from the compressed weights, with early exit.

    Positive weights are applied first, then ``bias`` is added, then negative
    weights are applied in the order ``comp_wt`` provides them. Accumulation
    stops once the running sum is already below zero.
    NOTE(review): the early exit presumably relies on non-negative inputs and
    a following ReLU - confirm with the caller.

    Args:
        input: indexed as input[0][channel][row][col]; only batch 0 is read.
        weight: one filter, shape (channels, window_rows, window_cols).
        comp_wt: object exposing ``get_pos_vector()`` / ``get_neg_vector()``,
            each returning (value, channel, row, col) tuples.
        r, c: top-left corner of the window in ``input``.
        bias: value added after the positive contributions.

    Returns:
        (img_out_cell, conv_mult_count, mult_nonzero): the (possibly partial)
        sum, the multiplications performed here, and the number of window
        positions where both the input and the weight are non-zero.
    """
    img_out_cell = 0
    conv_mult_count = 0

    channels = weight.shape[0]
    win_rows = weight.shape[1]
    win_cols = weight.shape[2]

    # Count positions where input AND weight are both non-zero. Comparing the
    # masks element-wise gives the true overlap; pairing the two non-zero
    # index lists by rank would only count coincidental matches.
    mult_nonzero = 0
    for channel in range(channels):
        inp_window = input[0][channel][r:r+win_rows, c:c+win_cols]
        both_nonzero = (inp_window != 0) & (weight[channel] != 0)
        mult_nonzero += int(both_nonzero.sum())

    for tup in comp_wt.get_pos_vector():
        conv_mult_count += 1
        img_out_cell += tup[0]*input[0][tup[1]][r+tup[2]][c+tup[3]]

    img_out_cell += bias

    for tup in comp_wt.get_neg_vector():
        if img_out_cell < 0:
            break
        conv_mult_count += 1
        img_out_cell += tup[0]*input[0][tup[1]][r+tup[2]][c+tup[3]]

    return img_out_cell, conv_mult_count, mult_nonzero


def compute_filter_conv_sparsepred(input, weights, comp_wt, width_out, height_out,stride=(1,1), padding=(0,0), bias=0):
    """
    Slide one filter over ``input`` and fill its output channel.

    The output buffer has shape (width_out, height_out) and is indexed with
    the window's row step first and column step second. ``padding`` is
    accepted but not used here.

    Returns:
        (img_out_channel, filter_mult_count, filter_calc_mult): the output
        buffer and the two multiplication counters summed over every window.
    """
    img_out_channel = torch.zeros(width_out, height_out)
    filter_mult_count = 0
    filter_calc_mult = 0

    row_starts = range(0, input.shape[2] - weights.shape[1] + 1, stride[0])
    col_starts = range(0, input.shape[3] - weights.shape[2] + 1, stride[1])

    # enumerate() yields the same output index as dividing the window start
    # by the stride.
    for r_out, r in enumerate(row_starts):
        for c_out, c in enumerate(col_starts):
            cell, mult_count, calc_mult = compute_conv_sparsepred(input, weights, comp_wt, r, c, bias)
            img_out_channel[r_out][c_out] = cell
            filter_mult_count += mult_count
            filter_calc_mult += calc_mult

    return img_out_channel, filter_mult_count, filter_calc_mult


def myconv2d_sparse_pred(input, weight, bias=None, stride=(1,1), padding=(0,0), dilation=(1,1), groups=1):
    """
    Convolution using compressed weights with early termination.

    Only ``input[0]`` is read by the helpers, so the result has batch size 1.
    ``dilation`` and ``groups`` are accepted for signature compatibility but
    are not used.

    Args:
        input: tensor of shape (batch, channels, rows, cols).
        weight: tensor of shape (filters, channels, kernel_rows, kernel_cols).
        bias: optional per-filter bias; treated as 0 when None.
        stride, padding: (row, column) pairs.

    Returns:
        (out, mult_count): ``out`` has shape (1, filters, w, h) and
        ``mult_count`` is the number of multiplications performed.
    """
    input = torch.nn.functional.pad(input, (padding[1], padding[1], padding[0], padding[0]), "constant", 0)
    filter_count = weight.shape[0]
    # Dim 2 is the row axis and dim 3 the column axis, so despite their names
    # ``w`` is the output row count and ``h`` the output column count.
    w = (input.shape[2] - weight.shape[2]) // stride[0] + 1
    h = (input.shape[3] - weight.shape[3]) // stride[1] + 1
    out = torch.empty(size=(1, filter_count, w, h))

    mult_count = 0  # multiplications with weight-side sparsity + early exit
    calc_mult = 0   # both-non-zero count; accumulated but not returned
    for i in range(filter_count):
        comp_wt = comp_vector_pred(weight[i])
        # The signature allows bias=None, so fall back to a zero bias.
        filter_bias = 0 if bias is None else bias[i]
        out[0][i], num1, num2 = compute_filter_conv_sparsepred(input, weight[i], comp_wt, w, h, stride, padding, filter_bias)
        mult_count += num1
        calc_mult += num2
    return (out, mult_count)
###################################################################################

# ALGO 5
class comp_vector_pred_twosided():
    """
    Split a 3-D filter into its positive and negative entries.

    Attributes:
        c, y, x: sizes of dimensions 0, 1 and 2 of ``arr``.
        pos_vector: (value, k, i, j) tuples of entries > 0, in scan order.
        neg_vector: (value, k, i, j) tuples of entries < 0, sorted by
            ascending value (most negative first).
    Zero entries are dropped.
    """

    def __init__(self, arr):
        self.x = arr.size(dim=2)
        self.y = arr.size(dim=1)
        self.c = arr.size(dim=0)

        positives = []
        negatives = []
        coords = (
            (k, i, j)
            for k in range(self.c)
            for i in range(self.y)
            for j in range(self.x)
        )
        for k, i, j in coords:
            value = arr[k][i][j]
            if value > 0:
                positives.append((value, k, i, j))
            elif value < 0:
                negatives.append((value, k, i, j))

        # Stable sort on the value only, so equal weights keep scan order.
        negatives.sort(key=lambda entry: entry[0])
        self.pos_vector = positives
        self.neg_vector = negatives

    def get_pos_vector(self):
        """Return the positive-entry tuples."""
        return self.pos_vector

    def get_neg_vector(self):
        """Return the negative-entry tuples, most negative first."""
        return self.neg_vector


def compute_conv_sparsepred_twosided(input, weight, comp_wt, r, c,bias=0):
    """
    Compute one output cell, skipping zero inputs and exiting early.

    Positive weights are applied first (zero inputs are skipped and not
    counted), then ``bias`` is added, then negative weights are applied in the
    order ``comp_wt`` provides them. Zero inputs are skipped there too, and
    accumulation stops once the running sum is already below zero.
    NOTE(review): the early exit presumably relies on non-negative inputs and
    a following ReLU - confirm with the caller.

    Args:
        input: indexed as input[0][channel][row][col]; only batch 0 is read.
        weight: one filter, shape (channels, window_rows, window_cols).
        comp_wt: object exposing ``get_pos_vector()`` / ``get_neg_vector()``,
            each returning (value, channel, row, col) tuples.
        r, c: top-left corner of the window in ``input``.
        bias: value added after the positive contributions.

    Returns:
        (img_out_cell, conv_mult_count, mult_nonzero): the (possibly partial)
        sum, the multiplications performed here, and the number of window
        positions where both the input and the weight are non-zero.
    """
    img_out_cell = 0
    conv_mult_count = 0

    channels = weight.shape[0]
    win_rows = weight.shape[1]
    win_cols = weight.shape[2]

    # Count positions where input AND weight are both non-zero. Comparing the
    # masks element-wise gives the true overlap; pairing the two non-zero
    # index lists by rank would only count coincidental matches.
    mult_nonzero = 0
    for channel in range(channels):
        inp_window = input[0][channel][r:r+win_rows, c:c+win_cols]
        both_nonzero = (inp_window != 0) & (weight[channel] != 0)
        mult_nonzero += int(both_nonzero.sum())

    for tup in comp_wt.get_pos_vector():
        activation = input[0][tup[1]][r+tup[2]][c+tup[3]]
        if activation == 0:
            continue
        conv_mult_count += 1
        img_out_cell += tup[0]*activation

    img_out_cell += bias

    for tup in comp_wt.get_neg_vector():
        activation = input[0][tup[1]][r+tup[2]][c+tup[3]]
        if activation == 0:
            continue
        if img_out_cell < 0:
            break
        conv_mult_count += 1
        img_out_cell += tup[0]*activation

    return img_out_cell, conv_mult_count, mult_nonzero


def compute_filter_conv_sparsepred_twosided(input, weights, comp_wt, width_out, height_out,stride=(1,1), padding=(0,0), bias=0):
    """
    Slide one filter over ``input`` and fill its output channel.

    The output buffer has shape (width_out, height_out) and is indexed with
    the window's row step first and column step second. ``padding`` is
    accepted but not used here. The bias is applied inside
    ``compute_conv_sparsepred_twosided``.

    Returns:
        (img_out_channel, filter_mult_count, filter_calc_mult): the output
        buffer and the two multiplication counters summed over every window.
    """
    img_out_channel = torch.zeros(width_out, height_out)
    filter_mult_count = 0
    filter_calc_mult = 0

    row_starts = range(0, input.shape[2] - weights.shape[1] + 1, stride[0])
    col_starts = range(0, input.shape[3] - weights.shape[2] + 1, stride[1])

    # enumerate() yields the same output index as dividing the window start
    # by the stride.
    for r_out, r in enumerate(row_starts):
        for c_out, c in enumerate(col_starts):
            cell, mult_count, calc_mult = compute_conv_sparsepred_twosided(input, weights, comp_wt, r, c, bias)
            img_out_channel[r_out][c_out] = cell
            filter_mult_count += mult_count
            filter_calc_mult += calc_mult

    return img_out_channel, filter_mult_count, filter_calc_mult


def myconv2d_sparse_pred_twosided(input, weight, bias=None, stride=(1,1), padding=(0,0), dilation=(1,1), groups=1):
    """
    Convolution using compressed weights, zero-input skipping and early exit.

    Only ``input[0]`` is read by the helpers, so the result has batch size 1.
    ``dilation`` and ``groups`` are accepted for signature compatibility but
    are not used.

    Args:
        input: tensor of shape (batch, channels, rows, cols).
        weight: tensor of shape (filters, channels, kernel_rows, kernel_cols).
        bias: optional per-filter bias; treated as 0 when None.
        stride, padding: (row, column) pairs.

    Returns:
        (out, mult_count): ``out`` has shape (1, filters, w, h) and
        ``mult_count`` is the number of multiplications performed.
    """
    input = torch.nn.functional.pad(input, (padding[1], padding[1], padding[0], padding[0]), "constant", 0)
    filter_count = weight.shape[0]
    # Dim 2 is the row axis and dim 3 the column axis, so despite their names
    # ``w`` is the output row count and ``h`` the output column count.
    w = (input.shape[2] - weight.shape[2]) // stride[0] + 1
    h = (input.shape[3] - weight.shape[3]) // stride[1] + 1
    out = torch.empty(size=(1, filter_count, w, h))

    mult_count = 0  # multiplications actually performed
    calc_mult = 0   # both-non-zero count; accumulated but not returned
    for i in range(filter_count):
        comp_wt = comp_vector_pred_twosided(weight[i])
        # The signature allows bias=None, so fall back to a zero bias.
        filter_bias = 0 if bias is None else bias[i]
        out[0][i], num1, num2 = compute_filter_conv_sparsepred_twosided(input, weight[i], comp_wt, w, h, stride=stride, padding=padding, bias=filter_bias)
        mult_count += num1
        calc_mult += num2

    return (out, mult_count)