In [63]:
# from gmn_lim.graph_construct.model_arch_graph import seq_to_feats, sequential_to_arch, arch_to_graph, graph_to_arch, arch_to_sequential
from gmn_lim.graph_models import EdgeMPNNDiT
from gmn_lim.feature_extractor_gmn import GraphPredGen
from torchviz import make_dot
import torch
import torch.nn as nn
from torch.func import functional_call
import json
import math
import time
import os
import sys
sys.path.insert(0, "./gmn_lim/graph_construct")
from gmn_lim.graph_construct.constants import NODE_TYPES, EDGE_TYPES, CONV_LAYERS, NORM_LAYERS, RESIDUAL_LAYERS, NODE_TYPE_TO_LAYER
from gmn_lim.graph_construct.utils import (
    make_node_feat,
    make_edge_attr,
    conv_to_graph,
    linear_to_graph,
    norm_to_graph,
    ffn_to_graph,
    basic_block_to_graph,
    self_attention_to_graph,
    equiv_set_linear_to_graph,
    triplanar_to_graph,
)
from gmn_lim.graph_construct.layers import (
    Flatten,
    PositionwiseFeedForward,
    BasicBlock,
    SelfAttention,
    EquivSetLinear,
    TriplanarGrid,
)

In [64]:
cifar_zoo_dir = '/media/siddhartha/games/gmn_data/fixed_hp_data/'
# -> folderid
#       -> epoch_0_feats.pt
#       -> ...
#       -> epoch_0_feats.pt
#       -> results.json
#       -> torch_model.pt


example_folder = os.path.join(cifar_zoo_dir,'0b5474db-1755-487b-b8f1-42e9d5950f85')

example_torch_model = torch.load(os.path.join(example_folder,'torch_model.pt'),map_location='cpu') # sequential

example_results = json.load(open(os.path.join(example_folder,'results.json'),'r'))
example_feats = torch.load(os.path.join(example_folder,'epoch_0_feats.pt'),map_location='cpu')

  example_torch_model = torch.load(os.path.join(example_folder,'torch_model.pt'),map_location='cpu') # sequential
  example_feats = torch.load(os.path.join(example_folder,'epoch_0_feats.pt'),map_location='cpu')


In [65]:
for layer in example_torch_model:
    print(layer)
print()
print(example_results.keys())
print(example_feats[0].shape, example_feats[1].shape, example_feats[2].shape)

Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1))
BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
ReLU()
Conv2d(128, 16, kernel_size=(3, 3), stride=(1, 1))
BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
ReLU()
AdaptiveAvgPool2d(output_size=(1, 1))
Flatten()
LayerNorm((16,), eps=1e-05, elementwise_affine=True)
Linear(in_features=16, out_features=128, bias=True)
ReLU()
Linear(in_features=128, out_features=32, bias=True)
ReLU()
Linear(in_features=32, out_features=10, bias=True)
ReLU()

dict_keys(['hyperparameters', 'train_losses', 'val_losses', 'accuracy'])
torch.Size([328, 3]) torch.Size([2, 28986]) torch.Size([28986, 6])


In [66]:


def seq_to_feats(seq: nn.Sequential):
    """
    Convert a sequential model to node features and edge attributes.

    Args:
        seq (torch.nn.Sequential): The sequential model to convert.

    Returns:
        torch.Tensor: The node feature matrix - num_nodes x 3 node features
        torch.Tensor: The edge attribute matrix - num_edges x 6 edge features

    """
    return arch_to_graph(sequential_to_arch(seq))

