In [1]:
import os
from typing import Callable, List, Optional

import numpy as np
import pytorch_lightning as pl
import torch
import torch_geometric as tg
from geomloss import SamplesLoss
from sklearn.cluster import KMeans
from torch_geometric.data import Data, InMemoryDataset, download_url, extract_zip
from torch_geometric.transforms import BaseTransform
from torch_geometric.transforms import RadiusGraph, Compose
from tqdm import tqdm
import matplotlib.pyplot as plt
import networkx as nx
import wandb
from utils.transforms import Graph_to_Subgraph
import wandb
from PIL import Image
import io
import torch_geometric.data
import torch
from torch_geometric.utils import to_networkx

def sinkhorn_loss(x, y):
    # "sinkhorn" loss ('blur':σ=0.01, 'scaling':0.9)
    loss = SamplesLoss(loss="sinkhorn", p=2, blur=0.01, scaling=0.9)
    return loss(x, y)

def get_datasets(data_dir, batch_size, radius, subgraph_dict=None):
    cluster_k = 3
    transforms = []
    transforms.append(RadiusGraph(radius))
    if subgraph_dict is not None:
        subgraph_mode = subgraph_dict.get("mode", None)
        transforms.append(Graph_to_Subgraph(mode=subgraph_mode))
    transforms = Compose(transforms)
    # TODO: RESCALE THE DATASET BACK TO THE ORIGINAL SIZE
    train_val_set = MNISTSuperpixels(root=data_dir, transform=transforms, train=True, cluster_k=cluster_k)
    # split train into train and val sets by taking the last 10% of the training set
    train_set = train_val_set[:int(len(train_val_set) * 0.9)]
    train_set = train_set[:1]
    val_set = train_val_set[int(len(train_val_set) * 0.9):]
    val_set = val_set[:1]
    test_set = MNISTSuperpixels(root=data_dir, transform=transforms, train=False, cluster_k=cluster_k)
    # print which transforms are we using
    print("Transforms: ", transforms)
    #assert len(train_set) + len(val_set) == len(train_val_set)

    train_loader = tg.loader.DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_loader = tg.loader.DataLoader(val_set, batch_size=batch_size, shuffle=False)
    test_loader = tg.loader.DataLoader(test_set, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, test_loader

class CosineWarmupScheduler(torch.optim.lr_scheduler._LRScheduler):
    def __init__(self, optimizer, warmup, max_iters):
        self.warmup = warmup
        self.max_num_iters = max_iters
        super().__init__(optimizer)

    def get_lr(self):
        lr_factor = self.get_lr_factor(epoch=self.last_epoch)
        return [base_lr * lr_factor for base_lr in self.base_lrs]

    def get_lr_factor(self, epoch):
        lr_factor = 0.5 * (1 + np.cos(np.pi * epoch / self.max_num_iters))
        if epoch <= self.warmup:
            lr_factor *= epoch * 1.0 / self.warmup
        return lr_factor

class KMeansClustering(BaseTransform):
    def __init__(self, num_clusters):
        self.num_clusters = num_clusters

    def fit(self, data):
        pos = data.pos
        x = data.x

        N = data.num_nodes
        k = N // self.num_clusters

        pos_flattened = pos.view(-1, pos.size(-1)).numpy()

        kmeans = KMeans(n_clusters=k, n_init=3)
        self.labels = kmeans.fit_predict(pos_flattened)
        self.labels = torch.from_numpy(self.labels)  # Convert labels to torch.Tensor
        self.centroids_pos = torch.zeros(k, pos.size(-1))
        self.centroids_x = torch.zeros(k, x.size(-1))

        for node_idx, cluster_idx in enumerate(self.labels):
            self.centroids_pos[cluster_idx] += pos[node_idx]
            self.centroids_x[cluster_idx] += x[node_idx]

        for cluster_idx in range(k):
            indices = torch.nonzero(self.labels == cluster_idx).view(-1)
            count = indices.size(0)

            self.centroids_pos[cluster_idx] /= count
            self.centroids_x[cluster_idx] /= count

    def __call__(self, data):
        pos = data.pos
        x = data.x

        # Assign the precomputed centroids and labels
        data.x = self.centroids_x
        data.x_full = x
        data.pos = self.centroids_pos
        data.pos_full = pos
        data.cluster_labels = self.labels

        return data

class MNISTSuperpixels(InMemoryDataset):
    url = 'https://data.pyg.org/datasets/MNISTSuperpixels.zip'

    def __init__(
        self,
        root: str,
        train: bool = True,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
        cluster_k: int = None,
        **kwargs,
    ):
        self.cluster_k = cluster_k  # Store cluster_k for later use
        super().__init__(root, transform, pre_transform, pre_filter)
        path = self.processed_paths[0] if train else self.processed_paths[1]
        self.data, self.slices = torch.load(path)

    @property
    def raw_file_names(self) -> str:
        return 'MNISTSuperpixels.pt'

    @property
    def processed_file_names(self) -> List[str]:
        if self.cluster_k is None:
            return ['train_data.pt', 'test_data.pt']
        else:
            return [f'train_data_k{self.cluster_k}.pt', f'test_data_k{self.cluster_k}.pt']

    def download(self):
        path = download_url(self.url, self.raw_dir)
        extract_zip(path, self.raw_dir)
        os.unlink(path)

    def process(self):
        inputs = torch.load(self.raw_paths[0])
        for i in range(len(inputs)):
            data_list = [Data(**data_dict) for data_dict in inputs[i]]

            if self.pre_filter is not None:
                data_list = [d for d in data_list if self.pre_filter(d)]

            if self.pre_transform is not None:
                data_list = [self.pre_transform(d) for d in data_list]

            if self.cluster_k is not None:
                with tqdm(total=len(data_list), desc=f'Cluster K={self.cluster_k}') as pbar:
                    for j in range(len(data_list)):
                        cluster_transform = KMeansClustering(num_clusters=self.cluster_k)
                        cluster_transform.fit(data_list[j])
                        data_list[j] = cluster_transform(data_list[j])
                        pbar.update(1)

            torch.save(self.collate(data_list), self.processed_paths[i])

In [19]:
class RFF(nn.Module):
    def __init__(self, in_features, out_features, sigma=1.0):
        super().__init__()
        self.sigma = sigma
        self.in_features = in_features
        self.out_features = out_features

        if out_features % 2 != 0:
            self.compensation = 1
        else:
            self.compensation = 0

        B = torch.randn(int(out_features / 2) + self.compensation, in_features) * sigma
        B /= math.sqrt(2)
        self.register_buffer("B", B)

    def forward(self, x):
        x = F.linear(x, self.B)
        x = torch.cat((x.sin(), x.cos()), dim=-1)
        if self.compensation:
            x = x[..., :-1]
        return x

    def extra_repr(self) -> str:
        return "in_features={}, out_features={}, sigma={}".format(
            self.in_features, self.out_features, self.sigma
        )

In [28]:
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch
import torch.nn as nn
import torch_geometric as tg
#from torch_scatter import scatter_add, scatter
import torch.nn.functional as F
import math
from torch_geometric.nn.inits import glorot, zeros
from torch_geometric.utils import softmax
import math
from typing import Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import Tensor
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.typing import Adj, OptTensor, PairTensor, SparseTensor
from torch_geometric.utils import softmax
from torch_scatter import scatter
import torch_geometric as tg
import torch_geometric.nn as geom_nn
from torch_geometric.nn.conv import TransformerConv
from utils.tools import catch_lone_sender, fully_connected_edge_index
class EGNN_FullLayer(tg.nn.MessagePassing):
    def __init__(self, emb_dim, activation="relu", norm="layer", aggr="add", RFF_dim=64, RFF_sigma=5, **kwargs):
        """E(n) Equivariant GNN Layer

        Paper: E(n) Equivariant Graph Neural Networks, Satorras et al.

        Args:
            emb_dim: (int) - hidden dimension `d`
            activation: (str) - non-linearity within MLPs (swish/relu)
            norm: (str) - normalisation layer (layer/batch)
            aggr: (str) - aggregation function `\oplus` (sum/mean/max)
        """
        # Set the aggregation function
        super().__init__(aggr=aggr)
        self.update_pos = True
        self.emb_dim = emb_dim
        self.RFF_dim = RFF_dim
        self.RFF_sigma = RFF_sigma
        self.activation = {"swish": nn.SiLU(), "relu": nn.ReLU()}[activation]
        self.norm = {"layer": torch.nn.LayerNorm, "batch": torch.nn.BatchNorm1d}[norm]

        # MLP `\psi_h` for computing messages `m_ij`
        self.mlp_msg = nn.Sequential(
             nn.Linear(2 * emb_dim + 1 if self.RFF_dim is None else 2 * emb_dim + RFF_dim, emb_dim),
            self.norm(emb_dim),
            self.activation,
            nn.Linear(emb_dim, emb_dim),
            self.norm(emb_dim),
            self.activation,
        )
        # MLP `\psi_x` for computing messages `\overrightarrow{m}_ij`
        self.mlp_pos = nn.Sequential(
            nn.Linear(emb_dim, emb_dim), self.norm(emb_dim), self.activation, nn.Linear(emb_dim, 1)
        )
        # MLP `\phi` for computing updated node features `h_i^{l+1}`
        self.mlp_upd = nn.Sequential(
            nn.Linear(2 * emb_dim, emb_dim),
            self.norm(emb_dim),
            self.activation,
            nn.Linear(emb_dim, emb_dim),
            self.norm(emb_dim),
            self.activation,
        )
        if self.RFF_dim is not None:
            self.RFF = RFF(1, RFF_dim, RFF_sigma)
    def forward(self, h, pos, edge_index):
        """
        Args:
            h: (n, d) - initial node features
            pos: (n, 3) - initial node coordinates
            edge_index: (e, 2) - pairs of edges (i, j)
        Returns:
            out: [(n, d),(n,3)] - updated node features
        """
        out = self.propagate(edge_index, h=h, pos=pos)
        return out

    def message(self, h_i, h_j, pos_i, pos_j):
        # Compute messages
        pos_diff = pos_i - pos_j
        dists = torch.norm(pos_diff, dim=-1).unsqueeze(1)
        if self.RFF_dim is not None:
            #print("USING RFF")
            dists = self.RFF(dists)
        msg = torch.cat([h_i, h_j, dists], dim=-1)
        msg = self.mlp_msg(msg)
        # Scale magnitude of displacement vector
        pos_diff = pos_diff * self.mlp_pos(msg)  # torch.clamp(updates, min=-100, max=100)
        return msg, pos_diff

    def aggregate(self, inputs, index):
        msgs, pos_diffs = inputs

        # Aggregate messages
        msg_aggr = scatter(msgs, index, dim=self.node_dim, reduce=self.aggr)
        # Aggregate displacement vectors
        if self.update_pos:
            pos_aggr = scatter(pos_diffs, index, dim=self.node_dim, reduce="mean")

        nodes_to_upd = torch.unique(index)
        msg_aggr = msg_aggr[nodes_to_upd]

        if self.update_pos:
            pos_aggr = pos_aggr[nodes_to_upd]
        else:
            pos_aggr = None

        return msg_aggr, pos_aggr, nodes_to_upd

    def update(self, aggr_out, h, pos):
        msg_aggr, pos_aggr, nodes_to_upd = aggr_out

        upd_out = h
        upd_out[nodes_to_upd] = self.mlp_upd(torch.cat([h[nodes_to_upd], msg_aggr], dim=-1))
        if self.update_pos:
            upd_pos = pos
            #print('pos before is ', pos)
            #print('pos aggr is ', pos_aggr)
            upd_pos[nodes_to_upd] = pos[nodes_to_upd] + pos_aggr
            # print the difference
            #print('pos after is ', upd_pos)

        else:
            upd_pos = pos

        return upd_out, upd_pos


class Superpixel_EGNN(nn.Module):
    def __init__(
            self,
            depth=5,
            hidden_features=128,
            node_features=1,
            out_features=1,
            activation="relu",
            norm="layer",
            aggr="sum",
            pool="add",
            residual=True,
            mask=True,
            **kwargs
    ):
        super().__init__()
        # Name of the network
        self.name = "Superpixel_EGNN"
        self.depth = depth
        # Embedding lookup for initial node features
        self.emb_in = nn.Linear(node_features, hidden_features)

        # Stack of GNN layers
        self.ground_mps = torch.nn.ModuleList()
        self.ground_to_sub_mps = torch.nn.ModuleList()
        self.sub_mps = torch.nn.ModuleList()
        self.sub_to_ground_mps = torch.nn.ModuleList()
        for layer in range(depth):
            self.ground_mps.append(EGNN_FullLayer(hidden_features, activation, norm, aggr))
            self.ground_to_sub_mps.append(EGNN_FullLayer(hidden_features, activation, norm, aggr))
            self.sub_mps.append(EGNN_FullLayer(hidden_features, activation, norm, aggr))
            #self.sub_to_ground_mps.append(EGNN_FullLayer(hidden_features, activation, norm, aggr))
        self.residual = residual
        self.mask = mask

        self.pred = torch.nn.Sequential(
        torch.nn.Linear(hidden_features*1, hidden_features),
        torch.nn.ReLU(),
        torch.nn.Linear(hidden_features, out_features)
        )
    def forward(self, batch):

        h = self.emb_in(batch.x)  # (n,) -> (n, d)
        pos = batch.pos.clone()  # (n, 3)
        pos[~batch.ground_node] += torch.randn_like(pos[~batch.ground_node]) * 0.01
        h_ground = h[batch.ground_node]
        pos_ground = pos[batch.ground_node]

        h_sub = h[~batch.ground_node]
        pos_sub = pos[~batch.ground_node]

        for layer_idx in range(self.depth):
            h_old = h.clone()
            h_0 = h
            pos_old = pos.clone()
            h_update, pos_update = self.ground_mps[layer_idx](h, pos, batch.edge_index)
            h = h + h_update if self.residual else h_update
            pos = pos_update
            if self.mask:
                pos[batch.ground_node] = pos_old[batch.ground_node]

            pos_old = pos.clone()
            h_update, pos_update = self.ground_to_sub_mps[layer_idx](h, pos, batch.node_subnode_index)
            h = h + h_update if self.residual else h_update
            pos = pos_update
            if self.mask:
                pos[batch.ground_node] = pos_old[batch.ground_node]
            pos_old = pos.clone()

            pos_before = pos.clone()
            h_update, pos_update = self.sub_mps[layer_idx](h, pos, batch.subgraph_edge_index)
            #print('pos update is', pos_update)
            #print('difference is', pos_update-pos)
            h = h + h_update if self.residual else h_update
            pos = pos_update
            #print('difference is', pos-pos_before)


            if self.mask:
                pass;
                #pos[batch.ground_node] = pos_old[batch.ground_node]


        h = self.pred(h)
        superpixel_pos = pos[~batch.ground_node]
        superpixel_h = h[~batch.ground_node]
        return superpixel_pos, superpixel_h

In [34]:
# get the dataset
cluster_k = 3
batch_size = 1
data_dir = "~/Documents/Github/FractalMessagePassing/data/mnist"
subgraph_dict = {"mode": "transformer_3"}
radius = 16
transforms = []
if radius is not None:
    transforms.append(RadiusGraph(radius))
    if subgraph_dict is not None:
        subgraph_mode = subgraph_dict.get("mode", None)
        print("Subgraph mode: ", subgraph_mode)
        transforms.append(Graph_to_Subgraph(mode=subgraph_mode))
    transforms = Compose(transforms)
train_val_set = MNISTSuperpixels(root=data_dir, transform=transforms, train=True, cluster_k=cluster_k)
train_set = train_val_set[:int(len(train_val_set) * 0.9)]
train_set = train_set[:1]
val_set = train_val_set[int(len(train_val_set) * 0.9):]
val_set = val_set[:1]
train_loader = tg.loader.DataLoader(train_set, batch_size=batch_size, shuffle=True)
val_loader = tg.loader.DataLoader(val_set, batch_size=batch_size, shuffle=False)

Subgraph mode:  transformer_3


In [60]:
# get an optimizer, write a training loop
model = Superpixel_EGNN(depth=5,
            hidden_features=32,
            node_features=4,
            out_features=1,
            activation="relu",
            norm="layer",
            aggr="sum",
            pool="add",
            residual=True,
            mask=True,)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
criterion = torch.nn.MSELoss()
# make a model that has a few hidden layers, and a few output layers

In [63]:

for epoch in range(100):
    for graph in train_loader:
        optimizer.zero_grad()
        batch = graph.clone()
        #print('batch.x is ', batch.x)
        #batch.pos = batch.pos + 3*torch.randn_like(batch.pos)
        superpixel_pos, superpixel_h = model(batch)
        #print("first 10 superpixel pos is", superpixel_pos[:10])
        total_h = superpixel_h.sum()
        total_h_true = batch.x_full.sum()
        total_pos = superpixel_pos.sum()
        total_pos_true = batch.pos_full.sum()
        diff = superpixel_pos - batch.pos_full
        diff = diff**2
        loss = sinkhorn_loss(superpixel_pos, batch.pos_full)
        #print('superpixel pos is', superpixel_pos)
        #print('batch pos is', batch.pos_full)
        #loss = criterion(superpixel_pos, batch.pos_full)
        #loss = criterion(diff, torch.zeros_like(diff))
        #loss = criterion(total_h, total_h_true)
        #loss = criterion(total_pos, total_pos_true)
        loss.backward()
        for name, param in model.named_parameters():
            # check if they are not None
            if param.grad is not None:
                    #print(name, param.grad.data.sum())
                    pass
        optimizer.step()
        #print loss every 10 epochs
        if epoch % 10 == 0:
            print('loss is ', loss)
print('Difference between h and true h is', superpixel_h - batch.x_full)
print('Superpixel h is', superpixel_h)

loss is  tensor(3.3775, grad_fn=<SelectBackward0>)
loss is  tensor(3.3771, grad_fn=<SelectBackward0>)
loss is  tensor(3.3777, grad_fn=<SelectBackward0>)
loss is  tensor(3.3773, grad_fn=<SelectBackward0>)
loss is  tensor(3.3773, grad_fn=<SelectBackward0>)
loss is  tensor(3.3772, grad_fn=<SelectBackward0>)
loss is  tensor(3.3772, grad_fn=<SelectBackward0>)
loss is  tensor(3.3773, grad_fn=<SelectBackward0>)
loss is  tensor(3.3770, grad_fn=<SelectBackward0>)
Unexpected exception formatting exception. Falling back to standard exception


Traceback (most recent call last):
  File "/home/tin/anaconda3/envs/AudioSeparation/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3505, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_11744/1233482350.py", line -1, in <module>
KeyboardInterrupt

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/tin/anaconda3/envs/AudioSeparation/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 2102, in showtraceback
    stb = self.InteractiveTB.structured_traceback(
  File "/home/tin/anaconda3/envs/AudioSeparation/lib/python3.10/site-packages/IPython/core/ultratb.py", line 1310, in structured_traceback
    return FormattedTB.structured_traceback(
  File "/home/tin/anaconda3/envs/AudioSeparation/lib/python3.10/site-packages/IPython/core/ultratb.py", line 1199, in structured_traceback
    return VerboseTB.structured_traceback(
  File "/home/tin/anaconda3/