<a href="https://colab.research.google.com/github/raghav96/FederatedSemiSupervisedGraphLearning/blob/main/Capstone.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Install the packages required by the Python environment (DGL, DGL-Go and PyTorch Geometric).

In [None]:
!pip install  dgl -f https://data.dgl.ai/wheels/cu117/repo.html
!pip install  dglgo -f https://data.dgl.ai/wheels-test/repo.html
!pip install torch_geometric

Looking in links: https://data.dgl.ai/wheels/cu117/repo.html
Looking in links: https://data.dgl.ai/wheels-test/repo.html


# Model

In [None]:
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import SAGEConv

class GraphSAGE(nn.Module):
    """Two-layer GraphSAGE network with mean aggregation and a ReLU in between.

    Args:
        in_feats: size of the input node features.
        hid_feats: size of the hidden representation after the first layer.
        out_feats: size of the output produced by the second layer.
    """

    def __init__(self, in_feats, hid_feats, out_feats):
        super().__init__()
        # The attribute names conv1/conv2 become the state-dict keys that the
        # server and clients exchange, so they must stay stable.
        self.conv1 = SAGEConv(in_feats, hid_feats, 'mean')
        self.conv2 = SAGEConv(hid_feats, out_feats, 'mean')

    def forward(self, g, in_feat):
        """Run both SAGE layers over graph ``g`` starting from ``in_feat``."""
        hidden = F.relu(self.conv1(g, in_feat))
        return self.conv2(g, hidden)

# Client

In [None]:
import torch
from torch import nn, optim
from tqdm import tqdm
from dgl import DGLGraph
import copy
import numpy as np
import torch as th
import scipy.sparse as sp
from scipy.linalg import fractional_matrix_power, inv

import dgl
import networkx as nx

from sklearn.preprocessing import MinMaxScaler

from dgl.nn import APPNPConv