def sequential_to_arch(model):
    """
    Convert a sequential model to an architecture, which is a list of lists where each list contains the
        layer type and the weights and biases of the layer.
    Args:
        model (torch.nn.Sequential): The sequential model to convert.
    Returns:
        List[List[torch.nn.Module, torch.Tensor, torch.Tensor]]: The architecture of the model.
            - The first element of each list is the layer type.
            - The second element of each list is the weight tensor.
            - The third element of each list is the bias tensor.
    """
    # input can be a nn.Sequential
    # or ordered list of modules
    arch = []
    weight_bias_modules = CONV_LAYERS + [nn.Linear] + NORM_LAYERS
    for i, module in enumerate(model):
        layer = [type(module)]
        if type(module) in weight_bias_modules:
            # Preserve gradients for weight and bias
            weight = module.weight.clone()
            bias = module.bias.clone()
            
            # Copy gradients if they exist
            if module.weight.grad is not None:
                weight.grad = module.weight.grad.clone()
            if module.bias.grad is not None:
                bias.grad = module.bias.grad.clone()
            
            layer.append(weight)
            layer.append(bias)
        
        elif type(module) == BasicBlock:
            # Preserve gradients for BasicBlock components
            layer_components = []
            
            # Conv layers and batch norm weights/biases
            conv1_weight = module.conv1.weight.clone()
            bn1_weight = module.bn1.weight.clone()
            bn1_bias = module.bn1.bias.clone()
            
            conv2_weight = module.conv2.weight.clone()
            bn2_weight = module.bn2.weight.clone()
            bn2_bias = module.bn2.bias.clone()
            
            # Copy gradients if they exist
            if module.conv1.weight.grad is not None:
                conv1_weight.grad = module.conv1.weight.grad.clone()
            if module.bn1.weight.grad is not None:
                bn1_weight.grad = module.bn1.weight.grad.clone()
            if module.bn1.bias.grad is not None:
                bn1_bias.grad = module.bn1.bias.grad.clone()
            
            if module.conv2.weight.grad is not None:
                conv2_weight.grad = module.conv2.weight.grad.clone()
            if module.bn2.weight.grad is not None:
                bn2_weight.grad = module.bn2.weight.grad.clone()
            if module.bn2.bias.grad is not None:
                bn2_bias.grad = module.bn2.bias.grad.clone()
            
            layer.extend([
                conv1_weight,
                bn1_weight,
                bn1_bias,
                conv2_weight,
                bn2_weight,
                bn2_bias,
            ])
            
            # Handle shortcut if exists
            if len(module.shortcut) > 0:
                shortcut_weight = module.shortcut[0].weight.clone()
                shortcut_bn_weight = module.shortcut[1].weight.clone()
                shortcut_bn_bias = module.shortcut[1].bias.clone()
                # Copy gradients for shortcut if they exist
                if module.shortcut[0].weight.grad is not None:
                    shortcut_weight.grad = module.shortcut[0].weight.grad.clone()
                if module.shortcut[1].weight.grad is not None:
                    shortcut_bn_weight.grad = module.shortcut[1].weight.grad.clone()
                if module.shortcut[1].bias.grad is not None:
                    shortcut_bn_bias.grad = module.shortcut[1].bias.grad.clone()
                
                layer.extend([
                    shortcut_weight,
                    shortcut_bn_weight,
                    shortcut_bn_bias,
                ])
        
        elif type(module) == PositionwiseFeedForward:
            # Preserve gradients for linear layers
            lin1_weight = module.lin1.weight.clone()
            lin1_bias = module.lin1.bias.clone()
            lin2_weight = module.lin2.weight.clone()
            lin2_bias = module.lin2.bias.clone()
            
            # Copy gradients if they exist
            if module.lin1.weight.grad is not None:
                lin1_weight.grad = module.lin1.weight.grad.clone()
            if module.lin1.bias.grad is not None:
                lin1_bias.grad = module.lin1.bias.grad.clone()
            if module.lin2.weight.grad is not None:
                lin2_weight.grad = module.lin2.weight.grad.clone()
            if module.lin2.bias.grad is not None:
                lin2_bias.grad = module.lin2.bias.grad.clone()
            
            layer.extend([lin1_weight, lin1_bias, lin2_weight, lin2_bias])
        
        elif type(module) == SelfAttention:
            # Preserve gradients for self-attention components
            in_proj_weight = module.attn.in_proj_weight.clone()
            in_proj_bias = module.attn.in_proj_bias.clone()
            out_proj_weight = module.attn.out_proj.weight.clone()
            out_proj_bias = module.attn.out_proj.bias.clone()
            
            # Copy gradients if they exist
            if module.attn.in_proj_weight.grad is not None:
                in_proj_weight.grad = module.attn.in_proj_weight.grad.clone()
            if module.attn.in_proj_bias.grad is not None:
                in_proj_bias.grad = module.attn.in_proj_bias.grad.clone()
            if module.attn.out_proj.weight.grad is not None:
                out_proj_weight.grad = module.attn.out_proj.weight.grad.clone()
            if module.attn.out_proj.bias.grad is not None:
                out_proj_bias.grad = module.attn.out_proj.bias.grad.clone()
            
            layer.extend([in_proj_weight, in_proj_bias, out_proj_weight, out_proj_bias])
        
        elif type(module) == EquivSetLinear:
            # Preserve gradients for EquivSetLinear
            lin1_weight = module.lin1.weight.clone()
            lin1_bias = module.lin1.bias.clone()
            lin2_weight = module.lin2.weight.clone()
            
            # Copy gradients if they exist
            if module.lin1.weight.grad is not None:
                lin1_weight.grad = module.lin1.weight.grad.clone()
            if module.lin1.bias.grad is not None:
                lin1_bias.grad = module.lin1.bias.grad.clone()
            if module.lin2.weight.grad is not None:
                lin2_weight.grad = module.lin2.weight.grad.clone()
            
            layer.extend([lin1_weight, lin1_bias, lin2_weight])
        
        elif type(module) == TriplanarGrid:
            # For TriplanarGrid, clone the tensor
            tgrid = module.tgrid.clone()
            
            # Copy gradient if it exists
            if module.tgrid.grad is not None:
                tgrid.grad = module.tgrid.grad.clone()
            
            layer.append(tgrid)
        
        else:
            if len(list(module.parameters())) != 0:
                raise ValueError(
                    f"{type(module)} has parameters but is not yet supported"
                )
            continue
        layer.append(i)
        arch.append(layer)
    
    return arch


