In [1]:
from __future__ import print_function
import torch
import os
from SNN import SNNetwork
from utils.training_utils import train, get_acc_and_loss
import time
import numpy as np
import tables
import argparse
import utils
import pickle
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.animation as animation
from matplotlib import rc
import utils.filters as filters
from IPython.display import HTML
%matplotlib inline

In [2]:
# Folder holding the pre-processed datasets.
# NOTE(review): machine-specific absolute path; adjust when running elsewhere.
local_data_path = '/home/cream/Desktop/arafin_experiments/SOCC/FL-SNN/data/'
# Results are written below the current working directory.
save_path = '{}{}'.format(os.getcwd(), r'/results')

# Dataset key -> HDF5 file name.
datasets = dict(mnist_dvs_10=r'mnist_dvs_25ms_26pxl_10_digits.hdf5')

# Full path of the dataset used throughout the notebook.
dataset = ''.join((local_data_path, datasets['mnist_dvs_10']))

In [3]:
# Read the four arrays through a single file handle and close it afterwards
# (the file used to be opened once per array and never closed).
with tables.open_file(dataset) as h5_file:
    input_train = torch.FloatTensor(h5_file.root.train.data[:])
    output_train = torch.FloatTensor(h5_file.root.train.label[:])

    input_test = torch.FloatTensor(h5_file.root.test.data[:])
    output_test = torch.FloatTensor(h5_file.root.test.label[:])

### sanity check
print("Shape of the training dataset:", input_train.shape)
print("Shape of the test dataset", input_test.shape)

### Network parameters
n_input_neurons = input_train.shape[1]
n_output_neurons = output_train.shape[1]
n_hidden_neurons = 4
# One "epoch" per sample: both counts are the number of samples in each split
epochs = input_train.shape[0]
epochs_test = input_test.shape[0]

test_accs = []

learning_rate = 0.005 / n_hidden_neurons
kappa = 0.2
alpha = 1
deltas = 1
num_ite = 1
r = 0.3
weights_magnitude=0.05
task='supervised'
# A trailing comma here used to turn this value into the 1-tuple ('train',)
mode='train'
tau_ff=10
tau_fb=10
# A trailing comma here used to turn this value into the 1-tuple (1.5,)
mu=1.5
n_basis_feedforward = 8
feedforward_filter = filters.raised_cosine_pillow_08
feedback_filter = filters.raised_cosine_pillow_08
n_basis_feedback = 1


Shape of the training dataset: torch.Size([9000, 676, 80])
Shape of the test dataset torch.Size([1000, 676, 80])


In [4]:
# Carve out three per-user shards: 100 consecutive training samples and
# 10 consecutive test samples for each of user1, user2 and user3.
user1_train_data, user2_train_data, user3_train_data = (
    input_train[start:start + 100] for start in (0, 100, 200)
)
print(user1_train_data.shape, user2_train_data.shape, user3_train_data.shape)

user1_test_data, user2_test_data, user3_test_data = (
    input_test[start:start + 10] for start in (0, 10, 20)
)
print(user1_test_data.shape, user2_test_data.shape, user3_test_data.shape)

torch.Size([100, 676, 80]) torch.Size([100, 676, 80]) torch.Size([100, 676, 80])
torch.Size([10, 676, 80]) torch.Size([10, 676, 80]) torch.Size([10, 676, 80])


In [5]:
# Connectivity mask: one row per hidden/output neuron, one column per
# input/hidden/output neuron; 1 marks an allowed connection.
topology = torch.ones(
    [n_hidden_neurons + n_output_neurons,
     n_input_neurons + n_hidden_neurons + n_output_neurons],
    dtype=torch.float,
)
# Remove self-connections: subtracting the identity zeroes the diagonal of
# the hidden/output-to-hidden/output part of the mask.
topology[:, n_input_neurons:] -= torch.eye(n_hidden_neurons + n_output_neurons)
# Every hidden/output neuron must remain connected to every input neuron
assert torch.sum(topology[:, :n_input_neurons]) == (n_input_neurons * (n_hidden_neurons + n_output_neurons))
print(topology[:, n_input_neurons:])

