In [1]:
import copy
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import random
import time
import sys, io

import mylibrary.datasets as datasets
import mylibrary.nnlib as tnn

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [3]:
import adam_custom
import json

In [4]:
# Sanity check: `+=` on an integer tensor updates it in place and keeps the int64 dtype.
a = torch.tensor([2, 3], dtype=torch.long)
a.add_(1)
a

tensor([3, 4])

## Changes

- Apply BatchNorm directly after each convolution.
    - This keeps the weight norm uniform while the BN scaling parameter controls the output scale.
    - This should keep the weight gradients well behaved.

- Reuse the same optimizer (Adam), including its state, when parameters are added or removed.
    - This is expected to remove the unstable training.
    - We may need a different learning rate for each parameter.

In [5]:
def _get_hidden_neuron_number(i, o):
    nh =  (max(i,o)*(min(i,o)**2))**(1/3)
#     return max(nh, 1)
    return nh

class Shortcut_Conv(nn.Module):
    """Bias-free 2-D convolution followed by BatchNorm2d whose input and output
    channels can be grown, frozen/decayed and pruned while training.

    On every forward pass each output filter is rescaled to unit L2 norm, so the
    output scale is carried by the BatchNorm affine weight.

    All parameters are appended to ``tree.optimizer.param_groups[0]``. Whenever a
    parameter is resized, its optimizer-state entries ``step``, ``exp_avg`` and
    ``exp_avg_sq`` (assumed to be element-wise tensors shaped like the parameter,
    as in ``adam_custom`` -- confirm) are resized the same way, so the optimizer
    can keep being reused.
    """

    ## optimizer-state entries that are resized together with their parameter
    _OPT_STATE_KEYS = ('step', 'exp_avg', 'exp_avg_sq')

    def __init__(self, tree, input_dim, output_dim, kernel=(3,3), stride=1):
        super().__init__()
        self.tree = tree
        self._kernel = np.array(kernel, dtype=int)
        ## "same" padding for odd kernel sizes
        self._padding = tuple(((self._kernel-1)/2).astype(int))
        self._stride = stride
        ## borrow nn.Conv2d's default initialisation; shape = (out, in, k0, k1)
        _wd = nn.Conv2d(input_dim, output_dim, self._kernel, stride=self._stride,
                        padding=self._padding, bias=False).weight.data
        self.weight = nn.Parameter(_wd.clone())
        del _wd
        self.bn = nn.BatchNorm2d(output_dim)

        ## bookkeeping for decaying (input columns) and freezing (output rows)
        self.to_remove = None
        self.to_freeze = None
        self.initial_remove = None
        self.initial_freeze = None
        self.initial_freeze_bn = None

        self.add_parameters_to_optimizer()

    def add_parameters_to_optimizer(self):
        """Append the conv weight and the BN weight/bias to the shared optimizer."""
        for p in self.parameters():
            self.tree.optimizer.param_groups[0]['params'].append(p)

    def _resize_optimizer_state(self, param, fn):
        """Replace each resizable optimizer-state tensor of `param` with fn(tensor).

        Does nothing when the optimizer holds no (or an empty) state for `param`,
        e.g. before the first optimizer step.
        """
        param_state = self.tree.optimizer.state.get(param)
        if param_state:
            for key in self._OPT_STATE_KEYS:
                param_state[key] = fn(param_state[key])

    @staticmethod
    def _append_zeros(t, num, dim=0):
        """Concatenate `num` zero slices to `t` along `dim` (same dtype/device)."""
        shape = list(t.shape)
        shape[dim] = num
        return torch.cat((t, t.new_zeros(shape)), dim=dim)

    def forward(self, x):
        """Return bn(conv2d(x)). Zero-sized input or output channel counts
        produce an all-zero tensor with the correct output shape instead."""
        in_ch, out_ch = x.shape[1], self.weight.shape[0]
        if in_ch > 0 and out_ch > 0:
            ## renormalise every output filter to unit L2 norm (in place, outside autograd)
            norms = torch.norm(self.weight.data.reshape(out_ch, -1), dim=1)
            self.weight.data /= norms.reshape(out_ch, 1, 1, 1)
            return self.bn(F.conv2d(x, self.weight, stride=self._stride, padding=self._padding))
        if out_ch == 0:
            ## no output channels: a dummy 1->1 conv is used only to get the output spatial size
            x = torch.zeros(x.shape[0], 1, x.shape[2], x.shape[3], dtype=x.dtype, device=x.device)
            w = torch.zeros(1, 1, self.weight.shape[2], self.weight.shape[3], dtype=x.dtype, device=x.device)
            o = F.conv2d(x, w, stride=self._stride, padding=self._padding)
            return torch.zeros(o.shape[0], 0, o.shape[2], o.shape[3], dtype=x.dtype, device=x.device)
        if in_ch == 0:
            ## no input channels: convolve a zero 1-channel input with zero weights -> zeros
            x = torch.zeros(x.shape[0], 1, x.shape[2], x.shape[3], dtype=x.dtype, device=x.device)
            w = torch.zeros(out_ch, 1, self.weight.shape[2], self.weight.shape[3], dtype=x.dtype, device=x.device)
            o = F.conv2d(x, w, stride=self._stride, padding=self._padding)
            return o.data
        raise ValueError(f"Unknown shape of input {x.shape} or weight {self.weight.shape}")

    def start_decaying_connection(self, to_remove):
        """Snapshot the input columns `to_remove` and register for per-step decay."""
        self.initial_remove = self.weight.data[:, to_remove]
        self.to_remove = to_remove
        self.tree.decay_connection_shortcut.add(self)

    def start_freezing_connection(self, to_freeze):
        """Snapshot output rows `to_freeze` (conv + BN affine) and register for freezing."""
        self.initial_freeze = self.weight.data[to_freeze, :]
        self.initial_freeze_bn = self.bn.weight.data[to_freeze], self.bn.bias.data[to_freeze]
        self.to_freeze = to_freeze
        self.tree.freeze_connection_shortcut.add(self)

    ## freeze output neuron's incoming weight
    def freeze_connection_step(self):
        """Restore the frozen output rows to their snapshotted values."""
        self.weight.data[self.to_freeze, :] = self.initial_freeze
        self.bn.weight.data[self.to_freeze] = self.initial_freeze_bn[0]
        self.bn.bias.data[self.to_freeze] = self.initial_freeze_bn[1]

    ## decay input neuron's outgoing weight
    def decay_connection_step(self):
        """Set the decaying input columns to snapshot * tree.decay_factor."""
        self.weight.data[:, self.to_remove] = self.initial_remove*self.tree.decay_factor

    ## remove output neuron
    def remove_freezed_connection(self, remaining):
        """Keep only the output channels in `remaining` (conv rows, BN buffers and
        affine params) and slice their optimizer state to match."""
        self.weight.data = self.weight.data[remaining]
        self.weight.grad = None
        self._resize_optimizer_state(self.weight, lambda t: t[remaining])

        self.initial_freeze = None
        self.initial_freeze_bn = None
        self.to_freeze = None

        self.bn.running_mean = self.bn.running_mean[remaining]
        self.bn.running_var = self.bn.running_var[remaining]

        for p in (self.bn.weight, self.bn.bias):
            p.data = p.data[remaining]
            p.grad = None
            self._resize_optimizer_state(p, lambda t: t[remaining])

        self.bn.num_features = len(remaining)

    ## remove input neuron
    def remove_decayed_connection(self, remaining):
        """Keep only the input channels in `remaining` and slice optimizer state to match."""
        self.weight.data = self.weight.data[:, remaining]
        self._resize_optimizer_state(self.weight, lambda t: t[:, remaining])
        self.weight.grad = None

        self.initial_remove = None
        self.to_remove = None

    def add_input_connection(self, num):
        """Append `num` zero-weight input channels (the output is unchanged)."""
        self.weight.data = self._append_zeros(self.weight.data, num, dim=1)
        self.weight.grad = None
        self._resize_optimizer_state(self.weight, lambda t: self._append_zeros(t, num, dim=1))

    def add_output_connection(self, num):
        """Append `num` output channels: uniform-initialised filters and
        BatchNorm channels at the BatchNorm defaults (mean 0, var 1, gain 1, bias 0)."""
        o, i, k0, k1 = self.weight.data.shape
        ## uniform(-1/sqrt(i), 1/sqrt(i)); with no input channels the new filters are
        ## empty, so any finite bound works (1/sqrt(0) would be infinite)
        stdv = 1. / np.sqrt(i) if i > 0 else 0.
        _new = torch.empty(num, i, k0, k1, dtype=self.weight.data.dtype,
                           device=self.weight.data.device).uniform_(-stdv, stdv)

        self.weight.data = torch.cat((self.weight.data, _new), dim=0)
        self.weight.grad = None
        self._resize_optimizer_state(self.weight, lambda t: self._append_zeros(t, num))

        ####https://pytorch.org/docs/stable/_modules/torch/nn/modules/batchnorm.html#BatchNorm2d
        rm = self.bn.running_mean
        self.bn.running_mean = torch.cat((rm, rm.new_zeros(num)))
        rv = self.bn.running_var
        self.bn.running_var = torch.cat((rv, rv.new_ones(num)))

        w = self.bn.weight
        w.data = torch.cat((w.data, w.data.new_ones(num)))
        w.grad = None
        self._resize_optimizer_state(w, lambda t: self._append_zeros(t, num))

        b = self.bn.bias
        b.data = torch.cat((b.data, b.data.new_zeros(num)))
        b.grad = None
        ## each parameter's state is checked on its own (bias state is resized even
        ## when the BN weight has no state yet)
        self._resize_optimizer_state(b, lambda t: self._append_zeros(t, num))

        self.bn.num_features += num

    def print_network_debug(self, depth):
        print(f"{'║     '*depth}S▚:{depth}[{self.weight.data.shape[1]},{self.weight.data.shape[0]}]")


In [6]:
class TempTree:
    """Minimal stand-in for the network 'tree' used to smoke-test Shortcut_Conv.

    It only exposes the shared `optimizer`. Layers append their own parameters
    to ``optimizer.param_groups[0]['params']`` after construction.
    """

    def __init__(self):
        # Adam is constructed with a single empty placeholder parameter so that
        # param_groups[0] exists before any layer registers its parameters.
        placeholder_params = [nn.Parameter(torch.Tensor(0))]
        self.optimizer = adam_custom.Adam(placeholder_params, lr=0.0001)

In [7]:
# Smoke test: build a 2 -> 1 channel shortcut conv against the stand-in tree.
tree = TempTree()
a = Shortcut_Conv(tree, input_dim=2, output_dim=1)
a

