In [1]:
import io
import os
import random
import time

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim.lr_scheduler as lr_scheduler
from torch_geometric.datasets import GNNBenchmarkDataset, TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.transforms.add_positional_encoding import AddLaplacianEigenvectorPE

from IST_models import RMGN
from utils import merge_model_parameters, hierarchical_graphs, build_network_params, count_parameters

# Benchmark to run and the directory under which checkpoints are stored.
data_name = "CIFAR10"
model_root = 'data/MODELS/'

# All three splits share the same on-disk root and benchmark name.
dataset_root = "data/" + data_name
dataset_train, dataset_val, dataset_test = [
    GNNBenchmarkDataset(root=dataset_root, name=data_name, split=split_name)
    for split_name in ("train", "val", "test")
]

# Only the training loader shuffles; validation and test keep dataset order.
BATCH_SIZE = 64
train_loader = DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset_val, batch_size=BATCH_SIZE)
test_loader = DataLoader(dataset_test, batch_size=BATCH_SIZE)

# from torch_loader import GraphClassificationBench

# # Load "hard"
# dataset_train = GraphClassificationBench("data/", split='train', easy=False, small=False)
# dataset_val = GraphClassificationBench("data/", split='val', easy=False, small=False)
# dataset_test = GraphClassificationBench("data/", split='test', easy=False, small=False)

# train_loader = DataLoader(dataset_train, batch_size=32, shuffle=True)
# valid_loader = DataLoader(dataset_val, batch_size=32)
# test_loader = DataLoader(dataset_test, batch_size=32)

# Per-dataset configuration: prediction level, metric/loss helpers, input
# feature width and whether the embedding layer uses a bias.
emb_bias = True
if data_name == "CIFAR10":
    return_level = "graph_level"
    from utils import accuracy_MNIST_CIFAR as accuracy
    from utils import loss_MNIST_CIFAR as loss_funtion
    # 3 node features plus the 2 `pos` coordinates concatenated at load time.
    num_features = 5
    emb_bias = True
elif data_name == "PATTERN":
    return_level = "node_level"
    from utils import accuracy_SBM as accuracy
    from utils import loss_SBM as loss_funtion
    emb_bias = False
    num_features = 3
elif data_name == "CLUSTER":
    return_level = "node_level"
    from utils import accuracy_SBM as accuracy
    from utils import loss_SBM as loss_funtion
    emb_bias = False
    num_features = 7
elif data_name == "MNIST":
    # MNIST shares the graph-level loss/accuracy helpers with CIFAR10, so it
    # must also request graph-level outputs (was "node_level", apparently
    # copied from the SBM branches).
    return_level = "graph_level"
    from utils import accuracy_MNIST_CIFAR as accuracy
    from utils import loss_MNIST_CIFAR as loss_funtion
    emb_bias = True
    num_features = 3
else:
    # Fail here rather than with a NameError several cells later.
    raise ValueError(f"Unsupported data_name: {data_name!r}")

# Peek at one batch to sanity-check the tensor shapes.
next(iter(train_loader))

  return torch._C._cuda_getDeviceCount() > 0


DataBatch(x=[7489, 3], edge_index=[2, 59912], edge_attr=[59912], y=[64], pos=[7489, 2], batch=[7489], ptr=[65])

In [2]:
# Message-passing settings for each supported architecture, keyed by model
# name. The entry for the selected model is handed to build_network_params
# together with the pooling, layer, data and setup dictionaries.
# Each entry groups its options into "edge_params", "node_params" and
# (for most models) "selection_params".
model_dict = {
    # Max edge-to-node aggregation with projected, ReLU-activated edges.
    "Graph_Sage_pool": {
        "edge_params": {
            "project": True,
            "activation": "relu",
            "initial_loops": False,
            "with_nodes_own": True,
        },
        "node_params": {
            "edge_to_node_aggr": "max",
            "project": True,
            "activation": "relu",
            "with_nodes_own": True,
        },
        "selection_params": {
            "remove_self_loops": True
        }
    },
    # "mean" edge normalisation, summed into nodes; no edge projection.
    "Graph_Sage_mean": {
        "edge_params": {
            "edge_att_norm": "mean",
            "project": False,
            "initial_loops": False,
            "with_nodes_own": True
        },
        "node_params": {
            "edge_to_node_aggr": "sum",
            "project": True,
            "activation": "relu",
            "with_nodes_own": True,
        },
        "selection_params": {
            "remove_self_loops": True
        }
    },
    # Same as the mean variant but with initial self-loops and without the
    # node's own features in the node update.
    "Graph_Sage_vanilla": {
        "edge_params": {
            "edge_att_norm": "mean",
            "project": False,
            "initial_loops": True,
            "with_nodes_own": True,
        },
        "node_params": {
            "edge_to_node_aggr": "sum",
            "project": True,
            "activation": "relu",
        },
        "selection_params": {
        }
    },
    # "symmetric" edge normalisation with initial self-loops.
    "Graph_Sage_GCN": {
        "edge_params": {
            "edge_att_norm": "symmetric",
            "project": False,
            "initial_loops": True,
            "with_nodes_own": True,
        },
        "node_params": {
            "edge_to_node_aggr": "sum",
            "project": True,
            "activation": "relu",
        },
        "selection_params": {
        }
    },
    # "attention" edge normalisation; projection happens on the edges only.
    # NOTE(review): unlike every other entry this one has no
    # "selection_params" key — confirm build_network_params tolerates that.
    "Graph_Attention": {
        "edge_params": {
            "edge_att_norm": "attention",
            "project": True,
            "bias": False,
            "activation": None,
            "initial_loops": True,
            "with_nodes_own": True
        },
        "node_params": {
            "edge_to_node_aggr": "sum",
            "project": False,
            "activation": "relu",
        },
    }
}