tensor([[0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0.]])


In [6]:
#start distributed training
import torch.distributed as dist
import torch.multiprocessing as mp
import os
from utils.training_utils import local_feedback_and_update, feedforward_sampling, get_acc_and_loss
from utils.distributed_utils import init_training, global_update, init_processes

In [7]:

    # Lookup table from filter name to the filter function in utils.filters
    filters_dict = {'base_ff_filter': filters.base_feedforward_filter, 'base_fb_filter': filters.base_feedback_filter, 'cosine_basis': filters.cosine_basis,
                    'raised_cosine': filters.raised_cosine, 'raised_cosine_pillow_05': filters.raised_cosine_pillow_05, 'raised_cosine_pillow_08': filters.raised_cosine_pillow_08}

    # Keyword arguments used to build the SNN (unpacked as SNNetwork(**...)).
    # The feedback entries now read the dedicated feedback variables instead of
    # reusing the feedforward ones / a literal, so changing n_basis_feedback,
    # feedback_filter or tau_fb in the configuration cell actually takes effect.
    network_parameters = {'n_input_neurons': n_input_neurons,
                          'n_hidden_neurons': n_hidden_neurons,
                          'n_output_neurons': n_output_neurons,
                          'topology': topology,
                          'n_basis_feedforward': n_basis_feedforward,
                          'feedforward_filter': feedforward_filter,
                          'n_basis_feedback': n_basis_feedback,
                          'feedback_filter': feedback_filter,
                          'tau_ff': tau_ff,
                          'tau_fb': tau_fb,
                          'mu': mu,
                          'weights_magnitude': weights_magnitude,
                          'save_path': save_path
                          }

    # Settings read by the distributed training entry point
    training_parameters = {'dataset': dataset,
                           'tau': 10,
                           'learning_rate': learning_rate,
                           'epochs': epochs,
                           'epochs_test': epochs_test,
                           'eta': 1,
                           'kappa': kappa,
                           'deltas': deltas,
                           'alpha':alpha,
                           'r': r,
                           'num_ite': num_ite
                           }

In [None]:
from torch.multiprocessing import Process
import datetime


def init_training(rank, num_nodes, nodes_group, dataset, eta, epochs, net_parameters):
    """
    Initializes the different parameters for distributed training.

    Args:
        rank: rank of the calling process; rank 0 acts as the master node.
        num_nodes: number of per-node buffers the master allocates for each gathered quantity.
        nodes_group: process group used for the barriers and the weight broadcast.
        dataset: dataset path, forwarded to `distribute_samples`.
        eta: forwarded to `distribute_samples`.
        epochs: number of training samples per node; also used to compute `S`.
        net_parameters: keyword arguments used to build the `SNNetwork`.

    Returns:
        Tuple `(network, local_training_sequence, weights_list, S_prime, S,
        eligibility_trace, et_temp, learning_signal, ls_temp)`.
    """
    print("At init_training")

    # Initialize an SNN
    network = SNNetwork(**net_parameters)

    # At the beginning, the master node:
    # - transmits its weights to the workers
    # - distributes the samples among workers
    if rank == 0:
        # Initializing an aggregation list for future weights collection:
        # one zero buffer per node for the feedforward weights, the feedback
        # weights, the bias and a one-element tensor.
        weights_list = [[torch.zeros(network.feedforward_weights.shape, dtype=torch.float) for _ in range(num_nodes)],
                        [torch.zeros(network.feedback_weights.shape, dtype=torch.float) for _ in range(num_nodes)],
                        [torch.zeros(network.bias.shape, dtype=torch.float) for _ in range(num_nodes)],
                        [torch.zeros(1, dtype=torch.float) for _ in range(num_nodes)]]
    else:
        # Workers never gather, so they need no buffers
        weights_list = []

    dist.barrier(nodes_group)
    # Randomly create a distribution of the training samples among nodes
    local_training_sequence = distribute_samples(nodes_group, rank, dataset, eta, epochs)
    S_prime = local_training_sequence.shape[-1]  # size of the last axis of the local sequence
    S = S_prime * epochs

    dist.barrier(nodes_group)

    # Master node sends its weights
    for parameter in network.get_parameters():
        dist.broadcast(network.get_parameters()[parameter], 0, group=nodes_group)
    if rank == 0:
        print('Node 0 has shared its model and training data is partitioned among workers')

    # The nodes initialize their eligibility trace and learning signal
    eligibility_trace = {'ff_weights': 0, 'fb_weights': 0, 'bias': 0}
    et_temp = {'ff_weights': 0, 'fb_weights': 0, 'bias': 0}

    learning_signal = 0
    ls_temp = 0

    return network, local_training_sequence, weights_list, S_prime, S, eligibility_trace, et_temp, learning_signal, ls_temp