Shortcut_Conv(
  (bn): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

In [8]:
class NonLinearity_Conv(nn.Module):
    """Parameter-free activation applied to convolutional feature maps.

    `io_dim` is accepted for interface parity with `NonLinearity` but is not
    used. Because there is no per-channel state, the add/remove hooks do nothing.
    """

    def __init__(self, tree, io_dim, actf_obj=nn.ReLU()):
        super().__init__()
        self.tree = tree
        self.actf = actf_obj

    def forward(self, x):
        activation = self.actf
        return activation(x)

    def add_neuron(self, num):
        """No per-channel state to grow."""
        return None

    def remove_neuron(self, remaining):
        """No per-channel state to prune."""
        return None

In [9]:
class NonLinearity(nn.Module):
    """Activation with a learnable per-neuron bias: ``actf(x + bias)``.

    The bias is registered with ``tree.optimizer``. When neurons are added or
    removed, the bias is resized *in place*: it stays the same ``nn.Parameter``
    object, so the optimizer keeps updating it. Any existing optimizer state
    (``step``, ``exp_avg``, ``exp_avg_sq``, assumed element-wise) is resized
    alongside it, mirroring Shortcut_Conv.
    """

    ## optimizer-state entries that are resized together with the bias
    _OPT_STATE_KEYS = ('step', 'exp_avg', 'exp_avg_sq')

    def __init__(self, tree, io_dim, actf_obj=nn.ReLU()):
        super().__init__()
        self.tree = tree
        self.bias = nn.Parameter(torch.zeros(io_dim))
        self.actf = actf_obj

        self.tree.optimizer.state[self.bias] = {}
        tree.optimizer.param_groups[0]['params'].append(self.bias)

    def forward(self, x):
        return self.actf(x+self.bias)

    def _resize_optimizer_state(self, fn):
        """Apply `fn` to each resizable optimizer-state tensor of the bias, if any."""
        param_state = self.tree.optimizer.state.get(self.bias)
        if param_state:
            for key in self._OPT_STATE_KEYS:
                param_state[key] = fn(param_state[key])

    def add_neuron(self, num):
        """Append `num` zero biases (and zero optimizer state)."""
        self.bias.data = torch.cat((self.bias.data, self.bias.data.new_zeros(num)))
        self.bias.grad = None
        self._resize_optimizer_state(lambda t: torch.cat((t, t.new_zeros(num))))

    def remove_neuron(self, remaining):
        """Keep only the biases (and optimizer state) at indices `remaining`."""
        self.bias.data = self.bias.data[remaining]
        self.bias.grad = None
        self._resize_optimizer_state(lambda t: t[remaining])

In [10]:
class Residual_Conv(nn.Module):
    """Residual branch ``fc1(act(fc0(x)))`` with prunable/growable hidden channels.

    fc0 maps input_dim -> hidden_dim channels (with `stride`) and fc1 maps
    hidden_dim -> output_dim. Both are HierarchicalResidual_Conv, so each can
    grow its own nested residual branch. The hidden channels are the "neurons"
    that add_hidden_neuron grows and that identify_removable_neurons /
    remove_decayed_neurons prune, based on a |grad * activation| significance.
    """

    def __init__(self, tree, input_dim, hidden_dim, output_dim, stride=1, activation=nn.ReLU()):
        super().__init__()
        self.tree = tree
        self.hidden_dim = hidden_dim
        self.del_neurons = 0.    ## EMA of neurons_added (see compute_del_neurons)
        self.neurons_added = 0

        ## Shortcut or Hierarchical Residual Layer
        self.fc0 = HierarchicalResidual_Conv(self.tree, input_dim, hidden_dim, stride=stride, activation=activation)
        self.non_linearity = NonLinearity_Conv(self.tree, hidden_dim, activation)
        self.fc1 = HierarchicalResidual_Conv(self.tree, hidden_dim, output_dim, activation=activation)
        ## zero BN gain on fc1's shortcut: the new branch initially outputs only the BN bias
        self.fc1.shortcut.bn.weight.data *= 0.
        self.fc1.shortcut.weight.data *= 0.1

        self.tree.parent_dict[self.fc0] = self
        self.tree.parent_dict[self.fc1] = self
        self.tree.parent_dict[self.non_linearity] = self

        self.hook = None
        self.activations = None
        self.significance = None
        self.count = None
        self.apnz = None        ## (ReLU only) count, later fraction, of positive activations
        self.to_remove = None

    def forward(self, x):
        x = self.fc0(x)
        x = self.non_linearity(x)
        ## cached for the significance hook
        self.activations = x.data
        x = self.fc1(x)
        return x

    def start_computing_significance(self):
        """Reset accumulators and hook the activation's backward pass."""
        self.significance = 0.
        self.count = 0
        self.apnz = 0
        ## NOTE(review): register_backward_hook is deprecated in favour of
        ## register_full_backward_hook; kept because only grad_output[0] is used.
        self.hook = self.non_linearity.register_backward_hook(self.compute_neuron_significance)

    def finish_computing_significance(self):
        """Remove the hook. For ReLU, convert apnz to a fraction and damp
        neurons that are almost always active (hand-tuned constants)."""
        self.hook.remove()
        if isinstance(self.non_linearity.actf, nn.ReLU):
            self.apnz = self.apnz/self.count
            self.significance = self.significance*(1-self.apnz**33) / 0.872 ## tried on desmos.
        self.hook = None

    def compute_neuron_significance(self, _class, grad_input, grad_output):
        """Backward hook: accumulate per-channel sum_batch |sum_spatial grad * act|."""
        with torch.no_grad():
            z = torch.sum(grad_output[0].data*self.activations, dim=(2,3))
            self.significance += z.abs().sum(dim=0)

            if isinstance(self.non_linearity.actf, nn.ReLU):
                self.count += grad_output[0].shape[0]*grad_output[0].shape[2]*grad_output[0].shape[3]
                self.apnz += torch.sum(self.activations > 0., dim=(0,2,3), dtype=z.dtype).to(z.device)

    def identify_removable_neurons(self, below=None, above=None, mask=None):
        """Mark hidden neurons for removal and start freezing/decaying them.

        A neuron is marked if significance <= `below` or significance > `above`
        (a threshold is ignored only when it is None), OR-ed with an optional
        boolean `mask`. Returns the number marked, or None if a previous removal
        is still pending.
        """
        if self.to_remove is not None:
            print("First remove all previous less significant neurons")
            return
        if mask is None:
            ## same device as significance, so logical_or also works on GPU
            mask = torch.zeros_like(self.significance, dtype=torch.bool)
        if below is not None:
            mask = torch.logical_or(mask, self.significance<=below)
        if above is not None:
            mask = torch.logical_or(mask, self.significance>above)

        print(f"Significance:\n{self.significance}\nPrune:\n{mask}")

        self.to_remove = torch.nonzero(mask).reshape(-1)
        if len(self.to_remove)>0:
            self.fc0.start_freezing_connection(self.to_remove)
            self.fc1.start_decaying_connection(self.to_remove)
            self.tree.remove_neuron_residual.add(self)
            return len(self.to_remove)

        self.to_remove = None
        return 0

    def remove_decayed_neurons(self):
        """Physically drop the neurons marked by identify_removable_neurons."""
        removed = set(self.to_remove.tolist())
        remaining = [i for i in range(self.hidden_dim) if i not in removed]

        self.non_linearity.remove_neuron(remaining)
        self.fc0.remove_freezed_connection(remaining)
        self.fc1.remove_decayed_connection(remaining)

        self.neurons_added -= len(self.to_remove)
        self.hidden_dim = len(remaining)
        self.to_remove = None

    def compute_del_neurons(self):
        """Fold neurons_added into the del_neurons EMA and reset the counter."""
        self.del_neurons = (1-self.tree.beta_del_neuron)*self.neurons_added \
                            + self.tree.beta_del_neuron*self.del_neurons
        self.neurons_added = 0
        return

    def add_hidden_neuron(self, num):
        self.fc0.add_output_connection(num)
        self.non_linearity.add_neuron(num)
        self.fc1.add_input_connection(num)

        self.hidden_dim += num
        self.neurons_added += num

    def morph_network(self):
        """Recurse into fc0/fc1. While hidden_dim is within the size budget, drop
        shortcut-only children from tree.DYNAMIC_LIST."""
        self.fc0.morph_network()
        self.fc1.morph_network()
        max_dim = _get_hidden_neuron_number(self.tree.parent_dict[self].input_dim,
            self.tree.parent_dict[self].output_dim)+1
        if self.hidden_dim <= max_dim:
            if self.fc0.residual is None: ## it is shortcut conv
                if self.fc0 in self.tree.DYNAMIC_LIST:
                    self.tree.DYNAMIC_LIST.remove(self.fc0)
            if self.fc1.residual is None:
                if self.fc1 in self.tree.DYNAMIC_LIST:
                    self.tree.DYNAMIC_LIST.remove(self.fc1)
        return

    def print_network_debug(self, depth):
        ## NonLinearity_Conv has no bias, so cross-check hidden_dim against fc0's output width
        print(f"{'║     '*depth}R▚:{depth}[{self.hidden_dim}|{self.fc0.output_dim}]")
        self.fc0.print_network_debug(depth+1)
        self.fc1.print_network_debug(depth+1)

    def print_network(self, pre_string):
        self.fc0.print_network(pre_string)
        print(f"{pre_string}{self.hidden_dim}")
        self.fc1.print_network(pre_string)
        return

In [11]:
class HierarchicalResidual_Conv(nn.Module):
    """Layer mapping input_dim -> output_dim channels as shortcut(x) + residual(x).

    `shortcut` is a Shortcut_Conv and `residual` a Residual_Conv. The residual is
    created lazily by add_hidden_neuron, and at most one of the two is None at a
    time. `self.forward` is rebound to forward_shortcut / forward_residual /
    forward_both to match whichever branches exist.

    std_ratio is an EMA of r_std / (s_std + r_std) per output channel
    (0 -> all variation from the shortcut, 1 -> all from the residual).
    target_std_ratio is an EMA of hidden_dim / _get_hidden_neuron_number(...),
    clipped to [0, 1]. Their mismatch is added to tree.std_loss.
    """

    def __init__(self, tree, input_dim, output_dim, stride=1, activation=nn.ReLU()):
        super().__init__()

        self.tree = tree
        self.input_dim = input_dim
        self.output_dim = output_dim
        ## NOTE(review): the `stride` argument is ignored -- every conv built here (and
        ## in residual branches grown from here) runs with stride 1. Left unchanged
        ## because layers defined elsewhere may be sized for un-strided outputs;
        ## confirm before switching to `self.stride = stride`.
        self.stride = 1

        self.activation = activation

        ## this can be Shortcut Layer or None
        self.shortcut = Shortcut_Conv(tree, self.input_dim, self.output_dim, stride=self.stride).to(self.tree.device)
        self.tree.parent_dict[self.shortcut] = self

        self.residual = None ## this can be Residual Layer or None
        ##### only one of shortcut or residual can be None at a time
        self.forward = self.forward_shortcut

        self.std_ratio = 0. ## 0-> all variation due to shortcut, 1-> residual
        self.target_std_ratio = 0.

    def forward_both(self, r):
        """Return shortcut(x) + residual(x). Outside significance computation, also
        update the std-ratio EMAs and accumulate the std-ratio loss."""
        s = self.shortcut(r)
        r = self.residual(r)

        if self.residual.hook is None: ### dont execute when computing significance
            ## per-channel std over batch and spatial dims -> shape (1, C)
            s_std = torch.std(s, dim=(0,2,3), keepdim=True).reshape(1, -1)
            r_std = torch.std(r, dim=(0,2,3), keepdim=True).reshape(1, -1)
            ## NOTE(review): NaN if both stds are 0 for a channel
            stdr = r_std/(s_std+r_std)

            self.std_ratio = self.tree.beta_std_ratio*self.std_ratio + (1-self.tree.beta_std_ratio)*stdr.data
            if r_std.min() > 1e-9:
                ## while neurons are being decayed, base the target ratio on the
                ## dimensions that will remain after removal
                if self.tree.total_decay_steps:
                    i, o = self.shortcut.weight.shape[1], self.shortcut.weight.shape[0]
                    if self.shortcut.to_remove is not None:
                        i -= len(self.shortcut.to_remove)
                    if self.shortcut.to_freeze is not None:
                        o -= len(self.shortcut.to_freeze)
                    h = self.residual.hidden_dim
                    if self.residual.to_remove is not None:
                        h -= len(self.residual.to_remove)

                    tr = h/_get_hidden_neuron_number(i, o)
                    self.compute_target_std_ratio(tr)
                else:
                    self.compute_target_std_ratio()
                self.get_std_loss(stdr)
        return s+r

    def forward_shortcut(self, x):
        return self.shortcut(x)

    def forward_residual(self, x):
        self.compute_target_std_ratio()
        return self.residual(x)

    def compute_target_std_ratio(self, tr = None):
        """EMA-update target_std_ratio towards clip(tr, 0, 1); by default
        tr = residual.hidden_dim / _get_hidden_neuron_number(input_dim, output_dim)."""
        if tr is None:
            tr = self.residual.hidden_dim/_get_hidden_neuron_number(self.input_dim, self.output_dim)

        tr = np.clip(tr, 0., 1.)
        self.target_std_ratio = self.tree.beta_std_ratio*self.target_std_ratio +\
                                (1-self.tree.beta_std_ratio)*tr

    def get_std_loss(self, stdr):
        """Add mean(d**2 + |d|), d = target_std_ratio - stdr, to tree.std_loss."""
        del_std = self.target_std_ratio-stdr
        del_std_loss = (del_std**2 + torch.abs(del_std)).mean()
        self.tree.std_loss += del_std_loss
        return

    def start_freezing_connection(self, to_freeze):
        if self.shortcut is not None:
            self.shortcut.start_freezing_connection(to_freeze)
        if self.residual is not None:
            self.residual.fc1.start_freezing_connection(to_freeze)

    def start_decaying_connection(self, to_remove):
        if self.shortcut is not None:
            self.shortcut.start_decaying_connection(to_remove)
        if self.residual is not None:
            self.residual.fc0.start_decaying_connection(to_remove)

    def remove_freezed_connection(self, remaining):
        """Keep only output channels `remaining` in every branch (and in std_ratio)."""
        if self.shortcut is not None:
            self.shortcut.remove_freezed_connection(remaining)
        if self.residual is not None:
            self.residual.fc1.remove_freezed_connection(remaining)
            if self.shortcut is not None:
                self.std_ratio = self.std_ratio[:, remaining]
        self.output_dim = len(remaining)

    def remove_decayed_connection(self, remaining):
        """Keep only input channels `remaining` in every branch."""
        if self.shortcut is not None:
            self.shortcut.remove_decayed_connection(remaining)
        if self.residual is not None:
            self.residual.fc0.remove_decayed_connection(remaining)
        self.input_dim = len(remaining)

    def add_input_connection(self, num):
        self.input_dim += num
        if self.shortcut is not None:
            self.shortcut.add_input_connection(num)
        if self.residual is not None:
            self.residual.fc0.add_input_connection(num)

    def add_output_connection(self, num):
        self.output_dim += num
        if self.shortcut is not None:
            self.shortcut.add_output_connection(num)
        if self.residual is not None:
            self.residual.fc1.add_output_connection(num)
            if self.shortcut is not None:
                self.std_ratio = torch.cat((self.std_ratio, torch.zeros(1, num, device=self.tree.device)), dim=1)

    def add_hidden_neuron(self, num):
        """Grow the residual branch by `num` hidden channels, creating it if absent."""
        if num<1: return

        if self.residual is None:
            self.residual = Residual_Conv(self.tree, self.input_dim,
                                          num, self.output_dim, stride=self.stride,
                                          activation=self.activation).to(self.tree.device)

            self.tree.parent_dict[self.residual] = self
            if self.shortcut is None:
                self.forward = self.forward_residual
                self.std_ratio = 1.
            else:
                self.forward = self.forward_both
                self.std_ratio = torch.zeros(1, self.output_dim, device=self.tree.device)

        else:
            self.residual.add_hidden_neuron(num)
        return

    def maintain_shortcut_connection(self):
        """Drop the shortcut once the residual carries (almost) all variation, and
        re-create it if the target ratio falls again; then recurse."""
        if self.residual is None: return

        if self.shortcut is not None:
            if self.std_ratio.min()>0.98 and self.target_std_ratio>0.98:
                del self.tree.parent_dict[self.shortcut]
                del self.shortcut
                self.shortcut = None
                self.forward = self.forward_residual
                self.std_ratio = 1.

        elif self.target_std_ratio<0.95:
            ## same device and parent registration as in __init__ (otherwise the next
            ## removal raises KeyError in parent_dict)
            self.shortcut = Shortcut_Conv(self.tree, self.input_dim, self.output_dim,
                                          stride=self.stride).to(self.tree.device)
            self.tree.parent_dict[self.shortcut] = self
            ## zero BN gain, as for Residual_Conv's fc1 shortcut
            self.shortcut.bn.weight.data *= 0.
            self.shortcut.weight.data *= 0.1
            ## std_ratio must be per-channel again while both branches exist;
            ## ones() gives the same EMA in forward_both as the scalar 1.
            self.std_ratio = torch.ones(1, self.output_dim, device=self.tree.device)
            self.forward = self.forward_both

        self.residual.fc0.maintain_shortcut_connection()
        self.residual.fc1.maintain_shortcut_connection()

    def morph_network(self):
        """Remove an empty residual branch; otherwise mark an oversized one's
        children as dynamic and recurse into it."""
        if self.residual is None: return

        if self.residual.hidden_dim < 1:
            del self.tree.parent_dict[self.residual]
            del self.residual
            ### its parent (Residual_Conv) removes it from dynamic list if possible
            self.residual = None
            self.forward = self.forward_shortcut
            self.std_ratio = 0.
            return

        max_dim = _get_hidden_neuron_number(self.input_dim, self.output_dim) + 1
        if self.residual.hidden_dim > max_dim:
            self.tree.DYNAMIC_LIST.add(self.residual.fc0)
            self.tree.DYNAMIC_LIST.add(self.residual.fc1)

        self.residual.morph_network()

    def print_network_debug(self, depth):
        stdr = self.std_ratio
        if torch.is_tensor(self.std_ratio):
            stdr = self.std_ratio.min()

        print(f"{'|     '*depth}H:{depth}[{self.input_dim},{self.output_dim}]"+\
              f"σ[t:{self.target_std_ratio}, s:{stdr}]")
        if self.shortcut is not None:
            self.shortcut.print_network_debug(depth+1)
        if self.residual is not None:
            self.residual.print_network_debug(depth+1)

    def print_network(self, pre_string=""):
        if self.residual is None:
            return

        if self.shortcut is not None:
            print(f"{pre_string}╠════╗")
            self.residual.print_network(f"{pre_string}║    ")
            print(f"{pre_string}╠════╝")
        else:
            print(f"{pre_string}╚════╗")
            self.residual.print_network(f"{pre_string}     ")
            print(f"{pre_string}╔════╝")
        return

### Conv Conv Connector

In [12]:
class Residual_Conv_Connector(nn.Module):
    """Residual branch ``hrnet1(post(act(hrnet0(x))))`` joining two existing layers.

    The channels between hrnet0 and hrnet1 are treated as hidden neurons that can
    be grown (add_hidden_neuron) and pruned (identify_removable_neurons ->
    remove_decayed_neurons). Pruning uses |grad * activation| significance measured
    on the activation output, before the optional `post_activation`.
    """

    def __init__(self, tree, hrnet0, hrnet1, activation, hidden_dim, post_activation=None):
        super().__init__()
        self.tree = tree
        self.hidden_dim = hidden_dim
        self.del_neurons = 0.    ## EMA of neurons_added (see compute_del_neurons)
        self.neurons_added = 0
        self.post_activation = post_activation

        ## Shortcut or Hierarchical Residual Layer
        self.fc0 = hrnet0
        self.non_linearity = NonLinearity_Conv(self.tree, hidden_dim, activation)
        self.fc1 = hrnet1

        self.tree.parent_dict[self.fc0] = self
        self.tree.parent_dict[self.fc1] = self
        self.tree.parent_dict[self.non_linearity] = self

        self.hook = None
        self.activations = None
        self.significance = None
        self.count = None
        self.apnz = None        ## (ReLU only) count, later fraction, of positive activations
        self.to_remove = None

    def forward(self, x):
        x = self.fc0(x)
        x = self.non_linearity(x)
        ## cached for the significance hook (before post_activation)
        self.activations = x.data
        if self.post_activation:
            x = self.post_activation(x)
        x = self.fc1(x)
        return x

    def start_computing_significance(self):
        """Reset accumulators and hook the activation's backward pass."""
        self.significance = 0.
        self.count = 0
        self.apnz = 0
        ## NOTE(review): register_backward_hook is deprecated in favour of
        ## register_full_backward_hook; kept because only grad_output[0] is used.
        self.hook = self.non_linearity.register_backward_hook(self.compute_neuron_significance)

    def finish_computing_significance(self):
        """Remove the hook. For ReLU, convert apnz to a fraction and damp
        neurons that are almost always active (hand-tuned constants)."""
        self.hook.remove()
        if isinstance(self.non_linearity.actf, nn.ReLU):
            self.apnz = self.apnz/self.count
            self.significance = self.significance*(1-self.apnz**33) / 0.872 ## tried on desmos.
        self.hook = None

    def compute_neuron_significance(self, _class, grad_input, grad_output):
        """Backward hook: accumulate per-channel sum_batch |sum_spatial grad * act|."""
        with torch.no_grad():
            z = torch.sum(grad_output[0].data*self.activations, dim=(2,3))
            self.significance += z.abs().sum(dim=0)

            if isinstance(self.non_linearity.actf, nn.ReLU):
                self.count += grad_output[0].shape[0]*grad_output[0].shape[2]*grad_output[0].shape[3]
                self.apnz += torch.sum(self.activations > 0., dim=(0,2,3), dtype=z.dtype).to(z.device)

    def identify_removable_neurons(self, below=None, above=None, mask=None):
        """Mark hidden neurons for removal and start freezing/decaying them.

        A neuron is marked if significance <= `below` or significance > `above`
        (a threshold is ignored only when it is None), OR-ed with an optional
        boolean `mask`. Returns the number marked, or None if a previous removal
        is still pending.
        """
        if self.to_remove is not None:
            print("First remove all previous less significant neurons")
            return
        if mask is None:
            ## same device as significance, so logical_or also works on GPU
            mask = torch.zeros_like(self.significance, dtype=torch.bool)
        if below is not None:
            mask = torch.logical_or(mask, self.significance<=below)
        if above is not None:
            mask = torch.logical_or(mask, self.significance>above)

        print(f"Significance:\n{self.significance}\nPrune:\n{mask}")

        self.to_remove = torch.nonzero(mask).reshape(-1)
        if len(self.to_remove)>0:
            self.fc0.start_freezing_connection(self.to_remove)
            self.fc1.start_decaying_connection(self.to_remove)
            self.tree.remove_neuron_residual.add(self)
            return len(self.to_remove)

        self.to_remove = None
        return 0

    def remove_decayed_neurons(self):
        """Physically drop the neurons marked by identify_removable_neurons."""
        removed = set(self.to_remove.tolist())
        remaining = [i for i in range(self.hidden_dim) if i not in removed]

        self.non_linearity.remove_neuron(remaining)
        self.fc0.remove_freezed_connection(remaining)
        self.fc1.remove_decayed_connection(remaining)

        self.neurons_added -= len(self.to_remove)
        self.hidden_dim = len(remaining)
        self.to_remove = None

    def compute_del_neurons(self):
        """Fold neurons_added into the del_neurons EMA and reset the counter."""
        self.del_neurons = (1-self.tree.beta_del_neuron)*self.neurons_added \
                            + self.tree.beta_del_neuron*self.del_neurons
        self.neurons_added = 0
        return

    def add_hidden_neuron(self, num):
        self.fc0.add_output_connection(num)
        self.non_linearity.add_neuron(num)
        self.fc1.add_input_connection(num)

        self.hidden_dim += num
        self.neurons_added += num

    def morph_network(self):
        """Recurse into fc0/fc1. While hidden_dim is within the size budget, drop
        shortcut-only children from tree.DYNAMIC_LIST."""
        self.fc0.morph_network()
        self.fc1.morph_network()
        max_dim = _get_hidden_neuron_number(self.tree.parent_dict[self].input_dim,
            self.tree.parent_dict[self].output_dim)+1
        if self.hidden_dim <= max_dim:
            if self.fc0.residual is None:
                if self.fc0 in self.tree.DYNAMIC_LIST:
                    self.tree.DYNAMIC_LIST.remove(self.fc0)
            if self.fc1.residual is None:
                if self.fc1 in self.tree.DYNAMIC_LIST:
                    self.tree.DYNAMIC_LIST.remove(self.fc1)
        return

    def print_network_debug(self, depth):
        ## NonLinearity_Conv has no bias, so cross-check hidden_dim against fc0's output width
        print(f"{'║     '*depth}R▚:{depth}[{self.hidden_dim}|{self.fc0.output_dim}]")
        self.fc0.print_network_debug(depth+1)
        self.fc1.print_network_debug(depth+1)

    def print_network(self, pre_string):
        self.fc0.print_network(pre_string)
        print(f"{pre_string}{self.hidden_dim}")
        self.fc1.print_network(pre_string)
        return

In [13]:
# Scratch: indices where two random boolean masks are both True (unseeded, so output varies).
a = torch.randn(10).lt(0)
b = torch.randn(10).gt(0.5)
torch.logical_and(a, b).nonzero(as_tuple=False)

tensor([[4]])

In [14]:
class HierarchicalResidual_Connector(nn.Module):
    """Internal hierarchy node joining two sub-networks ``hrnet0`` -> ``hrnet1``.

    All work is delegated to the ``Residual_Conv_Connector`` kept in
    ``self.residual``: input-side width changes go to ``residual.fc0`` and
    output-side changes to ``residual.fc1``. ``self.shortcut`` is always None.
    """

    def __init__(self, tree, hrnet0, hrnet1, activation=nn.ReLU(), post_activation=None):
        super().__init__()
        self.tree = tree
        self.input_dim = hrnet0.input_dim
        self.output_dim = hrnet1.output_dim

        ## this can be Shortcut Layer or None
        self.shortcut = None
        # hrnet0.output_dim is passed as the connector's 5th argument (presumably its hidden width)
        connector = Residual_Conv_Connector(self.tree, hrnet0, hrnet1, activation,
                                            hrnet0.output_dim, post_activation)
        self.residual = connector
        self.tree.parent_dict[connector] = self

    def forward(self, x):
        return self.residual(x)

    # ---- pruning (outputs are handled by fc1, inputs by fc0) ----
    def start_freezing_connection(self, to_freeze):
        self.residual.fc1.start_freezing_connection(to_freeze)

    def start_decaying_connection(self, to_remove):
        self.residual.fc0.start_decaying_connection(to_remove)

    def remove_freezed_connection(self, remaining):
        self.residual.fc1.remove_freezed_connection(remaining)
        self.output_dim = len(remaining)

    def remove_decayed_connection(self, remaining):
        self.residual.fc0.remove_decayed_connection(remaining)
        self.input_dim = len(remaining)

    # ---- growth ----
    def add_input_connection(self, num):
        self.input_dim += num
        self.residual.fc0.add_input_connection(num)

    def add_output_connection(self, num):
        self.output_dim += num
        self.residual.fc1.add_output_connection(num)

    def add_hidden_neuron(self, num):
        """Forward to the residual connector; requests for fewer than 1 neuron are ignored."""
        if num < 1:
            return
        self.residual.add_hidden_neuron(num)

    def maintain_shortcut_connection(self):
        self.residual.fc0.maintain_shortcut_connection()
        self.residual.fc1.maintain_shortcut_connection()

    def morph_network(self):
        self.residual.morph_network()

    # ---- printing ----
    def print_network_debug(self, depth):
        print(f"{'|     ' * depth}H:{depth}[{self.input_dim},{self.output_dim}]σ[t:None, s:None")
        self.residual.print_network_debug(depth + 1)

    def print_network(self, pre_string=""):
        print(f"{pre_string}╚╗")
        self.residual.print_network(pre_string + " ")
        print(f"{pre_string}╔╝")

## Shortcut-only hierarchical residual network

In [15]:
class Shortcut(nn.Module):
    """Bias-free linear layer (``y = x @ W.T``) whose width can change during training.

    The weight is appended to the shared optimizer ``tree.optimizer``. Whenever
    the weight is resized, the optimizer state entries ``step``, ``exp_avg`` and
    ``exp_avg_sq`` are sliced or zero-padded in the same way, so the history of
    the surviving connections is kept. This assumes ``adam_custom.Adam`` stores
    all three element-wise with the weight's shape (TODO confirm for ``step``).
    """

    # optimizer-state entries resized together with the weight
    _OPT_STATE_KEYS = ('step', 'exp_avg', 'exp_avg_sq')

    def __init__(self, tree, input_dim, output_dim):
        super().__init__()
        self.tree = tree
        # reuse nn.Linear's default initialisation, copied into a fresh parameter
        _wd = nn.Linear(input_dim, output_dim, bias=False).weight.data
        self.weight = nn.Parameter(
            torch.empty_like(_wd).copy_(_wd)
        )

        ## for removing and freezing neurons
        self.to_remove = None        # input columns currently being decayed
        self.to_freeze = None        # output rows currently being frozen
        self.initial_remove = None   # snapshot of the decaying columns
        self.initial_freeze = None   # snapshot of the frozen rows
        self.add_parameters_to_optimizer()

    def add_parameters_to_optimizer(self):
        """Append this layer's parameters to the first param group of the shared optimizer."""
        for p in self.parameters():
            self.tree.optimizer.param_groups[0]['params'].append(p)

    def _resize_optimizer_state(self, resize):
        """Replace each weight-shaped optimizer-state tensor ``t`` by ``resize(t)``.

        Nothing happens while the optimizer has no state for the weight yet
        (i.e. before its first step).
        """
        state = self.tree.optimizer.state[self.weight]
        if len(state) > 0:
            for key in self._OPT_STATE_KEYS:
                state[key] = resize(state[key])

    def forward(self, x):
        """Return ``x @ weight.T``; when both widths are 0 an explicit zero tensor is returned."""
        ## input_dim        ## output_dim
        if x.shape[1] + self.weight.shape[1] == 0:
            return torch.zeros(x.shape[0], self.weight.shape[0], dtype=x.dtype, device=x.device)
        return x.matmul(self.weight.t())

    def start_decaying_connection(self, to_remove):
        """Snapshot input columns ``to_remove`` and register this layer for decaying."""
        self.initial_remove = self.weight.data[:, to_remove]
        self.to_remove = to_remove
        self.tree.decay_connection_shortcut.add(self)

    def start_freezing_connection(self, to_freeze):
        """Snapshot output rows ``to_freeze`` and register this layer for freezing."""
        self.initial_freeze = self.weight.data[to_freeze, :]
        self.to_freeze = to_freeze
        self.tree.freeze_connection_shortcut.add(self)

    def freeze_connection_step(self):
        """Restore the frozen output rows to their snapshot."""
        self.weight.data[self.to_freeze, :] = self.initial_freeze

    def decay_connection_step(self):
        """Overwrite the decaying input columns with ``snapshot * tree.decay_factor``."""
        self.weight.data[:, self.to_remove] = self.initial_remove*self.tree.decay_factor

    def remove_freezed_connection(self, remaining):
        """Keep only the output rows listed in ``remaining`` (weight and optimizer state)."""
        self.weight.data = self.weight.data[remaining, :]
        self.weight.grad = None
        self._resize_optimizer_state(lambda t: t[remaining, :])

        self.initial_freeze = None
        self.to_freeze = None

    def remove_decayed_connection(self, remaining):
        """Keep only the input columns listed in ``remaining`` (weight and optimizer state)."""
        self.weight.data = self.weight.data[:, remaining]
        self.weight.grad = None
        self._resize_optimizer_state(lambda t: t[:, remaining])

        self.initial_remove = None
        self.to_remove = None

    def add_input_connection(self, num):
        """Append ``num`` zero-initialised input columns, so the layer output is unchanged."""
        w = self.weight.data
        o, _ = w.shape
        new_cols = torch.zeros(o, num, dtype=w.dtype, device=w.device)
        self.weight.data = torch.cat((w, new_cols), dim=1)
        self.weight.grad = None

        self._resize_optimizer_state(
            lambda t: torch.cat((t, torch.zeros(o, num, dtype=t.dtype, device=t.device)), dim=1))

    def add_output_connection(self, num):
        """Append ``num`` output rows initialised uniformly in ``[-1/sqrt(in), 1/sqrt(in)]``."""
        _, i = self.weight.data.shape
        stdv = 1. / np.sqrt(i)

        new_rows = torch.empty(num, i, dtype=self.weight.dtype,
                               device=self.weight.data.device).uniform_(-stdv, stdv)
        self.weight.data = torch.cat((self.weight.data, new_rows), dim=0)
        self.weight.grad = None

        # pad the optimizer state with (num, in) zeros; this previously used undefined
        # conv-kernel sizes (k0, k1) and raised NameError once the optimizer had state
        self._resize_optimizer_state(
            lambda t: torch.cat((t, torch.zeros(num, i, dtype=t.dtype, device=t.device)), dim=0))

    def print_network_debug(self, depth):
        print(f"{'|     '*depth}S:{depth}[{self.weight.data.shape[1]},{self.weight.data.shape[0]}]")


class HierarchicalResidual_Shortcut(nn.Module):
    """Leaf hierarchy node wrapping a single shortcut layer.

    The layer is a linear ``Shortcut`` when ``kernel`` is None, otherwise a
    ``Shortcut_Conv`` with the given kernel and stride. The node has no
    residual branch (``self.residual`` is None), so it cannot grow hidden
    neurons; width changes are forwarded to the shortcut layer.
    """

    def __init__(self, tree, input_dim, output_dim, kernel=None, stride=1):
        super().__init__()
        self.tree = tree
        self.input_dim = input_dim
        self.output_dim = output_dim

        ## this can be Shortcut Layer or None
        self.shortcut = (
            Shortcut(tree, self.input_dim, self.output_dim)
            if kernel is None else
            Shortcut_Conv(tree, self.input_dim, self.output_dim, kernel, stride)
        )
        self.tree.parent_dict[self.shortcut] = self

        self.residual = None

    def forward(self, x):
        return self.shortcut(x)

    # ---- pruning ----
    def start_freezing_connection(self, to_freeze):
        self.shortcut.start_freezing_connection(to_freeze)

    def start_decaying_connection(self, to_remove):
        self.shortcut.start_decaying_connection(to_remove)

    def remove_freezed_connection(self, remaining):
        self.shortcut.remove_freezed_connection(remaining)
        self.output_dim = len(remaining)

    def remove_decayed_connection(self, remaining):
        self.shortcut.remove_decayed_connection(remaining)
        self.input_dim = len(remaining)

    # ---- growth ----
    def add_input_connection(self, num):
        self.input_dim += num
        self.shortcut.add_input_connection(num)

    def add_output_connection(self, num):
        self.output_dim += num
        self.shortcut.add_output_connection(num)

    def add_hidden_neuron(self, num):
        """Not supported for a shortcut-only node; only prints a message."""
        print("Cannot Add Hidden neuron to Shortcut Only Layer")

    # ---- structural maintenance: nothing to do for a leaf ----
    def maintain_shortcut_connection(self):
        return None

    def morph_network(self):
        return None

    # ---- printing ----
    def print_network_debug(self, depth):
        print(f"{'|     ' * depth}H:{depth}[{self.input_dim},{self.output_dim}]σ[t:None, s:None")
        self.shortcut.print_network_debug(depth + 1)

    def print_network(self, pre_string=""):
        return None

# Tree and Controller

In [16]:
class Tree_State():
    """Mutable state shared by every node of the dynamic network.

    Holds the set of growable blocks (``DYNAMIC_LIST``), the child -> parent map
    (``parent_dict``), moving-average coefficients, the shared optimizer and the
    bookkeeping of an in-progress neuron-decay phase.
    """

    def __init__(self):
        ## Only nodes that may get a residual connection are added (keeps iteration cheap);
        ## parents not intended to have a residual connection should not be added.
        self.DYNAMIC_LIST = set()
        self.beta_std_ratio = None
        self.beta_del_neuron = None
        self.device = 'cpu'

        self.parent_dict = {}   # child module -> parent node (None for the root)

        # decay-phase bookkeeping starts out empty
        self.clear_decay_variables()

        self.decay_rate_std = 0.001

        # removal target is about neurons_added / add_to_remove_ratio neurons per round
        self.add_to_remove_ratio = 2.

        # assigned later by the owner of the tree
        self.optimizer = None

    def get_decay_factor(self):
        """Set ``decay_factor = (1 - r)**2`` where ``r = current/total`` clipped to [0, 1]."""
        progress = np.clip(self.current_decay_step / self.total_decay_steps, 0, 1)
        self.decay_factor = (1 - progress) ** 2

    def clear_decay_variables(self):
        """Reset every decay-phase attribute to None."""
        self.total_decay_steps = None
        self.current_decay_step = None
        self.decay_factor = None
        self.remove_neuron_residual = None
        self.freeze_connection_shortcut = None
        self.decay_connection_shortcut = None
        

In [17]:
# scratch instance confirming Tree_State constructs; NOTE(review): `t` is not used by the code visible here
t = Tree_State()

## Constructing the hierarchical residual CNN (ResNet-inspired)

In [18]:
class DropoutActivation(nn.Module):
    """Channel dropout (``nn.Dropout2d``) followed by an activation.

    NOTE: the default ``activation`` module is created once, when the class is
    defined, and is shared by every instance that uses the default. This is
    harmless for a stateless activation such as ReLU.
    """

    def __init__(self, p=0.1, activation=nn.ReLU()):
        super().__init__()
        ## ok to reuse dropout !! caution
        self.dropout = nn.Dropout2d(p)
        self.activation = activation

    def forward(self, x):
        dropped = self.dropout(x)
        return self.activation(dropped)

In [19]:
class Dynamic_CNN(nn.Module):
    """Self-morphing CNN assembled from hierarchical residual blocks.

    Growable blocks are registered in ``self.tree.DYNAMIC_LIST``. The network
    grows (``add_neurons``) and scores hidden neurons
    (``start_computing_significance`` / ``finish_computing_significance``).
    It then decays the least significant ones over several steps
    (``identify_removable_neurons`` -> ``decay_neuron_start`` ->
    ``decay_neuron_step``) and finally removes them (``remove_decayed_neurons``).
    ``self.tree.optimizer`` is the single optimizer that the layers append
    their parameters to.
    """

    def __init__(self, device, lr, input_dim = 1, hidden_dims = (8, 16, 32, 64), output_dim = 10, final_activation=None,
                 num_stat=5, num_std=100, decay_rate_std=0.001):
        """
        Args:
            device: stored as ``tree.device``.
            lr: learning rate of the shared optimizer.
            input_dim, hidden_dims: NOTE(review) currently ignored by
                ``_construct_root_net``, which hard-codes 3 input channels and
                16/32/64/128-wide stages.
            output_dim: number of outputs of the final layer.
            final_activation: applied to the network output; identity if None.
            num_stat: sets ``tree.beta_del_neuron = (num_stat-1)/num_stat``, the
                moving-average coefficient of the added-neuron statistics.
            num_std: sets ``tree.beta_std_ratio = (num_std-1)/num_std``.
            decay_rate_std: copied to ``tree.decay_rate_std``.
        """
        super().__init__()
        self.tree = Tree_State()
        self.tree.beta_del_neuron = (num_stat-1)/num_stat
        self.tree.beta_std_ratio = (num_std-1)/num_std
        self.tree.decay_rate_std = decay_rate_std
        self.tree.device = device

        # The optimizer is built with a throw-away parameter (presumably because, like
        # torch.optim, it rejects an empty parameter list) that is removed right away;
        # layers append their own parameters to param_groups[0] as they are created.
        dummy_param = nn.Parameter(torch.Tensor([0]))
        self.tree.optimizer = adam_custom.Adam([dummy_param], lr=lr, weight_decay=1e-5)
        self.tree.optimizer.param_groups[0]['params'] = []

        self.root_net = None
        self._construct_root_net(input_dim, hidden_dims, output_dim)
        self.tree.parent_dict[self.root_net] = None

        if final_activation is None:
            final_activation = lambda x: x
        self.non_linearity = NonLinearity(self.tree, output_dim, final_activation)

        self.neurons_added = 0

        ## temporary thresholds produced by identify_removable_neurons
        self._remove_below = None
        self._remove_above = None

    def _construct_root_net(self, input_dim, hidden_dims, output_dim):
        """Build the default ResNet-like hierarchy and mark its nodes as dynamic.

        NOTE(review): ``input_dim`` and ``hidden_dims`` are ignored; the stem takes
        3 input channels and the stages are 16 -> 16 -> 32 -> 64 -> 128 wide.
        """
        actf = DropoutActivation()

        # 3x3 conv stem followed by four conv stages (the last three with stride 2)
        hrnR = HierarchicalResidual_Shortcut(self.tree, 3, 16, kernel=(3,3), stride=1)
        hrn0 = HierarchicalResidual_Conv(self.tree, 16, 16, activation=actf)
        hrn1 = HierarchicalResidual_Conv(self.tree, 16, 32, stride=2, activation=actf)
        hrn2 = HierarchicalResidual_Conv(self.tree, 32, 64, stride=2, activation=actf)
        hrn3 = HierarchicalResidual_Conv(self.tree, 64, 128, stride=2, activation=actf)

        # connectors chain the stages with an identity activation
        actf = lambda x: x
        hrnR0 = HierarchicalResidual_Connector(self.tree, hrnR, hrn0, actf)
        hrnR01 = HierarchicalResidual_Connector(self.tree, hrnR0, hrn1, actf)
        hrnR012 = HierarchicalResidual_Connector(self.tree, hrnR01, hrn2, actf)
        hrnR0123 = HierarchicalResidual_Connector(self.tree, hrnR012, hrn3, actf)
        hrnfc = HierarchicalResidual_Shortcut(self.tree, 128, output_dim)

        def pool_and_reshape(x):
            # global average pool, then flatten to (batch, channels)
            x = F.adaptive_avg_pool2d(x, (1,1))
            x = x.view(x.shape[0], -1)
            return x

        actf = lambda x: x
        hrnR0123fc = HierarchicalResidual_Connector(self.tree, hrnR0123, hrnfc,
                                                   activation=actf, post_activation=pool_and_reshape)
        self.root_net = hrnR0123fc

        ## make every hierarchical Layer Morphable
        morphables = [self.root_net, hrnR0123, hrnR012, hrnR01, hrnR0, hrn3, hrn2, hrn1, hrn0]
        for hr in morphables:
            self.tree.DYNAMIC_LIST.add(hr)
        return

    def _construct_root_net2(self, input_dim, hidden_dims, output_dim):
        """Alternative 32/64/128-wide, three-stage layout (not called by ``__init__``).

        Unlike ``_construct_root_net`` the connectors here also use the dropout
        activation.
        """
        actf = DropoutActivation()

        hrnR = HierarchicalResidual_Shortcut(self.tree, 3, 32, kernel=(3,3), stride=1)
        hrn0 = HierarchicalResidual_Conv(self.tree, 32, 32)
        hrn1 = HierarchicalResidual_Conv(self.tree, 32, 64, stride=2)
        hrn2 = HierarchicalResidual_Conv(self.tree, 64, 128, stride=2)

        hrnR0 = HierarchicalResidual_Connector(self.tree, hrnR, hrn0, actf)
        hrnR01 = HierarchicalResidual_Connector(self.tree, hrnR0, hrn1, actf)
        hrnR012 = HierarchicalResidual_Connector(self.tree, hrnR01, hrn2, actf)

        hrnfc = HierarchicalResidual_Shortcut(self.tree, 128, output_dim)

        def pool_and_reshape(x):
            x = F.adaptive_avg_pool2d(x, (1,1))
            x = x.view(x.shape[0], -1)
            return x

        hrnR012fc = HierarchicalResidual_Connector(self.tree, hrnR012, hrnfc,
                                                   activation=actf, post_activation=pool_and_reshape)
        self.root_net = hrnR012fc

        ## make every hierarchical Layer Morphable
        morphables = [self.root_net, hrnR012, hrnR01, hrnR0, hrn2, hrn1, hrn0]
        for hr in morphables:
            self.tree.DYNAMIC_LIST.add(hr)
        return

    def forward(self, x):
        return self.non_linearity(self.root_net(x))

    def add_neurons(self, num):
        """Distribute ``num`` new hidden neurons over the dynamic blocks.

        ``int(num*0.7)`` neurons are allocated by a multinomial draw weighted by
        ``softplus(del_neurons)`` of each block. Blocks without a residual get
        weight ``softplus(0)``. The remaining neurons go to uniformly random blocks.
        """
        num_stat = int(num*0.7)
        num_random = num - num_stat

        DL = list(self.tree.DYNAMIC_LIST)
        if num_random>0:
            rands = torch.randint(high=len(DL), size=(num_random,))
            index, count = torch.unique(rands, sorted=False, return_counts=True)
            for i, idx in enumerate(index):
                DL[idx].add_hidden_neuron(int(count[i]))

        if num_stat>0:
            ## residual layer not yet created -> statistic 0
            del_neurons = [hr.residual.del_neurons if hr.residual else 0. for hr in DL]
            # softplus keeps every weight positive; F.softplus avoids the float overflow
            # of the previous log(exp(x)+1) for large statistics
            prob_stat = F.softplus(torch.tensor(del_neurons))
            m = torch.distributions.multinomial.Multinomial(total_count=num_stat,
                                                            probs= prob_stat)
            count = m.sample()
            for i, hr in enumerate(DL):
                if count[i] < 1: continue
                hr.add_hidden_neuron(int(count[i]))

        self.neurons_added += num
        pass

    def identify_removable_neurons(self, num=None, threshold_min=0., threshold_max=1.):
        """Choose the significance threshold below which hidden neurons will be pruned.

        The significance scores of all residual blocks are pooled and normalised
        so that 1 equals the average score. The threshold is the ``num``-th
        smallest normalised score among those below ``threshold_max``; it is never
        lower than ``threshold_min``. ``num`` defaults to
        ``ceil(neurons_added / tree.add_to_remove_ratio)``. The un-normalised
        threshold is stored in ``self._remove_below`` for ``decay_neuron_start``.

        Returns:
            ``(remove_below, all_sigs)``, where ``all_sigs`` holds the pooled raw
            scores. When no score is below ``threshold_max`` (or no residual exists)
            ``remove_below`` is 0 and ``self._remove_below`` is cleared, so no
            stale threshold from an earlier call is reused.
        """
        self.all_sig_ = []
        # reset up front: every early return must leave "nothing to remove" behind
        self._remove_below = None
        self._remove_above = None

        sig_list = [hr.residual.significance for hr in self.tree.DYNAMIC_LIST if hr.residual]
        if not sig_list:
            # no residual block exists yet, so there is nothing to rank
            return 0, torch.empty(0)
        all_sigs = torch.cat(sig_list)

        ### Normalizes such that importance 1 is average importance
        normalizer = float(torch.sum(all_sigs))/len(all_sigs)
        all_sig = all_sigs/normalizer

        print(f"Significance Stat:\nMin, Max: {float(all_sig.min()), float(all_sig.max())}")
        print(f"Mean, Std: {float(all_sig.mean()), float(all_sig.std())}")
        all_sig = all_sig[all_sig<threshold_max]
        if len(all_sig)<1: ## if all significance is above threshold max
            return 0, all_sigs
        all_sig = torch.sort(all_sig)[0] ### sorted significance scores

        self.all_sig_ = all_sig

        if not num:
            num = int(np.ceil(self.neurons_added/self.tree.add_to_remove_ratio))

        remove_below = threshold_min
        if num>len(all_sig):
            remove_below = float(all_sig[-1])
        elif num>0:
            remove_below = float(all_sig[num-1])

        ### sig < threshold_min is always removed; whatsoever
        if remove_below < threshold_min:
            remove_below = threshold_min

        print("remove_below", remove_below, "true:", remove_below*normalizer)
        remove_below *= normalizer

        self._remove_below = remove_below
        self._remove_above = None
        return remove_below, all_sigs

    def decay_neuron_start(self, decay_steps=1000):
        """Start a decay phase of ``decay_steps`` steps using ``self._remove_below``.

        Returns the sum of what each residual's ``identify_removable_neurons``
        reports (presumably the number of neurons scheduled for removal). It
        returns 0 without starting when no threshold is set. When the sum is
        below 1 the decay bookkeeping is cleared again.
        """
        if self._remove_below is None: return 0

        self.neurons_added = 0 ## resetting this variable

        self.tree.total_decay_steps = decay_steps
        self.tree.current_decay_step = 0
        self.tree.remove_neuron_residual = set()
        self.tree.freeze_connection_shortcut = set()
        self.tree.decay_connection_shortcut = set()

        count_remove = 0
        for hr in self.tree.DYNAMIC_LIST:
            if hr.residual:
                ### random Bernoulli(0.05) mask: intended to additionally prune ~5 % of the
                ### neurons at random (may overlap with the low-significance ones)
                mask = torch.bernoulli(torch.ones_like(hr.residual.significance)*0.05).type(torch.bool)
                count_remove += hr.residual.identify_removable_neurons(below=self._remove_below,
                                                                       above=self._remove_above,
                                                                       mask = mask
                                                                      )
        if count_remove<1:
            self.tree.clear_decay_variables()
        return count_remove

    def decay_neuron_step(self):
        """Advance the decay schedule by one step.

        Returns 0 when no decay phase is active, 1 while the schedule is still
        running and -1 once ``total_decay_steps`` has been reached (presumably
        the caller's cue to run ``remove_decayed_neurons``).
        """
        if self.tree.total_decay_steps is None:
            return 0

        self.tree.current_decay_step += 1

        # Refresh the factor on every step: get_decay_factor clips the progress to
        # [0, 1], so after the last step the factor is exactly 0. Previously it was
        # left at (1/T)**2 and the decayed weights never reached zero.
        self.tree.get_decay_factor()

        ### need to decay and freeze all the time (re-applied each step, presumably
        ### because optimizer updates keep moving these weights)
        for sh in self.tree.decay_connection_shortcut:
            sh.decay_connection_step()
        for sh in self.tree.freeze_connection_shortcut:
            sh.freeze_connection_step()

        if self.tree.current_decay_step < self.tree.total_decay_steps:
            return 1
        return -1

    def remove_decayed_neurons(self):
        """Remove the decayed neurons, clear the decay state and re-run maintenance.

        Safe to call when no decay phase was started (nothing is removed then).
        """
        for rs in (self.tree.remove_neuron_residual or ()):
            rs.remove_decayed_neurons()

        self.tree.clear_decay_variables()
        self.maintain_network()
        return

    def compute_del_neurons(self):
        for hr in self.tree.DYNAMIC_LIST:
            if hr.residual:
                hr.residual.compute_del_neurons()

    def maintain_network(self):
        self.root_net.maintain_shortcut_connection()
        self.root_net.morph_network()

    def start_computing_significance(self):
        for hr in self.tree.DYNAMIC_LIST:
            if hr.residual:
                hr.residual.start_computing_significance()

    def finish_computing_significance(self):
        for hr in self.tree.DYNAMIC_LIST:
            if hr.residual:
                hr.residual.finish_computing_significance()

    def print_network_debug(self):
        self.root_net.print_network_debug(0)

    def print_network(self):
        print(self.root_net.input_dim)
        self.root_net.print_network()
        print("│")
        print(self.root_net.output_dim)
        return

## Train the dynamic CNN

In [20]:
# use the first GPU when available; fall back to CPU so the notebook still runs end-to-end
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [21]:
from torchvision import datasets, transforms

In [22]:
# cifar_train = transforms.Compose([
#     transforms.RandomCrop(size=32, padding=4),
#     transforms.RandomHorizontalFlip(),
#     transforms.ToTensor(),
#     transforms.Normalize(
#         mean=[0.4914, 0.4822, 0.4465], # mean=[0.5071, 0.4865, 0.4409] for cifar100
#         std=[0.2023, 0.1994, 0.2010], # std=[0.2009, 0.1984, 0.2023] for cifar100
#     ),
# ])

# cifar_test = transforms.Compose([
#     transforms.ToTensor(),
#     transforms.Normalize(
#         mean=[0.4914, 0.4822, 0.4465], # mean=[0.5071, 0.4865, 0.4409] for cifar100
#         std=[0.2023, 0.1994, 0.2010], # std=[0.2009, 0.1984, 0.2023] for cifar100
#     ),
# ])

# train_dataset = datasets.CIFAR10(root="../../_Datasets/cifar10/", train=True, download=False, transform=cifar_train)
# test_dataset = datasets.CIFAR10(root="../../_Datasets/cifar10/", train=False, download=False, transform=cifar_test)

In [23]:
# CIFAR-100 per-channel statistics shared by the train and test pipelines
CIFAR100_MEAN = [0.5071, 0.4865, 0.4409]
CIFAR100_STD = [0.2009, 0.1984, 0.2023]
CIFAR100_ROOT = "../../_Datasets/cifar100/"

# training: random crop with 4-pixel padding and horizontal flips before normalising
cifar_train = transforms.Compose([
    transforms.RandomCrop(size=32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=CIFAR100_MEAN, std=CIFAR100_STD),
])

# evaluation: no augmentation
cifar_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=CIFAR100_MEAN, std=CIFAR100_STD),
])