# Settings for the pooling / cluster-selection stage, passed to
# build_network_params. The training cell overwrites
# pooling_dict["node_params"]["dropout_prob"] for each dropout value it sweeps.
pooling_dict = {
    "edge_params": {
        "broadcast_method": "reverse_cluster",
        "with_edge_weights": True
    },
    "node_params": {
        "broadcast_method": "reverse_cluster",
        "normalize": True,
        "dropout_prob" : 0.00
    },
    "selection_params": {
        "selection_meth": "mlp",
        "edge_to_edge_aggr": "sum",
        "node_to_node_aggr": "sum",
        "project": True,
        "softmax_direction": -1,
        "normalize_adj": True
    }
}

# Feature-routing flags for the three stages named by the keys below
# ("emb_params", "pro_params", "dec_params"); also passed to
# build_network_params.
# NOTE(review): presumably embedding / processor / decoder — confirm against
# build_network_params.
layer_dict = {
    "emb_params": {
        "with_nodes_below": True,
        "with_nodes_above": True,
        "with_globals": True,
        "with_nodes_depth": True
    },
    "pro_params": {
        "with_nodes_below": True,
        "with_nodes_above": True,
    },
    "dec_params": {
        "with_nodes_below": True,
        "with_nodes_above": False,
        "with_globals": False,
        "with_nodes_depth": True,
        "with_selection_previous": True
    }
}

# Global run configuration: architecture switches, optimiser / scheduler
# hyper-parameters and the positional-encoding options.
setup_dict = {"HMGN_emb": True, "HMGN_dec": False, "pos_encoding": False, "pos_enc_dim": 10,
              "n_hid_layer": 4, "emb_bias": emb_bias, "return_type": return_level,
              "patience": 10, "reduce_factor": 0.5, "initial_lr": 5e-3, "stop_learning_rate": 1e-5,
              "seq_len": 12, "n_heads": 8}

# Derive the input width without mutating `num_features`: an in-place `+=`
# would add the encoding dimension again every time this cell is re-run.
n_node_features = num_features
if setup_dict["pos_encoding"]:
    n_node_features = num_features + setup_dict["pos_enc_dim"]

# Feature dimensions handed to the model builder; only the spatial node
# features and the number of output classes are non-zero here.
data_dict = {'n_node_feat_S': n_node_features, 'n_node_feat_ST': 0, 'n_feat_glob_ST': 0,
             'n_edge_feat_S': 0, 'n_edge_feat_ST': 0, 'n_feat_glob_S': 0, "n_out_final" : dataset_train.num_classes}

# Runs to execute, one per entry: [model name, cluster-node counts, hidden width].
# An empty cluster list switches the hierarchical embedding off (see the loops
# that consume this list).
n_cluster_with_hid_list = [
    ["Graph_Sage_mean", [], 100],
    # ["Graph_Sage_GCN", [6], 116],
    # ["Graph_Sage_vanilla", [], 146],
    # ["Graph_Sage_mean", [6], 49],
    # ["Graph_Attention", [], 41],
    # ["Graph_Sage_mean", [64, 32, 6], 60],
    # ["Graph_Sage_mean", [32, 6], 131],
    # ["Graph_Sage_pool", [10], 59],
    # ["Graph_Sage_GCN", [10], 100],
    # ["Graph_Sage_vanilla", [10], 100],
    # ["Graph_Attention", [8], 20],
    ]

