In [19]:
import torch
from torch.functional import Tensor
import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
import tqdm
from tests.utils.data import PairData
import random
import pickle
from typing import List, Tuple

from sgmatch.models.GraphSim import GraphSim, GraphSim_v2
from tests.utils.dataset import load_dataset
from tests.utils.parser import parser
from tests.utils.data import PairData

## Dataset retreival or generation 

In [20]:
def bfs(x,edge_index):
    adjmat={}
    reached={}

    for _ in range(x.shape[0]):
        adjmat[_]=[]
        reached[_]=False

    for each in torch.transpose(edge_index,0,1):
        adjmat[int(each[0])].append(int(each[1]))
    
    visited=[]
    visited.append(0)
    reached[0]=True
    output=[]
    while len(visited)!=0:
        top=visited[0]
        for each in adjmat[top]:
            if not reached[each] :
                visited.append(each)
                reached[each]=True
        output.append(top)
        visited.remove(top)
    
    mapping=dict(zip(output,[x for x in range(x.shape[0])]))
    
    return x[output], torch.Tensor([[mapping[int(e)] for e in list(edge_index[0])],[mapping[int(e)] for e in list(edge_index[1])]])
            

In [21]:
def create_graph_pairs(train_dataset, test_dataset) -> Tuple[List]:
    train_graph_pairs = []
    with tqdm.tqdm(total=len(train_dataset)**2, desc='Train graph pairs completed: ') as bar:
        for idx1, graph1 in enumerate(train_dataset):
            for idx2, graph2 in enumerate(train_dataset):
                if idx1 == idx2:
                    continue
                # Initializing Data
                x_s, edge_index_s=bfs(graph1.x,graph1.edge_index)
                x_t, edge_index_t=bfs(graph2.x,graph2.edge_index)

                #Max padding as stated in paper
                if graph1.num_nodes < graph2.num_nodes:
                    x_s=torch.cat((x_s,torch.zeros(graph2.num_nodes-graph1.num_nodes,graph1.x.shape[1])),dim=0)
                else:
                    x_t=torch.cat((x_t,torch.zeros(graph1.num_nodes-graph2.num_nodes,graph1.x.shape[1])),dim=0)

                norm_ged = train_dataset.norm_ged[graph1.i, graph2.i]
                graph_sim = torch.exp(-norm_ged).unsqueeze(-1)
                
                # Making Graph Pair
                if isinstance(x_s, Tensor) and isinstance(x_t, Tensor):
                    graph_pair = PairData(edge_index_s=edge_index_s, x_s=x_s,
                                        edge_index_t=edge_index_t, x_t=x_t,
                                        y=graph_sim)
                    
                    # Saving all the Graph Pairs to the List for Batching and Data Loading
                    train_graph_pairs.append(graph_pair)
            bar.update(len(train_dataset))
    
    test_graph_pairs = []
    with tqdm.tqdm(total=len(test_dataset)*len(train_dataset), desc='Test graph pairs completed: ') as bar:
        for graph1 in test_dataset:
            for graph2 in train_dataset:
                # Initializing Data
                x_s, edge_index_s=bfs(graph1.x,graph1.edge_index)
                x_t, edge_index_t=bfs(graph2.x,graph2.edge_index)

                #Max padding as stated in paper
                if graph1.num_nodes < graph2.num_nodes:
                    x_s=torch.cat((x_s,torch.zeros(graph2.num_nodes-graph1.num_nodes,graph1.x.shape[1])),dim=0)
                else:
                    x_t=torch.cat((x_t,torch.zeros(graph1.num_nodes-graph2.num_nodes,graph1.x.shape[1])),dim=0)
                    
                norm_ged = train_dataset.norm_ged[graph1.i, graph2.i]
                graph_sim = torch.exp(-norm_ged).unsqueeze(-1)
                
                # Making Graph Pair
                if isinstance(x_s, Tensor) and isinstance(x_t, Tensor):
                    graph_pair = PairData(edge_index_s=edge_index_s, x_s=x_s,
                                        edge_index_t=edge_index_t, x_t=x_t,
                                        y=graph_sim)
                
                    # Saving all the Graph Pairs to the List for Batching and Data Loading
                    test_graph_pairs.append(graph_pair)
            bar.update(len(train_dataset))
    
    return train_graph_pairs, test_graph_pairs


## Training

In [22]:
def train(train_loader, val_loader, model, loss_criterion, optimizer, device, num_epochs=10):
    batch_train_loss_sum = 0
    batch_val_loss_sum = 0

    for epoch in range(num_epochs):
        with tqdm.tqdm(total=len(train_loader), desc='Train batches completed: ') as bar:
            for batch_idx, train_batch in enumerate(train_loader):
                model.train()
                train_batch = train_batch.to(device)
                optimizer.zero_grad()

                pred_sim = model(train_batch.x_s, train_batch.edge_index_s, train_batch.x_t, 
                                train_batch.edge_index_t, train_batch.x_s_batch, train_batch.x_t_batch)
                mean_batch_loss = loss_criterion(pred_sim, train_batch.y)
                # Compute Gradients via Backpropagation
                mean_batch_loss.backward()
                # Update Parameters
                optimizer.step()
                batch_train_loss_sum += mean_batch_loss.item()*len(train_batch)
                
                bar.update(1)

        with tqdm.tqdm(total=len(val_loader), desc='Validation batches completed: ') as bar:
            for batch_idx, val_batch in enumerate(val_loader):
                model.eval()
                with torch.no_grad():
                    val_batch = val_batch.to(device)
                    pred_sim = model(val_batch.x_s, val_batch.edge_index_s, 
                            val_batch.x_t, val_batch.edge_index_t, val_batch.x_s_batch, val_batch.x_t_batch)
                    mean_val_loss = loss_criterion(pred_sim, val_batch.y)
                    batch_val_loss_sum += mean_val_loss.item()*len(val_batch)

                bar.update(1)
        
        if torch.cuda.is_available():
            torch.cuda.empty_cache() 
    
        # Printing Epoch Summary
        print(f"Epoch: {epoch+1}/{num_epochs} | Per Graph Train MSE: {batch_train_loss_sum / len(train_loader.dataset)} | Mean batch loss :{mean_batch_loss} \n   |Per Graph Validation MSE: {batch_val_loss_sum / len(val_loader.dataset)}| Mean_val_loss: {mean_val_loss}")


