# Gradient-Based Training of WANNs

In [1]:
import torch
from torch.autograd import Variable
from functools import reduce, partial
import numpy as np

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

from sklearn.metrics import accuracy_score, precision_score, recall_score

In [2]:
class MultiActivationModule(torch.nn.Module):
    """Applies multiple elementwise activation functions to a tensor."""
    
    discretized = False
    
    available_act_functions = [
        ('relu', torch.relu),
        ('sigmoid', torch.sigmoid),
        ('tanh', torch.tanh),
        ('gaussian (standard)', lambda x: torch.exp(-torch.square(x) / 2.0)),
        ('step', lambda t: (t > 0.0) * 1.0),
        ('identity', lambda x: x),
        ('inverse', torch.neg),
        ('squared', torch.square),
        ('abs', torch.abs),
        ('cos', torch.cos),
        ('sin', torch.sin),
    ]
    
    @property
    def n_funcs(self):
        return len(self.funcs)
    
    def __init__(self, n_out):
        super().__init__()
        self.funcs = [f[1] for f in self.available_act_functions]
        
        self.weight = torch.nn.Parameter(torch.zeros((self.n_funcs, n_out)))
        self.frozen = torch.zeros(n_out, dtype=bool)
        self.soft = torch.nn.Softmax(dim=0)

    def forward(self, x):
        coefficients = self.soft(self.weight)
        
        return reduce(
            lambda first, act: (
                torch.add(
                    first,
                    torch.mul(
                        act[1](x),  # apply activation func
                        coefficients[act[0], :])
            )),
            enumerate(self.funcs),  # index, func
            torch.zeros_like(x)  # start value
        )

In [3]:
class RestrictedLinear(torch.nn.Module):
    """Similar to torch.nn.Linear, but restricts weights between -1 and 1."""
    
    def __init__(self, n_in, n_out):
        super().__init__()     
        self.sign = torch.nn.Softsign()
        self.weight = torch.nn.Parameter(torch.empty(n_in, n_out))
        self.frozen = torch.zeros(*self.weight.size(), dtype=bool)
        
    def forward(self, x):
        weight = torch.where(self.frozen, self.weight, self.sign(self.weight))
        return torch.nn.functional.linear(x, weight.T)
        

In [4]:
class ConcatLayer(torch.nn.Module):
    """Contatenates output of the active nodes and prior nodes."""
    
    def __init__(self, n_in, n_out, shared_weight):
        super().__init__()
        self.linear = RestrictedLinear(n_in, n_out)
        self.activation = MultiActivationModule(n_out)
        
        self.shared_weight = shared_weight
        
    def forward(self, x):
        linear = self.linear(x) * self.shared_weight[:, None, None]
        
        inner_out = self.activation(linear)
        
        return torch.cat([x, inner_out], dim=-1)

In [5]:
def weight_init(m):
    if isinstance(m, RestrictedLinear):
        torch.nn.init.normal_(m.weight.data)
    elif isinstance(m, MultiActivationModule):
        torch.nn.init.normal_(m.weight.data)
               
def reset_masked_gradients(m):
    # set gradient of masked parameters to 0 -> these parameters won't be updated by the optimizer
    if isinstance(m, MultiActivationModule):
        m.weight.grad[:, m.frozen] = 0.0
    if isinstance(m, RestrictedLinear):
        m.weight.grad[m.frozen] = 0.0
        
def freeze_some_act_funcs(m, ratio=0.2):
    if isinstance(m, MultiActivationModule):
        indices = torch.max(m.weight, 0).indices
        
        to_freeze = (torch.rand(*m.frozen.size()) <= ratio) & ~m.frozen
        
        if torch.any(to_freeze):
            indices = torch.max(m.weight[:, to_freeze], 0).indices
            m.weight.data[:, to_freeze] = torch.nn.functional.one_hot(indices, m.n_funcs).T.float()
            m.frozen[to_freeze] = True
        
