In [1]:
from gmn_lim.graph_construct.model_arch_graph import seq_to_feats, sequential_to_arch, arch_to_graph, graph_to_arch
from gmn_lim.graph_models import EdgeMPNNDiT
from gmn_lim.feature_extractor_gmn import GraphPredGen
from torchviz import make_dot
import torch
import torch.nn as nn
from torch.func import functional_call
import json
import math
import time
import os
import sys
sys.path.insert(0, "./gmn_lim/graph_construct")
from gmn_lim.graph_construct.constants import NODE_TYPES, EDGE_TYPES, CONV_LAYERS, NORM_LAYERS, RESIDUAL_LAYERS, NODE_TYPE_TO_LAYER
from gmn_lim.graph_construct.utils import (
    make_node_feat,
    make_edge_attr,
    conv_to_graph,
    linear_to_graph,
    norm_to_graph,
    ffn_to_graph,
    basic_block_to_graph,
    self_attention_to_graph,
    equiv_set_linear_to_graph,
    triplanar_to_graph,
)
from gmn_lim.graph_construct.model_arch_graph import (
    seq_to_feats,
    sequential_to_arch,
    arch_to_graph,
    graph_to_arch,
    arch_to_named_params

)
from gmn_lim.graph_construct.layers import (
    Flatten,
    PositionwiseFeedForward,
    BasicBlock,
    SelfAttention,
    EquivSetLinear,
    TriplanarGrid,
)

In [None]:
model_dir = '/media/siddhartha/games/gmn_data/fixed_hp_data/'
# -> folderid
#       -> epoch_0_feats.pt
#       -> ...
#       -> epoch_0_feats.pt
#       -> results.json
#       -> torch_model.pt


example_folder = os.path.join(model_dir,'0b5474db-1755-487b-b8f1-42e9d5950f85')

example_torch_model = torch.load(os.path.join(example_folder,'torch_model.pt'),map_location='cpu') # sequential

example_results = json.load(open(os.path.join(example_folder,'results.json'),'r'))
example_feats = torch.load(os.path.join(example_folder,'epoch_0_feats.pt'),map_location='cpu')

  example_torch_model = torch.load(os.path.join(example_folder,'torch_model.pt'),map_location='cpu') # sequential
  example_feats = torch.load(os.path.join(example_folder,'epoch_0_feats.pt'),map_location='cpu')


In [None]:
for layer in example_torch_model:
    print(layer)
print()
print(example_results.keys())
print(example_feats[0].shape, example_feats[1].shape, example_feats[2].shape)

Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1))
BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
ReLU()
Conv2d(128, 16, kernel_size=(3, 3), stride=(1, 1))
BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
ReLU()
AdaptiveAvgPool2d(output_size=(1, 1))
Flatten()
LayerNorm((16,), eps=1e-05, elementwise_affine=True)
Linear(in_features=16, out_features=128, bias=True)
ReLU()
Linear(in_features=128, out_features=32, bias=True)
ReLU()
Linear(in_features=32, out_features=10, bias=True)
ReLU()

dict_keys(['hyperparameters', 'train_losses', 'val_losses', 'accuracy'])
torch.Size([328, 3]) torch.Size([2, 28986]) torch.Size([28986, 6])


In [None]:
example_torch_model.requires_grad_(True)
arch = sequential_to_arch(example_torch_model)
node_feats, edge_index, edge_feats = arch_to_graph(arch)
reconstructed_arch = graph_to_arch(arch, edge_feats[:,0])


for i in range(len(arch)):
    weight_tensor = arch[i][1]
    weight_tensor_reconstructed = reconstructed_arch[i][1]
    assert weight_tensor.shape == weight_tensor_reconstructed.shape
    assert torch.all(torch.eq(weight_tensor, weight_tensor_reconstructed))

In [None]:
params = {name: param
                for name, param in example_torch_model.named_parameters()}
params_from_arch = {name:param
                for name, param in arch_to_named_params(arch)}
test_input = torch.randn(1, 3, 32, 32)

out1 = functional_call(example_torch_model, params, (test_input,))
out2 = functional_call(example_torch_model, params_from_arch, (test_input,))

print(out1.shape, out2.shape)

print(torch.all(torch.eq(out1, out2)))

torch.Size([1, 10]) torch.Size([1, 10])
tensor(True)


In [None]:
from torch.func import functional_call