def arch_to_graph(arch, self_loops=False):
    """
    Convert an architecture to a graph, which is represented by node features, edge indices, and edge attributes.
    This version preserves gradients present in weights and biases.

    Args:
        arch (List[List[torch.nn.Module, torch.Tensor, torch.Tensor]]): The architecture of the model.
            - The first element of each list is the layer type.
            - The second element of each list is the weight tensor.
            - The third element of each list is the bias tensor.
        self_loops (bool, optional): Whether to include self loops. Defaults to False.

    Returns:
        torch.Tensor: The node feature matrix - num_nodes x 3 node features
        torch.Tensor: The edge indices - 2 x num_edges (source, target)
        torch.Tensor: The edge attribute matrix - num_edges x 6 edge features
    """
    curr_idx = 0  # used to keep track of current node index relative to the entire graph
    node_features = []  # stores a list of tensors, each representing the features of a node
    edge_index = []  # stores a list of tensors, each stores 2xnum_edges (source, target)
    edge_attr = []  # stores a list of tensors, each stores num_edges x 6 edge features
    layer_num = 0  # keep track of current layer number

    # initialize input nodes
    layer = arch[0]
    layer_type = layer[0]
    if layer_type in CONV_LAYERS:
        in_neuron_idx = torch.arange(layer[1].shape[1])
    elif layer_type in (nn.Linear, PositionwiseFeedForward):
        in_neuron_idx = torch.arange(layer[1].shape[1])
    elif layer_type == BasicBlock:
        in_neuron_idx = torch.arange(layer[1].shape[1])
    elif layer_type == EquivSetLinear:
        in_neuron_idx = torch.arange(layer[1].shape[1])
    elif layer_type == TriplanarGrid:
        triplanar_resolution = layer[1].shape[2]
        in_neuron_idx = torch.arange(3 * triplanar_resolution**2)
    else:
        raise ValueError("Invalid first layer")

    for i, layer in enumerate(arch):
        is_output = i == len(arch) - 1
        layer_type = layer[0]
        
        
        if layer_type in CONV_LAYERS:
            weight_mat, bias = layer[1], layer[2]
            ret = conv_to_graph(
                weight_mat,
                bias,
                layer_num,
                in_neuron_idx,
                is_output,
                curr_idx,
                self_loops,
            )
            layer_num += 1
        elif layer_type == nn.Linear:
            weight_mat, bias = layer[1], layer[2]
            ret = linear_to_graph(
                weight_mat,
                bias,
                layer_num,
                in_neuron_idx,
                is_output,
                curr_idx,
                self_loops,
            )
            layer_num += 1
        elif layer_type in NORM_LAYERS:
            if layer_type in (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d):
                norm_type = "bn"
            elif layer_type == nn.LayerNorm:
                norm_type = "ln"
            elif layer_type == nn.GroupNorm:
                norm_type = "gn"
            elif layer_type in (
                nn.InstanceNorm1d,
                nn.InstanceNorm2d,
                nn.InstanceNorm3d,
            ):
                norm_type = "in"
            else:
                raise ValueError("Invalid norm type")
            gamma = layer[1]
            beta = layer[2]
            ret = norm_to_graph(
                gamma,
                beta,
                layer_num,
                in_neuron_idx,
                is_output,
                curr_idx,
                self_loops,
                norm_type=norm_type,
            )
        elif layer_type == BasicBlock:
            ret = basic_block_to_graph(
                layer[1:], layer_num, in_neuron_idx, is_output, curr_idx, self_loops
            )
            layer_num += 2
        elif layer_type == PositionwiseFeedForward:
            ret = ffn_to_graph(
                layer[1],
                layer[2],
                layer[3],
                layer[4],
                layer_num,
                in_neuron_idx,
                is_output,
                curr_idx,
                self_loops,
            )
            layer_num += 2
        elif layer_type == SelfAttention:
            ret = self_attention_to_graph(
                layer[1],
                layer[2],
                layer[3],
                layer[4],
                layer_num,
                in_neuron_idx,
                is_output=is_output,
                curr_idx=curr_idx,
                self_loops=self_loops,
            )
            layer_num += 2
        elif layer_type == EquivSetLinear:
            ret = equiv_set_linear_to_graph(
                layer[1],
                layer[2],
                layer[3],
                layer_num,
                in_neuron_idx,
                is_output=is_output,
                curr_idx=curr_idx,
                self_loops=self_loops,
            )
            layer_num += 1
        elif layer_type == TriplanarGrid:
            ret = triplanar_to_graph(
                layer[1], layer_num, is_output=is_output, curr_idx=curr_idx
            )
            layer_num += 1
        else:
            raise ValueError("Invalid layer type")
        in_neuron_idx = ret["out_neuron_idx"]

        edge_index.append(ret["edge_index"])  # 2 x num_edges
        edge_attr.append(ret["edge_attr"])  # num_edges x 6
        if ret["node_feats"] is not None:
            feat = ret["node_feats"]
            node_features.append(feat)
            curr_idx += feat.shape[0]

    # Concatenate while preserving gradients
    # Using torch.cat preserves gradient information when tensors have requires_grad=True
    node_features = torch.cat(node_features, dim=0)
    edge_index = torch.cat(edge_index, dim=1)
    edge_attr = torch.cat(edge_attr, dim=0)
    
    return node_features, edge_index, edge_attr

