In [1]:
import torch

In [834]:
n_inputs = 100
n_hidden = 64
inputs = torch.randn((5, n_inputs))

In [1050]:
class TanhLayer:
    def forward(self, inputs):
        self.output = torch.tanh(inputs)
        return self.output
    
    def backward(self, upstream_grad):
        return upstream_grad * (1 - self.output**2)
    
    def int_analysis_forward(self, center, radius):
        lower = torch.tanh(center - radius)
        upper = torch.tanh(center + radius)
        self.output_center = (lower+upper)/2
        self.output_radius = upper - self.output_center

        return self.output_center, self.output_radius
    
    def int_analysis_backward(self, upstream_center, upstream_radius):
        grad_center = upstream_center * (1 - self.output_center**2)
        grad_radius = upstream_radius * 2 * self.output_radius
        return grad_center, grad_radius

In [1081]:
class LinearLayer:
    def __init__(self, input_width, output_width):
        self.weights = 0.1*torch.randn((input_width, output_width))
        
    def forward(self, inputs):
        self.input = inputs
        self.output = inputs@self.weights
        return self.output
    
    def backward(self, upstream_grad):
        self.grad = self.input.T @ upstream_grad
        return upstream_grad @ self.weights.T
    
    def int_analysis_forward(self, center, radius):
        #print('center:', center.shape, 'radius:', radius.shape, 'weights:', self.weights.shape)
        self.input_center = center
        self.input_radius = radius
        new_center = center@self.weights
        new_radius = radius@self.weights.abs() # beta is all zeros
        
        self.output_center = new_center
        return new_center, new_radius
    
    def int_analysis_backward(self, upstream_center, upstream_radius):
        grad_center = upstream_center @ self.weights.T
        grad_radius = upstream_radius @ self.weights.T.abs() # beta is all zeros
        
        self.grad_center = self.input_center.T @ upstream_center
        #print('needed dim:', self.grad_center.shape)
        ur = torch.nn.functional.normalize(upstream_radius)
        #ur = upstream_radius
        #print(ur)
        self.grad_radius = self.input_center.T.abs() @ ur +\
                           self.input_radius.T @ ur.abs() +\
                           self.input_radius.T @ ur
        #self.grad_radius = self.input_radius
        return grad_center, grad_radius

In [1082]:
class Classifier:
    def __init__(self, n_features, n_hidden=10):
        self.layers = [
            LinearLayer(n_features, n_hidden),
            TanhLayer(),
            LinearLayer(n_hidden, n_hidden),
            TanhLayer(),
            LinearLayer(n_hidden, n_hidden),
            TanhLayer(),
            LinearLayer(n_hidden, 1),
            TanhLayer(),
        ]

    def forward(self, inputs):
        output = inputs
        
        for l in self.layers:
            output = l.forward(output)
        
        return output
    
    def backward(self):
        last_layer = self.layers[-1]
        output_shape = last_layer.output.shape
        # dL / dY
        upstream_grad = torch.ones(output_shape)
        #upstream_grad = last_layer.output
        
        for layer in reversed(self.layers):
            downstream_grad = layer.backward(upstream_grad)          
            upstream_grad = downstream_grad

        return upstream_grad
    
    def int_analysis_forward(self, center, radius):
        for layer in self.layers:
            #print(layer)
            center, radius = layer.int_analysis_forward(center, radius)
            
        return center, radius
    
    def int_analysis_backward(self):
        last_layer = self.layers[-1]
        output_shape = last_layer.output_center.shape
        # dL / dY
        upstream_grad_center = torch.zeros(output_shape)
        upstream_grad_radius = torch.ones(output_shape)
        
        for layer in reversed(self.layers):
            d_center, d_radius = layer.int_analysis_backward(upstream_grad_center, upstream_grad_radius)

            upstream_grad_center = d_center
            upstream_grad_radius = d_radius

        return upstream_grad_center, upstream_grad_radius

In [1083]:
c = Classifier(n_inputs, n_hidden)

In [1084]:
%%time
x = c.forward(inputs)

CPU times: user 1.17 ms, sys: 654 µs, total: 1.83 ms
Wall time: 1.33 ms


In [1085]:
c.backward().shape

torch.Size([5, 100])

In [1086]:
if False:
    for layer in c.layers:
        if isinstance(layer, LinearLayer):
            print('weights shape:', layer.weights.shape)
            print('grad shape:', layer.grad.shape)

In [1087]:
center = torch.zeros((1, n_inputs))
radius = torch.full((1, n_inputs), 1.)

In [1088]:
%%time
c.int_analysis_forward(center, radius)

CPU times: user 1.5 ms, sys: 1.08 ms, total: 2.58 ms
Wall time: 1.46 ms


(tensor([[0.]]), tensor([[1.0000]]))

In [1089]:
%%time
center, radius = c.int_analysis_backward()

CPU times: user 3.73 ms, sys: 1.78 ms, total: 5.51 ms
Wall time: 3.7 ms


In [1090]:
radius.norm()

tensor(2007.8702)

In [1091]:
for layer in c.layers:
    if isinstance(layer, LinearLayer):
        #print(layer.grad_center)
        print(layer.grad_radius.norm())

tensor(20.0000)
tensor(16.0000)
tensor(15.9981)
tensor(15.9982)