def freeze_some_weights(m, ratio=0.2, zero_ratio=0.4):
    if isinstance(m, RestrictedLinear):       
        # mask-0 operation
        if torch.any(~m.frozen):
            alpha_zero = np.percentile(m.weight.data[~m.frozen].abs(), 100 * ratio * zero_ratio)
            mask_zero = (m.weight.data.abs() <= alpha_zero) & ~m.frozen

            # mask-1 operation
            alpha_one = np.percentile(-m.weight.data[~m.frozen].abs(), 100 * ratio * (1-zero_ratio))
            mask_one = (-m.weight.data.abs() <= alpha_one) & ~m.frozen
        
            m.frozen[mask_zero] = True
            m.frozen[mask_one] = True

            m.weight.data[mask_zero] = 0.0
            m.weight.data[mask_one] = m.weight.data[mask_one].sign()

In [6]:
dataset = load_iris()
#train_X, test_X, train_y, test_y = train_test_split(dataset['data'],
#                                                    dataset['target'], test_size=0.2)


# Just going for fitting here. (Todo: Change stuff in paper for usefull comparison)
train_X = dataset['data']
train_y = dataset['target']
test_X = train_X
test_y = train_y

train_X = np.hstack([train_X, np.ones((train_X.shape[0], 1))])
test_X = np.hstack([test_X, np.ones((test_X.shape[0], 1))])

In [7]:
# wrap up with Variable in pytorch
train_X = Variable(torch.Tensor(train_X).float())
test_X = Variable(torch.Tensor(test_X).float())
train_y = Variable(torch.Tensor(train_y).long())
test_y = Variable(torch.Tensor(test_y).long())

In [8]:
class Model(torch.nn.Module):
    def __init__(self, *layer_sizes):
        shared_weight = Variable(torch.Tensor([1]))
        super().__init__()
        
        layers = list()
        
        n_in = layer_sizes[0]
        
        for n_out in layer_sizes[1:]:
            layers.append(ConcatLayer(n_in, n_out, shared_weight))
            n_in += n_out
        
        self.network = torch.nn.Sequential(*layers)
        self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self, x):
        net_out = self.network(x)
        net_out = net_out[..., -3:]
        return self.softmax(net_out)    

In [9]:
criterion = torch.nn.CrossEntropyLoss()# cross entropy loss

def train(model, n_epochs=2000):
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    
    for epoch in range(n_epochs):
        optimizer.zero_grad()

        batch = train_X.unsqueeze(dim=0)
        out = model(batch).view(-1, 3)
        loss = criterion(out, train_y)
        loss.backward()
        
        # set masked gradients to zero, so the weights will not be updated
        model.apply(reset_masked_gradients)
        
        optimizer.step()

        if epoch % 100 == 0:
            print ('number of epoch', epoch, 'loss', loss.data)

def evaluate(model):
    predict_out = model(test_X.unsqueeze(dim=0)).view(-1, 3)
    _, predict_y = torch.max(predict_out, 1)

    print ('prediction accuracy', accuracy_score(test_y.data, predict_y.data))

#    print ('macro precision', precision_score(test_y.data, predict_y.data, average='macro'))
#    print ('micro precision', precision_score(test_y.data, predict_y.data, average='micro'))
#    print ('macro recall', recall_score(test_y.data, predict_y.data, average='macro'))
#    print ('micro recall', recall_score(test_y.data, predict_y.data, average='micro'))
    

In [10]:
model = Model(5,2,2,3)
model.apply(weight_init)
train(model, 1000)
evaluate(model)

number of epoch 0 loss tensor(1.1104)
number of epoch 100 loss tensor(1.0213)
number of epoch 200 loss tensor(0.9851)
number of epoch 300 loss tensor(0.9474)
number of epoch 400 loss tensor(0.8961)
number of epoch 500 loss tensor(0.8268)
number of epoch 600 loss tensor(0.7644)
number of epoch 700 loss tensor(0.7366)
number of epoch 800 loss tensor(0.7192)
number of epoch 900 loss tensor(0.7023)
prediction accuracy 0.98


