In [1]:
import itertools
import warnings

import torch
import torch.nn.functional as F
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_sparse import matmul

from label_prop_heterophily import get_myo_h, LabelPropagationHeterophilyTog

from load_data import *
from basic_model import Basic_MLP, train_MLP, test_MLP
from utils import seed, masked_test, makeDoubleStochastic

# Device: the experiments were configured for the second GPU ('cuda:1'). On a
# single-GPU machine that ordinal does not exist and the first `.to(device)` fails
# with "invalid device ordinal", so fall back to 'cuda:0' there.
if torch.cuda.is_available():
    device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
else:
    device = torch.device('cpu')

# Fixed seed for reproducibility; replace with `random.randint(1, 10001)` to draw a fresh one.
SEED = 8997  # other seeds tried: 7969, 5206, 42
seed(value=SEED)

# Hide all Python warnings in the notebook output.
warnings.filterwarnings('ignore')



In [2]:
# --- Dataset / split -------------------------------------------------------
# Other datasets tried: syn-10-1000-0.5-0.01, syn-10-1000-0.5-0.005, 'MixHopSyn-1'
data_name = 'Texas'
num_train = 10      # 0: keep the fixed split; otherwise the train ratio (10 gave 18 of 183 nodes on Texas)
split_id = 8
add_feat = False

# --- Propagation / training ------------------------------------------------
n_hops = 1
n_layers = 20
dropout = 0.5
n_epoch = 501       # the MLP loop iterates range(1, n_epoch)
patience = 100      # early-stopping window in epochs

In [3]:
# Load the (sparse) graph with a random split, move it to the device, and derive
# the MLP input width (feature count) and output width (number of distinct labels).
data = load_data(
    data_name,
    num_train=num_train,
    fixed_split=False,
    split_id=split_id,
    add_feat=add_feat,
    to_sparse=True,
)
data = data.to(device)
in_dim = data.x.size(1)
out_dim = data.y.unique().numel()

The obtained data Texas has 183 nodes, 325 edges, 1703 features, 5 labels, 18 training, 18 validation and 147 testing nodes


In [4]:
# Train a feature-only MLP. Its predictions at the best-validation epoch are kept as
# `best_pred` / `prior_estim`, which the label-propagation cells below consume.
model_mlp = Basic_MLP(in_dim, out_dim, hid_dim=64, dropout=dropout).to(device)
print('Information about the model:')
print(model_mlp)
optimizer = torch.optim.Adam(params=model_mlp.parameters(), lr=0.01, weight_decay=5e-5)

best_train_acc = best_val_acc = best_test_acc = 0
best_epoch = 0
best_pred = None
for epoch in range(1, n_epoch):
    loss = train_MLP(data, model_mlp, optimizer)
    (train_acc, val_acc, test_acc), pred = test_MLP(data, model_mlp)
    # Always record the first epoch so `best_pred` exists even if validation
    # accuracy never rises above 0 (previously a NameError further down).
    if best_pred is None or val_acc > best_val_acc:
        best_epoch = epoch
        best_train_acc, best_val_acc, best_test_acc = train_acc, val_acc, test_acc
        # Copy, so later in-place changes to `pred` (if any) cannot alter the stored best.
        best_pred = pred.clone()
    # Early stopping: no validation improvement for `patience` epochs.
    if epoch - best_epoch >= patience:
        break

if best_pred is None:
    raise RuntimeError(f'No MLP epoch was run (n_epoch={n_epoch}); cannot build the prior.')
print('Best Train: {:.4f}, Best Val: {:.4f}, Best Test: {:.4f}'.format(
    best_train_acc, best_val_acc, best_test_acc))
prior_estim = best_pred.clone()

Information about the model:
Basic_MLP(
  (conv_layers): ModuleList(
    (0): Linear(in_features=1703, out_features=64, bias=True)
    (1): Linear(in_features=64, out_features=64, bias=True)
    (2): Linear(in_features=64, out_features=5, bias=True)
  )
)
Best Train: 0.9444, Best Val: 0.6111, Best Test: 0.6667


In [5]:
# Per-edge weights for label propagation, all read off data.adj_t's edge storage.
adj_t = data.adj_t
deg = adj_t.sum(dim=1).to(torch.float)
# Non-in-place so `deg` keeps holding the degree; nodes with degree 0 get 0 instead of inf.
deg_inv_sqrt = deg.pow(-0.5)
deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
# NOTE(review): both factors scale the rows, so this is D^{-1} A (row / random-walk
# normalisation), not the symmetric D^{-1/2} A D^{-1/2} that `gcn_weight` below presumably
# holds. If symmetric normalisation was intended, the second factor should be
# deg_inv_sqrt.view(1, -1). Kept as-is to preserve results.
DA = deg_inv_sqrt.view(-1, 1) * deg_inv_sqrt.view(-1, 1) * adj_t
DA = DA.storage.value()

gcn_weight = gcn_norm(data.adj_t, add_self_loops=False).storage.value()

A = data.adj_t
Y_true = F.one_hot(data.y.view(-1)).float()
# Class-to-class edge counts Y^T A Y, computed once (it was evaluated twice before).
class_edge_counts = torch.matmul(Y_true.transpose(0, 1), matmul(A, Y_true))
row_sums = class_edge_counts.sum(-1, keepdim=True)
# Guard 0/0 -> NaN for a class with no edges: such a row stays all zero. Whether
# makeDoubleStochastic handles an all-zero row is not visible here — TODO confirm.
H_true = class_edge_counts / row_sums.masked_fill(row_sums == 0, 1.)
H_true = makeDoubleStochastic(H_true)