class Client:
    """Federated-learning client that trains its model on an unlabelled subgraph.

    Local training minimises the negative-sampling link objective in
    ``loss_fn``. Parameters are exchanged with the server through
    ``getParame`` (download) and ``uploadParame`` (upload), tagged with the
    round id the server handed out.
    """

    def loss_fn(self, batch_graph, embeddings, negative_samples=5):
        """Negative-sampling link loss over the edges of ``batch_graph``.

        Every edge (u, v) is a positive pair. For each edge, ``negative_samples``
        destination nodes are drawn uniformly at random as negatives.
        Similarity is the dot product of the two embeddings.

        Args:
            batch_graph: graph whose ``edges()`` returns ``(src, dst)`` node-id tensors.
            embeddings: tensor with one row per node, indexed by node id.
            negative_samples: number of random negatives drawn per edge.

        Returns:
            Scalar tensor ``mean(-logsigmoid(pos)) + mean(-logsigmoid(-neg))``.
            When the graph has no edges, the result is zero (still attached to
            ``embeddings``).
        """
        src, dst = batch_graph.edges()
        if src.numel() == 0:
            # The means below would be NaN for an empty edge set; return a
            # differentiable zero so callers can still call backward().
            return embeddings.sum() * 0.0

        pos_u_embeddings = embeddings[src]
        pos_v_embeddings = embeddings[dst]
        pos_similarity = torch.sum(pos_u_embeddings * pos_v_embeddings, dim=1)

        num_nodes = embeddings.shape[0]
        # Draw the negatives on the embeddings' device so indexing needs no transfer.
        neg_v = torch.randint(
            0, num_nodes, (src.shape[0], negative_samples), device=embeddings.device
        )
        # expand() broadcasts each source embedding over its negatives without copying.
        neg_u_embeddings = pos_u_embeddings.unsqueeze(1).expand(-1, negative_samples, -1)
        neg_v_embeddings = embeddings[neg_v]
        neg_similarity = torch.sum(neg_u_embeddings * neg_v_embeddings, dim=2)

        pos_loss = -F.logsigmoid(pos_similarity)
        neg_loss = -F.logsigmoid(-neg_similarity)
        return torch.mean(pos_loss) + torch.mean(neg_loss)

    def __init__(self, client_id, data, model, device, epoch, lr, l2norm, model_path, logging, writer) -> None:
        # log setting
        self.model_path = model_path
        self.logging = logging
        self.writer = writer

        # client setting
        self.client_id = client_id
        self.device = device
        self.data = data
        self.model = model.to(self.device)
        self.epoch = epoch
        self.hid_dim = 64

        self.Epoch = -1                    # FL round counter (not advanced by this class)
        self.round = None                  # round id received from the server, echoed back on upload
        self.val_acc = None
        self.model_param = None            # snapshot of the best weights from the last train()
        self.patience = 10                 # early-stopping patience, in epochs
        # Local epochs per round come from the constructor argument, not a global.
        self.epochs = epoch
        self.graph = data.to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr, weight_decay=l2norm)

    def train(self):
        """Train the local model for up to ``self.epochs`` epochs with early stopping.

        After each optimiser step, the epoch's loss is compared with the best
        seen so far. On improvement, the weights are saved to
        ``model_<client_id>.pkl`` and deep-copied into ``self.model_param``,
        the snapshot that ``uploadParame`` sends. Training stops after
        ``self.patience`` consecutive epochs without improvement.
        """
        print('Client: {0}'.format(self.client_id))
        best = float('inf')
        cnt_wait = 0
        features = self.graph.ndata['feat'].float()
        for epoch in range(self.epochs):
            self.model.train()
            out = self.model(self.graph, features)
            loss = self.loss_fn(self.graph, out)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Compare plain floats; keeping the loss tensor would hold its graph alive.
            loss_value = loss.item()
            print('Epoch: {0}, Loss: {1:0.4f}'.format(epoch, loss_value))

            if loss_value < best:
                best = loss_value
                cnt_wait = 0
                th.save(self.model.state_dict(), 'model' + '_' + str(self.client_id) + '.pkl')
                # state_dict() holds references to the live parameters; copy it so
                # later optimizer steps cannot overwrite the recorded best weights.
                self.model_param = copy.deepcopy(self.model.state_dict())
            else:
                cnt_wait += 1

            if cnt_wait == self.patience:
                print('Early stopping')
                break

    # def test(self):
    #     self.model.eval()
    #     with torch.no_grad():
    #         logits = self.model(self.graph)
    #         test_loss = self.loss_fn(logits[self.test_idx], self.labels[self.test_idx])
    #         test_acc =  (logits[self.test_idx].argmax(1) == self.labels[self.test_idx]).float().sum()
    #         test_acc = test_acc / len(self.test_idx)

    #     self.logging.info("Client {:>2} Test: Test Loss: {:.4f} | Test Acc: {:.4f}".format(
    #         self.client_id, test_loss.item(), test_acc
    #     ))
    #     # save to disk
    #     torch.save(self.model_param, self.model_path + "client" + str(self.client_id) + '_model.ckpt')

    #     return test_acc

    def getParame(self, round, param):
        """Receive the server's round id and global parameters and load them locally.

        The server's dict is only read, never modified. ``load_state_dict``
        copies the values into this client's own parameter tensors, so local
        fine-tuning starts from the global model.
        """
        self.round = round
        self.model.load_state_dict(param)

    def uploadParame(self):
        """Return ``(round, params)`` for the server, using the best local snapshot.

        If ``train`` never recorded a snapshot (for example, every loss was
        NaN), the current weights are uploaded instead.
        """
        if self.model_param is None:
            self.model_param = copy.deepcopy(self.model.state_dict())
        return self.round, dict(self.model_param)

# Server

In [None]:
import torch
from torch import nn
import random
import copy

