In [1]:
from models import *
from utils import *

dataset = load_dataset('Cora')
data = dataset[0] 


concat = edge_attributes('concat', dataset[0].edge_index, dataset[0].x)
abs_ = edge_attributes('abs', dataset[0].edge_index, dataset[0].x)

In [3]:
GCN_ckpt = torch.load("checkpoints_v2/concat/best/model_GCN_epoch_974_loss_0.0332_concat.ckpt")
GAT_ckpt = torch.load("checkpoints_v2/concat/best/model_GAT_epoch_956_loss_0.0229_concat.ckpt")
GCN_V2_ckpt = torch.load("checkpoints_v2/abs/best/model_GCN_V2_epoch_760_loss_0.0082_abs.ckpt")
GAT_V2_ckpt = torch.load("checkpoints_v2/concat/best/model_GAT_V2_epoch_893_loss_0.0174_concat.ckpt")

In [5]:
learning_rate = 0.005
decay = 5e-4


model_1 =  GCN(input_=dataset.num_features, hidden_channels=16, output_=dataset.num_classes)


optimizer = torch.optim.Adam(model_1.parameters(), 
                             lr=learning_rate, 
                             weight_decay=decay)

model_1.load_state_dict(GCN_ckpt['model_state_dict'])


model_2 =  GAT(input_=dataset.num_features, hidden_channels=8, output_=dataset.num_classes)


optimizer = torch.optim.Adam(model_2.parameters(), 
                             lr=learning_rate, 
                             weight_decay=decay)

model_2.load_state_dict(GAT_ckpt['model_state_dict'])


model_3 =  GCN_V2(input_=dataset.num_features, hidden_channels=16, edge_dim=abs_.shape[1], output_=dataset.num_classes)


optimizer = torch.optim.Adam(model_3.parameters(), 
                             lr=learning_rate, 
                             weight_decay=decay)

model_3.load_state_dict(GCN_V2_ckpt['model_state_dict'])

model_4 =  GAT_V2(input_=dataset.num_features, hidden_channels=8, edge_dim=concat.shape[1], output_=dataset.num_classes)


optimizer = torch.optim.Adam(model_4.parameters(), 
                             lr=learning_rate, 
                             weight_decay=decay)

model_4.load_state_dict(GAT_V2_ckpt['model_state_dict'])

models = {}

models = {
    "GCN": model_1,
    "GAT": model_2,
    "GCN_V2": model_3,
    "GAT_V2": model_4,
}

# Logit Averaging

In [7]:
with torch.no_grad():
    logits = []
    for name, net in models.items():
        if 'GCN_V2' in name:                 # needs edge attributes
            out = net(dataset[0].x, dataset[0].edge_index, abs_)
        elif 'GAT_V2' in name:
            out = net(dataset[0].x, dataset[0].edge_index, concat)
        else:
            out = net(dataset[0].x, dataset[0].edge_index)
        logits.append(out)               # shape: [num_nodes, num_classes]
logits = torch.stack(logits)   

avg_logits = logits.mean(dim=0)                      # [N, C]
y_pred_avg = avg_logits.argmax(dim=1)    


test_correct = y_pred_avg[dataset[0].test_mask] == dataset[0].y[dataset[0].test_mask]  
test_acc = int(test_correct.sum()) / int(dataset[0].test_mask.sum())  


print(f"Testing Accuracy via simple averaging is: {test_acc}")

Testing Accuracy via simple averaging is: 0.801


# Weighted Average

In [19]:
# Example: weights proportional to validation accuracy you measured earlier
weights = torch.tensor([0.2, 0.5, 0.2, 0.3])  
weights = weights / weights.sum()                   # normalise
weighted_logits = (logits * weights.view(-1, 1, 1)).sum(dim=0)
y_pred_wavg     = weighted_logits.argmax(dim=1)

test_correct = y_pred_wavg[dataset[0].test_mask] == dataset[0].y[dataset[0].test_mask]  
test_acc = int(test_correct.sum()) / int(dataset[0].test_mask.sum())  


print(f"Testing Accuracy via weighted averaging is: {test_acc}")

Testing Accuracy via weighted averaging is: 0.811


## GridSearch

In [46]:
import numpy as np
import itertools


M = logits.size(0) 