train_dataset = datasets.CIFAR100(root=CIFAR100_ROOT, train=True, download=False, transform=cifar_train)
test_dataset = datasets.CIFAR100(root=CIFAR100_ROOT, train=False, download=False, transform=cifar_test)

In [24]:
# train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=32, shuffle=True, num_workers=4)
# test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=32, shuffle=False, num_workers=4)

_loader_kwargs = dict(batch_size=128, num_workers=4)  # shared by both loaders
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, shuffle=True, **_loader_kwargs)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, shuffle=False, **_loader_kwargs)

In [25]:
### hyperparameters
# optimisation (0.00003 was also tried)
learning_rate = 0.0003

# growth / pruning schedule
num_add_neuron = 50                        # tried: 25, 10
num_decay_steps = 2 * len(train_loader)    # decay over two epochs' worth of batches (3 also tried)

remove_above = 12                          # tried: 10
threshold_max = 0.5
threshold_min = 0.01

# training length per growth loop (presumably in epochs)
train_epoch_min = 1
train_epoch_max = 15                       # tried: 10, 5

In [26]:
# number of steps over which pruned connections are decayed (two passes over train_loader)
num_decay_steps

782

In [27]:
criterion = nn.CrossEntropyLoss()
# CIFAR-100 -> 100 outputs; the model creates its own shared optimizer (dynet.tree.optimizer)
dynet = Dynamic_CNN(device, learning_rate, output_dim=100).to(device)