class Server:
    """Parameter server holding the labelled subgraph and the global model.

    ``train`` pre-trains the model with supervision. ``aggregate`` averages the
    uploaded client parameters, loads the average into the model, applies one
    supervised step, and republishes the result through ``sendParame``.
    """

    def __init__(self, data, globalModel, device, logging, writer) -> None:
        # log setting
        self.logging = logging
        self.writer = writer

        # server setting
        self.device = device
        self.Epoch = -1                          # index of the last aggregation (-1 before the first)
        # Round id that clients must echo back. randint needs integer bounds:
        # a float bound such as 1e8 raises TypeError on Python >= 3.12.
        self.round = random.randint(0, 10**8)
        self.local_state_dict = []               # parameter dicts uploaded in the current round
        self.model = globalModel.to(self.device)
        self.global_state_dict = copy.deepcopy(globalModel.state_dict())
        self.graph = data.to(self.device)

        # loss function and optimiser for the per-round supervised step
        self.loss_fn = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.005)
        self.supervised_loss = 0

    def _features(self):
        """Return the server graph's node features as float32."""
        return self.graph.ndata['feat'].float()

    def _accuracy(self, mask=None):
        """Return the fraction of correct predictions (over all nodes when ``mask`` is None)."""
        labels = self.graph.ndata['label']
        self.model.eval()
        with torch.no_grad():
            logits = self.model(self.graph, self._features())
        predictions = torch.argmax(logits, dim=1)
        if mask is not None:
            predictions, labels = predictions[mask], labels[mask]
        return (predictions == labels).float().mean().item()

    def _sync_global_state(self):
        """Publish a detached copy of the current model weights as the global state."""
        self.global_state_dict = copy.deepcopy(self.model.state_dict())

    def train(self, epochs=100):
        """Supervised pre-training on the server's labelled nodes.

        Trains for ``epochs`` epochs with a dedicated Adam optimiser (lr 0.005),
        printing loss and validation accuracy each epoch and test accuracy at
        the end. The trained weights become the global state sent to clients.
        """
        labels = self.graph.ndata['label']
        train_mask = self.graph.ndata['train_mask']
        val_mask = self.graph.ndata['val_mask']
        test_mask = self.graph.ndata['test_mask']
        features = self._features()
        print(self.model)
        print(self.graph)
        print(self.loss_fn)

        # A separate optimiser, so pre-training does not touch self.optimizer's state.
        optimizer = torch.optim.Adam(self.model.parameters(), lr=0.005)

        print('Server training')
        for epoch in range(epochs):
            self.model.train()
            logits = self.model(self.graph, features)
            loss = self.loss_fn(logits[train_mask], labels[train_mask])

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            val_acc = self._accuracy(val_mask)
            print('Epoch: {0}, Loss: {1:0.4f} , Val Acc: {2:0.4f}'.format(epoch, loss.item(), 100 * val_acc))

        test_acc = self._accuracy(test_mask)
        print(f'Test Acc: {100*test_acc}')

        # Clients should start from the pre-trained weights, not the initial ones.
        self._sync_global_state()

    # the global model is not complete
    def test(self):
        """Print the global model's accuracy over every node of the server graph."""
        accuracy = self._accuracy()
        print(f"Server test error: \n Accuracy: {(100*accuracy):>0.1f}\n")

    def aggregate(self):
        """Federated averaging of this round's uploads, then one supervised step.

        Does nothing when no client uploaded. Otherwise:
        1. the uploaded parameters are averaged with equal weight;
        2. the average is loaded into the server model;
        3. the model takes one optimiser step on the labelled training nodes;
        4. the refined weights become the new global state.
        A fresh round id is drawn for the next round.
        """
        clientNum = len(self.local_state_dict)
        if clientNum == 0:
            return
        self.Epoch += 1

        labels = self.graph.ndata['label']
        train_mask = self.graph.ndata['train_mask']

        # Equal-weight average of every uploaded tensor.
        for layer_name in self.local_state_dict[0].keys():
            averaged = torch.zeros_like(self.global_state_dict[layer_name])
            for localParame in self.local_state_dict:
                averaged.add_(localParame[layer_name])
            self.global_state_dict[layer_name] = averaged.div_(clientNum)

        self.local_state_dict.clear()
        self.round = random.randint(0, 10**8)

        # Load the average into the model; otherwise the supervised step and
        # test() would never see the clients' contributions.
        self.model.load_state_dict(self.global_state_dict)

        self.model.train()
        self.optimizer.zero_grad()
        logits = self.model(self.graph, self._features())
        self.supervised_loss = self.loss_fn(logits[train_mask], labels[train_mask])
        self.supervised_loss.backward()
        self.optimizer.step()

        # Share the supervised refinement with the clients in the next round.
        self._sync_global_state()

    def sendParame(self):
        """Return the current round id and the global parameter dict."""
        return self.round, self.global_state_dict

    def getParame(self, round, localParame):
        """Accept a client's upload only if it carries the current round id."""
        if round == self.round:
            self.local_state_dict.append(localParame)
            #self.val_acc += val_acc

# Federated learning: initialization and training loop

In [None]:
# init FL setting
def init_fed(args):
    """Create the server and the ``args.num_client`` clients for a federated run.

    The data split comes from ``labels_in_server_split_subgraph``: the server
    receives the first subgraph and client ``i`` receives the i-th client
    subgraph. Every party gets its own freshly initialised GraphSAGE model.

    Returns:
        ``(server, clients)`` where ``clients`` is a list ordered by client id.
    """
    server_data, client_graphs = labels_in_server_split_subgraph(args)

    def build_model():
        # One independent model per party, all with the same architecture.
        return GraphSAGE(args.in_channels, args.n_hidden, args.out_channels)

    server = Server(server_data, build_model(), args.gpu, logging, args.writer)

    client_dir = args.state_dir + args.run_mode + "/"
    clients = [
        Client(i, client_graphs[i], build_model(), args.gpu, args.local_epoch,
               args.lr, args.l2norm, client_dir, logging, args.writer)
        for i in range(args.num_client)
    ]
    return server, clients