def _load_train_arrays(dataset):
    """Reads the training data and labels from the HDF5 file and closes the file."""
    with tables.open_file(dataset) as h5_file:
        return h5_file.root.train.data[:], h5_file.root.train.label[:]


def _build_teaching_signal(train_data, train_labels, sample_indices):
    """Concatenates the selected inputs and labels along axis 1."""
    local_input = train_data[sample_indices]
    local_output = train_labels[sample_indices]
    return torch.cat((torch.FloatTensor(local_input), torch.FloatTensor(local_output)), dim=1)


def distribute_samples(nodes, rank, dataset, eta, epochs):
    """
    The master node (rank 0) randomly chooses and transmits samples indices to each device for training.
    Upon reception of their assigned samples, the nodes create their training dataset.

    Args:
        nodes: process group used for the scatter.
        rank: rank of the calling process; rank 0 is the master.
        dataset: path of the HDF5 file holding `train.data` and `train.label`.
        eta: fraction of a worker's samples drawn from its main class.
        epochs: number of samples assigned to each worker.

    Returns:
        Tensor with the selected inputs and labels concatenated along axis 1.
        On the master it covers the samples of both workers.

    NOTE(review): the scatter list holds three tensors (a placeholder for the
    master plus two workers), so this presumably expects a 3-process group.
    Confirm against the caller.
    """
    # `math` is not among the notebook imports; importing it here fixes the
    # NameError previously raised by math.floor on the master.
    import math

    if rank == 0:
        train_data, train_labels = _load_train_arrays(dataset)

        n_samples = train_data.shape[0]  # Total number of samples
        print(n_samples)
        n_samples_train_per_class = int(n_samples / 2 * 0.9)  # There are 2 classes and 10% of the dataset is kept for testing

        # Class of each sample: argmax over axis 1 of the labels summed over their last axis
        sample_classes = torch.max(torch.sum(torch.FloatTensor(train_labels), dim=-1), dim=-1).indices

        # Indices corresponding to each class
        indices_0 = np.asarray(sample_classes == 0).nonzero()[0][:n_samples_train_per_class]
        indices_1 = np.asarray(sample_classes == 1).nonzero()[0][:n_samples_train_per_class]

        assert len(indices_0) == len(indices_1)
        n_main_class = math.floor(epochs * eta)
        n_secondary_class = epochs - n_main_class
        assert (n_main_class + n_secondary_class) == epochs

        # Randomly select samples for each worker
        indices_worker_0 = np.hstack((np.random.choice(indices_0, [n_main_class], replace=False), np.random.choice(indices_1, [n_secondary_class], replace=False)))
        np.random.shuffle(indices_worker_0)
        # Samples already given to worker 0 are excluded for worker 1.
        # indices_0 / indices_1 are sorted and unique, so setdiff1d keeps their order.
        remaining_indices_0 = np.setdiff1d(indices_0, indices_worker_0)
        remaining_indices_1 = np.setdiff1d(indices_1, indices_worker_0)
        indices_worker_1 = np.hstack((np.random.choice(remaining_indices_0, [n_secondary_class], replace=False), np.random.choice(remaining_indices_1, [n_main_class], replace=False)))
        np.random.shuffle(indices_worker_1)

        assert len(indices_worker_0) == len(indices_worker_1)

        # Send samples to the workers (the first entry is the master's own placeholder)
        indices = [torch.zeros([epochs], dtype=torch.int), torch.IntTensor(indices_worker_0), torch.IntTensor(indices_worker_1)]
        indices_local = torch.zeros([epochs], dtype=torch.int)
        dist.scatter(tensor=indices_local, src=0, scatter_list=indices, group=nodes)

        # Save samples sent to the workers at master to evaluate train loss and accuracy later
        indices_local = torch.IntTensor(np.hstack((indices_worker_0, indices_worker_1)))
        local_teaching_signal = _build_teaching_signal(train_data, train_labels, indices_local)

    else:
        indices_local = torch.zeros([epochs], dtype=torch.int)
        dist.scatter(tensor=indices_local, src=0, scatter_list=[], group=nodes)

        # An all-zero buffer means nothing was received from the master
        assert torch.sum(indices_local) != 0

        train_data, train_labels = _load_train_arrays(dataset)
        local_teaching_signal = _build_teaching_signal(train_data, train_labels, indices_local)

    return local_teaching_signal