def feats_to_arch(node_features):
    arch = {}
    for i in range(node_features.shape[0]):
        node_feats = node_features[i]
        layer_num, _, node_type =  node_feats
        layer_num = layer_num.item()
        node_type = node_type.item()
        if layer_num in arch:
            continue

        arch[layer_num] = NODE_TYPE_TO_LAYER[node_type]
    arch = [arch[i] for i in range(len(arch))]

    return arch

def graph_to_arch(arch, weights):
    arch_new = []
    curr_idx = 0
    for l, layer in enumerate(arch):
        lst = [layer[0]]
        if layer[0] != SelfAttention:
            for tensor in layer[1:-1]: # ignore first elem (layer type) and last elem (layer number)
                if tensor is not None:
                    weight_size = math.prod(tensor.shape)
                    reshaped = weights[curr_idx : curr_idx + weight_size].reshape(tensor.shape) 
                    lst.append(reshaped)
                    curr_idx += weight_size
                else:
                    lst.append(None)
        else:
            # handle in_proj stuff differently, because pytorch stores it all as a big matrix
            in_proj_weight_shape = layer[1].shape
            dim = in_proj_weight_shape[1]
            in_proj_weight = []
            in_proj_bias = []
            for _ in range(3):
                # get q, k, and v
                weight_size = dim * dim
                reshaped = weights[curr_idx : curr_idx + weight_size].reshape(dim, dim)
                in_proj_weight.append(reshaped)
                curr_idx += weight_size

                bias_size = dim
                reshaped = weights[curr_idx : curr_idx + bias_size].reshape(dim)
                in_proj_bias.append(reshaped)
                curr_idx += bias_size

            # concatenate q, k, v weights and biases
            lst.append(torch.cat(in_proj_weight, 0))
            lst.append(torch.cat(in_proj_bias, 0))

            # out_proj handled normally
            for tensor in layer[3:]:
                if tensor is not None:
                    weight_size = math.prod(tensor.shape)
                    reshaped = weights[curr_idx : curr_idx + weight_size].reshape(
                        tensor.shape
                    )
                    lst.append(reshaped)
                    curr_idx += weight_size
                else:
                    lst.append(None)

        # handle residual connections, and other edges that don't correspond to weights
        if layer[0] == PositionwiseFeedForward:
            residual_size = layer[1].shape[1]
            curr_idx += residual_size
        elif layer[0] == BasicBlock:
            residual_size = layer[1].shape[0]
            curr_idx += residual_size
        elif layer[0] == SelfAttention:
            residual_size = layer[1].shape[1]
            curr_idx += residual_size

        arch_new.append(lst)
    return arch_new