# Build every configured model once and report its parameter count.
for n_cluster_with_hid in n_cluster_with_hid_list:
    model_name, n_cluster_nodes = n_cluster_with_hid[0], n_cluster_with_hid[1]
    # The hierarchical embedding is only enabled when cluster levels exist.
    setup_dict.update({
        "HMGN_emb": len(n_cluster_nodes) > 0,
        "model_name": model_name,
        "n_hid": n_cluster_with_hid[2],
    })
    hierarchical_network_params = build_network_params(
        model_dict[model_name], pooling_dict, data_dict, layer_dict, setup_dict, n_cluster_nodes
    )
    model = RMGN(model_name, setup_dict, data_dict, n_cluster_nodes, hierarchical_network_params)
    count_parameters(model, model_name, True, True)

S_encoder.0.weight :  torch.Size([100, 5])
S_encoder.0.bias :  torch.Size([100])
S_processor.0.0.model.1.node_project_fn.weight :  torch.Size([100, 200])
S_processor.0.0.model.1.node_project_fn.bias :  torch.Size([100])
S_processor.0.0.model.1.normalize_fn.weight :  torch.Size([100])
S_processor.0.0.model.1.normalize_fn.bias :  torch.Size([100])
S_processor.1.0.model.1.node_project_fn.weight :  torch.Size([100, 200])
S_processor.1.0.model.1.node_project_fn.bias :  torch.Size([100])
S_processor.1.0.model.1.normalize_fn.weight :  torch.Size([100])
S_processor.1.0.model.1.normalize_fn.bias :  torch.Size([100])
S_processor.2.0.model.1.node_project_fn.weight :  torch.Size([100, 200])
S_processor.2.0.model.1.node_project_fn.bias :  torch.Size([100])
S_processor.2.0.model.1.normalize_fn.weight :  torch.Size([100])
S_processor.2.0.model.1.normalize_fn.bias :  torch.Size([100])
S_processor.3.0.model.1.node_project_fn.weight :  torch.Size([100, 200])
S_processor.3.0.model.1.node_project_fn.bias 

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Running training on: ", device)

# Model family tag; it is part of the checkpoint file name and is read by
# later cells.
model_type = "RMGN"

# Keys left out of the per-layer parameter printout to keep it readable.
HIDDEN_PARAM_KEYS = ["broadcast_method", "dropout_prob", "edge_to_node_aggr", "with_edges_own", "n_feat_edge",
                     "with_edges_below", "with_edges_above", "with_edges_depth", "edge_att_norm",
                     "softmax_direction", "node_to_node_aggr", "edge_to_edge_aggr", "selection_meth"]


def _prepare_batch(snapshot, position_encoder):
    """Append `pos` to the node features, optionally add the positional
    encoding, and move the batch to `device`.

    `position_encoder` is None when positional encodings are disabled; in that
    case the node features keep the width the model was built for.
    """
    if "pos" in snapshot:
        snapshot.x = torch.cat((snapshot.x, snapshot.pos), dim=-1)
    if position_encoder is not None:
        snapshot = position_encoder(snapshot)
    return snapshot.to(device)


def _forward_rmgn(model, snapshot, n_cluster_nodes, initial_loops):
    """Build the hierarchical graphs for one batch and run the model.

    Returns whatever the model returns, unpacked by callers as
    (y_hat, mincut_loss, ortho_loss).
    """
    # The model reads node features from `node_attr`, not `x`.
    snapshot.node_attr = snapshot.x
    del snapshot.x
    S_graphs, ST_graphs = hierarchical_graphs(n_cluster_nodes, snapshot, initial_loops)
    S_graphs = [graph.clone().to(device) for graph in S_graphs]
    ST_graphs = [graph.clone().to(device) for graph in ST_graphs]
    return model(S_graphs, ST_graphs)