def feedforward_sampling_accum_gradients(network, training_sequence, et, ls, gradients_accum, s, S_prime, alpha, r):
    """
    Runs one feedforward sampling step and updates the running statistics.

    The step feeds `training_sequence[int(s / S_prime), :, s % S_prime]` to the
    network, then:
    - adds the new contribution to the learning signal `ls`,
    - adds the current gradients to every entry of the eligibility trace `et`,
    - adds the absolute 'ff_weights' gradients to `gradients_accum`.

    Returns:
        (log_proba, gradients_accum, ls, et)
    """
    # Forward pass on the selected sample / time step
    sample_idx = int(s / S_prime)
    step_idx = s % S_prime
    log_proba = network(training_sequence[sample_idx, :, step_idx])

    # Learning signal, first part: log-probabilities of the output neurons
    output_term = torch.sum(log_proba[network.output_neurons - network.n_non_learnable_neurons]) / network.n_learnable_neurons

    # Learning signal, second part: alpha-weighted term over the hidden
    # neurons, built from their last spikes, sigmoid(potential) and r
    hidden_spikes = network.spiking_history[network.hidden_neurons, -1]
    hidden_proba = torch.sigmoid(network.potential[network.hidden_neurons - network.n_non_learnable_neurons])
    spike_part = hidden_spikes * torch.log(1e-07 + hidden_proba / r)
    silent_part = (1 - hidden_spikes) * torch.log(1e-07 + (1. - hidden_proba) / (1 - r))
    hidden_term = alpha * torch.sum(spike_part + silent_part) / network.n_learnable_neurons

    ls += output_term - hidden_term

    # Eligibility trace (all entries) and gradient magnitudes (feedforward weights only)
    for name in et:
        if name == 'ff_weights':
            gradients_accum += torch.abs(network.get_gradients()[name])
        et[name] += network.get_gradients()[name]

    return log_proba, gradients_accum, ls, et


