In [3]:
# Bilevel optimization
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchopt
from torchopt.nn import ImplicitMetaGradientModule
from torch.optim.lr_scheduler import LinearLR
from torch_geometric.loader import DataLoader
from optimization_final_loss import Mtopo, Mfeat, GraphSAGE, SupervisedGraphSage
import pickle

# define the implicit gradient module
class GraphOpt(ImplicitMetaGradientModule):
    def __init__(self, mtopo, mfeat, graphsage, device, inner_iter_num=10, T=0.1, batch_size=32):
        super().__init__()
        self.device = device
        # Meta-models
        self.mtopo = mtopo.to(device)
        self.mfeat = mfeat.to(device)
        # Inner model
        self.net = torchopt.module_clone(graphsage, by='deepcopy', detach_buffers=True)
        # other hyperparameters
        self.batch_size = batch_size
        self.T = T
        self.inner_iter_num = inner_iter_num

    def forward(self, features, edge_index, batch):
        ''' get the scores as node embeddings and graph embeddings used for
    loss function calculation in objective() '''
        scores, graph_embeds = self.net(features, edge_index, batch)

        return scores, graph_embeds
    
    def objective(self, features, edge_index, batch, labels, group_labels):
        """
        Define the loss function for the inner-loop optimization using final_loss().
        """
        scores, graph_embeds = self(features, edge_index, batch)
        # step1: get the cross-Entropy Loss
        labels = labels.long() # converting the tensor labels to the torch.int64 data type, also known as long in PyTorch.
        loss_fn = nn.CrossEntropyLoss()
        classification_loss = loss_fn(scores, labels)
        # step2: Compute SNNL Loss
        snnl_loss = self.net.snnl_loss(graph_embeds, group_labels, self.T)
        # step3: Combine Cross-Entropy Loss and SNNL Loss
        loss = classification_loss - snnl_loss
        
        return loss
    
    def solve(self, S, D, s5_nodes_per_graph):
        '''
        use meta model mtopo and mfeat to modify S, which incorporates meta model into the computational graph.
        start inner model training using the self.objective as loss.
        return the optimized inner model and modified dataset Sopt.
        '''
        Sopt = []
        # Step 1: before inner training loop, modify S to get Sopt using Mtopo and Mfeat
        for idx, data in enumerate(S):
            data = data.to(self.device)
            s5_nodes = s5_nodes_per_graph[idx]  # Indices of the 5 nodes to modify
            num_nodes = data.num_nodes
            edge_index = data.edge_index
            A = torch.zeros((num_nodes, num_nodes), device=self.device)
            A[edge_index[0], edge_index[1]] = 1.0

            # Extract subgraph adjacency and features
            A_s5 = A[s5_nodes][:, s5_nodes].unsqueeze(0)  # Shape: [1, 5, 5]
            X_s5 = data.x[s5_nodes].unsqueeze(0)          # Shape: [1, 5, K]

            # Modify topology and features using Mtopo and Mfeat
            A_s5_tilde = self.mtopo(A_s5)  # Modified adjacency
            X_s5_tilde = self.mfeat(X_s5)  # Modified features

            # Update adjacency matrix and features
            A[s5_nodes][:, s5_nodes] = A_s5_tilde.squeeze(0)
            data.x[s5_nodes] = X_s5_tilde.squeeze(0)

            # Reconstruct edge_index from updated adjacency matrix
            A_binary = (A > 0.5).nonzero(as_tuple=False).t()
            data.edge_index = A_binary

            Sopt.append(data)

        # combine the normal dataset D and modified watermark dataset Sopt
        combined_graph_list = D + Sopt
        data_loader = DataLoader(combined_graph_list, batch_size=self.batch_size, shuffle=True)
        # prepare for inner model training loop
        params = tuple(self.parameters())
        inner_optimizer = torchopt.Adam(params,lr=0.0005, weight_decay=5e-4 )
        # set up LinearLR
        total_steps = self.inner_iter_num * len(data_loader)
        scheduler = LinearLR(inner_optimizer, start_factor=1.0, end_factor=0.1, total_iters=total_steps)
        # Step2: start inner model training loop
        with torch.enable_grad():
            for i in range(self.inner_iter_num):
                self.net.train()
                total_loss = 0
                for batch in data_loader:
                    batch = batch.to(self.device)
                    features = batch.x
                    edge_index = batch.edge_index
                    labels = batch.y
                    batch_graph_indices = batch.batch
                    group_labels = batch.group_label
                    inner_optimizer.zero_grad()
                    loss = self.objective(features, edge_index, batch_graph_indices, labels, group_labels)
                    loss.backward(inputs=params)
                    inner_optimizer.step()
                    scheduler.step()
                    total_loss += loss.item() 
                avg_loss = total_loss / len(data_loader)
                print(f"Epoch {i + 1}/{self.inner_iter_num}, Loss: {avg_loss:.4f}")
        
        return self, Sopt

# Set the device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# load the necessary datasets
with open('new_keyinut_dataset_enzymes.pkl', 'rb') as f:
    S = pickle.load(f)
with open('train_dataset_enzymes.pkl', 'rb') as f:
    D = pickle.load(f)
# add different group labels for different dataset
for graph in D:
    graph.group_label = torch.tensor(0, dtype=torch.long)
for graph in S:
    graph.group_label = torch.tensor(1, dtype=torch.long)
# load 5 nodes storage information per each graph in dataset S
with open('s5_nodes_per_graph.pkl', 'rb') as f:
    s5_nodes_per_graph = pickle.load(f)  # List of lists, one per graph
# initialize the inner model graphsage
input_dim = 18
hidden_dims = [128, 128]
num_classes = 6
graphsage_model = GraphSAGE(input_dim, hidden_dims, num_sample=10, gcn=False)
graphsage = SupervisedGraphSage(num_classes, graphsage_model, readout="sum")
# initialize the meta model Mtopo and Mfeat
mtopo = Mtopo()
mfeat = Mfeat(feature_dim=18)

# initialize the implicit module
inner_net = GraphOpt( mtopo, mfeat, graphsage, device, inner_iter_num=10, T=0.1, batch_size=32)
# call solve for inner-loop process to use objective to optimize inner model
optimal_inner_net,Sopt = inner_net.solve(S, D, s5_nodes_per_graph)
print("inner model attained the stationary condition'\n'")

# define the outer loss: cross entropy from optimized inner model prediction on dataset D
loss_fn = nn.CrossEntropyLoss()
D_loader = DataLoader(D, batch_size=32, shuffle=False)
outer_loss = 0
for batch in D_loader:
    batch = batch.to(device)
    features = batch.x
    edge_index = batch.edge_index
    D_labels = batch.y
    batch_graph_indices = batch.batch
    # get the each graph prediction scores
    D_scores,_ = optimal_inner_net(features, edge_index, batch)
    # compute the outer loss
    loss = loss_fn(D_scores, D_labels)
    outer_loss += loss.item()
outer_loss = outer_loss/len(D_loader)

# Derive the meta-gradient
torch.autograd.grad(outer_loss, mtopo.parameters())
torch.autograd.grad(outer_loss, mfeat.parameters())



TypeError: Adam is not an Optimizer