def _evaluate(model, loader, position_encoder, n_cluster_nodes, initial_loops):
    """Return the accuracy (in percent, as a Python float) of `model` on `loader`."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for snapshot in loader:
            snapshot = _prepare_batch(snapshot, position_encoder)
            y_hat, _, _ = _forward_rmgn(model, snapshot, n_cluster_nodes, initial_loops)
            n_targets = snapshot.y.size(0)
            total += n_targets
            correct += accuracy(snapshot.y, y_hat) * n_targets
    # float() turns a 0-dim tensor (possibly on GPU) into a plain number so it
    # can be formatted, compared and averaged with numpy.
    return float(100.0 * correct / total)


for n_cluster_with_hid in n_cluster_with_hid_list:
    model_name = n_cluster_with_hid[0]
    n_cluster_nodes = n_cluster_with_hid[1]
    setup_dict["HMGN_emb"] = len(n_cluster_nodes) > 0
    setup_dict["model_name"] = model_name
    setup_dict["n_hid"] = n_cluster_with_hid[2]
    initial_loops = model_dict[model_name]["edge_params"]["initial_loops"]
    for node_dropout_prob in [0.05]:
        start_time = time.time()
        all_train_accuracy, all_val_accuracy, all_test_accuracy, all_epoch = [], [], [], []
        # Number of epochs actually executed across all seeds (for timing).
        total_epochs_run = 0
        for seed in [0, 1, 2, 3]:
            torch.cuda.empty_cache()
            torch.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
            random.seed(seed)
            np.random.seed(seed)

            pooling_dict["node_params"]["dropout_prob"] = node_dropout_prob
            hierarchical_network_params = build_network_params(model_dict[model_name], pooling_dict, data_dict,
                                                            layer_dict, setup_dict, n_cluster_nodes)
            print("\n", model_name)
            for key, val in hierarchical_network_params["params_S"].items():
                print("\n", key)
                for val2 in val:
                    for key2, val3 in val2.items():
                        print(key2)
                        print({k.replace("with_", ""): v for k, v in val3.items() if k not in HIDDEN_PARAM_KEYS})

            model = RMGN(model_name, setup_dict, data_dict, n_cluster_nodes, hierarchical_network_params).to(device)
            # Only build the encoder when it is enabled: with attr_name=None it
            # appends pos_enc_dim columns to `x`, which would no longer match
            # the input width in data_dict when pos_encoding is off.
            position_encoder = (AddLaplacianEigenvectorPE(setup_dict["pos_enc_dim"], attr_name=None)
                                if setup_dict["pos_encoding"] else None)

            if seed == 0:
                print("\n", model)
                count_parameters(model, model_name, False, False)
            optimizer = torch.optim.Adam(model.parameters(), setup_dict["initial_lr"])

            scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=setup_dict["reduce_factor"],
                                                    patience= setup_dict["patience"],
                                                    verbose=True, min_lr=setup_dict["stop_learning_rate"])

            # Resolve the checkpoint location up front so it is always defined
            # and the target directory exists before the first save.
            model_dir = "".join([model_root, data_name])
            os.makedirs(model_dir, exist_ok=True)
            f_name = "_".join([model_type, model_name, str(n_cluster_nodes),
                                str(setup_dict["n_hid_layer"]), str(setup_dict["initial_lr"]), str(node_dropout_prob),
                                str(setup_dict["n_hid"])])
            saved_state_dict_path = model_dir + "/" + f_name

            # -inf guarantees the first epoch always writes a checkpoint, so the
            # reload after training cannot hit a missing file.
            current_best = float("-inf")
            best_training = 0
            best_epoch = 0
            for epoch in range(1000):
                # Training loop
                model.train()
                correct = 0
                total = 0
                for snapshot in train_loader:
                    snapshot = _prepare_batch(snapshot, position_encoder)

                    optimizer.zero_grad()
                    y_hat, mincut_loss, ortho_loss = _forward_rmgn(model, snapshot, n_cluster_nodes, initial_loops)

                    loss = loss_funtion(snapshot.y, y_hat, data_dict["n_out_final"], device) + mincut_loss + ortho_loss

                    n_targets = snapshot.y.size(0)
                    total += n_targets
                    correct += (accuracy(snapshot.y, y_hat) * n_targets)

                    loss.backward()
                    optimizer.step()

                train_accuracy = float(100.0 * correct / total)
                total_epochs_run += 1

                print(f"Epoch: {epoch+1}, Training Accuracy: {train_accuracy:.3f}%")

                # Validation loop
                val_accuracy = _evaluate(model, valid_loader, position_encoder, n_cluster_nodes, initial_loops)

                scheduler.step(val_accuracy)

                # Stop once the scheduler has decayed the learning rate to the
                # floor; read it from the optimizer rather than from the
                # scheduler's private `_last_lr`.
                if optimizer.param_groups[0]["lr"] <= setup_dict["stop_learning_rate"]:
                    break

                print(f"Epoch: {epoch+1}, Validation Accuracy: {val_accuracy:.3f}%")

                if val_accuracy > current_best:
                    current_best = val_accuracy
                    best_training = train_accuracy
                    best_epoch = epoch
                    torch.save(model.state_dict(), saved_state_dict_path)

            # Evaluate the best checkpoint (by validation accuracy), not the
            # last-epoch weights. map_location keeps this working when the
            # checkpoint was written on a different device.
            model = RMGN(model_name, setup_dict, data_dict, n_cluster_nodes, hierarchical_network_params).to(device)
            model.load_state_dict(torch.load(saved_state_dict_path, map_location=device))

            test_accuracy = _evaluate(model, test_loader, position_encoder, n_cluster_nodes, initial_loops)
            print("------------------------------------------------------")
            print(f"Epoch: {epoch+1}, Test Accuracy: {test_accuracy:.3f}%")
            print("------------------------------------------------------")
            print("saving final model!")
            torch.save(model.state_dict(), model_dir + "/" + f_name + "_" + str(best_epoch) + "_" + str(best_training)[:4]
                       + "_" + str(current_best) + "_" + str(test_accuracy))
            print("------------------------------------------------------")

            all_train_accuracy.append(best_training)
            all_val_accuracy.append(current_best)
            all_test_accuracy.append(test_accuracy)
            all_epoch.append(best_epoch)

        print("Results for Model: ", model_name)

        print("Dropout Probability: ", node_dropout_prob)
        print("Number of Cluster Nodes: ", n_cluster_nodes)

        print("Training Accuracy Mean: {:.3f}".format(np.mean(all_train_accuracy)))
        print("Training Accuracy Std Deviation: {:.3f}".format(np.std(all_train_accuracy)))

        print("Validation Accuracy Mean: {:.3f}".format(np.mean(all_val_accuracy)))
        print("Validation Accuracy Std Deviation: {:.3f}".format(np.std(all_val_accuracy)))

        print("Test Accuracy Mean: {:.3f}".format(np.mean(all_test_accuracy)))
        print("Test Accuracy Std Deviation: {:.3f}".format(np.std(all_test_accuracy)))

        print("Best Average Epoch: {:.3f}".format(np.mean(all_epoch)))

        # Divide by the epochs actually run; the sum of best-epoch indices is a
        # different quantity and can be zero.
        print(f"Average Runtime per Epoch: {(time.time() - start_time)/max(total_epochs_run, 1):.3f} seconds.")

Running training on:  cpu

 Graph_Sage_mean

 pro_params
edge_params
{'bias': True, 'globals': False, 'nodes_below': False, 'nodes_above': False, 'nodes_depth': False, 'nodes_own': True, 'edge_weights': True, 'project': False, 'initial_loops': False, 'n_feat_node': 100, 'n_out': 100}
node_params
{'bias': True, 'receivers': True, 'nodes_below': False, 'nodes_own': True, 'nodes_above': False, 'nodes_depth': False, 'globals': False, 'normalize': True, 'project': True, 'activation': 'relu', 'n_feat': 200, 'n_out': 100}

 RMGN(
  (S_encoder): Sequential(
    (0): Linear(in_features=5, out_features=100, bias=True)
    (1): ReLU()
  )
  (S_processor): ModuleList(
    (0-3): 4 x ModuleList(
      (0): GraphNetwork(
        (model): ModuleList(
          (0): EdgeModel(
            (norm_att_fn): Normalize()
          )
          (1): NodeModel(
            (node_project_fn): Linear(in_features=200, out_features=100, bias=True)
            (activation_fn): ReLU()
            (dropout): Dropout(

In [None]:
# Inspect a saved checkpoint: list its learnable tensors and count their entries.
model_dir = "".join([model_root, data_name])
saved_state_dict_path = model_dir + "/" + "RMGN_[6]_tensor(70.3675)"
# Keep the loaded weights in their own name so the live `model` object is not
# replaced by a plain dictionary; load on CPU since only shapes are needed.
saved_state_dict = torch.load(saved_state_dict_path, map_location="cpu")
counter = 0
for entry_name, entry in saved_state_dict.items():
    # Skip entries whose names mark them as running statistics or bookkeeping
    # counters rather than learnable weights.
    if "running_mean" not in entry_name and "running_var" not in entry_name and "num_batches_tracked" not in entry_name:
        print(entry_name, entry.shape)
        counter += entry.numel()
# Total number of scalar weights listed above.
counter

In [None]:
# Rebuild the architecture, load a saved checkpoint and measure its accuracy
# on the training set.
# NOTE(review): the hard-coded file name mentions "[6]" cluster nodes — confirm
# it matches the current n_cluster_nodes / hierarchical_network_params.
model = RMGN(model_name, setup_dict, data_dict, n_cluster_nodes, hierarchical_network_params).to(device)
model_dir = "".join([model_root, data_name])
saved_state_dict_path = model_dir + "/" + "RMGN_[6]_tensor(70.3675)"
# map_location lets a checkpoint written on GPU load on a CPU-only machine.
model.load_state_dict(torch.load(saved_state_dict_path, map_location=device))

model.eval()
correct = 0
total = 0
with torch.no_grad():
    for snapshot in train_loader:
        if "pos" in snapshot: snapshot.x = torch.cat((snapshot.x, snapshot.pos), dim=-1)
        snapshot = snapshot.to(device)

        if model_type == "RMGN":
            # The model reads node features from `node_attr`, not `x`.
            snapshot.node_attr = snapshot.x
            del snapshot.x
            S_graphs, ST_graphs = hierarchical_graphs(n_cluster_nodes, snapshot,
                                        model_dict[model_name]["edge_params"]["initial_loops"])
            S_graphs = [graph.to(device) for graph in S_graphs]
            ST_graphs = [graph.to(device) for graph in ST_graphs]
            y_hat, mincut_loss, ortho_loss = model(S_graphs, ST_graphs)

        total += snapshot.y.size(0)
        correct += (accuracy(snapshot.y, y_hat) * snapshot.y.size(0))

# This score comes from train_loader, so it gets its own name instead of
# overwriting the `test_accuracy` produced by the training cell.
reloaded_train_accuracy = 100.0 * correct / total
reloaded_train_accuracy

In [None]:
from torch_geometric.transforms.add_positional_encoding import AddLaplacianEigenvectorPE
# Take the first batch yielded by the training loader and display it.
for snapshot in train_loader:
    break
# test = AddLaplacianEigenvectorPE(10, attr_name=None)
# test(snapshot)
snapshot

In [None]:
from GAT_imlementation import GATConv
# Smoke-test the standalone GAT layer on graphs produced by hierarchical_graphs.
# Place the layer on the notebook's `device` instead of forcing CUDA, so the
# cell also runs on CPU-only machines.
gat_model = GATConv(5, 19, 8).to(device)
# NOTE(review): this walks the whole test loader and keeps only the graphs of
# the last batch — a single batch is probably enough; confirm before changing.
for snapshot in test_loader:
    if "pos" in snapshot: snapshot.x = torch.cat((snapshot.x, snapshot.pos), dim=-1)
    snapshot = snapshot.to(device)

    # hierarchical_graphs expects node features under `node_attr`.
    snapshot.node_attr = snapshot.x
    del snapshot.x
    S_graphs, ST_graphs = hierarchical_graphs(n_cluster_nodes, snapshot,
                                model_dict[model_name]["edge_params"]["initial_loops"])
    S_graphs = [graph.to(device) for graph in S_graphs]

gat_model(S_graphs[0].node_attr, S_graphs[0].edge_index)

In [None]:

# time.time() differences are already in seconds, so report the value as-is
# (dividing by 1000 would print kiloseconds under a "seconds" label).
elapsed_time = time.time() - start_time
print(f"The code executed in {elapsed_time} seconds.")

In [None]:
from torch_geometric.utils import to_dense_adj
# Dense adjacency built from the edge list of the current `snapshot`.
# NOTE(review): no batch vector is passed, so all graphs in the batch presumably
# end up in one large matrix — confirm this is intended.
adj_mat = to_dense_adj(snapshot.edge_index)

In [None]:
import numpy as np

# Assuming adj_mat is a (possibly batched) dense adjacency matrix whose last two
# axes index nodes: check that every node has at least one incoming and one
# outgoing edge. Reducing over the last two axes keeps a leading batch axis
# from being mistaken for the node axis.
# Note: this only rules out isolated nodes; it does not prove the graph is
# connected.
has_incoming = (adj_mat.sum(dim=-2) > 0).all()
has_outgoing = (adj_mat.sum(dim=-1) > 0).all()
print(bool(has_incoming and has_outgoing))  # True when no node is isolated


In [None]:
from torch_geometric.utils import degree
# Toy example: call `degree` on a small index vector, passing 180 as the
# second argument, and display the result.
row = torch.tensor([0, 1, 0, 2, 0])
degree(row, 180)
        

In [None]:
import networkx as nx

# Minimal connectivity example: an undirected graph with the single edge 1-2.
G = nx.Graph()
G.add_edge(1, 2)

# Prints True when every node can be reached from every other node.
print(nx.is_connected(G))


In [None]:
# Scratch tensor of shape (100, 4) filled with uniform random values.
test = torch.rand(100, 4)
test

In [None]:
# Divide every entry of `test` by the sum of its second column (column index 1).
test / test[:,1].sum()

In [None]:
import numpy as np

# Accuracies collected from three earlier runs.
training_accuracies = [50.489, 50.496, 50.490]
test_accuracies = [50.37, 50.14, 50.21]

# Compute the statistics first, then report them in the usual order.
train_mean, train_std = np.mean(training_accuracies), np.std(training_accuracies)
test_mean, test_std = np.mean(test_accuracies), np.std(test_accuracies)

print("Training Accuracy Mean: {:.3f}".format(train_mean))
print("Training Accuracy Std Deviation: {:.3f}".format(train_std))

print("Test Accuracy Mean: {:.3f}".format(test_mean))
print("Test Accuracy Std Deviation: {:.3f}".format(test_std))


In [None]:
# Load a saved model and run one test batch through it with full tensor printing.
torch.set_printoptions(profile="full")

# Build and load the model once, using the same constructor arguments and
# configuration objects as the training cell (the previous `data_config`,
# `model_config`, `after_config` and `merged_dict` names are not defined
# anywhere in this notebook).
# NOTE(review): this short file-name scheme differs from the one the training
# cell writes — confirm the checkpoint exists and matches the architecture.
model_dir = "".join([model_root, dataset_train.name])
f_name = "_".join([model_type, str(n_cluster_nodes)])
saved_state_dict_path = model_dir + "/" + f_name
model = RMGN(model_name, setup_dict, data_dict, n_cluster_nodes, hierarchical_network_params).to(device)
# map_location lets a checkpoint written on GPU load on a CPU-only machine.
model.load_state_dict(torch.load(saved_state_dict_path, map_location=device))
model.eval()

for snapshot in test_loader:
    # Same preprocessing as every other evaluation cell: append `pos` to the
    # node features and expose them as `node_attr`.
    if "pos" in snapshot: snapshot.x = torch.cat((snapshot.x, snapshot.pos), dim=-1)
    snapshot = snapshot.to(device)
    snapshot.node_attr = snapshot.x
    del snapshot.x
    S_graphs, ST_graphs = hierarchical_graphs(n_cluster_nodes, snapshot,
                                    model_dict[model_name]["edge_params"]["initial_loops"])
    S_graphs = [graph.to(device) for graph in S_graphs]
    ST_graphs = [graph.to(device) for graph in ST_graphs]
    with torch.no_grad():
        model(S_graphs, ST_graphs)
    print(model)
    # One batch is enough for inspection.
    break

"".join([model_root, dataset_train.name, "/"])

In [None]:
# Display the `att_dst` attribute of the edge model's normalisation function in
# the first processor layer.
# NOTE(review): presumably only present for the attention variant — confirm.
model.S_processor[0][0].model[0].norm_att_fn.att_dst

In [None]:
# Display the current `model` object.
model

In [None]:
import torch
import numpy as np
from sklearn.metrics import accuracy_score

# Assuming y_true and y_pred are 1D tensors or numpy arrays containing 
# the true and predicted class labels, respectively

def calculate_weighted_accuracy(y_true, y_pred):
    """Return the class-frequency-weighted average of per-class accuracies.

    For every label present in `y_true` the accuracy on that class is
    computed, and the per-class accuracies are averaged with weights
    proportional to the number of samples in each class.

    Args:
        y_true: 1-D tensor of ground-truth labels.
        y_pred: 1-D tensor of predicted labels, same length as `y_true`.

    Returns:
        A 0-dim float tensor in [0, 1].

    Raises:
        ValueError: if `y_true` is empty (there are no classes to average).
    """
    if y_true.numel() == 0:
        raise ValueError("y_true must contain at least one label")

    unique_classes = torch.unique(y_true)
    class_weights = []
    class_accuracies = []

    for class_label in unique_classes:
        class_indices = (y_true == class_label)
        class_count = class_indices.sum().item()
        n_correct = y_true[class_indices].eq(y_pred[class_indices]).sum().item()
        # Assign (not accumulate): each class has its own accuracy.
        class_accuracies.append(n_correct / class_count)
        class_weights.append(class_count)

    # Normalize weights so they sum up to 1
    class_weights = torch.tensor(class_weights, dtype=torch.float32)
    class_weights = class_weights / class_weights.sum()

    # Compute the weighted average accuracy (reduced to a single scalar)
    weighted_accuracy = (torch.tensor(class_accuracies, dtype=torch.float32) * class_weights).sum()

    return weighted_accuracy

# Example: three samples of class 0 and one of class 1, with a predictor that
# always answers 0.
y_true = torch.Tensor([0, 0, 0, 1])
y_pred = torch.Tensor([0, 0, 0, 0])  # just for example

# Pass the labels to calculate_weighted_accuracy and print the result with two decimals.
weighted_accuracy = calculate_weighted_accuracy(y_true, y_pred)
print(f'Weighted Accuracy: {weighted_accuracy:.2f}')



In [None]:
from sklearn.utils.class_weight import compute_class_weight

# Toy labels: three samples of class 0, one of class 1, all predicted as 0.
y_true = torch.Tensor([0, 0, 0, 1])
y_pred = torch.Tensor([0, 0, 0, 0])  # just for example

unique_classes = torch.unique(y_true)

# One accuracy per class: correct predictions within the class divided by the
# class size.
class_accuracies = [
    y_true[class_mask].eq(y_pred[class_mask]).sum().item() / class_mask.sum()
    for class_mask in (y_true == class_label for class_label in unique_classes)
]

# Unweighted mean over classes, so every class counts equally.
weighted_accuracy = torch.Tensor(class_accuracies).mean()

In [None]:
# Display the value computed in the previous cell.
weighted_accuracy

In [None]:
import torch
# Embedding table with 7 entries of width 108. Fall back to CPU when no GPU is
# available instead of failing on a hard-coded "cuda:0".
test = torch.nn.Embedding(7, 108).to(torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))

In [None]:
# Shape of the embedding lookup applied to the batch's node attributes.
# NOTE(review): assumes snapshot.node_attr holds integer indices below 7 and
# lives on the same device as `test` — confirm.
test(snapshot.node_attr).shape

In [None]:
# Display the `att_src` attribute of the edge model's normalisation function in
# the first processor layer.
# NOTE(review): presumably only present for the attention variant — confirm.
model.S_processor[0][0].model[0].norm_att_fn.att_src

In [None]:
# def node_to_node_broadcast(broadcast_method, graphs, idx):
#     batch_size = len(graphs[idx]["ptr"]) -1
#     s_dim = graphs[idx].selection.shape
#     v_dim = graphs[idx+1].node_attr.shape
#     s_dim = [batch_size] + [int(s_dim[0] / batch_size)] + list(s_dim[1:])
#     v_dim = [batch_size] + [int(v_dim[0] / batch_size)] + list(v_dim[1:])
#     if len(v_dim) == 3:
#         return torch.einsum("abc, acd -> abd", graphs[idx].selection.reshape(s_dim), graphs[idx+1].node_attr.reshape(v_dim)).flatten(end_dim=1)
#     else:
#         return torch.einsum("abc, acde -> abde", graphs[idx].selection.reshape(s_dim), graphs[idx+1].node_attr.reshape(v_dim)).flatten(end_dim=1)
        
def node_to_node_broadcast(broadcast_method, graphs, idx):
    """Broadcast the attributes of level ``idx + 1`` down to the nodes of level ``idx``.

    Row ``i`` of ``graphs[idx + 1].node_attr`` is repeated as many times as
    there are nodes between consecutive ``ptr`` entries of ``graphs[idx]``, and
    the result is multiplied element-wise by ``graphs[idx].selection``.

    ``broadcast_method`` is accepted for interface compatibility and is not
    used here.
    """
    lower_level = graphs[idx]
    # Differences of consecutive ptr entries give the node count per graph.
    node_counts = lower_level.ptr[1:] - lower_level.ptr[:-1]
    expanded = graphs[idx + 1].node_attr.repeat_interleave(node_counts, dim=0)
    return expanded * lower_level.selection


In [None]:
from torch_scatter import scatter
# Sum the edge attributes grouped by the node index in edge_index[1] and show
# the shape of the result.
scatter(snapshot.edge_attr, snapshot.edge_index[1], dim=0, reduce="sum").shape

In [None]:
# Toy scatter call: rows 0 and 1 share segment id 0, row 2 goes to segment 2,
# and dim_size=4 requests four output rows. No `reduce` argument is given, so
# the library's default reduction applies.
data = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
segment_ids = torch.tensor([0, 0, 2])
scatter(data, segment_ids, dim=0, dim_size=4)

In [None]:
# For every edge, look up the `batch` entry of the node stored in edge_index[1].
snapshot.batch[snapshot.edge_index[1]]

In [None]:
# Display the second row of the edge index.
snapshot.edge_index[1]

In [None]:
# Display the `batch` vector of the current snapshot.
snapshot.batch