In [28]:
# dynet

In [29]:
# module-level alias of the shared optimizer created inside Dynamic_CNN
# NOTE(review): a later cell rebinds `optimizer = None`; reset_optimizer() works on dynet.tree.optimizer directly
optimizer = dynet.tree.optimizer

In [30]:
# number of batches in the train / test loaders
len(train_loader), len(test_loader)

(391, 79)

In [31]:
# identify_removable_neurons targets ceil(neurons_added / add_to_remove_ratio) neurons per pruning round
dynet.tree.add_to_remove_ratio = 2.5

In [32]:
# make sure the shared optimizer uses the configured learning rate
dynet.tree.optimizer.param_groups[0]['lr'] = learning_rate

## Training log

In [33]:
## identifiers used to name this run's output files
index, name, exp_index = '00.3', 'dynCNN_reuse_optim_v1_c100', 0

## Save hyperparameters as JSON

In [34]:
## hyperparameters of this run (key order is preserved in the JSON dump)
hyp = dict(
    learning_rate=learning_rate,
    num_add_neuron=num_add_neuron,
    num_decay_steps=num_decay_steps,
    remove_above=remove_above,
    threshold_max=threshold_max,
    threshold_min=threshold_min,
    train_epoch_min=train_epoch_min,
    train_epoch_max=train_epoch_max,
    add_to_remove_ratio=dynet.tree.add_to_remove_ratio,
)

