In [33]:
import torch

In [34]:
import fc_layer

In [44]:
import math
import torch
import fc_layer

class FCFunction(torch.autograd.Function):
    """Autograd bridge to the compiled ``fc_layer`` extension.

    All arithmetic is delegated to ``fc_layer.forward`` / ``fc_layer.backward``;
    this class only stashes the inputs that the backward kernel needs.
    """

    @staticmethod
    def forward(ctx, input, weights, bias):
        """Run the extension's forward pass.

        Args:
            ctx: autograd context used to save tensors for ``backward``.
            input: input tensor, passed unchanged to ``fc_layer.forward``.
            weights: weight tensor. NOTE(review): ``FullyConnected`` creates it
                with shape (input_features, output_features) — confirm that is
                the layout ``fc_layer`` expects.
            bias: bias tensor.

        Returns:
            The first element of the sequence returned by ``fc_layer.forward``.
        """
        outputs = fc_layer.forward(input, weights, bias)
        # fc_layer.backward receives exactly these three tensors (in this order)
        # followed by the upstream gradient.
        ctx.save_for_backward(input, weights, bias)
        return outputs[0]

    @staticmethod
    def backward(ctx, dout):
        """Run the extension's backward pass.

        Args:
            ctx: autograd context holding the tensors saved in ``forward``.
            dout: gradient of the loss w.r.t. the forward output.

        Returns:
            ``(dx, dW, db)`` — gradients for ``input``, ``weights`` and ``bias``,
            in the same order as ``forward``'s tensor arguments.
        """
        # `saved_tensors` is the supported accessor; `saved_variables` is
        # deprecated and emits a DeprecationWarning on every backward call.
        input, weights, bias = ctx.saved_tensors
        dx, dW, db = fc_layer.backward(input, weights, bias, dout)
        return dx, dW, db
    
class FullyConnected(torch.nn.Module):
    """Fully connected layer whose computation is done by ``FCFunction``.

    The module owns two learnable parameters, ``weights`` with shape
    (input_features, output_features) and ``bias`` with shape (output_features,).
    Both are drawn from a normal distribution with mean 0 and standard
    deviation 1 / sqrt(input_features).
    """

    def __init__(self, input_features, output_features):
        super().__init__()
        self.input_features = input_features
        self.output_features = output_features

        std = 1 / input_features ** 0.5
        # Draw weights before bias so the global RNG is consumed in a fixed order.
        weight_init = torch.normal(0, std, (input_features, output_features))
        bias_init = torch.normal(0, std, (output_features,))
        self.weights = torch.nn.Parameter(weight_init)
        self.bias = torch.nn.Parameter(bias_init)

    def forward(self, input):
        """Apply the layer to ``input`` through the custom autograd function."""
        return FCFunction.apply(input, self.weights, self.bias)

In [None]:
# Sanity check of the reference layer: Linear(2 -> 3) applied to a batch of 5 rows.
FC = torch.nn.Linear(in_features=2, out_features=3)
input1 = torch.randn((5, 2))
output1 = FC(input1)

In [51]:
import time
import torch

# Benchmark: compare the C++-extension layer against torch.nn.Linear.
# One fixed input batch is reused so only the layer's own cost is measured.
batch_size = 128
input_size = 20
output_size = 30
N_ITERS = 100_000  # timed iterations per implementation
N_WARMUP = 100     # untimed iterations run first to let allocators settle
input1 = torch.randn(batch_size, input_size)

FC_Cpp = FullyConnected(input_size, output_size)
FC_Torch = torch.nn.Linear(input_size, output_size)


def _time_forward_backward(module, inputs, n_iters=N_ITERS, n_warmup=N_WARMUP):
    """Time the forward pass and ``.sum().backward()`` of ``module`` separately.

    Parameter gradients are never cleared, so after the first call each
    backward also accumulates into existing ``.grad`` buffers. This applies
    equally to both modules being compared.

    Args:
        module: callable module to benchmark.
        inputs: tensor passed to ``module`` on every iteration.
        n_iters: number of timed iterations.
        n_warmup: number of untimed iterations run before timing.

    Returns:
        ``(forward_total, backward_total, last_output)``. The totals are summed
        seconds over ``n_iters``; ``last_output`` is the final forward result.
    """
    for _ in range(n_warmup):
        module(inputs).sum().backward()

    forward_total = 0.0
    backward_total = 0.0
    out = None
    for _ in range(n_iters):
        # perf_counter is monotonic and high-resolution; time.time() is too
        # coarse for per-call timings in the microsecond range.
        start = time.perf_counter()
        out = module(inputs)
        forward_total += time.perf_counter() - start

        start = time.perf_counter()
        out.sum().backward()
        backward_total += time.perf_counter() - start
    return forward_total, backward_total, out


forward2, backward2, output2 = _time_forward_backward(FC_Cpp, input1)
# Report per-iteration microseconds. The original printed the undefined names
# `forward`/`backward` here, which raises NameError on a fresh kernel.
print('C++ extension   | Forward: {:.3f} us | Backward {:.3f} us'.format(
    forward2 * 1e6 / N_ITERS, backward2 * 1e6 / N_ITERS))

forward1, backward1, output1 = _time_forward_backward(FC_Torch, input1)
print('torch.nn.Linear | Forward: {:.3f} us | Backward {:.3f} us'.format(
    forward1 * 1e6 / N_ITERS, backward1 * 1e6 / N_ITERS))



Forward: 34.797 us | Backward 94.797 us
Forward: 46.753 us | Backward 90.307 us