# FL process
def FedRunning(args, server, clients):
    """Run the federated training loop.

    The server is pre-trained once. Then, for ``args.round`` rounds, a random
    subset of clients participates. The subset size is
    ``int(args.num_client * args.fraction)``, clamped to ``[1, args.num_client]``.
    Each chosen client downloads the global parameters, trains locally and
    uploads; the server then aggregates and evaluates.
    """
    server.train()

    # int(num_client * fraction) is 0 for a small fraction, which used to build
    # an empty float index array and crash; always pick at least one client.
    per_round = min(args.num_client, max(1, int(args.num_client * args.fraction)))

    for t in range(args.round):

        logging.info(f"---------------------Round {t}---------------------")

        # Step 0: choose participating clients without replacement.
        selected = random.sample(range(args.num_client), per_round)

        # Step 1: download global parameters, train locally, upload.
        for idx in selected:
            client = clients[idx]
            client.getParame(*server.sendParame())
            client.train()
            server.getParame(*client.uploadParame())

        # Step 2: aggregate uploads and evaluate the global model.
        server.aggregate()
        server.test()

    logging.info("--------------------Finally Test--------------------")
    # TODO: average per-client test accuracy here once Client.test() is implemented.

# Main code to execute

In [None]:
import sys
sys.path.append("../")
import copy
import torch
from torch.utils.data import random_split
from torch.utils.tensorboard import SummaryWriter
from dgl.data import load_data
from dgl.data import AsNodePredDataset
from dgl import edge_subgraph
import numpy as np
from torch_geometric.nn.models import GCN
import argparse
import json
import logging
import os
import random


# init directories
def init_dir(args):
    """Create every output directory the run needs, skipping those that already exist.

    The directories are, in order: the log directory, the state directory,
    the per-run-mode state subdirectory, and the TensorBoard log directory.
    """
    required_paths = (
        args.log_dir,
        args.state_dir,
        args.state_dir + args.run_mode + '/',
        args.tb_log_dir,
    )
    for path in required_paths:
        if not os.path.exists(path):
            os.makedirs(path)


# init logger
def init_logger(args):
    """Configure root logging to ``<log_dir>/<run_mode>.log`` and add a console handler.

    ``basicConfig`` appends INFO-and-above records to the file. It only takes
    effect if the root logger has no handlers yet. An INFO-level console
    handler with the same message format is always attached to the root logger.
    """
    line_format = '%(asctime)s [%(levelname)s] | %(message)s'

    logging.basicConfig(
        filename=os.path.join(args.log_dir, args.run_mode + '.log'),
        filemode='a+',
        level=logging.INFO,
        format=line_format,
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(logging.Formatter(line_format))
    logging.getLogger('').addHandler(console_handler)


import random

def numbers_with_sum(n, k):
    """Return ``n`` random positive integers that sum to ``k``.

    Parts are drawn left to right. Each draw is uniform over the values that
    still leave at least 1 for every remaining part, and the last part takes
    whatever is left. For ``n == 1`` the result is ``[k]`` unchanged.

    Raises:
        ValueError: if ``n < 1``, or if ``n > 1`` and ``k < n`` (no split into
            positive parts exists).
    """
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}")
    if n == 1:
        return [k]
    if k < n:
        raise ValueError(f"cannot split {k} into {n} positive parts")

    parts = []
    remaining = k
    for parts_left in range(n, 1, -1):
        # Reserve 1 for each of the (parts_left - 1) later parts so none ends up 0.
        num = random.randint(1, remaining - (parts_left - 1))
        parts.append(num)
        remaining -= num
    parts.append(remaining)
    return parts