In [35]:
import os

# write this run's hyperparameters; create the output folder on a fresh checkout
# instead of failing with FileNotFoundError
hyp_json = f'hyperparameters/{index}_hyp_exp_{exp_index}.json'
os.makedirs(os.path.dirname(hyp_json), exist_ok=True)
with open(hyp_json, 'w') as fp:
    json.dump(hyp, fp, indent=0)

## Automated training loop

In [36]:
class AutoTrainer:
    """Drives the grow -> train -> prune -> maintain cycle of the dynamic network.

    Each stage is a callable assigned as an attribute after construction:
    ``adding_func``, ``training_func``, ``pruning_func``, ``maintainance_func``,
    ``log_func(loop_index)`` and the optional ``extra_func``.
    """

    def __init__(self):
        self.training_func = None
        self.adding_func = None
        self.pruning_func = None
        self.maintainance_func = None
        self.extra_func = None

        self.log_func = None

    def _log_and_extra(self, i):
        """Log loop ``i`` and run the optional ``extra_func`` hook."""
        self.log_func(i)
        if self.extra_func:
            self.extra_func()

    def loop(self, count = 15, warmup_loops=0, pause=2):
        """Run ``count`` growth/prune loops, then one final training round.

        Args:
            count: number of loops.
            warmup_loops: number of leading loops that skip adding and pruning.
                Instead they print the network and reset the optimizer through the
                notebook-level ``dynet`` and ``reset_optimizer``. The default 0 keeps
                the previous behaviour, where this branch was disabled by an
                always-true ``i > -0.1`` test.
            pause: seconds to sleep after every loop so the user can interrupt.
        """
        for i in range(count):
            warming_up = i < warmup_loops
            if warming_up:
                dynet.print_network()
                reset_optimizer()
            else:
                self.adding_func()

            self.training_func()
            self._log_and_extra(i)

            if not warming_up:
                self.pruning_func()
            self.maintainance_func()
            self._log_and_extra(i)

            print(f"=====================")
            print(f"===LOOPS FINISHED :{i} ===")
            print(f"Pausing for {pause} second to give user time to STOP PROCESS")
            time.sleep(pause)
        self.training_func()

### Deciding when to stop training