In [11]:
for i in range(20):
    model.apply(partial(freeze_some_act_funcs, ratio=0.1))
    train(model, 100)
    if i % 4 == 0:
        print (f"--- Iteration {i} ---")
        evaluate(model)
    
print ("=== Final ===")
model.apply(partial(freeze_some_act_funcs, ratio=1))
evaluate(model)

number of epoch 0 loss tensor(0.6604)
--- Iteration 0 ---
prediction accuracy 0.98
number of epoch 0 loss tensor(0.6103)
number of epoch 0 loss tensor(0.5796)
number of epoch 0 loss tensor(0.6108)
number of epoch 0 loss tensor(0.5776)
--- Iteration 4 ---
prediction accuracy 0.9866666666666667
number of epoch 0 loss tensor(0.5701)
number of epoch 0 loss tensor(0.5684)
number of epoch 0 loss tensor(0.5667)
number of epoch 0 loss tensor(0.5645)
--- Iteration 8 ---
prediction accuracy 0.9933333333333333
number of epoch 0 loss tensor(0.8847)
number of epoch 0 loss tensor(0.5706)
number of epoch 0 loss tensor(0.5683)
number of epoch 0 loss tensor(0.5666)
--- Iteration 12 ---
prediction accuracy 0.9933333333333333
number of epoch 0 loss tensor(0.5654)
number of epoch 0 loss tensor(0.5644)
number of epoch 0 loss tensor(0.8510)
number of epoch 0 loss tensor(0.5697)
--- Iteration 16 ---
prediction accuracy 0.9933333333333333
number of epoch 0 loss tensor(0.5671)
number of epoch 0 loss tensor(0.5

In [12]:
epochs = 20

for i in range(epochs):
    model.apply(partial(freeze_some_weights, ratio=1/(epochs-i)))
    train(model, 50)
    if i % 4 == 0:
        print (f"--- Iteration {i} ---")
        evaluate(model)

print("=== Final ===")
evaluate(model)

number of epoch 0 loss tensor(1.1854)
--- Iteration 0 ---
prediction accuracy 0.48
number of epoch 0 loss tensor(0.6064)
number of epoch 0 loss tensor(0.9192)
number of epoch 0 loss tensor(1.2106)
number of epoch 0 loss tensor(0.9370)
--- Iteration 4 ---
prediction accuracy 0.6333333333333333
number of epoch 0 loss tensor(0.8860)
number of epoch 0 loss tensor(0.8859)
number of epoch 0 loss tensor(0.8851)
number of epoch 0 loss tensor(0.8848)
--- Iteration 8 ---
prediction accuracy 0.6666666666666666
number of epoch 0 loss tensor(0.8848)
number of epoch 0 loss tensor(0.8848)
number of epoch 0 loss tensor(0.8848)
number of epoch 0 loss tensor(0.8848)
--- Iteration 12 ---
prediction accuracy 0.6666666666666666
number of epoch 0 loss tensor(0.8848)
number of epoch 0 loss tensor(0.8848)
number of epoch 0 loss tensor(0.8848)
number of epoch 0 loss tensor(0.8848)
--- Iteration 16 ---
prediction accuracy 0.6666666666666666
number of epoch 0 loss tensor(0.8848)
number of epoch 0 loss tensor(0.8

In [13]:
model.network[2].linear.weight

Parameter containing:
tensor([[ 1., -1., -1.],
        [ 1.,  0.,  0.],
        [ 1., -1.,  0.],
        [ 0.,  0.,  0.],
        [ 1., -1.,  0.],
        [ 0.,  1.,  0.],
        [ 1.,  0., -1.],
        [ 0.,  0.,  1.],
        [ 1.,  0.,  0.]], requires_grad=True)