In [None]:
from __future__ import division
from __future__ import print_function
import time
import argparse
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
import random
import matplotlib.pyplot as plt
from sklearn import metrics

from utils import get_plot
from models import GCN
from data_process import load_data
from train_func import test, Block_matrix_train, Block_matrix_train_batch

In [None]:
def get_K_hop_neighbors(adj_matrix, index, K):
    """Return the ids of all nodes reachable from ``index`` in at most ``K`` hops.

    Self-loops are added to ``adj_matrix`` first, so every node counts as its
    own neighbor and the seed nodes are always part of the result (for K >= 1).

    Args:
        adj_matrix: dense 2-D torch tensor; a non-zero entry ``[i, j]`` is
            treated as an edge from node ``i`` to node ``j``.
        index: seed node id(s) — anything usable to index rows of
            ``adj_matrix`` (int, list of ints, index tensor).
        K: number of hops to expand.

    Returns:
        A sorted 1-D tensor of unique node ids. When ``K == 0`` the ``index``
        argument is returned unchanged.
    """
    # Build the identity on the same device as the adjacency so the addition
    # also works for CUDA tensors.
    adj_matrix = adj_matrix + torch.eye(adj_matrix.shape[0], adj_matrix.shape[1],
                                        device=adj_matrix.device)  # make sure the diagonal part >= 1
    hop_neighbor_index = index
    for _ in range(K):
        # Index the *local* matrix (with self-loops), not a notebook global.
        rows = adj_matrix[hop_neighbor_index]
        if rows.dim() == 1:
            # A scalar index selects a single 1-D row; keep it 2-D so that
            # column 1 of nonzero() is the neighbor id.
            rows = rows.unsqueeze(0)
        hop_neighbor_index = torch.unique(torch.nonzero(rows)[:, 1])
    return hop_neighbor_index

In [None]:
def get_K_hop_neighbors_BDS(adj_matrix, index, K):
    """Return the one-hop neighborhood (seed nodes included) of ``index``.

    Self-loops are added to ``adj_matrix`` first, so the seed nodes themselves
    are always part of the result.

    Args:
        adj_matrix: dense 2-D torch tensor; a non-zero entry ``[i, j]`` is
            treated as an edge from node ``i`` to node ``j``.
        index: seed node id(s) — anything usable to index rows of
            ``adj_matrix`` (int, list of ints, index tensor).
        K: accepted for signature parity with ``get_K_hop_neighbors`` but not
            used; only a single hop is expanded.

    Returns:
        A sorted 1-D tensor of unique node ids.
    """
    adj_matrix = adj_matrix + torch.eye(adj_matrix.shape[0], adj_matrix.shape[1],
                                        device=adj_matrix.device)  # make sure the diagonal part >= 1

    # Index the *local* matrix (with self-loops), not a notebook global.
    rows = adj_matrix[index]
    if rows.dim() == 1:
        # A scalar index selects a single 1-D row; keep it 2-D so that
        # column 1 of nonzero() is the neighbor id.
        rows = rows.unsqueeze(0)
    onehop_neighbor_index = torch.unique(torch.nonzero(rows)[:, 1])

    # NOTE(review): the previous version computed np.setdiff1d(index, neighbors)
    # here and discarded the result, so it had no effect and was removed.
    # If a set difference was meant to be returned, that still needs to be
    # implemented — confirm the intended behavior.
    return onehop_neighbor_index

In [None]:
import scipy.sparse as sp
def normalize(mx):
    """Add self-loops to a dense adjacency tensor and row-normalize it.

    Computes ``D^-1 (mx + I)``, where ``D`` is the diagonal matrix holding the
    row sums of ``mx + I``. Rows whose sum is exactly zero are mapped to all
    zeros instead of inf/NaN.

    The computation stays in torch, so no NumPy/SciPy round trip is needed and
    the input may live on any device.

    Args:
        mx: dense 2-D torch tensor.

    Returns:
        A torch tensor with the same shape as ``mx``.
    """
    mx = mx + torch.eye(mx.shape[0], mx.shape[1], device=mx.device)

    rowsum = mx.sum(1, keepdim=True)
    r_inv = rowsum.pow(-1)
    # 1/0 yields +-inf; zero those entries so empty rows stay all-zero.
    r_inv = torch.where(torch.isinf(r_inv), torch.zeros_like(r_inv), r_inv)
    # Broadcasting an (n, 1) column over mx equals left-multiplying by diag(r_inv).
    return r_inv * mx

# Model

In [None]:

def Collaborative_Reasoning(K, features, adj, labels, idx_train, idx_val, idx_test, iid_percent):
    """Train K local GCN models on a node partition and average them every round.

    Each round, every device that owns labeled training nodes runs
    ``args_epochs`` calls to ``Block_matrix_train`` on its own node block; the
    local weights are then averaged (weighted by each device's number of
    training nodes) into the global model, which is evaluated with ``test`` on
    the train and validation indexes and copied back to every device. After the
    last round the global model is evaluated on ``idx_test``. All metrics are
    appended to a tab-separated log file.

    Besides its arguments this function reads the following module-level
    globals, which must be defined before calling it: ``args_hidden``,
    ``args_dropout``, ``args_cuda``, ``args_lr``, ``args_weight_decay``,
    ``args_epochs``, ``iterations``, ``mode`` and ``dataset_name``.

    Args:
        K: number of devices (local models).
        features: node feature tensor; ``features.shape[1]`` is the input size.
        adj: adjacency tensor handed to the training / evaluation helpers.
        labels: 1-D tensor of node labels (assumes integer class ids starting
            at 0 — the class count is taken as ``labels.max() + 1``).
        idx_train, idx_val, idx_test: node-id index tensors of the splits.
        iid_percent: share in [0, 1] of each device's nodes taken from the
            label-agnostic pool instead of from the device's own class.

    Returns:
        ``(loss_test, acc_test)`` as returned by ``test`` on ``idx_test``.
    """
    nclass = labels.max().item() + 1

    def _new_model():
        # Every model (global and local) shares the same architecture.
        return GCN(nfeat=features.shape[1],
                   nhid=args_hidden,
                   nclass=nclass,
                   dropout=args_dropout)

    global_model = _new_model()
    models = [_new_model() for _ in range(K)]

    # ------------------------------------------------------------------
    # Split the nodes across the K devices.
    # This part uses NumPy, which cannot read CUDA tensors, so it works on
    # CPU copies and runs before anything is moved to the GPU.
    # ------------------------------------------------------------------
    labels_cpu = labels.cpu()
    if torch.is_tensor(idx_train):
        idx_train_np = idx_train.cpu().numpy()
    else:
        idx_train_np = np.asarray(idx_train)

    non_iid_percent = 1 - float(iid_percent)

    # Shuffled node ids per class.
    # NOTE(review): only classes 0..K-1 are collected here (range(K), not
    # range(nclass)) — confirm this is intended when K < nclass.
    shuffle_labels = []
    for i in range(K):
        current = torch.nonzero(labels_cpu == i).reshape(-1)
        current = current[np.random.permutation(len(current))]  # shuffle
        shuffle_labels.append(current)

    # Number of devices that share one class: ceil(K / nclass).
    average_device_of_class = K // nclass
    if K % nclass != 0:
        average_device_of_class += 1

    # Non-IID share: each device takes its own slice of "its" class.
    split_data_indexes = []
    for i in range(K):
        labels_class = shuffle_labels[i // average_device_of_class]
        slot = i % average_device_of_class
        average_num = int(len(labels_class) // average_device_of_class * non_iid_percent)
        split_data_indexes.append(
            np.array(labels_class[average_num * slot:average_num * (slot + 1)]))

    # Every node not handed out above forms the pool for the IID share.
    assigned = [node for part in split_data_indexes for node in part]
    iid_indexes = np.setdiff1d(range(len(labels)), assigned)

    # IID share: consecutive chunks of the pool.
    # NOTE(review): np.setdiff1d returns sorted ids, so these chunks are
    # contiguous id ranges of the pool rather than random samples — confirm.
    for i in range(K):
        labels_class = shuffle_labels[i // average_device_of_class]
        average_num = int(len(labels_class) // average_device_of_class * (1 - non_iid_percent))
        split_data_indexes[i] = list(split_data_indexes[i]) + list(iid_indexes[:average_num])
        iid_indexes = iid_indexes[average_num:]

    # Only part of the nodes on each device are labeled training nodes; store
    # their *local* positions inside the device's sorted node block.
    split_train_ids = []
    for i in range(K):
        split_data_indexes[i].sort()
        inter = np.intersect1d(split_data_indexes[i], idx_train_np)
        split_train_ids.append(np.searchsorted(split_data_indexes[i], inter))

    # ------------------------------------------------------------------
    # Device placement and optimizers.
    # ------------------------------------------------------------------
    if args_cuda:
        device = torch.device('cuda:0')
        models = [model.to(device) for model in models]
        global_model = global_model.to(device)
        features = features.to(device)
        adj = adj.to(device)
        labels = labels.to(device)
        idx_train = idx_train.to(device)
        idx_val = idx_val.to(device)
        idx_test = idx_test.to(device)

    optimizers = [optim.SGD(model.parameters(), lr=args_lr, weight_decay=args_weight_decay)
                  for model in models]

    # All local models start from the global model's weights.
    for model in models:
        model.load_state_dict(global_model.state_dict())

    # Aggregation weights: number of labeled training nodes per device.
    device_weights = [len(ids) for ids in split_train_ids]
    total_weight = sum(device_weights)

    log_path = (mode+'_'+dataset_name+'_IID_'+str(iid_percent)
                +'_Collaborative_Reasoning_iter_'+str(iterations)
                +'_epoch_'+str(args_epochs)+'_device_num_'+str(K))

    # ------------------------------------------------------------------
    # Training rounds.
    # ------------------------------------------------------------------
    t = -1  # keeps the final log line well-defined when iterations == 0
    for t in range(iterations):
        for i in range(K):
            if device_weights[i] == 0:
                # Nothing to train on for this device.
                continue
            for epoch in range(args_epochs):
                Block_matrix_train(epoch, models[i], optimizers[i], features, adj, labels,
                                   split_data_indexes[i], split_train_ids[i])

        if total_weight > 0:
            # Weighted average of all local parameters.
            states = [model.state_dict() for model in models]
            global_state = dict()
            for key in global_model.state_dict():
                averaged = device_weights[0] * states[0][key]
                for i in range(1, K):
                    averaged += device_weights[i] * states[i][key]
                averaged /= total_weight
                global_state[key] = averaged
            global_model.load_state_dict(global_state)
        else:
            # No device owns a training node: nothing was trained, so keep the
            # current global weights instead of dividing by zero.
            global_state = global_model.state_dict()

        loss_train, acc_train = test(global_model, features, adj, labels, idx_train)
        loss_val, acc_val = test(global_model, features, adj, labels, idx_val)  # validation

        with open(log_path, 'a+') as log_file:
            log_file.write(str(t)+'\t'+"train"+'\t'+str(loss_train)+'\t'+str(acc_train)+'\n')
            log_file.write(str(t)+'\t'+"val"+'\t'+str(loss_val)+'\t'+str(acc_val)+'\n')

        # Broadcast the aggregated weights back to every device.
        for model in models:
            model.load_state_dict(global_state)

    # Final evaluation on the held-out test nodes.
    loss_test, acc_test = test(global_model, features, adj, labels, idx_test)
    with open(log_path, 'a+') as log_file:
        log_file.write(str(t)+'\t'+"test"+'\t'+str(loss_test)+'\t'+str(acc_test)+'\n')

    return loss_test, acc_test



In [None]:
# Seed every random number generator the notebook imports so that a fresh
# "Restart & Run All" reproduces the same data split and model initialisation.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# `mode` and `dataset_name` are also used to build the log-file name.
mode="real"
dataset_name='cora'
features, adj, labels, idx_train, idx_val, idx_test = load_data(dataset_name)
class_num = labels.max().item() + 1




In [None]:
# To get a reproducible run with a fixed seed, re-run the data-loading cell
# before this one. Note that `adj` is overwritten in place below, so running
# this cell twice without reloading the data normalizes the matrix twice.

args_normalize = True

model_type = 'GCN'    #GCN
args_hidden = 16
args_dropout = 0.5
args_lr = 0.5
args_weight_decay = 5e-4     #L2 penalty
args_epochs = 3
args_no_cuda = False
args_cuda = not args_no_cuda and torch.cuda.is_available()

args_device_num = class_num #split data into args_device_num parts
#iterations = 100



if args_normalize:
    adj = normalize(adj)
    # Disabled alternative: symmetric normalization D^{-1/2} (A + I) D^{-1/2}
    #   adj = adj + torch.eye(adj.shape[0],adj.shape[1])
    #   d=torch.sum(adj,axis=1)
    #   D_minus_one_over_2=torch.zeros(adj.shape[0],adj.shape[0])
    #   D_minus_one_over_2[range(len(D_minus_one_over_2)), range(len(D_minus_one_over_2))] = d**(-0.5)
    #   adj = torch.mm(torch.mm(D_minus_one_over_2,adj),D_minus_one_over_2)
    




In [None]:
# Hyper-parameter sweep over local epochs, IID share, device count and number
# of communication rounds. `args_epochs` and `iterations` are not passed to
# Collaborative_Reasoning as arguments: it reads them as module-level globals,
# so these two loop variables must keep exactly these names.
for args_epochs in (3,):
    for args_random_assign in (0.0, 0.5, 1):
        for args_device_num in (2, 3, 4, 5, 6):
            for iterations in (10, 20, 40, 80, 100, 200, 300, 400, 500, 600):
                Collaborative_Reasoning(
                    K=args_device_num,
                    features=features,
                    adj=adj,
                    labels=labels,
                    idx_train=idx_train,
                    idx_val=idx_val,
                    idx_test=idx_test,
                    iid_percent=args_random_assign,
                )