Install the packages required to run PyTorch Geometric (PyG) and define a t-SNE visualization helper.



In [None]:
# Install required packages.
import os
import torch
# Expose the installed torch version (e.g. "2.0.1+cu118") as $TORCH so that
# ${TORCH} in the pip commands below selects the matching PyG wheel index.
os.environ['TORCH'] = torch.__version__
print(torch.__version__)

# torch-scatter / torch-sparse: prebuilt wheels from the PyG index for this torch build.
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
# NOTE(review): installs PyG from the unpinned git default branch, so the installed
# code can change between runs; consider pinning a released version instead.
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git

# Helper function for visualization.
%matplotlib inline
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

def visualize(h, color):
    """Scatter-plot a 2-D t-SNE embedding of the rows of ``h``.

    Args:
        h: Feature tensor; it is detached and moved to the CPU before t-SNE.
        color: Per-point colours forwarded to ``plt.scatter`` (``"Set2"`` colormap).
    """
    features = h.detach().cpu().numpy()
    embedding = TSNE(n_components=2).fit_transform(features)

    plt.figure(figsize=(10, 10))
    # Hide ticks: t-SNE coordinates have no meaningful scale.
    plt.xticks([])
    plt.yticks([])

    xs, ys = embedding[:, 0], embedding[:, 1]
    plt.scatter(xs, ys, s=70, c=color, cmap="Set2")
    plt.show()