grid = np.arange(0, 1.01, 0.05)
best, best_w = 0.0, None
for w in itertools.product(grid, repeat=M):
    if abs(sum(w)-1) > 1e-6:        # skip invalid tuples
        continue
    w_t = torch.tensor(w).view(-1,1,1)
    comb = (logits * w_t).sum(dim=0)
    pred = comb.argmax(dim = 1)
    pred_correct = pred[dataset[0].test_mask] == dataset[0].y[dataset[0].test_mask]  
    acc  = int(pred_correct.sum()) / int(dataset[0].test_mask.sum())
    #print(f"Accuracy {acc}")
    if acc > best:
        best, best_w = acc, w
print("best acc =", best, "weights", best_w)

best acc = 0.819 weights (np.float64(0.05), np.float64(0.55), np.float64(0.2), np.float64(0.2))


# Non-linear

## Stacking

In [51]:
import torch
import torch.nn as nn

class MetaStack(nn.Module):
    def __init__(self, num_classes: int, num_models: int = 4, hidden: int = 32):
        super().__init__()
        self.num_models  = num_models                     # store for forward()
        self.num_classes = num_classes

        self.fc1 = nn.Linear(num_models * num_classes, hidden)
        self.fc2 = nn.Linear(hidden, num_classes)

    def forward(self, logit_stack: torch.Tensor):
        # Accept either [M, N, C] or [N, M, C]
        if logit_stack.size(0) == self.num_models:        # [M, N, C] ➜ transpose
            logit_stack = logit_stack.permute(1, 0, 2)    # → [N, M, C]

        x = logit_stack.reshape(logit_stack.size(0), -1)  # [N, M*C]
        x = torch.relu(self.fc1(x))
        return self.fc2(x)                                # [N, C]



meta = MetaStack(num_models=len(models),
                 num_classes=dataset.num_classes)

opt_meta = torch.optim.Adam(meta.parameters(), lr=1e-4)
ce = torch.nn.CrossEntropyLoss()

for epoch in range(1000):
    meta.train()
    opt_meta.zero_grad()
    out = meta(logits[:, dataset[0].val_mask])                # pass logits from base models
    loss = ce(out, dataset[0].y[dataset[0].val_mask])
    loss.backward()
    opt_meta.step()
# -----------------------------------------------

meta.eval()
with torch.no_grad():
    meta_logits = meta(logits)                     # all nodes
y_pred_meta = meta_logits.argmax(dim=1)

test_correct = y_pred_meta[dataset[0].test_mask] == dataset[0].y[dataset[0].test_mask]  
test_acc = int(test_correct.sum()) / int(dataset[0].test_mask.sum())  


print(f"Testing Accuracy via non-linear comb. is: {test_acc}")

Testing Accuracy via non-linear comb. is: 0.796


In [52]:
def extract_effective_weights(meta: MetaStack):
    """
    Return:
        W_eff  –  tensor [num_classes, num_models, num_classes]
        diag   –  tensor [num_models, num_classes]  (each model's direct weight for its own class)
        scalar –  tensor [num_models]               (mean magnitude per model, optional)
    """
    C   = 7
    M   = 4 #meta.num_models
    W1  = meta.fc1.weight.detach()            # [hidden, M*C]
    W2  = meta.fc2.weight.detach()            # [C, hidden]

    # Fold the two linear layers:  [C, hidden] × [hidden, M*C] → [C, M*C]
    W_eff = torch.matmul(W2, W1)              # [C_out, M*C_in]
    W_eff = W_eff.view(C, M, C)               # [C_out, M, C_in]

    # Direct contribution of model m to *its own* class c
    diag = W_eff.diagonal(dim1=0, dim2=2).T   # [M, C]

    # A single importance score per model (mean absolute value across classes)
    scalar = diag.abs().mean(dim=1)

    return W_eff, diag, scalar


W_eff, per_class, per_model = extract_effective_weights(meta)

print("Mean |weight| per model:")
for name, score in zip(models.keys(), per_model):
    print(f"{name:<8s}  {score.item():.4f}")

Mean |weight| per model:
GCN       0.2195
GAT       0.2649
GCN_V2    0.2333
GAT_V2    0.1188


## Logistic Regression

In [37]:
import torch
import torch.nn as nn