def global_update(nodes, rank, network, weights_list):
    """
    Global update step for distributed learning.

    For every network parameter, rank 0 gathers the values held by all nodes,
    replaces its own value with the mean over the entries gathered from ranks
    1 and above, and broadcasts the result to the group.
    """
    is_master = rank == 0

    for idx, name in enumerate(network.get_parameters()):
        if is_master:
            # Collect every node's value; the master's own entry (index 0) is left out of the mean
            dist.gather(tensor=network.get_parameters()[name].data, gather_list=weights_list[idx], dst=0, group=nodes)
            worker_values = torch.stack(weights_list[idx][1:])
            network.get_parameters()[name].data = torch.mean(worker_values, dim=0)
        else:
            # Workers only contribute their value
            dist.gather(tensor=network.get_parameters()[name].data, gather_list=[], dst=0, group=nodes)

        # Everyone receives the master's value
        dist.broadcast(network.get_parameters()[name], 0, group=nodes)


def global_update_subset(nodes, rank, network, weights_list, gradients_accum, n_weights_to_send):
    """
    Global update step for distributed learning when transmitting only a subset of the weights.

    For the first parameter returned by network.get_parameters(), each worker node (rank != 0)
    transmits a tensor in which only the last-axis indices with the n_weights_to_send largest
    accumulated gradients (gradients_accum summed over its first two axes) are kept nonzero.
    Rank 0 sums the tensors gathered from the workers and divides each last-axis index by the
    number of workers that sent a nonzero value for it.
    Every other parameter is replaced on rank 0 by the mean of the values gathered from the workers.
    In all cases the result is then broadcast from rank 0 to the group.
    """

    for j, parameter in enumerate(network.get_parameters()):
        if j == 0:
            if rank != 0:
                to_send = network.get_parameters()[parameter].data  # NOTE(review): no explicit copy is made here, so the zeroing below presumably also affects the worker's own weights (they are overwritten by the broadcast at the end of the loop) - confirm
                # Selection of the indices to set to zero before transmission
                indices_not_to_send = [i for i in range(network.n_basis_feedforward) if i not in torch.topk(torch.sum(gradients_accum, dim=(0, 1)), n_weights_to_send)[1]]
                to_send[:, :, indices_not_to_send] = 0

                # Transmission of the quantized weights
                dist.gather(tensor=to_send, gather_list=[], dst=0, group=nodes)
            else:
                dist.gather(tensor=network.get_parameters()[parameter].data, gather_list=weights_list[j], dst=0, group=nodes)

                # For each last-axis index, count how many workers (entries 1 and above) sent a nonzero sum over axes 1 and 2
                indices_received = torch.bincount(torch.nonzero(torch.sum(torch.stack(weights_list[j][1:]), dim=(1, 2)))[:, 1])
                multiples = torch.zeros(network.n_basis_feedforward)  # per-index count of transmitting workers, used as divisor in the averaging step
                multiples[:len(indices_received)] = indices_received
                # Indices that nobody transmitted get a divisor of 1 to avoid dividing by zero
                multiples[multiples == 0] = 1

                # Averaging step
                network.get_parameters()[parameter].data = torch.sum(torch.stack(weights_list[j][1:]), dim=0) / multiples.type(torch.float)

        else:
            if rank != 0:
                dist.gather(tensor=network.get_parameters()[parameter].data, gather_list=[], dst=0, group=nodes)
            else:
                dist.gather(tensor=network.get_parameters()[parameter].data, gather_list=weights_list[j], dst=0, group=nodes)
                # Plain mean over the workers' values; the master's own entry (index 0) is excluded
                network.get_parameters()[parameter].data = torch.mean(torch.stack(weights_list[j][1:]), dim=0)
        # Rank 0 sends the updated parameter to every node
        dist.broadcast(network.get_parameters()[parameter], 0, group=nodes)