In [37]:
def update_coeff(num_iter, coeff0, coeff1, coeff2, coeff_opt, loss_list):
    """Fit ``y = exp(c0*x) * (1-x) * c1 + c2`` to the normalised loss curve.

    The losses are shifted so the last point is 0, scaled by ``first - min``
    (so the first point is 1 when the last loss is the minimum), clamped to
    [-1.1, 1.1] and placed on x in [0, 1]. ``num_iter`` optimizer steps are
    taken with ``coeff_opt``. After each step c0, c1 and c2 are clamped to
    [-20, 20], [0.7, 2] and [-0.2, 0.1]. If c0 ends up NaN, the coefficients are
    reset to the constant curve y = 1.

    Returns:
        ``(x, normalised_losses, fitted_y)`` as numpy arrays. With fewer than 10
        losses it returns ``(array([0]), array([0]), float(c0))`` instead.
    """
    if len(loss_list) < 10:
        return np.array([0]), np.array([0]), float(coeff0.data.cpu()[0])

    raw = torch.tensor(loss_list)
    target = torch.clamp((raw - raw[-1]) / (raw[0] - raw.min()), -1.1, 1.1)
    xs = torch.linspace(0, 1, steps=len(target))

    def _curve():
        return torch.exp(coeff0*xs)*(1-xs)*coeff1 + coeff2

    for _ in range(num_iter):
        coeff_opt.zero_grad()
        fit_loss = ((_curve() - target)**2).mean()
        fit_loss.backward()
        coeff_opt.step()

        # keep the fit inside a sane region
        coeff0.data = torch.clamp(coeff0.data, -20., 20.)
        coeff1.data = torch.clamp(coeff1.data, 0.7, 2.)
        coeff2.data = torch.clamp(coeff2.data, -0.2, 0.1)

    if torch.isnan(coeff0.data[0]):
        # a diverged fit is replaced by the constant curve y = 1 (used downstream as a signal)
        coeff0.data[0] = 0.
        coeff1.data[0] = 0.
        coeff2.data[0] = 1.

    return xs.numpy(), target.numpy(), _curve().data.cpu().numpy()

## Train the network dynamically

In [38]:
## global variables shared by the training helpers
# NOTE(review): this rebinds `optimizer` to None although an earlier cell aliased it to
# dynet.tree.optimizer; reset_optimizer() uses dynet.tree.optimizer directly — confirm
# nothing downstream relies on the global `optimizer`.
optimizer, warmup, coeff_opt = None, None, None

# training history (each a fresh, independent list)
loss_all, accs_all, accs_test, events_all = [], [], [], []

## for adam optimizer =
# learning_rate *= 0.1

In [39]:
# def reset_optimizer():
#     global optimizer, warmup
# #     optimizer = torch.optim.Adam(dynet.parameters(), lr=learning_rate, weight_decay=1e-4)
#     optimizer = adam_custom.Adam(dynet.parameters(), lr=learning_rate, weight_decay=1e-4)

# #     optimizer = torch.optim.SGD(dynet.parameters(), lr=learning_rate, momentum=0.9, weight_decay=1e-4)
#     warmup = WarmupLR_Polynomial(optimizer, 0.5, len(train_loader), power=2)
# #     warmup = WarmupLR_Polynomial(optimizer, 10/len(train_loader), len(train_loader))
# #     get_bn_params()

In [40]:
def reset_optimizer():
    """Restore the base learning rate on the network's optimizer and restart warm-up.

    Reuses the optimizer owned by the network (``dynet.tree.optimizer``) instead
    of building a new one.  It writes the global ``learning_rate`` into every
    parameter group and stores a fresh ``WarmupLR_Polynomial`` scheduler in the
    global ``warmup``.  The scheduler uses 2 warm-up epochs of
    ``len(train_loader)`` batches each and a quadratic ramp.
    """
    global warmup

    net_optimizer = dynet.tree.optimizer
    ## usually a single param group, but handle any number
    for group in net_optimizer.param_groups:
        group['lr'] = learning_rate

    warmup = WarmupLR_Polynomial(net_optimizer, 2, len(train_loader), power=2)

In [41]:
# def reset_optimizer():
#     global dynet, warmup, optimizer
    
#     def get_cosine_lr(lr, epoch, tmax):
#         return (np.cos(epoch/tmax*np.pi)*0.5+0.5)*lr
    
#     ## there are no param groups, but consider there are len=1
#     pg = dynet.tree.optimizer.param_groups
#     for i in range(len(pg)):
#         pg[i]['lr'] = get_cosine_lr(learning_rate, len(accs_all), tmax=400)
        
#     warmup = WarmupLR_Polynomial(dynet.tree.optimizer, 0.0, len(train_loader), power=2)
    
    
# #     copy_optimizer()
#     return

In [42]:
# def copy_optimizer():
#     global dynet
#     old_optim = dynet.tree.optimizer
#     new_optim = adam_custom.Adam(dynet.parameters(), lr=learning_rate, weight_decay=1e-4)
    
#     found=False
#     for p in dynet.parameters():
#         for _p in old_optim.param_groups[0]['params']:
#             if _p is p:
#                 found = True
#                 new_optim.state[p] = old_optim.state[p]
#     if not found:
#         raise ValueError("Parameter could not be found")
    
#     dynet.tree.optimizer = new_optim
#     return

In [43]:
# optimizer = dynet.tree.optimizer

# copy_optimizer(optimizer)

# dynet.tree.optimizer.param_groups

In [44]:
class WarmupLR_Polynomial():
    """Polynomial learning-rate warm-up, advanced once per optimizer step.

    Progress is measured in epochs as ``steps / num_batch_in_epoch``.  While
    progress is below ``warmup_epoch``, every parameter group's lr is set to
    ``base_lr * (progress / warmup_epoch) ** power``.  Afterwards it is set to
    ``base_lr``.  The base lrs are captured once, when the scheduler is built.
    """

    def __init__(self, optimizer, warmup_epoch, num_batch_in_epoch, power=5):
        self.optimizer = optimizer
        self.warmup_epoch = warmup_epoch
        self.num_batch = num_batch_in_epoch
        self.power = power
        self.steps = 0
        ## target lr of each param group, snapshotted at construction
        self.backup_lr = [float(group['lr']) for group in optimizer.param_groups]

    def step(self):
        """Advance one batch, rescale every group's lr and return True while warming up."""
        self.steps += 1
        progress = self.steps / self.num_batch   # elapsed epochs (fractional)

        warming = progress < self.warmup_epoch
        factor = (progress / self.warmup_epoch) ** self.power if warming else 1

        for group, base_lr in zip(self.optimizer.param_groups, self.backup_lr):
            group['lr'] = base_lr * factor
        return warming

In [45]:
def add_neurons_func():
    """Grow the network, reset the optimizer and re-measure accuracy right after growth.

    The number of neurons added is ``num_add_neuron`` plus 7% of the current
    number of hidden neurons in the residual (dynamic) blocks.  It is also
    stored in the global ``added``.  If an accuracy history already exists, the
    train-set and test-set accuracies are measured immediately after growth and
    appended to ``accs_all`` / ``accs_test``.  A "neurons added" event is
    always recorded.
    """
    global added

    def _accuracy(loader, num_samples):
        ## percentage of correctly classified samples over a whole loader
        corrects = 0
        with torch.no_grad():
            for batch_x, batch_y in loader:
                yout = dynet.forward(batch_x.to(device))
                outputs = tnn.Logits.logit_to_index(yout.data.cpu().numpy())
                corrects += (outputs == batch_y.data.cpu().numpy()).sum()
        return corrects/num_samples*100

    ### number of neurons currently held by residual (dynamic) blocks
    count = sum(hr.residual.hidden_dim for hr in dynet.tree.DYNAMIC_LIST if hr.residual)

    ## add more neurons relatively (+7%)
    adding = num_add_neuron+int(0.07*count)
    dynet.add_neurons(adding)
    print(f"Adding {adding} Neurons")
    added = adding
    dynet.print_network()

    reset_optimizer()

    if len(accs_all)>0:
        ## the train-set pass runs in whatever mode the network is currently in
        ## (normally training mode), exactly as before
        accs_all.append(_accuracy(train_loader, len(train_dataset)))

        dynet.eval()
        try:
            accs_test.append(_accuracy(test_loader, len(test_dataset)))
        finally:
            ## always return to training mode, even if evaluation fails
            dynet.train()

    events_all.append((len(accs_all), "neurons added"))

    return

In [46]:
def get_children(module):
    """Return the leaf sub-modules of ``module`` in depth-first, left-to-right order.

    A module that has no children is its own (single) leaf.
    """
    leaves = []
    pending = [module]
    while pending:
        current = pending.pop()
        kids = list(current.children())
        if not kids:
            leaves.append(current)
        else:
            ## push in reverse so the leftmost child is visited first
            pending.extend(reversed(kids))
    return leaves

bn_params = []
def get_bn_params(module=None):
    """Collect the affine parameters of every ``nn.BatchNorm2d`` in a network.

    Args:
        module: network to scan; defaults to the global ``dynet``.

    Returns:
        List of BN parameters (weight then bias for each layer).  The list is
        also stored in the global ``bn_params`` used by the clip/decay helpers.
        BN layers built with ``affine=False`` have no parameters and are
        skipped.
    """
    global dynet, bn_params
    if module is None:
        module = dynet
    bn_params = []
    ## .modules() yields each sub-module once, even if it is shared, so no
    ## parameter gets clipped/decayed twice per step
    for sub in module.modules():
        if isinstance(sub, nn.BatchNorm2d):
            for p in (sub.weight, sub.bias):
                if p is not None:
                    bn_params.append(p)
    return bn_params
            
def clip_bn_weight_grads(val=0.05, params=None):
    """Clamp the gradients of BatchNorm parameters to ``[-val, val]`` in place.

    Args:
        val: clipping bound.
        params: parameters to clip; defaults to the global ``bn_params``.

    Parameters without a gradient (``grad is None``) are skipped instead of
    raising.  This happens, for example, after ``zero_grad(set_to_none=True)``
    or when a layer was not used in the last backward pass.
    """
    global bn_params
    if params is None:
        params = bn_params
    for bnp in params:
        if bnp.grad is not None:
            bnp.grad.clamp_(-val, val)
        
def get_bn_params_grads(val=0.05, params=None):
    """Warn about BatchNorm parameters receiving large gradients.

    For each parameter whose largest absolute gradient exceeds ``val``, prints a
    warning followed by the gradient.  Parameters without a gradient are
    skipped.

    Args:
        val: threshold on the absolute gradient.
        params: parameters to inspect; defaults to the global ``bn_params``.
    """
    global bn_params
    if params is None:
        params = bn_params
    for bnp in params:
        if bnp.grad is None:
            continue
        if bnp.grad.abs().max() > val:
            print("Batch Norm receiving high gradients!!")
            print(bnp.grad)
            print()
            
def decay_bn_params(val=5e-5, params=None):
    """Shrink BatchNorm parameters toward zero by ``val`` (soft-threshold / L1 step).

    Each element moves ``val`` closer to zero but never crosses it: elements
    with ``|x| <= val`` become exactly 0.  A plain ``x -= sign(x) * val`` would
    overshoot zero for such small values and make them flip sign every call.

    Args:
        val: amount of shrinkage per call.
        params: parameters to decay; defaults to the global ``bn_params``.
    """
    global bn_params
    if params is None:
        params = bn_params
    with torch.no_grad():
        for bnp in params:
            shrunk = (bnp.abs() - val).clamp_(min=0.)
            bnp.copy_(torch.sign(bnp) * shrunk)

In [47]:
# def train_step(xx, yy):
#     global dynet
    
#     yout = dynet(xx)
#     loss = criterion(yout, yy) #+ dynet.tree.decay_rate_std*dynet.tree.std_loss

#     dynet.tree.optimizer.zero_grad(set_to_none = True)
    
#     loss.backward(create_graph=False, retain_graph=False)
#     clip_bn_weight_grads()

#     dynet.tree.optimizer.step()
# #     dynet.zero_grad(True)
    
#     return yout, loss

In [48]:
def training_network_func():
    """Train ``dynet`` until its loss curve flattens, or for ``train_epoch_max`` epochs.

    Each step runs forward, loss, backward, BN-gradient clipping, the optimizer
    step and the warm-up step.

    Losses are smoothed into ``loss_list``: a plain mean of the first 101
    losses seeds an exponential moving average (beta = 0.999).  Every 100 steps
    ``update_coeff`` fits a decay curve to the smoothed history since its
    maximum and a progress plot is saved under ``./output/logs``.  Training
    stops once the fitted ``coeff0`` is below -5 and ``epoch > train_epoch_min``.

    After every completed epoch the train and test accuracies are appended to
    ``accs_all`` / ``accs_test``.
    """
    global optimizer, warmup, loss_all, accs_all

    ## coefficients of the curve fitted to the loss history (see update_coeff)
    coeff0 = torch.zeros(1, requires_grad=True)
    coeff1 = torch.zeros(1, requires_grad=True)
    coeff2 = torch.zeros(1, requires_grad=True)
    coeff_opt = torch.optim.Adam([coeff0, coeff1, coeff2], lr=0.8)
    loss_list = []                  ## smoothed losses
    prev_loss = None                ## current smoothed loss
    beta_loss = (1000-1)/1000       ## smoothing factor of the moving average
    loss_ = []                      ## raw losses collected before smoothing starts
    optimizer = dynet.tree.optimizer

    fig = plt.figure(figsize=(10,4))
    ax = fig.add_subplot(121)
    ax2 = fig.add_subplot(122)
    breakall = False

    steps_ = -1
    for epoch in range(train_epoch_max):

        train_acc = 0
        train_count = 0
        for train_x, train_y in train_loader:
            train_x, train_y = train_x.to(device), train_y.to(device)
            steps_ += 1

            dynet.tree.std_loss = 0.

            yout = dynet(train_x)
            loss = criterion(yout, train_y)

            optimizer.zero_grad()
            loss.backward(retain_graph=False)

            clip_bn_weight_grads()
            optimizer.step()

            warmup.step()

            ## smoothed loss: mean of the first 101 losses, then an EMA
            if steps_>100:
                prev_loss = (1-beta_loss)*float(loss)+beta_loss*prev_loss
                loss_list.append(prev_loss)
            elif steps_ == 100:
                loss_.append(float(loss))
                prev_loss = np.mean(loss_)
                loss_ = []
            else:
                loss_.append(float(loss))

            outputs = tnn.Logits.logit_to_index(yout.data.cpu().numpy())
            targets = train_y.data.cpu().numpy()

            correct = (outputs == targets).sum()
            train_acc += correct
            train_count += len(outputs)

            if steps_%100 == 0 and steps_>0:
                ## only fit the part of the history after its peak
                if len(loss_list)>0:
                    max_indx = np.argmax(loss_list)
                    loss_list = loss_list[max_indx:]

                _x, _t, _y = update_coeff(50, coeff0, coeff1, coeff2, coeff_opt, loss_list)
                _c = float(coeff0.data.cpu()[0])
                ## fresh Adam state for the next fit
                coeff_opt = torch.optim.Adam([coeff0, coeff1, coeff2], lr=0.8)
                _info = f'ES: {epoch}:{steps_}, coeff:{_c:.3f}/{-5}, \nLoss:{float(loss):.3f}, Acc:{correct/len(outputs)*100:.3f}%'

                ax.clear()
                if len(_x)>0:
                    ax.plot(_x, _t, c='c')
                    ax.plot(_x, _y, c='m')
                xmin, xmax = ax.get_xlim()
                ymin, ymax = ax.get_ylim()
                ax.text(xmin, ymin, _info)

                ax2.clear()
                if len(accs_all)>0:
                    acc_tr = accs_all
                    acc_te = accs_test
                    if len(acc_tr)>20: acc_tr = acc_tr[-20:]
                    if len(acc_te)>20: acc_te = acc_te[-20:]
                    ax2.plot(acc_tr, marker='.', label="train")
                    ax2.plot(acc_te, marker='.', label="test")
                    ax2.legend(loc="lower right")

                    ymin, ymax = ax2.get_ylim()
                    ax2.text(0, 0.1*ymin+0.9*ymax, f"TR:max{max(acc_tr):.3f} end{acc_tr[-1]:.3f}")
                    ax2.text(0, 0.2*ymin+0.8*ymax, f"TE:max{max(acc_te):.3f} end{acc_te[-1]:.3f}")

                fig.canvas.draw()
                fig.savefig(f"./output/logs/_{index}_temp_train_plot.png")

                torch.cuda.empty_cache()
                if _c < -5 and epoch>train_epoch_min:
                    breakall = True
                    break

        if breakall:
            ## early-stopping criterion met: leave the epoch loop too.  Without
            ## this, only the batch loop would exit and training would carry on
            ## through the remaining epochs without recording their accuracy.
            break

        accs_all.append(train_acc/train_count*100.)
        dynet.eval()
        try:
            with torch.no_grad():
                corrects = 0
                for test_x, test_y in test_loader:
                    test_x  = test_x.to(device)
                    yout = dynet.forward(test_x)
                    outputs = tnn.Logits.logit_to_index(yout.data.cpu().numpy())
                    correct = (outputs == test_y.data.cpu().numpy()).sum()
                    corrects += correct
        finally:
            ## always return to training mode, even if evaluation fails
            dynet.train()
        accs_test.append(corrects/len(test_dataset)*100)
    plt.close(fig)
    return