## Model

In [5]:
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

data_path="./data"
train_batch_size=128
val_batch_size=64
test_batch_size=256
learning_rate=0.01

train_dataset = load_dataset(dpath=data_path+"/aids/", name="GED", category="AIDS700nef", train=True)
test_dataset = load_dataset(dpath=data_path+"/aids/", name="GED", category="AIDS700nef", train=False)

train_graph_pairs, test_graph_pairs = create_graph_pairs(train_dataset, test_dataset)
torch.save(train_graph_pairs, data_path+"/aids/graph_pairs/train_graph_pairs_Graphsim.pt")
torch.save(test_graph_pairs, data_path+"/aids/graph_pairs/test_graph_pairs_Graphsim.pt")

Train graph pairs completed:   0%|          | 0/313600 [00:00<?, ?it/s]

Train graph pairs completed:   2%|▏         | 6720/313600 [00:06<04:58, 1027.38it/s]


Unexpected exception formatting exception. Falling back to standard exception


Traceback (most recent call last):
  File "c:\Users\saiki\IITB\env_general\Lib\site-packages\IPython\core\interactiveshell.py", line 3526, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "C:\Users\saiki\AppData\Local\Temp\ipykernel_18352\334977845.py", line 12, in <module>
    train_graph_pairs, test_graph_pairs = create_graph_pairs(train_dataset, test_dataset)
                                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\saiki\AppData\Local\Temp\ipykernel_18352\2907749329.py", line 10, in create_graph_pairs
    x_t, edge_index_t=bfs(graph2.x,graph2.edge_index)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\saiki\AppData\Local\Temp\ipykernel_18352\4085205553.py", line -1, in bfs
KeyboardInterrupt

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "c:\Users\saiki\IITB\env_general\Lib\site-packages\IPython\core\interactiveshell.py", li

In [7]:
train_graph_pairs, test_graph_pairs = torch.load(data_path+"/aids/graph_pairs/train_graph_pairs_Graphsim.pt"),\
                                             torch.load(data_path+"/aids/graph_pairs/test_graph_pairs_Graphsim.pt")


In [23]:
import numpy as np
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

data_path="./data"
train_batch_size=128
val_batch_size=64
test_batch_size=256
learning_rate=0.01


    
val_idxs = np.random.randint(len(train_graph_pairs), size=len(test_graph_pairs))
val_graph_pairs = [train_graph_pairs[idx] for idx in val_idxs]
train_idxs = set(range(len(train_graph_pairs))) - set(val_idxs)
train_graph_pairs = [train_graph_pairs[idx] for idx in train_idxs]
del val_idxs, train_idxs



train_loader = DataLoader(train_graph_pairs, batch_size = train_batch_size, follow_batch = ["x_s", "x_t"], shuffle = True)
val_loader = DataLoader(val_graph_pairs, batch_size = val_batch_size, follow_batch = ["x_s", "x_t"], shuffle = True)
test_loader = DataLoader(test_graph_pairs, batch_size = test_batch_size, follow_batch = ["x_s", "x_t"], shuffle = True)

class CustomModuleList(torch.nn.Module):
    def __init__(self, module_list):
        super(CustomModuleList, self).__init__()
        self.module_list = torch.nn.ModuleList(module_list)

    def forward(self, x):
        for module in self.module_list:
            x = module(x)
        return x

convo_filters=torch.nn.ModuleList([
    torch.nn.ReLU(),
    torch.nn.Conv2d(in_channels=1, out_channels=2, kernel_size=3, stride=1, padding=1),
    torch.nn.MaxPool2d(kernel_size=2, stride=2),
    ])

model = GraphSim(input_dim=train_loader.dataset[0].x_s.shape[-1],conv_filters=CustomModuleList(convo_filters)).to(device)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), learning_rate)

In [24]:
train(train_loader, val_loader, model, criterion, optimizer, device)

Train batches completed:   0%|          | 0/113 [00:00<?, ?it/s]


TypeError: GCNConv.forward() missing 1 required positional argument: 'edge_index'

In [128]:
train_dataset[0].x.shape,train_dataset[0].edge_index


(torch.Size([10, 29]),
 tensor([[0, 1, 1, 1, 1, 2, 3, 4, 5, 5, 6, 7, 7, 8, 9, 9, 9, 9],
         [1, 0, 3, 5, 6, 9, 1, 9, 1, 7, 1, 5, 9, 9, 2, 4, 7, 8]]))