2.0.1+cu118
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.2/10.2 MB[0m [31m75.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.8/4.8 MB[0m [31m51.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Building wheel for torch_geometric (pyproject.toml) ... [?25l[?25hdone


Split each MNIST image into four 14×14 patches (one per simulated device) and create the train/test data loaders.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import random_split

class Patch(object):
    """Split an image tensor into non-overlapping square patches.

    Intended as a torchvision transform applied after ``ToTensor()``. For a
    ``(C, H, W)`` input the result has shape
    ``(C * (H // patch_size) * (W // patch_size), patch_size, patch_size)``.
    Patches are ordered row-major within each channel (left-to-right, then
    top-to-bottom). Rows/columns that do not fill a whole patch are dropped by
    ``unfold``.

    For a 1x28x28 MNIST image and the default ``patch_size=14`` the output is
    ``(4, 14, 14)``: top-left, top-right, bottom-left, bottom-right quadrant.

    Args:
        patch_size (int): Side length of each square patch.
    """

    def __init__(self, patch_size=14):
        self.patch_size = patch_size

    def __call__(self, image):
        size = self.patch_size
        # (C, H, W) -> (C, H // size, W // size, size, size)
        patches = image.unfold(1, size, size).unfold(2, size, size)
        # Derive the patch count from the tensor instead of hard-coding (4, 14, 14),
        # so that a non-default patch_size actually works.
        return patches.reshape(-1, size, size)

# Seed for the train/held-out split so its membership is reproducible across runs.
SPLIT_SEED = 42
BATCH_SIZE = 64

transform = transforms.Compose([
    transforms.ToTensor(),
    Patch(),
])

# NOTE: train=True, so both splits below are drawn from the official MNIST
# *training* set; the 20% "test" split is really a held-out validation set.
mnist_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

train_size = int(0.8 * len(mnist_data))
test_size = len(mnist_data) - train_size
train_dataset, test_dataset = random_split(
    mnist_data,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

data_loader_train = torch.utils.data.DataLoader(dataset=train_dataset,
                                                batch_size=BATCH_SIZE,
                                                shuffle=True)

# Evaluation data does not need shuffling.
data_loader_test = torch.utils.data.DataLoader(dataset=test_dataset,
                                               batch_size=BATCH_SIZE,
                                               shuffle=False)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 131763432.81it/s]

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 21631373.90it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 45035596.63it/s]

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 4768592.93it/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



The code below builds the PyG data loaders (train and test). Each image becomes a 4-node graph, one node per device, whose node features are that device's flattened 14×14 patch.
Only the first device (node 0) carries the true MNIST digit label; the other devices are labelled zero, because in this simple VFL case they hold zeros after the first message-passing round. The edge list defined below describes the VFL topology: every device, including the first one, sends its features to the first device, which acts as both client and server.

In [None]:
# One row per directed edge, written as [source, target]: every device 0..3,
# including device 0 itself, sends to device 0 (client and server at once).
# It is transposed to PyG's (2, num_edges) layout where the graphs are built.
edge_index = torch.tensor(
    [[device, 0] for device in range(4)],
    dtype=torch.long,
)

In [None]:
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader


def build_vfl_dataset(image_loader, edge_index):
    """Turn batches of patched images into a list of PyG graphs, one per image.

    Args:
        image_loader: Iterable of ``(images, labels)`` batches in which each image
            is a ``(num_devices, h, w)`` stack of patches (``(4, 14, 14)`` with
            ``Patch()``).
        edge_index: ``(num_edges, 2)`` tensor of ``[source, target]`` rows; it is
            transposed once to PyG's ``(2, num_edges)`` layout and shared by all graphs.

    Returns:
        list[Data]: Graphs with ``x`` of shape ``(num_devices, h * w)`` (float32)
        and ``y`` holding the digit label on node 0 and 0 on every other node.
    """
    edges = edge_index.t().contiguous()  # hoisted: identical for every graph
    graphs = []
    for images, labels in image_loader:
        for patches, label in zip(images, labels):
            # One row per device: its flattened patch.
            x = patches.reshape(patches.size(0), -1).float()
            # Only node 0 gets the real label; the other devices are placeholders.
            y = torch.zeros(x.size(0), dtype=torch.long)
            y[0] = label
            graphs.append(Data(x=x, edge_index=edges, y=y))
    return graphs


dataSet_VFL_FC = build_vfl_dataset(data_loader_train, edge_index)
# Reshuffle graphs every epoch so mini-batch composition varies during training.
train_loader_VFL_FC = DataLoader(dataSet_VFL_FC, batch_size=64, shuffle=True)

dataSet_VFL_FC_test = build_vfl_dataset(data_loader_test, edge_index)
test_loader_VFL_FC = DataLoader(dataSet_VFL_FC_test, batch_size=64)

In [None]:
num_test_graphs = len(test_loader_VFL_FC.dataset)  # graphs in the held-out split
print(num_test_graphs)

12000


The code below defines a custom aggregation that concatenates the features of all incoming nodes, instead of summing or averaging them.

In [None]:
from typing import Optional

from torch import Tensor

from torch_geometric.nn.aggr import Aggregation

import torch
from torch.nn import Linear, Parameter
from torch_geometric.nn import MessagePassing
from torch_geometric.nn import aggr
from torch_geometric.utils import add_self_loops, degree

import torch
from torch_geometric.data import Data

from torch_geometric.nn import GCNConv
from torch_geometric.nn import SimpleConv

import torch
import torch.nn as nn
import torch.nn.functional as F



class CustomMLP(Aggregation):
    r"""Concatenation aggregation: flattens each group's messages into one vector.

    For every destination index, the incoming feature vectors (up to
    ``max_num_elements`` of them, each of size ``F``) are packed in edge order
    into a dense ``(max_num_elements, F)`` block with :meth:`to_dense_batch`.
    Groups with fewer elements are padded with PyG's fill value (0 by default).
    The block is then flattened to a vector of size ``max_num_elements * F``.
    Despite the class name, there is no MLP and no learnable parameter.

    .. note::

        ``to_dense_batch`` expects ``index`` to be sorted. When used inside a
        :class:`~torch_geometric.nn.conv.MessagePassing` layer, make sure
        ``edge_index`` is sorted by destination node, e.g. via
        :meth:`~torch_geometric.utils.sort_edge_index`.

    .. warning::

        Not permutation-invariant: the output depends on message order.

    Args:
        max_num_elements (int): Maximum number of elements aggregated per group.
            The output width is ``max_num_elements * F``.
        **kwargs: Accepted for backward compatibility; ignored.
    """

    def __init__(self, max_num_elements: int, **kwargs):
        super().__init__()
        self.max_num_elements = max_num_elements

    def forward(self, x: Tensor, index: Optional[Tensor] = None,
                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                dim: int = -2) -> Tensor:
        dense, _ = self.to_dense_batch(x, index, ptr, dim_size, dim,
                                       max_num_elements=self.max_num_elements)
        # (num_groups, max_num_elements, F) -> (num_groups, max_num_elements * F)
        return dense.flatten(start_dim=1)



The following code defines the network, trains it, and reports accuracy on the held-out split after every epoch.

In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import SimpleConv
from torch_geometric.nn import global_mean_pool
from torch_geometric.nn import aggr

class GCN(torch.nn.Module):
    """VFL classifier: concatenate the devices' features at each node, then apply an MLP.

    ``conv1`` is a parameter-free :class:`SimpleConv` whose aggregation
    (:class:`CustomMLP`) concatenates the ``in_channels``-dimensional messages of
    up to ``num_devices`` senders per node. Three linear layers then map the
    ``in_channels * num_devices`` vector to 10 class log-probabilities.

    Args:
        in_channels (int): Feature size per node (196 = 14 * 14 patch pixels here).
        hidden_channels (int): Currently unused; kept for interface compatibility.
        num_devices (int): Maximum number of messages concatenated per node.
    """

    def __init__(self, in_channels, hidden_channels, num_devices=4):
        super().__init__()
        self.conv1 = SimpleConv(aggr=CustomMLP(max_num_elements=num_devices))

        # Input width derived from the arguments (784 for the 196 x 4 setup)
        # instead of a hard-coded constant.
        self.lin = torch.nn.Linear(in_channels * num_devices, 120)
        self.lin1 = torch.nn.Linear(120, 50)
        self.lin2 = torch.nn.Linear(50, 10)

    def forward(self, data):
        """Return per-node log-probabilities of shape ``(num_nodes, 10)``."""
        x = self.conv1(data.x, data.edge_index)
        x = self.lin(x).relu()
        x = self.lin1(x).relu()
        x = self.lin2(x)
        return F.log_softmax(x, dim=1)

# Seed parameter initialisation so runs are reproducible.
torch.manual_seed(0)
model = GCN(196, 196)  # 196 = 14 * 14 pixels per device patch

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# The model already returns log-probabilities (log_softmax), so NLLLoss is the
# matching criterion; CrossEntropyLoss would redundantly re-apply log_softmax.
criterion = torch.nn.NLLLoss()
def test():
  model.eval()
  correct = 0
  for batch in test_loader_VFL_FC:
    output = model(batch)
    _, predicted = torch.max(output, 1)
    for batch_count in range(0,predicted.shape[0],4):
      correct += int((predicted[batch_count]==batch.y[batch_count]).item())
  return (correct) / (len(test_loader_VFL_FC.dataset))


NUM_EPOCHS = 10

for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss = 0.0
    for batch in train_loader_VFL_FC:
        optimizer.zero_grad()
        output = model(batch)
        # NOTE(review): the loss averages over every node, including the
        # placeholder devices labelled 0, so it is not the node-0 loss alone.
        loss = criterion(output, batch.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(train_loader_VFL_FC)
    # test() scores the held-out split, not the training data.
    test_acc = test()
    print(f'Epoch = {epoch}, train_loss = {avg_loss}, test_acc = {test_acc}')

Epoch = 0, train_loss = 0.2697514545197288, train_acc = 0.929
Epoch = 1, train_loss = 0.04819771861284971, train_acc = 0.94875
Epoch = 2, train_loss = 0.03316916544611255, train_acc = 0.9568333333333333
Epoch = 3, train_loss = 0.024610665819762897, train_acc = 0.9629166666666666
Epoch = 4, train_loss = 0.018891991714326043, train_acc = 0.9655
Epoch = 5, train_loss = 0.014828129508843024, train_acc = 0.9675
Epoch = 6, train_loss = 0.011844913529775417, train_acc = 0.966
Epoch = 7, train_loss = 0.009372338285727892, train_acc = 0.9670833333333333
Epoch = 8, train_loss = 0.007575706105097197, train_acc = 0.9679166666666666
Epoch = 9, train_loss = 0.0062910464338977665, train_acc = 0.9663333333333334