In [49]:
%matplotlib inline

In [50]:
def pruning_func():
    """Prune low-significance neurons by decaying them during training, then removing them.

    1. ``reset_optimizer()``; ``optimizer`` is set to ``dynet.tree.optimizer``.
    2. In eval mode, run forward + backward over the whole ``train_loader``
       between ``dynet.start_computing_significance()`` and
       ``dynet.finish_computing_significance()``.  No optimizer step is taken.
    3. ``dynet.identify_removable_neurons`` selects neurons using
       ``threshold_min`` / ``threshold_max``.  ``dynet.decay_neuron_start``
       returns how many are scheduled for removal.
    4. If any are scheduled:
       - train normally, calling ``dynet.decay_neuron_step()`` before every batch;
       - record per-epoch train/test accuracy and save a progress plot;
       - stop once ``epoch >= num_decay_steps/len(train_loader) + 1.99``.
    5. ``dynet.remove_decayed_neurons()`` and log a "neurons pruned" event.
       Both happen even when nothing was scheduled.
    """
    global optimizer, warmup
    reset_optimizer()
#     optimizer = torch.optim.Adam(dynet.parameters(), lr=learning_rate)
#     optimizer = torch.optim.SGD(dynet.parameters(), lr=learning_rate, momentum=0.9)
#     warmup = WarmupLR_Polynomial(optimizer, 0, len(train_loader), power=0.5)
    
    optimizer = dynet.tree.optimizer
    
    
    print(f"Computing Network Siginificance")
    
    ## significance pass: gradients only, no parameter update
    dynet.eval()
    dynet.start_computing_significance()

    for train_x, train_y in train_loader:
        train_x, train_y = train_x.to(device), train_y.to(device)
        dynet.tree.std_loss = 0.    
        yout = dynet(train_x)
#         yout.backward(gradient=torch.ones_like(yout))
        loss = criterion(yout, train_y)
        optimizer.zero_grad()
        loss.backward(retain_graph=False)

    optimizer.zero_grad()
    dynet.finish_computing_significance()
    
    dynet.identify_removable_neurons(num=None,
                                 threshold_min = threshold_min,
                                 threshold_max = threshold_max)
    ## number of neurons scheduled to be decayed over num_decay_steps steps
    num_remove = dynet.decay_neuron_start(decay_steps=num_decay_steps)
    
    dynet.train()
    
    if num_remove > 0:
#     if num_remove < 0:
        decayed = False
        print(f"pruning {num_remove} neurons.")
        
        fig = plt.figure(figsize=(10,4))
        ax = fig.add_subplot(121)
        ax2 = fig.add_subplot(122)
        
        loss_list = []
        steps_ = -1
        breakall=False

        ## upper bound on epochs; the loop normally stops earlier (see breakall below)
        for epoch in range(train_epoch_max+int(np.ceil(num_decay_steps/len(train_loader)))):
            loss_ = []
            train_acc = 0
            train_count = 0
            
            for train_x, train_y in train_loader:
                train_x, train_y = train_x.to(device), train_y.to(device)
                steps_ += 1
                
#                 with torch.no_grad():
                ret = dynet.decay_neuron_step()
                dynet.tree.std_loss = 0.    
        
                ## NOTE(review): -1 presumably means the decay schedule has finished; confirm in decay_neuron_step
                if ret == -1 and not decayed:
                    events_all.append((len(accs_all), "neurons decayed"))
                    decayed = True
                
#                     copy_optimizer()
#                     breakall = True
#                     break

                yout = dynet(train_x)
                loss = criterion(yout, train_y) #+ dynet.tree.decay_rate_std*dynet.tree.std_loss
                
                optimizer.zero_grad() ##set_to_none = True
                loss.backward(retain_graph=False)
                clip_bn_weight_grads()
                
                optimizer.step()
                
                loss = float(loss)
#                 yout, loss = train_step(train_x, train_y)
                                
                warmup.step()
#                 decay_bn_params()
                loss_.append(float(loss))
                

                outputs = tnn.Logits.logit_to_index(yout.data.cpu().numpy())
                targets = train_y.data.cpu().numpy()
                correct = (outputs == targets).sum()
                train_acc += correct
                train_count += len(outputs)

#                 dynet.decay_neuron_step()
                
                ## every 50 steps: record the mean raw loss of the last window
                if steps_%50 == 0 and steps_>0:
                    loss = np.mean(loss_)
                    loss_ = []
                    loss_list.append(loss)
                
                ## every 100 steps: refresh and save the progress plot
                if steps_%100 == 0 and steps_>0:
                    
                    _info = f'ES: {epoch}:{steps_}, Loss:{float(loss):.3f}, Acc:{correct/len(outputs)*100:.3f}%'
#                     print(_info)
                    ax.clear()
                    ## NOTE(review): `out` is computed but never used
                    out = (yout.data.cpu().numpy()>0.5).astype(int)
                    ax.plot(loss_list)
                    
                    xmin, xmax = ax.get_xlim()
                    ymin, ymax = ax.get_ylim()
                    ax.text(xmin, ymin, _info)
                    
                    ax2.clear()
                    if len(accs_all)>0:
                        acc_tr = accs_all
                        acc_te = accs_test
                        if len(acc_tr)>20: acc_tr = acc_tr[-20:]
                        if len(acc_te)>20: acc_te = acc_te[-20:]
                        ax2.plot(acc_tr, marker='.', label="train")
                        ax2.plot(acc_te, marker='.', label="test")
                        ax2.legend(loc="lower right")

                        ymin, ymax = ax2.get_ylim()
                        ax2.text(0, 0.1*ymin+0.9*ymax, f"TR:max{max(acc_tr):.3f} end{acc_tr[-1]:.3f}")
                        ax2.text(0, 0.2*ymin+0.8*ymax, f"TE:max{max(acc_te):.3f} end{acc_te[-1]:.3f}")

                    
                    fig.canvas.draw()
                    plt.savefig(f"./output/logs/_{index}_temp_prune_plot.png")
#                     plt.pause(0.01)
#                     print("\n")
                    
#                 if steps_>num_decay_steps+int(num_decay_steps/2): breakall=True
#                 if steps_>(num_decay_steps+int(len(train_loader)*2.05)): breakall=True
#                 if breakall: break

#             if steps_>=(num_decay_steps):
            ## stop about two epochs after the decay schedule should have completed
            if epoch >= (num_decay_steps/len(train_loader))+1.99:
                breakall = True
                
            with torch.no_grad():
                corrects = 0
                ## NOTE(review): one extra decay step per epoch, outside the batch loop — confirm intended
                ret = dynet.decay_neuron_step()
                dynet.eval()
                for test_x, test_y in test_loader:
                    test_x  = test_x.to(device)
                    yout = dynet.forward(test_x)
                    outputs = tnn.Logits.logit_to_index(yout.data.cpu().numpy())
                    correct = (outputs == test_y.data.cpu().numpy()).sum()
                    corrects += correct
                dynet.train()
                accs_test.append(corrects/len(test_dataset)*100)        

            accs_all.append(train_acc/train_count*100.)

#             if not breakall:
#                 accs_all.append(train_acc/train_count*100.)
#             else:
#                 accs_all.append(accs_all[-1])
#                 break
            if breakall: break

        plt.close()
    
    ## physically remove the neurons that were decayed above
    dynet.remove_decayed_neurons()
    events_all.append((len(accs_all), "neurons pruned"))
    return

In [51]:
len(train_loader)

391

In [52]:
def maintain_network():
    """Run the network's maintenance pass and print the resulting architecture.

    Calls ``compute_del_neurons()``, then ``maintain_network()``, then
    ``print_network()`` on the global ``dynet``.
    """
    net = dynet
    net.compute_del_neurons()
    net.maintain_network()
    net.print_network()

In [53]:
def save_network_stat(loop_indx):
    """Append a snapshot of the current network and training state to the run's log.

    The log file is ``output/logs/{index}_{name}_log_{exp_index}.txt``, opened
    in append mode.  When ``loop_indx >= 0`` a header with the current time and
    the hyper-parameters is written first.  Then come:

    - the loop index and the epoch count (``len(accs_all)``);
    - the number of dynamic neurons and the total / trainable parameter counts;
    - the latest train and test accuracy;
    - the architecture diagram printed by ``dynet.print_network()``.

    Args:
        loop_indx: index of the current trainer loop.
    """
    import contextlib
    from datetime import datetime

    ## capture the diagram print_network() writes to stdout; redirect_stdout
    ## restores sys.stdout even if printing raises
    buf = io.StringIO(newline="")
    with contextlib.redirect_stdout(buf):
        dynet.print_network()
    architecture = buf.getvalue()
    buf.close()

    ### number of neurons held by residual (dynamic) blocks
    count = sum(hr.residual.hidden_dim for hr in dynet.tree.DYNAMIC_LIST if hr.residual)

    with open(f"output/logs/{index}_{name}_log_{exp_index}.txt", "a+") as f:
        ### Print the configuration header.
        if loop_indx >= 0:
            f.write(f"━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n")
            f.write(f"━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n")

            dt_string = datetime.now().strftime("%B %d, %Y @ %H:%M:%S")
            ## terminate the line so the timestamp does not run into the next entry
            f.write(f"DateTime: {dt_string}\n")

            f.write(f"num_add_neuron :{num_add_neuron}\n add_to_remove_ratio :{dynet.tree.add_to_remove_ratio}\n")
            f.write(f"learning_rate :{learning_rate}\n num_decay_steps :{num_decay_steps}\n")
            f.write(f"threshold_max :{threshold_max}\n threshold_min :{threshold_min}\n")
            ## label the value correctly: it is train_epoch_max, not threshold_max
            f.write(f"train_epoch_min :{train_epoch_min}\n train_epoch_max :{train_epoch_max}\n")
            f.write(f"━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n")

        f.write(f"####################| Loop:{loop_indx} | Epoch: {len(accs_all)} \n")
        num_params = sum(p.numel() for p in dynet.parameters())
        num_trainable = sum(p.numel() for p in dynet.parameters() if p.requires_grad)
        f.write(f"| Dynamic Neurons:{count} | Total Parameters: {num_params} | Trainable Parameters: {num_trainable}\n")
        f.write(f"| Train Acc:{accs_all[-1]:.3f} | Test Acc: {accs_test[-1]:.3f}\n")
        f.write(architecture)
        f.write("\n\n")

In [54]:
def load_hyperparameters_from_json():
    """Reload the tunable hyper-parameters from the JSON file at ``hyp_json``.

    Rebinds the module-level training hyper-parameters and sets
    ``dynet.tree.add_to_remove_ratio``.  Keys are read in a fixed order; a
    missing key raises ``KeyError``.
    """
    global learning_rate, num_add_neuron, num_decay_steps, \
           threshold_max, threshold_min, train_epoch_min, train_epoch_max
    with open(hyp_json, 'r') as fp:
        config = json.load(fp)

    learning_rate = config['learning_rate']
    num_add_neuron = config['num_add_neuron']
    num_decay_steps = config['num_decay_steps']
    threshold_max = config['threshold_max']
    threshold_min = config['threshold_min']
    train_epoch_min = config['train_epoch_min']
    train_epoch_max = config['train_epoch_max']
    dynet.tree.add_to_remove_ratio = config['add_to_remove_ratio']