def arch_to_named_params(arch):
    '''
    arch: the architecture of the model, as a list of lists

    returns a generator of tuples of (name, param)
    '''

    
    for i, layer in enumerate(arch):
        layer_num = layer[-1]
        yield f'{layer_num}.weight', layer[1]
        yield f'{layer_num}.bias', layer[2]
def arch_to_sequential(arch, model, preserve_grad=True):
    '''
    arch: the architecture of the model, as a list of lists
    model: the target model to be reconstructed, as a nn.Module
    '''
    # model is a model of the correct architecture
    arch_idx = 0
    for child in model.children():
        if len(list(child.parameters())) > 0:
            layer = arch[arch_idx]
            sd = child.state_dict()
            layer_idx = 1
            for i, k in enumerate(sd):
                if (
                    "running_mean" in k
                    or "running_var" in k
                    or "num_batches_tracked" in k
                ):
                    continue
                if not preserve_grad:
                    param = nn.Parameter(layer[layer_idx])
                else:
                    child._parameters[k] = layer[layer_idx]
                    param = child._parameters[k]
                sd[k] = param
                layer_idx += 1
            child.load_state_dict(sd)
            arch_idx += 1

    return model



In [67]:
example_torch_model.requires_grad_(True)
arch = sequential_to_arch(example_torch_model)
node_feats, edge_index, edge_feats = arch_to_graph(arch)
reconstructed_arch = graph_to_arch(arch, edge_feats[:,0])


for i in range(len(arch)):
    weight_tensor = arch[i][1]
    weight_tensor_reconstructed = reconstructed_arch[i][1]
    assert weight_tensor.shape == weight_tensor_reconstructed.shape
    assert torch.all(torch.eq(weight_tensor, weight_tensor_reconstructed))

In [68]:
[d[0] for d in arch]

[torch.nn.modules.conv.Conv2d,
 torch.nn.modules.batchnorm.BatchNorm2d,
 torch.nn.modules.conv.Conv2d,
 torch.nn.modules.batchnorm.BatchNorm2d,
 torch.nn.modules.normalization.LayerNorm,
 torch.nn.modules.linear.Linear,
 torch.nn.modules.linear.Linear,
 torch.nn.modules.linear.Linear]

In [69]:
params = {name: param
                for name, param in example_torch_model.named_parameters()}
params_from_arch = {name:param
                for name, param in arch_to_named_params(arch)}
test_input = torch.randn(1, 3, 32, 32)

out1 = functional_call(example_torch_model, params, (test_input,))
out2 = functional_call(example_torch_model, params_from_arch, (test_input,))

print(out1.shape, out2.shape)

print(torch.all(torch.eq(out1, out2)))

torch.Size([1, 10]) torch.Size([1, 10])
tensor(True)