class NeuralNet(nn.Module):
    def __init__(self, sequential: nn.Sequential):
        super(NeuralNet, self).__init__()
        self.sequential = sequential
        self.arch = sequential_to_arch(sequential)
        self.params = {
            name: param
            for name, param in sequential.named_parameters()
        }
        self.node_feats, self.edge_index, self.edge_feats = arch_to_graph(self.arch)
    def update(self, weights):
        self.arch = graph_to_arch(self.arch, weights)
        new_params = {
            name: param
            for name, param in arch_to_named_params(self.arch)
        }
        # assert all shapes same as before
        for key in self.params.keys():
            assert self.params[key].shape == new_params[key].shape
        # update params
        self.params = {name: new_params[name] for name in self.params.keys()}
    def forward(self, x):
        return functional_call(self.sequential, self.params, (x,)) 

In [None]:
net = NeuralNet(example_torch_model)
net.train()

node_in_dim = node_feats.shape[1]
edge_in_dim = edge_feats.shape[1]
hidden_dim = 16
edge_out_dim = 1
num_layers = 3
gmn = EdgeMPNNDiT(3, edge_in_dim, hidden_dim, edge_out_dim, 
                  num_layers,  dropout=0.0, reduce='mean', activation='silu', use_global=False,)
# gmn.init_weights_()
for block in gmn.convs:
    if block.update_node:
        nn.init.constant_(block.node_mlp[-1].weight, 1)
        nn.init.constant_(block.node_mlp[-1].bias, 1)

    if block.update_edge:
        nn.init.constant_(block.edge_mlp[-1].weight, 1)
        nn.init.constant_(block.edge_mlp[-1].bias, 1)
# init
node_feats, next_edge_attr = gmn.forward(net.node_feats, net.edge_index, net.edge_feats, None, None)

node_feats.shape, next_edge_attr.shape

(torch.Size([328, 16]), torch.Size([28986, 1]))

In [None]:
torch.autograd.set_detect_anomaly(True)
gmn.train()
gmn = EdgeMPNNDiT(3, edge_in_dim, hidden_dim, edge_out_dim, 
                  num_layers,  dropout=0.0, reduce='mean', activation='silu', use_global=False,)
# gmn.init_weights_()
for block in gmn.convs:
    if block.update_node:
        nn.init.constant_(block.node_mlp[-1].weight, 1)
        nn.init.constant_(block.node_mlp[-1].bias, 1)

    if block.update_edge:
        nn.init.constant_(block.edge_mlp[-1].weight, 1)
        nn.init.constant_(block.edge_mlp[-1].bias, 1)
gmn.requires_grad_(True)

meta_optimizer = torch.optim.Adam(gmn.parameters(), lr=0.01)
meta_optimizer.zero_grad()



num_samples = 1000
test_input = torch.randn(num_samples,3,32,32)
target = torch.randn(num_samples,10)
dataset = torch.utils.data.TensorDataset(test_input, target)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=128, shuffle=True)
it = iter(dataloader)

def get_batch():
    global it
    # Get next batch, create new iterator if we've exhausted the current one
    try:
        batch = next(it)
    except StopIteration:
        it = iter(dataloader)
        batch = next(it)
    return batch
        
criterion = nn.MSELoss()
outer_iterations = 10
inner_iterations = 20  # or however many iterations you want
for epoch in range(outer_iterations):
    meta_optimizer.zero_grad()
    net = NeuralNet(example_torch_model)
    net.train()
    losses = []
    for i in range(inner_iterations):
        test_input, target = get_batch()
        out = net(test_input)
        loss = criterion(out, target)
        # print(f"Epoch {epoch}, Iteration {i}, Loss: {loss.item()}")
        retain_graph =  i< inner_iterations 
        loss.backward(retain_graph=retain_graph)
        losses.append(loss)
        # Ensure all parameters of net have gradients
        for name, param in net.named_parameters():
            assert param.grad is not None, f"{name} has no grad"
        
        _, next_edge_attr = gmn(net.node_feats, net.edge_index, net.edge_feats, None, None)
        
        # Update net's weights using the new edge attributes
        net.update(next_edge_attr[:,0])
    for name, param in gmn.named_parameters():
        assert param.grad is not None, f"{name} has no grad"
    
    # weight later losses higher
    weights = torch.tensor([1.0/(i+1) for i in range(inner_iterations)])
    weights = weights / torch.sum(weights)
    meta_loss = torch.sum(torch.stack(losses) * weights)
    # make_dot(meta_loss, params=net.params).render('graph20', format='png')
    # break
    gmn_params = {name: param
                for name, param in gmn.named_parameters()}
    
    print(f"Meta Loss: {meta_loss.item()}")
    meta_loss.backward()
    meta_optimizer.step()


Meta Loss: 3.887474536895752


KeyboardInterrupt: 