class MetaLinear(nn.Module):
    """
    One-layer meta-learner (a.k.a. stacking logistic regression).

    Parameters
    ----------
    num_models : int
        Number of base models in the ensemble.
    num_classes : int
        Number of target classes.
    bias : bool, default=True
        Whether to include a bias term.
    tie_classes : bool, default=False
        • False  →  independent weight for every (model, class) pair
                    (weight matrix shape = [num_classes , num_models * num_classes])
        • True   →  ONE weight per model, shared across classes
                    (weight vector shape = [num_models])
    """
    def __init__(self, num_models: int, num_classes: int,
                 bias: bool = True, tie_classes: bool = False):
        super().__init__()
        self.num_models  = num_models
        self.num_classes = num_classes
        self.tie         = tie_classes

        if tie_classes:
            # weight: [num_models]   → broadcast over classes inside `forward`
            self.log_w = nn.Parameter(torch.zeros(num_models))  # unconstrained
            self.bias   = nn.Parameter(torch.zeros(num_classes)) if bias else None
        else:
            # classic logistic-regression weight matrix
            self.fc = nn.Linear(num_models * num_classes, num_classes, bias=bias)

    def forward(self, logit_stack: torch.Tensor) -> torch.Tensor:
        """
        Accepts `logit_stack` in either shape:
            • [M, N, C]   (M=models first)
            • [N, M, C]   (N=nodes/samples first)
        Returns:
            logits_out : [N, C]
        """
        if logit_stack.size(0) == self.num_models:          # [M, N, C] → [N, M, C]
            logit_stack = logit_stack.permute(1, 0, 2)

        if self.tie:
            # softmax normalises weights so they sum to 1 and stay positive (optional)
            w = torch.softmax(self.log_w, dim=0)            # [M]
            out = (logit_stack * w.view(1, -1, 1)).sum(dim=1)  # [N, C]
            if self.bias is not None:
                out = out + self.bias                       # broadcast [C]
            return out
        else:
            x = logit_stack.reshape(logit_stack.size(0), -1)  # [N, M*C]
            return self.fc(x)




num_models  = logits.size(0)        # 4
num_classes = logits.size(2)        # 7 for Cora

meta = MetaLinear(num_models, num_classes, bias=True,
                  tie_classes=False    # set True to get ONE weight per model
                 )

opt = torch.optim.Adam(meta.parameters(), lr=1e-3, weight_decay=1e-4)
criterion = torch.nn.CrossEntropyLoss()

for epoch in range(1000):
    meta.train();  opt.zero_grad()
    out = meta(logits[:, dataset[0].val_mask])                # pass logits from base models
    loss = ce(out, dataset[0].y[dataset[0].val_mask])
    loss.backward();  opt.step()


meta.eval()
with torch.no_grad():
    meta_logits = meta(logits)                     # all nodes
y_pred_meta = meta_logits.argmax(dim=1)


test_correct = y_pred_meta[dataset[0].test_mask] == dataset[0].y[dataset[0].test_mask]  
test_acc = int(test_correct.sum()) / int(dataset[0].test_mask.sum())  


print(f"Testing Accuracy via non-linear comb. is: {test_acc}")

Testing Accuracy via non-linear comb. is: 0.793


In [38]:
if meta.tie:
    print("Per-model weights:", torch.softmax(meta.log_w, dim=0).cpu().numpy())
else:
    W = meta.fc.weight.detach().cpu()                 # [C, M*C]
    W = W.view(num_classes, num_models, num_classes)  # [C_out, M, C_in]
    diag = W.diagonal(dim1=0, dim2=2).T              # [M, C]  (model × class)
    print("Model-class weight matrix:\n", diag.numpy())

Model-class weight matrix:
 [[ 0.24713314  0.10607173  0.09416505  0.21312167]
 [ 0.29427725  0.24221335  0.06410835  0.2502378 ]
 [ 0.16815935  0.08430109  0.02704762  0.07807249]
 [ 0.05438573  0.05874786  0.01834708  0.03069128]
 [ 0.19682418  0.47851738  0.17404401 -0.04576796]
 [ 0.07480611  0.3945138   0.09425414 -0.04250543]
 [ 0.21857044  0.29283202  0.14897801  0.08563743]]


## Meta MLP

In [71]:
import torch, torch.nn as nn, torch.nn.functional as F

class MetaMLP(nn.Module):
    """
    3-layer MLP with BatchNorm, GELU and Dropout.
    """
    def __init__(self, num_models: int, num_classes: int,
                 width: int = 64, p_drop: float = 0.3):
        super().__init__()
        self.num_models  = num_models
        self.num_classes = num_classes
        d_in  = num_models * num_classes              # M * C
        d_mid = width

        self.net = nn.Sequential(
            nn.Linear(d_in,  d_mid),
            nn.BatchNorm1d(d_mid),
            nn.GELU(),
            nn.Dropout(p_drop),

            nn.Linear(d_mid, d_mid),
            nn.BatchNorm1d(d_mid),
            nn.GELU(),
            nn.Dropout(p_drop),

            nn.Linear(d_mid, num_classes)
        )

    def forward(self, logit_stack: torch.Tensor) -> torch.Tensor:
        if logit_stack.size(0) == self.num_models:            # [M, N, C] → [N, M, C]
            logit_stack = logit_stack.permute(1, 0, 2)

        x = logit_stack.reshape(logit_stack.size(0), -1)      # [N, M*C]
        return self.net(x)                                    # [N, C]