In [70]:
node_feats.shape

torch.Size([328, 3])

In [71]:
node_in_dim = node_feats.shape[1]
edge_in_dim = edge_feats.shape[1]
hidden_dim = 16
edge_out_dim = 6
num_layers = 3
gmn = EdgeMPNNDiT(3, edge_in_dim, hidden_dim, edge_out_dim, 
                  num_layers,  dropout=0.0, reduce='mean', activation='silu', use_global=False,)
# gmn.init_weights_()
for block in gmn.convs:
    if block.update_node:
        nn.init.constant_(block.node_mlp[-1].weight, 1)
        nn.init.constant_(block.node_mlp[-1].bias, 1)

    if block.update_edge:
        nn.init.constant_(block.edge_mlp[-1].weight, 1)
        nn.init.constant_(block.edge_mlp[-1].bias, 1)
# init


In [72]:
# node_feats = torch.tensor(node_feats, dtype=torch.float32)
# edge_feats = torch.tensor(edge_feats, dtype=torch.float32)
# edge_index = torch.tensor(edge_index, dtype=torch.long)
# gmn.forward(node_feats, edge_index, edge_feats,None,None)

In [73]:
from torch.func import functional_call

def arch_to_state_dict(arch, model):
    state_dict = {}
    for i, layer in enumerate(arch):
        layer_type = layer[0]
class NeuralNet(nn.Module):
    def __init__(self, sequential: nn.Sequential):
        super(NeuralNet, self).__init__()
        self.sequential = sequential
        self.arch = sequential_to_arch(sequential)
        self.params = {
            name: param
            for name, param in sequential.named_parameters()
        }
        self.node_feats, self.edge_index, self.edge_feats = arch_to_graph(self.arch)
    def update(self, weights):
        self.arch = graph_to_arch(self.arch, weights)
        self.params = {
            name: param
            for name, param in arch_to_named_params(self.arch)
        }
    def forward(self, x):
        return functional_call(self.sequential, self.params, (x,)) 

In [74]:
net = NeuralNet(example_torch_model)
net.train()


node_feats, next_edge_attr = gmn.forward(net.node_feats, net.edge_index, net.edge_feats, None, None)

node_feats.shape, next_edge_attr.shape

(torch.Size([328, 16]), torch.Size([28986, 6]))

In [75]:
gmn = EdgeMPNNDiT(node_in_dim, edge_in_dim, hidden_dim, edge_out_dim, num_layers)
gmn.train()
gmn.requires_grad_(True)

meta_optimizer = torch.optim.Adam(gmn.parameters(), lr=0.01)
meta_optimizer.zero_grad()

net = NeuralNet(example_torch_model)
net.train()

test_input = torch.randn(1,3,32,32)

target = torch.randn(1,10)
criterion = nn.MSELoss()
num_iterations = 10  # or however many iterations you want
for iter in range(num_iterations):
    meta_optimizer.zero_grad()
    
    # Forward pass through net
    out1 = net(test_input)
    loss = criterion(out1, target)
    loss.backward(retain_graph=True)
    
    # Ensure all parameters of net have gradients
    for name, param in net.named_parameters():
        assert param.grad is not None, f"{name} has no grad"
    
    # Update net's graph features from its sequential representation
    net.node_feats, net.edge_index, net.edge_feats = seq_to_feats(net.sequential)
    # Confirm the edge features are connected to the computation graph
    for i in range(len(net.edge_feats)):
        assert net.edge_feats[i].grad_fn is not None, f"edge_feats[{i}] is not connected"
    
    # Get updated edge attributes from the graph network
    _, next_edge_attr = gmn(net.node_feats, net.edge_index, net.edge_feats, None, None)
    
    # Update net's weights using the new edge attributes
    net.update(next_edge_attr)
    
    # Second forward/backward pass through net
    out2 = net(test_input)
    loss2 = criterion(out2, target)
    loss2.backward(retain_graph=True)
    
    # Update the parameters of gmn using the meta optimizer
    meta_optimizer.step()
    
    print(f"Iteration {iter}: loss {loss.item()}, loss2 {loss2.item()}")


RuntimeError: shape '[128, 3, 3, 3]' is invalid for input of size 20736