# Splits the dataset into a labels-in-server scenario
def labels_in_server_split_subgraph(args):
    """Split a dataset's edges between the server and ``args.num_client`` clients.

    The edge ids of the dataset's first graph are shuffled and cut into
    ``num_client + 1`` random-sized chunks (sizes from ``numbers_with_sum``,
    sorted ascending). The first chunk, which is the smallest, becomes the
    server's edge-induced subgraph, which carries the labels. Each remaining
    chunk becomes one client's subgraph.

    Returns:
        ``(server_data, datasets)`` where ``datasets`` holds one subgraph per client.
    """
    dataset = load_data(args)
    graph = dataset[0]
    num_edges = graph.num_edges()
    num_clients = args.num_client

    # Chunk sizes for the server plus every client, smallest first.
    edges_split = sorted(numbers_with_sum(num_clients + 1, num_edges))

    edge_idx = list(range(num_edges))
    np.random.shuffle(edge_idx)
    # The cumulative sum ends at num_edges, so the final piece returned by
    # np.split is empty and is dropped.
    edge_idx_splits = np.split(edge_idx, np.cumsum(edges_split))[:-1]

    server_data = edge_subgraph(graph, edge_idx_splits[0])
    datasets = [edge_subgraph(graph, part) for part in edge_idx_splits[1:]]
    return server_data, datasets

class Args:
    """Static configuration for the notebook run, read as attributes of an instance."""

    # dataset and output locations
    dataset = "cora"
    state_dir = "../log/state/"
    log_dir = "../log/"
    tb_log_dir = "../log/tb_log/"
    labeled_rate = 0.7

    # optimisation
    lr = 0.01
    l2norm = 0

    # runtime
    num_cpu = 1
    run_mode = "Fed"
    seed = 12345
    gpu = '0'  # device index as a string; the script appends it to 'cuda:'

    # federated schedule
    num_client = 5
    fraction = 1
    round = 100
    local_epoch = 10

    # model shape
    n_hidden = 64
    input = 50
    n_hidden_layers = 2
    in_channels = 1433
    out_channels = 7

args = Args()
# Log the actual configuration values; str(args) would only give the default
# "<Args object at 0x...>" representation.
args_str = str({key: value for key, value in vars(Args).items() if not key.startswith('_')})

# Use the configured CUDA device when one is available, otherwise fall back to
# the CPU so the notebook still runs on a machine without a GPU.
if torch.cuda.is_available():
    args.gpu = torch.device('cuda:' + args.gpu)
else:
    args.gpu = torch.device('cpu')
print(args.gpu)

# Set random seeds. Python's `random` drives client sampling and the server's
# round ids, so it is seeded too.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)

# create directories
init_dir(args)

# init writer
writer = SummaryWriter(args.tb_log_dir + args.run_mode + "/")
args.writer = writer

# init logger
init_logger(args)
logging.info(args_str)

if args.run_mode == 'Fed':
    # init FL setting
    server, clients = init_fed(args)
    # running
    FedRunning(args, server, clients)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Epoch: 1, Loss: 0.8519
Epoch: 2, Loss: 0.8656
Epoch: 3, Loss: 0.8558
Epoch: 4, Loss: 0.8529
Epoch: 5, Loss: 0.8455
Epoch: 6, Loss: 0.8659
Epoch: 7, Loss: 0.8504
Epoch: 8, Loss: 0.8541
Epoch: 9, Loss: 0.8528
Server test error: 
 Accuracy: 4.5

Client: 4
Epoch: 0, Loss: 0.8725
Epoch: 1, Loss: 0.8754
Epoch: 2, Loss: 0.8715
Epoch: 3, Loss: 0.8681
Epoch: 4, Loss: 0.8741
Epoch: 5, Loss: 0.8708
Epoch: 6, Loss: 0.8703
Epoch: 7, Loss: 0.8712
Epoch: 8, Loss: 0.8690
Epoch: 9, Loss: 0.8694
Client: 0
Epoch: 0, Loss: 0.7029
Epoch: 1, Loss: 0.5453
Epoch: 2, Loss: 0.6683
Epoch: 3, Loss: 0.6759
Epoch: 4, Loss: 0.6417
Epoch: 5, Loss: 0.6898
Epoch: 6, Loss: 0.6783
Epoch: 7, Loss: 0.5493
Epoch: 8, Loss: 0.5551
Epoch: 9, Loss: 0.5751
Client: 1
Epoch: 0, Loss: 0.8102
Epoch: 1, Loss: 0.8036
Epoch: 2, Loss: 0.7959
Epoch: 3, Loss: 0.8046
Epoch: 4, Loss: 0.7824
Epoch: 5, Loss: 0.7965
Epoch: 6, Loss: 0.7868
Epoch: 7, Loss: 0.7852
Epoch: 8, Loss: 0.