def run(rank,size,net_params,train_params):
    """
    Per-process entry point of the distributed run.

    Reads the training settings, creates a process group covering all `size`
    ranks, builds the local network and waits on a barrier.

    Args:
        rank: rank of the calling process.
        size: world size; the process group spans ranks 0 .. size - 1.
        net_params: keyword arguments used to build the `SNNetwork`; must also hold 'save_path'.
        train_params: dict holding 'dataset', 'epochs', 'epochs_test', 'deltas', 'num_ite',
            'tau', 'learning_rate', 'alpha', 'eta', 'kappa' and 'r'.
    """
    # Setup training parameters
    # NOTE(review): most of these values are not used by the code below.
    dataset = train_params['dataset']
    epochs = train_params['epochs']
    epochs_test = train_params['epochs_test']
    deltas = train_params['deltas']
    num_ite = train_params['num_ite']
    save_path = net_params['save_path']
    tau = train_params['tau']

    learning_rate = train_params['learning_rate']
    alpha = train_params['alpha']
    eta = train_params['eta']
    kappa = train_params['kappa']
    r = train_params['r']

    # Create network groups for communication.
    # The group is derived from `size` (it used to be hard-coded to four ranks).
    # The timeout is 360000 seconds.
    all_nodes = dist.new_group(list(range(size)), timeout=datetime.timedelta(0, 360000))

    test_accuracies = []  # used to store test accuracies
    test_loss = [[] for _ in range(num_ite)]
    # At most epochs_test indices from each of the ranges 900-999 and 1900-1999
    test_indices = np.hstack((np.arange(900, 1000)[:epochs_test], np.arange(1900, 2000)[:epochs_test]))

    print('training at node', rank)
    network = SNNetwork(**net_params)
    # Wait until every rank has built its network
    dist.barrier(all_nodes)

    
def init_process(rank, size, network_parameters,training_parameters,fn, backend='gloo'):
    """
    Sets up the distributed environment for one process, then runs `fn`.

    The rendezvous address and port are written to the environment, the
    default process group is initialised with the given backend, and
    `fn(rank, size, network_parameters, training_parameters)` is called.
    """
    # Rendezvous point shared by all processes (local machine)
    os.environ.update({'MASTER_ADDR': '127.0.0.1', 'MASTER_PORT': '29501'})
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size, network_parameters, training_parameters)
# Launch one process per rank, each running `run` through `init_process`
size = 4
processes = []
for rank in range(size):
    p = Process(
        target=init_process,
        args=(rank, size, network_parameters, training_parameters, run),
    )
    p.start()
    processes.append(p)

# Block until every process has finished
for p in processes:
    p.join()

training at node 0
At network
tensor([[1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        ...,
        [1., 1., 1.,  ..., 0., 1., 1.],
        [1., 1., 1.,  ..., 1., 0., 1.],
        [1., 1., 1.,  ..., 1., 1., 0.]])
self.n_neurons are 690
training at node 1
Hi [[[1. 1. 1. ... 1. 1. 1.]
  [1. 1. 1. ... 1. 1. 1.]
  [1. 1. 1. ... 1. 1. 1.]
  ...
  [1. 1. 1. ... 1. 1. 1.]
  [1. 1. 1. ... 1. 1. 1.]
  [1. 1. 1. ... 1. 1. 1.]]

 [[1. 1. 1. ... 1. 1. 1.]
  [1. 1. 1. ... 1. 1. 1.]
  [1. 1. 1. ... 1. 1. 1.]
  ...
  [1. 1. 1. ... 1. 1. 1.]
  [1. 1. 1. ... 1. 1. 1.]
  [1. 1. 1. ... 1. 1. 1.]]

 [[1. 1. 1. ... 1. 1. 1.]
  [1. 1. 1. ... 1. 1. 1.]
  [1. 1. 1. ... 1. 1. 1.]
  ...
  [1. 1. 1. ... 1. 1. 1.]
  [1. 1. 1. ... 1. 1. 1.]
  [1. 1. 1. ... 1. 1. 1.]]

 ...

 [[1. 1. 1. ... 1. 1. 1.]
  [1. 1. 1. ... 1. 1. 1.]
  [1. 1. 1. ... 1. 1. 1.]
  ...
  [0. 0. 0. ... 0. 0. 0.]
  [1. 1. 1. ... 1. 1. 1.]
  [1. 1. 1. ... 1. 1. 1.]]

 [[1. 1. 