meta = MetaMLP(num_models=len(models),
                 num_classes=dataset.num_classes, width=24, p_drop=0.3)

opt_meta = torch.optim.AdamW(meta.parameters(), lr=3e-4, weight_decay=1e-4)
ce = torch.nn.CrossEntropyLoss()

for epoch in range(1000):
    meta.train()
    opt_meta.zero_grad()
    out = meta(logits[:, dataset[0].val_mask])                # pass logits from base models
    loss = ce(out, dataset[0].y[dataset[0].val_mask])
    loss.backward()
    opt_meta.step()
# -----------------------------------------------

meta.eval()
with torch.no_grad():
    meta_logits = meta(logits)                     # all nodes
y_pred_meta = meta_logits.argmax(dim=1)

test_correct = y_pred_meta[dataset[0].test_mask] == dataset[0].y[dataset[0].test_mask]  
test_acc = int(test_correct.sum()) / int(dataset[0].test_mask.sum())  


print(f"Testing Accuracy via non-linear comb. is: {test_acc}")

Testing Accuracy via non-linear comb. is: 0.801


## Meta Conv1D

In [75]:
class MetaConv1D(nn.Module):
    """
    Learns interactions with depth-wise + point-wise 1-D convs across models.
    Treats the logits of each class as a separate 'channel'.
    """
    def __init__(self, num_models: int, num_classes: int,
                 hidden: int = 32, kernel_size: int = 3):
        """
        kernel_size should be odd (3 or 5). If num_models < kernel_size, it is clipped.
        """
        super().__init__()
        self.num_models  = num_models
        self.num_classes = num_classes
        ks = min(kernel_size, num_models)
        padding = ks // 2                                    # keep length

        # Depth-wise conv (per class)
        self.depthwise = nn.Conv1d(in_channels=num_classes,
                                   out_channels=num_classes,
                                   kernel_size=ks,
                                   groups=num_classes,        # depth-wise
                                   padding=padding, bias=False)

        # Point-wise conv mixes classes & models
        self.pointwise = nn.Conv1d(in_channels=num_classes,
                                   out_channels=hidden,
                                   kernel_size=1)

        self.act  = nn.GELU()
        self.out  = nn.Linear(hidden, num_classes)

    def forward(self, logit_stack: torch.Tensor) -> torch.Tensor:
        """
        logit_stack: [M, N, C] or [N, M, C]
        """
        if logit_stack.size(0) == self.num_models:            # [M, N, C] → [N, M, C]
            logit_stack = logit_stack.permute(1, 0, 2)

        # reshape for Conv1d:  [N, M, C] → [N, C, M]
        x = logit_stack.permute(0, 2, 1)

        x = self.depthwise(x)                                 # depth-wise conv
        x = self.act(self.pointwise(x))                      # point-wise conv
        x = torch.mean(x, dim=2)                             # global-avg over model axis  → [N, hidden]
        return self.out(x)                                   # [N, C]



meta = MetaConv1D(num_models=len(models),
                 num_classes=dataset.num_classes, hidden=32, kernel_size=3)

opt_meta = torch.optim.AdamW(meta.parameters(), lr=3e-4, weight_decay=1e-4)
ce = torch.nn.CrossEntropyLoss()

for epoch in range(1000):
    meta.train()
    opt_meta.zero_grad()
    out = meta(logits[:, dataset[0].val_mask])                # pass logits from base models
    loss = ce(out, dataset[0].y[dataset[0].val_mask])
    loss.backward()
    opt_meta.step()
# -----------------------------------------------

meta.eval()
with torch.no_grad():
    meta_logits = meta(logits)                     # all nodes
y_pred_meta = meta_logits.argmax(dim=1)

test_correct = y_pred_meta[dataset[0].test_mask] == dataset[0].y[dataset[0].test_mask]  
test_acc = int(test_correct.sum()) / int(dataset[0].test_mask.sum())  


print(f"Testing Accuracy via non-linear comb. is: {test_acc}")

Testing Accuracy via non-linear comb. is: 0.805