In [55]:
def plot_accs_save():
    """Save the train/test accuracy curves and the raw histories of this run.

    Writes two files:

    - ``output/plots/{index}_{name}_cifar100_{exp_index}.png``, the accuracy
      plot with max/final values annotated;
    - ``..._event_dict.json``, containing ``train_accs``, ``test_accs`` and
      ``event_dict`` (the recorded (epoch, event) pairs).
    """
    ## draw on a dedicated figure: plt.plot() would otherwise draw onto whatever
    ## figure happens to be current, and close it even if saving fails
    fig, ax = plt.subplots()
    try:
        ax.plot(accs_all, label="train")
        ax.plot(accs_test, label="test")
        ## max() of an empty history would raise; skip the annotation then
        if len(accs_all) > 0 and len(accs_test) > 0:
            ymin, ymax = ax.get_ylim()
            ax.text(0, 0.8*ymin+0.2*ymax, f"Train-> max:{max(accs_all):.3f} end:{accs_all[-1]:.3f} \nTest-> max:{max(accs_test):.3f} end:{accs_test[-1]:.3f}")

        ax.legend()
        fig.savefig(f"output/plots/{index}_{name}_cifar100_{exp_index}.png")
    finally:
        plt.close(fig)

    with open(f"output/plots/{index}_{name}_cifar100_{exp_index}_event_dict.json", 'w') as f:
        d = {
            "train_accs":accs_all,
            "test_accs":accs_test,
            "event_dict":events_all,
        }
        json.dump(d, f, indent=0)

In [56]:
def extra_func():
    """Extra per-loop hook: reload hyper-parameters from JSON, then save plots/histories."""
    for action in (load_hyperparameters_from_json, plot_accs_save):
        action()

# Register all callback functions and start the automated loop

In [57]:
trainer = AutoTrainer()

In [58]:
## Register the callbacks used by the AutoTrainer.
trainer.adding_func = add_neurons_func        # grow the network, reset optimizer
trainer.training_func = training_network_func # train until the loss curve flattens
trainer.pruning_func = pruning_func           # decay and remove low-significance neurons
trainer.maintainance_func = maintain_network  # network maintenance + print architecture
trainer.log_func = save_network_stat          # append a snapshot to the log file
trainer.extra_func = extra_func               # reload hyper-parameters, save plots

In [59]:
# add_neurons_func()

In [60]:
dynet.print_network()

3
╚╗
 ╚╗
  ╚╗
   ╚╗
    ╚╗
     16
    ╔╝
    16
   ╔╝
   32
  ╔╝
  64
 ╔╝
 128
╔╝
│
100


In [None]:
## NOTE: the string below is an inert note (not executed as code) recording a
## PyTorch deprecation warning about non-full backward hooks seen during runs.
'''
/home/tsuman/All_Files/Program_Files/miniconda/lib/python3.9/site-packages/torch/nn/modules/module.py:1033: 
UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. 
This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.
  warnings.warn("Using a non-full backward hook when the forward contains multiple autograd Nodes "
'''

## start the automated loop (40 is presumably the number of loops — confirm in AutoTrainer.loop)
trainer.loop(40)

Adding 67 Neurons
3
╚╗
 ╚╗
  ╚╗
   ╚╗
    ╚╗
     28
     ╠════╗
     ║    6
     ╠════╝
    ╔╝
    21
    ╠════╗
    ║    8
    ╠════╝
   ╔╝
   40
   ╠════╗
   ║    7
   ╠════╝
  ╔╝
  72
  ╠════╗
  ║    8
  ╠════╝
 ╔╝
 133
╔╝
│
100
Computing Network Siginificance




Significance Stat:
Min, Max: (0.22380323708057404, 4.426582336425781)
Mean, Std: (1.0, 0.9323133230209351)
remove_below 0.2608362138271332 true: 10.014231497516471
Significance:
tensor([39.7486, 59.2894, 44.9194, 58.8616, 68.1051, 63.1673, 53.3258],
       device='cuda:0')
Prune:
tensor([False, False, False, False, False, False, False], device='cuda:0')
Significance:
tensor([ 94.9595, 104.6909,  65.5081, 108.3864, 119.2151,  66.5272,  61.8384,
        100.5653,  80.3782, 112.0496,  74.8113,  74.6065, 110.4939,  91.7815,
        115.2610,  87.4883,  70.4332,  32.1204,  93.7566, 136.4382, 132.8173,
         52.2519,  52.3999,  63.3256,  48.1672,  46.2661,  59.3031,  24.6271],
       device='cuda:0')
Prune:
tensor([False, False, False, False, False, False, False, False,  True, False,
        False, False, False, False, False, False, False, False, False,  True,
        False, False, False, False, False, False,  True, False],
       device='cuda:0')
Significance:
tensor([13.6113, 14.2317, 1

3
╚╗
 ╚╗
  ╚╗
   ╚╗
    ╚╗
     27
     ╠════╗
     ║    11
     ╠════╝
    ╔╝
    24
    ╠════╗
    ║    20
    ╠════╝
   ╔╝
   45
   ╠════╗
   ║    12
   ╠════╝
  ╔╝
  69
  ╠════╗
  ║    19
  ╠════╝
 ╔╝
 75
╔╝
│
100
===LOOPS FINISHED :1 ===
Pausing for 2 second to give user time to STOP PROCESS
Adding 71 Neurons
3
╚╗
 ╚╗
  ╚╗
   ╚╗
    ╚╗
     32
     ╠════╗
     ║    18
     ╠════╝
    ╔╝
    29
    ╠════╗
    ║    32
    ╠════╝
   ╔╝
   56
   ╠════╗
   ║    19
   ╠════╝
  ╔╝
  77
  ╠════╗
  ║    31
  ╠════╝
 ╔╝
 79
╔╝
│
100
Computing Network Siginificance
Significance Stat:
Min, Max: (0.11784704029560089, 7.508367538452148)
Mean, Std: (0.9999999403953552, 0.9722679257392883)
remove_below 0.34298115968704224 true: 19.56260183641635
Significance:
tensor([38.3228, 51.3449, 31.0247, 33.8964, 43.3232, 40.8503, 37.8105, 37.2553,
        39.1778, 32.9133, 79.5452, 33.0307, 39.6319, 40.2456, 31.3873, 39.1138,
        39.9800, 34.9961, 24.8156], device='cuda:0')
Prune:
tensor([ True, False,

3
╚╗
 ╚╗
  ╚╗
   ╚╗
    ╚╗
     27
     ╠════╗
     ║    16
     ╠════╝
    ╔╝
    27
    ╠════╗
    ║    29
    ╠════╝
   ╔╝
   45
   ╠════╗
   ║    16
   ╠════╝
  ╔╝
  64
  ╠════╗
  ║    30
  ╠════╝
 ╔╝
 70
╔╝
│
100
===LOOPS FINISHED :2 ===
Pausing for 2 second to give user time to STOP PROCESS
Adding 72 Neurons
3
╚╗
 ╚╗
  ╚╗
   ╚╗
    ╚╗
     34
     ╠════╗
     ║    22
     ╠════╝
    ╔╝
    32
    ╠════╗
    ║    39
    ╠════╝
   ╔╝
   54
   ╠════╗
   ║    23
   ╠════╝
  ╔╝
  67
  ╠════╗
  ║    52
  ╠════╝
 ╔╝
 73
╔╝
│
100
Computing Network Siginificance
Significance Stat:
Min, Max: (0.07295168191194534, 7.3780198097229)
Mean, Std: (1.0, 0.9924499988555908)
remove_below 0.24474456906318665 true: 14.651596834164318
Significance:
tensor([51.0117, 41.2123, 25.1683, 57.8832, 39.2307, 39.1956, 21.7188, 42.0842,
        97.0670, 30.7621, 35.7598, 43.2440, 37.2424, 30.1626, 36.0937, 26.7707,
        20.6318, 28.6216, 34.7739, 24.3996, 23.7586, 20.1020, 22.0941],
       device='cuda:0')
P

3
╚╗
 ╚╗
  ╚╗
   ╚╗
    ╚╗
     26
     ╠════╗
     ║    21
     ╠════╝
    ╔╝
    25
    ╠════╗
    ║    37
    ╠════╝
   ╔╝
   44
   ╠════╗
   ║    21
   ╠════╝
  ╔╝
  61
  ╠════╗
  ║    49
  ╠════╝
 ╔╝
 63
╔╝
│
100
===LOOPS FINISHED :3 ===
Pausing for 2 second to give user time to STOP PROCESS
Adding 74 Neurons
3
╚╗
 ╚╗
  ╚╗
   ╚╗
    ╚╗
     35
     ╠════╗
     ║    29
     ╠════╝
    ╔╝
    31
    ╠════╗
    ║    ╠════╗
    ║    ║    1
    ║    ╠════╝
    ║    50
    ║    ╠════╗
    ║    ║    3
    ║    ╠════╝
    ╠════╝
   ╔╝
   53
   ╠════╗
   ║    25
   ╠════╝
  ╔╝
  65
  ╠════╗
  ║    64
  ╠════╝
 ╔╝
 65
╔╝
│
100
Computing Network Siginificance
Significance Stat:
Min, Max: (0.04906795546412468, 7.5425543785095215)
Mean, Std: (1.0000001192092896, 1.074583649635315)
remove_below 0.17580267786979675 true: 10.01737934639987
Significance:
tensor([44.0550, 48.6391, 30.7270, 40.4542, 41.1871, 41.2453, 24.3705, 33.2434,
        99.8060, 31.3628, 43.5789, 35.8084, 33.7666, 31.1903, 21.

3
╚╗
 ╚╗
  ╚╗
   ╚╗
    ╚╗
     26
     ╠════╗
     ║    28
     ╠════╝
    ╔╝
    24
    ╠════╗
    ║    49
    ║    ╠════╗
    ║    ║    1
    ║    ╠════╝
    ╠════╝
   ╔╝
   42
   ╠════╗
   ║    25
   ╠════╝
  ╔╝
  61
  ╠════╗
  ║    53
  ╠════╝
 ╔╝
 57
╔╝
│
100
===LOOPS FINISHED :4 ===
Pausing for 2 second to give user time to STOP PROCESS
Adding 75 Neurons
3
╚╗
 ╚╗
  ╚╗
   ╚╗
    ╚╗
     29
     ╠════╗
     ║    ╠════╗
     ║    ║    4
     ║    ╠════╝
     ║    31
     ║    ╠════╗
     ║    ║    4
     ║    ╠════╝
     ╠════╝
    ╔╝
    30
    ╠════╗
    ║    ╠════╗
    ║    ║    1
    ║    ╠════╝
    ║    65
    ║    ╠════╗
    ║    ║    6
    ║    ╠════╝
    ╠════╝
   ╔╝
   45
   ╠════╗
   ║    34
   ╠════╝
  ╔╝
  62
  ╠════╗
  ║    71
  ╠════╝
 ╔╝
 59
╔╝
│
100


In [None]:
dynet.print_network()

In [None]:
dynet.tree.optimizer

In [None]:
### check if parameter in param_groupd
### list the parameters registered in the optimizer's first param group, then their count
group_params = optimizer.param_groups[0]['params']
for p in group_params:
    print(p.shape)
c = len(group_params)
print(c)

In [None]:
dynet.root_net.residual.fc1.shortcut.weight.shape

In [None]:
## Run the first training batch through the network (leaves train_x/train_y/yout
## in the namespace for interactive inspection).
for train_x, train_y in train_loader:
    train_x, train_y = train_x.to(device), train_y.to(device)
    yout = dynet(train_x)
    break

In [None]:
## For every network parameter print its shape, plus 'Found' if that exact tensor
## object is registered in the optimizer's first param group; then print the
## number of network parameters.
c = 0
for p in dynet.parameters():
    print(p.shape)
    for _p in optimizer.param_groups[0]['params']:
        if _p is p:  # identity, not value equality
            print('Found')
    print()
    c += 1
print(c)

In [None]:
# trainer.training_func()
# trainer.pruning_func()
# trainer.maintainance_func()

In [None]:
## Test-set accuracy (%) of the current network, evaluated in eval mode.
with torch.no_grad():
    corrects = 0
    dynet.eval()
    for test_x, test_y in test_loader:
        test_x  = test_x.to(device)
        yout = dynet.forward(test_x)
        outputs = tnn.Logits.logit_to_index(yout.data.cpu().numpy())
        correct = (outputs == test_y.data.cpu().numpy()).sum()
        corrects += correct
    dynet.train()  # back to training mode for further training
    acc = corrects/len(test_dataset)*100
acc

In [None]:
# with torch.no_grad():
#     corrects = 0
#     dynet.train()
#     for test_x, test_y in train_loader:
#         test_x  = test_x.to(device)
#         yout = dynet.forward(test_x)
#         outputs = tnn.Logits.logit_to_index(yout.data.cpu().numpy())
#         correct = (outputs == test_y.data.cpu().numpy()).sum()
#         corrects += correct
#     acc = corrects/len(train_dataset)*100
# acc

In [None]:
# tr 66.908 -> 62.422 ## the adding neuron function is wrong.. not preserving the function.
# te 71.77 -> 41.959999999999994

# te -> 53.32, 53.32
# 68.51 -> 68.51

In [None]:
# trainer.adding_func()

In [None]:
dynet.tree.beta_del_neuron

In [None]:
dynet.print_network()

In [None]:
## Plot the train/test accuracy histories, annotate max/final values and save.
## NOTE(review): unlike plot_accs_save this writes a "cifar10" file name and
## leaves the figure open.
plt.plot(accs_all, label="train")
plt.plot(accs_test, label="test")
ymin, ymax = plt.gca().get_ylim()
plt.text(0, 0.8*ymin+0.2*ymax, f"Train-> max:{max(accs_all):.3f} end:{accs_all[-1]:.3f} \nTest-> max:{max(accs_test):.3f} end:{accs_test[-1]:.3f}")
                    
plt.legend()
plt.savefig(f"output/plots/{index}_{name}_cifar10_{exp_index}.png")

In [None]:
max(accs_test)

In [None]:
max(accs_all)

In [None]:
np.argmax(accs_test)

In [None]:
len(accs_test)

In [None]:
dynet.non_linearity.bias

In [None]:
# torch.cuda.memory_allocated(device="cuda:0")
torch.cuda.empty_cache()