H, B = get_myo_h(
    A, data.y, prior_estim,
    mask=data.train_mask
)
# Per-edge weights: the row of B at each edge's source, times H, times the row of B at its target.
# Assumes DA / gcn_weight / A.coo() share adj_t's edge order — TODO confirm.
A_row, A_col, _ = A.coo()
cm_weight = torch.matmul(B[A_col], H) * B[A_row]
# "Echo" weights use the two-step compatibility H @ H, scaled by the D^{-1} A edge values.
echo_H = torch.matmul(H, H)
echo_value = torch.matmul(B[A_col], echo_H) * B[A_row]
echo_value = echo_value * DA.view(-1, 1)

In [6]:
# Grid-search the label-propagation settings on top of the MLP prior `best_pred`.
y_soft = best_pred.clone()
spread_mask, eval_mask, test_mask = data.train_mask, data.val_mask, data.test_mask
(train_acc, val_acc, test_acc) = masked_test(data, y_soft)
print(f'Train: {train_acc:.4f}, Val: {val_acc:.4f}, Test: {test_acc:.4f}')
model = LabelPropagationHeterophilyTog(max_layers=n_layers, num_hops=n_hops)
print(model)


def _softmax_step(y):
    """Post-step used when diff_post_step is True: row-wise softmax."""
    return F.softmax(y, dim=-1)


def _clamp_step(y):
    """Post-step used when diff_post_step is False: clamp into [0, 1] in place."""
    return y.clamp_(0., 1.)


# Same order as the former nested loops: alpha outermost, weight_type innermost.
search_space = itertools.product(
    [step / 10. for step in range(1, 11)],  # alpha = 0.1, 0.2, ..., 1.0
    [True, False],                          # diff_post_step
    [True, False],                          # echo_set
    [True, False],                          # select_eval
    [0, 1],                                 # weight_type
)

best_val_res = best_test_res = 0
best_config = None
for alpha, diff_post_step, echo_set, select_eval, weight_type in search_space:
    edge_weight = cm_weight if weight_type == 0 else cm_weight * gcn_weight.view(-1, 1)
    echo_weight = echo_value if echo_set else None
    post_step = _softmax_step if diff_post_step else _clamp_step

    idx, best_epoch, val_acc, test_acc, output = model(
        # Fresh copy per run, in case the model / in-place post_step modifies its input.
        y_true=data.y, y_soft=y_soft.clone(), alpha=alpha, data_name=data.name,
        spread_mask=spread_mask, eval_mask=eval_mask, test_mask=test_mask,
        adj=data.adj_t, edge_weight=edge_weight, echo_weight=echo_weight,
        verbose=False, post_step=post_step, select_eval=select_eval,
    )
    print('Method: {}, Best epoch: {} -- Best Val: {:.4f}, Best Test: {:.4f}'.
          format(idx+1, best_epoch, val_acc, test_acc))
    # Select on validation accuracy only: picking the config with the best *test*
    # accuracy (as before) leaks test labels into model selection. Ties keep the first.
    if best_config is None or val_acc > best_val_res:
        best_val_res, best_test_res = val_acc, test_acc
        best_config = (idx, alpha, diff_post_step, echo_set, select_eval, weight_type)

(best_idx, best_alpha, best_diff_post_step,
 best_echo_set, best_select_eval, best_weight_type) = best_config
print('Method: {}, alpha: {}, diff: {}, echo_set: {}, eval: {}, weight: {}'.format(
    best_idx+1, best_alpha, best_diff_post_step, best_echo_set, best_select_eval, best_weight_type))
print('Best Val: {:.4f}, Best Test: {:.4f}'.format(best_val_res, best_test_res))

Train: 0.9444, Val: 0.6111, Test: 0.6667
LabelPropagationHeterophilyTog(max_layers=20)
Method: 1, Best epoch: 1 -- Best Val: 0.6111, Best Test: 0.6667
Method: 1, Best epoch: 1 -- Best Val: 0.6111, Best Test: 0.6599
Method: 1, Best epoch: -1 -- Best Val: 0.6111, Best Test: 0.6667
Method: 1, Best epoch: -1 -- Best Val: 0.6111, Best Test: 0.6667
Method: 3, Best epoch: 2 -- Best Val: 0.6667, Best Test: 0.6599
Method: 1, Best epoch: 1 -- Best Val: 0.6111, Best Test: 0.6599
Method: 3, Best epoch: -1 -- Best Val: 0.6667, Best Test: 0.6599
Method: 1, Best epoch: -1 -- Best Val: 0.6111, Best Test: 0.6667
Method: 1, Best epoch: 1 -- Best Val: 0.6111, Best Test: 0.6667
Method: 1, Best epoch: 1 -- Best Val: 0.6111, Best Test: 0.6599
Method: 1, Best epoch: -1 -- Best Val: 0.6111, Best Test: 0.6667
Method: 1, Best epoch: -1 -- Best Val: 0.6111, Best Test: 0.6667
Method: 1, Best epoch: 1 -- Best Val: 0.6111, Best Test: 0.6667
Method: 1, Best epoch: 1 -- Best Val: 0.6111, Best Test: 0.6599
Method: 1, 