In [None]:
class MeanAcrossTime(torch.nn.Module):
    """Average the input over its first dimension.

    Used as the read-out of the norse model. It presumably relies on norse's
    time-first layout, i.e. dim 0 is time; TODO confirm against the data loader.
    Defined at module level (not inside ``Network.__init__``) so that models
    containing it can be pickled, e.g. by ``torch.save``.
    """

    def forward(self, x):
        return torch.mean(x, dim=0)


class MeanAcrossGroupsOfNeurons(torch.nn.Module):
    """Split the last dimension into ``n_out`` equal groups and average each.

    An input of shape ``(..., n_out * k)`` becomes ``(..., n_out)``: output
    ``j`` is the mean of neurons ``j*k .. (j+1)*k - 1``.

    Args:
        n_out: number of groups (outputs) to produce.

    Raises:
        ValueError: (in ``forward``) if the last dimension is not divisible by
            ``n_out``.
    """

    def __init__(self, n_out: int = 12):
        # Initialise the Module machinery before assigning any attribute.
        super().__init__()
        self.n_out = n_out

    def forward(self, x):
        n_features = x.shape[-1]
        if n_features % self.n_out:
            raise ValueError(
                f"last dimension ({n_features}) is not divisible by "
                f"n_out ({self.n_out})"
            )
        group_size = n_features // self.n_out
        return x.reshape(*x.shape[:-1], self.n_out, group_size).mean(dim=-1)


class Network(torch.nn.Module):
    """Spiking classifier built both with lava-dl slayer and with norse.

    Two equivalent stacks are constructed:

    * ``self.blocks``: a ``torch.nn.ModuleList`` of slayer CUBA blocks
      (hidden layers -> 288-neuron ``Dense`` read-out -> ``Average`` down to
      12 outputs).
    * ``self.model``: a ``norse.torch.SequentialState`` with the same layer
      sizes, followed by a mean over dim 0 and a mean over groups of 24
      read-out neurons (288 / 12).

    No ``forward`` is defined in this cell. Which stack is used is decided by
    code outside this class.

    The norse layers read ``norse_param_dict['dt']`` and ``LIF_parameters``
    from the enclosing (notebook) scope.

    Args:
        n_recc_neur: hidden-layer sizes, one entry per hidden layer. Must be
            non-empty. The default list is never mutated (it is copied).
        n_channels: number of input features.
        recurr_weight_scale: ``weight_scale`` passed to every slayer hidden
            layer (used for feed-forward layers too, despite the name).
        quantize: the string ``'None'`` passes ``pre_hook_fx=None`` to every
            slayer layer. Any other value omits the keyword, so slayer keeps
            its own default hook (presumably 8-bit weight quantization;
            confirm against lava-dl).
        recurrent: if True, hidden layers are recurrent (slayer ``Recurrent``
            / norse ``LIFRecurrent``). Otherwise they are feed-forward
            (slayer ``Dense`` / ``Linear`` + ``LIF``). The default ``False``
            matches the architecture the original cell ended up with, since
            its later feed-forward assignments overwrote the recurrent ones.

    Raises:
        ValueError: if ``n_recc_neur`` is empty.
    """

    def __init__(self, n_recc_neur: List[int] = [256, 256],
                 n_channels: int = 512,
                 recurr_weight_scale: float = 1.0,
                 quantize: str = 'default',
                 *, recurrent: bool = False):
        super(Network, self).__init__()

        sizes = list(n_recc_neur)  # private copy; the default is shared
        if not sizes:
            raise ValueError("n_recc_neur must contain at least one layer size")

        n_readout = 288  # spiking read-out neurons
        n_out = 12       # TODO 35 (from the original cell)

        cuba_params = {
            'threshold'    : 1.25,
            'current_decay': 0.25,
            'voltage_decay': 0.25,
            'tau_grad'     : 0.1,
            'scale_grad'   : 0.8,
            'shared_param' : False,
            'requires_grad': False,
            'graded_spike' : False,
        }

        # Only *pass* pre_hook_fx when disabling it: slayer distinguishes an
        # explicit pre_hook_fx=None from an omitted keyword.
        hook_kwargs = {'pre_hook_fx': None} if quantize == 'None' else {}

        # Input width of each hidden layer: the network input, then the
        # previous hidden layer.
        in_sizes = [n_channels] + sizes[:-1]

        self.blocks = self._build_slayer_blocks(
            cuba_params, in_sizes, sizes, recurr_weight_scale,
            hook_kwargs, recurrent, n_readout, n_out)
        self.model = self._build_norse_model(
            in_sizes, sizes, recurrent, n_readout, n_out)

    @staticmethod
    def _build_slayer_blocks(cuba_params, in_sizes, sizes, weight_scale,
                             hook_kwargs, recurrent, n_readout, n_out):
        """Return the slayer CUBA stack: hidden layers, Dense read-out, Average."""
        hidden_cls = (slayer.block.cuba.Recurrent if recurrent
                      else slayer.block.cuba.Dense)
        hidden = [
            hidden_cls(cuba_params, n_in, n_hid,
                       weight_scale=weight_scale, **hook_kwargs)
            for n_in, n_hid in zip(in_sizes, sizes)
        ]
        # The read-out honours `quantize` in both variants. Previously the
        # feed-forward read-out always kept slayer's default hook.
        readout = slayer.block.cuba.Dense(
            cuba_params, sizes[-1], n_readout, weight_scale=2, **hook_kwargs)
        return torch.nn.ModuleList(
            [*hidden, readout, slayer.block.cuba.Average(num_outputs=n_out)])

    @staticmethod
    def _build_norse_model(in_sizes, sizes, recurrent, n_readout, n_out):
        """Return the norse stack with the same layer sizes as the slayer one."""
        dt = norse_param_dict['dt']
        layers = []
        for n_in, n_hid in zip(in_sizes, sizes):
            if recurrent:
                layers.append(
                    norse.torch.LIFRecurrent(n_in, n_hid, dt=dt, p=LIF_parameters))
            else:
                layers.append(torch.nn.Linear(n_in, n_hid))
                layers.append(norse.torch.LIF(dt=dt, p=LIF_parameters))
        layers += [
            torch.nn.Linear(sizes[-1], n_readout),
            norse.torch.LIF(dt=dt, p=LIF_parameters),
            MeanAcrossTime(),
            MeanAcrossGroupsOfNeurons(n_out),
        ]
        return norse.torch.SequentialState(*layers)
