Import dependencies

In [12]:
from model.layers import GraphSpectralFilterLayer, AnalysisFilter
from model.spectral_filter import Graph
import torch
import torch.nn.functional as F
from torch import nn
from random import seed as rseed
from numpy.random import seed as nseed
from webkb import get_dataset, run
from webkb.train_eval import evaluate
import numpy as np


Define hyperparameters

In [13]:
# --- Data ---------------------------------------------------------------
dataset_name = 'Wisconsin'
normalize_features = True
random_splits = False
edge_dropout = 0
node_feature_dropout = 0

# --- Optimisation -------------------------------------------------------
runs = 1
epochs = 2000
lr = 0.01
weight_decay = 0.0008
patience = 100
seed = 729
cuda = False

# --- Model --------------------------------------------------------------
hidden = 256
heads = 4
dropout = 0.3
alpha = 0.2
order = 16
filter_name = 'analysis'
pre_training = False

# Seed every random number generator used below (Python, NumPy, PyTorch).
rseed(seed)
nseed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x7f9075ef1770>

Define model

In [14]:
class Net(torch.nn.Module):
    """Two stacked graph spectral filter layers: ``analysis`` followed by ``synthesis``.

    The layer sizes and options are read from the notebook-level
    hyperparameters (``hidden``, ``heads``, ``dropout``, ``alpha``, ``order``,
    ``filter_name``, ``pre_training``, ``cuda``).

    Args:
        dataset: dataset whose first element provides ``edge_index``,
            ``num_edges`` and ``num_nodes``, and which exposes
            ``num_node_features`` and ``num_classes``.
    """

    def __init__(self, dataset):
        super(Net, self).__init__()
        data = dataset[0]
        device = 'cuda' if cuda else 'cpu'

        # Give the adjacency its size explicitly. Without it the shape is
        # inferred from the largest index in edge_index, which is too small
        # when the highest-numbered nodes have no edges.
        adj = torch.sparse_coo_tensor(data.edge_index, torch.ones(data.num_edges),
                                      (data.num_nodes, data.num_nodes))
        self.G = Graph(adj)
        self.G.estimate_lmax()

        # First layer: node features -> `hidden` channels for each of `heads` filters.
        self.analysis = GraphSpectralFilterLayer(self.G, dataset.num_node_features, hidden,
                                                 dropout=dropout, out_channels=heads, filter=filter_name,
                                                 pre_training=pre_training, device=device,
                                                 alpha=alpha, order=order)

        # Second layer: concatenated first-layer output -> one score per class.
        self.synthesis = GraphSpectralFilterLayer(self.G, hidden * heads, dataset.num_classes, filter=filter_name,
                                                  device=device, dropout=dropout,
                                                  out_channels=1, alpha=alpha, pre_training=False,
                                                  order=order)

    def reset_parameters(self):
        """Re-initialise the parameters of both filter layers."""
        self.analysis.reset_parameters()
        self.synthesis.reset_parameters()

    def forward(self, data):
        """Run both layers on ``data.x``.

        Returns:
            A tuple ``(log_probs, attentions_1, attentions_2)``: the
            log-softmax over dimension 1 of the second layer's output, and the
            second value returned by the analysis and synthesis layers.
        """
        x = data.x
        x = F.dropout(x, p=dropout, training=self.training)
        x, attentions_1 = self.analysis(x)
        x = F.dropout(x, p=dropout, training=self.training)
        x, attentions_2 = self.synthesis(x)
        x = F.elu(x)
        return F.log_softmax(x, dim=1), attentions_1, attentions_2


# Load the dataset selected by `dataset_name`, forwarding the feature
# normalisation flag and the edge / node-feature dropout settings.
dataset = get_dataset(dataset_name, normalize_features, edge_dropout=edge_dropout,
                                node_feature_dropout=node_feature_dropout)

# NOTE(review): the return value of `.to('cuda')` is discarded. If `dataset[0]`
# builds a new object on every access, this line has no lasting effect and the
# later `dataset[0]` calls still see CPU data. Confirm before setting `cuda = True`.
if cuda:
    dataset[0].to('cuda')

# permute_masks = random_planetoid_splits if random_splits else None
# run(dataset, Net(dataset), runs, epochs, lr, weight_decay,
#     early_stopping, permute_masks)

In [15]:
class SingleNet(torch.nn.Module):
    """One graph spectral filter layer (``analysis``) followed by a linear classifier.

    The layer sizes and options are read from the notebook-level
    hyperparameters (``hidden``, ``heads``, ``dropout``, ``alpha``, ``order``,
    ``filter_name``, ``pre_training``, ``cuda``).

    Args:
        dataset: dataset whose first element provides ``edge_index``,
            ``num_edges`` and ``num_nodes``, and which exposes
            ``num_node_features`` and ``num_classes``.
    """

    def __init__(self, dataset):
        super(SingleNet, self).__init__()
        data = dataset[0]

        # Give the adjacency its size explicitly. Without it the shape is
        # inferred from the largest index in edge_index, which is too small
        # when the highest-numbered nodes have no edges.
        adj = torch.sparse_coo_tensor(data.edge_index, torch.ones(data.num_edges),
                                      (data.num_nodes, data.num_nodes))
        self.G = Graph(adj)
        self.G.estimate_lmax()

        # Spectral layer: node features -> `hidden` channels for each of `heads` filters.
        self.analysis = GraphSpectralFilterLayer(self.G, dataset.num_node_features, hidden,
                                                 dropout=dropout, out_channels=heads, filter=filter_name,
                                                 pre_training=pre_training, device='cuda' if cuda else 'cpu',
                                                 alpha=alpha, order=order)

        # Dense classifier on the concatenated filter outputs.
        self.linear = torch.nn.Linear(hidden * heads, dataset.num_classes)

    def reset_parameters(self):
        """Re-initialise the spectral layer and the linear classifier."""
        self.analysis.reset_parameters()
        self.linear.reset_parameters()

    def forward(self, data):
        """Run the spectral layer and the classifier on ``data.x``.

        Returns:
            A tuple ``(log_probs, attentions_1, None)``. The trailing ``None``
            keeps the return shape aligned with ``Net.forward``, which has a
            second filter layer.
        """
        x = data.x
        x = F.dropout(x, p=dropout, training=self.training)
        x, attentions_1 = self.analysis(x)
        x = F.dropout(x, p=dropout, training=self.training)
        x = F.elu(self.linear(x))
        return F.log_softmax(x, dim=1), attentions_1, None

Load trained model and evaluate

In [16]:
# Evaluate the trained two-layer model on each of the 10 splits.
accs = []

for i in range(10):
    # Each split has its own checkpoint, so build a fresh model per split.
    model = Net(dataset)
    # map_location='cpu' lets checkpoints saved on a GPU load on a CPU-only
    # machine; load_state_dict then copies the values into the model's own
    # parameters.
    # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
    state_dict = torch.load('./model/best_wisconsin_dec_split_{}.pkl'.format(i), map_location='cpu')
    model.load_state_dict(state_dict)

    eval_info = evaluate(model, dataset[0], split=i)
    accs.append(eval_info['test_acc'])

# Summarise test accuracy across splits.
accs = torch.tensor(accs)
print('acc:', accs.mean().item(), 'std:', accs.std().item())

acc: 0.8588235294117647 std: 0.03788595227761946


Obtain attention weights in layer 1 and 2

Build NetworkX Graph

In [6]:
from matplotlib import pyplot as plt
def plot_filter_banks(idx=None, kernel=None, ax=None, no_ticks=False, legend=True):
    """Plot the response of a filter kernel on 100 points in ``[0, 2]``.

    Args:
        idx: column indices of the kernel output to plot. Defaults to all
            ``heads`` columns.
        kernel: callable mapping a 1-D tensor of frequencies to a 2-D tensor
            with one column per filter. Defaults to the analysis kernel of the
            current notebook-level ``model``.
        ax: matplotlib axes to draw on. When omitted, a new 10x10 figure is
            created and shown.
        no_ticks: if true, hide axis ticks; otherwise use font size 18 for
            the tick labels.
        legend: accepted for backward compatibility; currently unused.
    """
    # Resolve defaults at call time. As default argument values they would be
    # frozen to whatever `heads` and `model` were when this cell was executed.
    if idx is None:
        idx = list(range(heads))
    if kernel is None:
        kernel = model.analysis.filter._kernel

    # Only create (and later show) a figure when the caller did not supply one.
    own_figure = ax is None
    if own_figure:
        _, ax = plt.subplots(figsize=(10, 10))

    x = torch.linspace(0, 2, 100)
    ax.plot(x, kernel(x).detach()[:, idx])

    # Configure ticks on the axes actually drawn on, not on pyplot's
    # "current" axes, which may belong to a different figure.
    if no_ticks:
        ax.set_xticks([])
        ax.set_yticks([])
    else:
        ax.tick_params(axis='both', labelsize=18)

    if own_figure:
        plt.show()


### Frequency cutoff analysis

In [7]:
# Cut frequency bands abruptly
class CutOff(nn.Module):
    """Wrap a filter kernel and zero its response on the band ``[min_val, max_val)``.

    Args:
        kernel: callable mapping a 1-D tensor of frequencies to a 2-D tensor
            with one row per frequency and one column per filter.
        min_val: inclusive lower edge of the band that is zeroed.
        max_val: exclusive upper edge of the band that is zeroed.
    """

    def __init__(self, kernel, min_val = 0, max_val = 2):
        super(CutOff, self).__init__()
        self.min = min_val
        self.max = max_val
        self.kernel = kernel

    def reset_parameters(self):
        """No-op: the wrapper adds no parameters of its own."""
        pass

    def forward(self, x):
        """Return ``kernel(x)`` with the rows whose frequency lies in the band set to zero."""
        response = self.kernel(x)
        # One mask entry per frequency, broadcast across all filter columns.
        # Taking shape, dtype and device from the kernel output removes the
        # dependence on a global head count and on the default device.
        in_band = ((x >= self.min) & (x < self.max)).view(-1, 1)
        return torch.where(in_band, torch.zeros_like(response), response)

Evaluate band cutoff: zero out one frequency band of width `step` at a time and re-evaluate

In [22]:
# Sweep over [0, 2) in bands of width `step`. For each band, zero the analysis
# filter response inside it and measure test accuracy over the 10 splits.
step = 0.25
acc_mean_by_step = []
acc_std_by_step = []

for threshold in torch.arange(0, 2, step):
    accs = []
    for i in range(10):
        model = Net(dataset)
        # map_location='cpu' lets checkpoints saved on a GPU load on a
        # CPU-only machine.
        # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
        state_dict = torch.load('./model/best_wisconsin_dec_split_{}.pkl'.format(i), map_location='cpu')
        model.load_state_dict(state_dict)

        # Wrap the trained kernel so its response is zero on
        # [threshold, threshold + step). The wrapper is installed on both
        # attributes that hold the kernel (presumably the layer and its
        # filter each keep a reference; verify against the layer code).
        filter_kernel = model.analysis.filter_kernel
        cutoff = CutOff(min_val=threshold, max_val=threshold + step, kernel=filter_kernel)
        model.analysis.filter_kernel = cutoff
        model.analysis.filter._kernel = cutoff

        eval_info = evaluate(model, dataset[0], split=i)
        accs.append(eval_info['test_acc'])

    accs = torch.tensor(accs)
    acc_mean_by_step.append(accs.mean().item())
    acc_std_by_step.append(accs.std().item())
    # The value printed after 'step:' is the lower edge of the removed band.
    print('step:', threshold.item(), 'acc:', accs.mean().item(), 'std:', accs.std().item())
print(acc_mean_by_step, acc_std_by_step)

step: 0.0 acc: 0.8588235294117647 std: 0.03788595227761946
step: 0.25 acc: 0.8588235294117647 std: 0.03788595227761946
step: 0.5 acc: 0.8588235294117647 std: 0.03788595227761946
step: 0.75 acc: 0.8588235294117647 std: 0.03788595227761946
step: 1.0 acc: 0.6176470588235294 std: 0.06552267207765823
step: 1.25 acc: 0.8588235294117647 std: 0.03788595227761946
step: 1.5 acc: 0.8607843137254901 std: 0.041799416876652264
step: 1.75 acc: 0.8607843137254901 std: 0.041799416876652264
[0.8588235294117647, 0.8588235294117647, 0.8588235294117647, 0.8588235294117647, 0.6176470588235294, 0.8588235294117647, 0.8607843137254901, 0.8607843137254901] [0.03788595227761946, 0.03788595227761946, 0.03788595227761946, 0.03788595227761946, 0.06552267207765823, 0.03788595227761946, 0.041799416876652264, 0.041799416876652264]


In [18]:
# Evaluate the trained single-layer model on each of the 10 splits.
accs = []

for i in range(10):
    # Each split has its own checkpoint, so build a fresh model per split.
    model = SingleNet(dataset)
    # map_location='cpu' lets checkpoints saved on a GPU load on a CPU-only
    # machine; load_state_dict then copies the values into the model's own
    # parameters.
    # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
    state_dict = torch.load('./model/best_wisconsin_single_dec_split_{}.pkl'.format(i), map_location='cpu')
    model.load_state_dict(state_dict)

    eval_info = evaluate(model, dataset[0], split=i)
    accs.append(eval_info['test_acc'])

# Summarise test accuracy across splits.
accs = torch.tensor(accs)
print('acc:', accs.mean().item(), 'std:', accs.std().item())

acc: 0.8431372549019607 std: 0.05228758169934639


In [23]:
# Sweep over [0, 2) in bands of width `step`. For each band, zero the analysis
# filter response of the single-layer model inside it and measure test
# accuracy over the 10 splits.
step = 0.25
acc_mean_by_step = []
acc_std_by_step = []

for threshold in torch.arange(0, 2, step):
    accs = []
    for i in range(10):
        model = SingleNet(dataset)
        # map_location='cpu' lets checkpoints saved on a GPU load on a
        # CPU-only machine.
        # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
        state_dict = torch.load('./model/best_wisconsin_single_dec_split_{}.pkl'.format(i), map_location='cpu')
        model.load_state_dict(state_dict)

        # Wrap the trained kernel so its response is zero on
        # [threshold, threshold + step). The wrapper is installed on both
        # attributes that hold the kernel (presumably the layer and its
        # filter each keep a reference; verify against the layer code).
        filter_kernel = model.analysis.filter_kernel
        cutoff = CutOff(min_val=threshold, max_val=threshold + step, kernel=filter_kernel)
        model.analysis.filter_kernel = cutoff
        model.analysis.filter._kernel = cutoff

        eval_info = evaluate(model, dataset[0], split=i)
        accs.append(eval_info['test_acc'])

    accs = torch.tensor(accs)
    acc_mean_by_step.append(accs.mean().item())
    acc_std_by_step.append(accs.std().item())
    # The value printed after 'step:' is the lower edge of the removed band.
    print('step:', threshold.item(), 'acc:', accs.mean().item(), 'std:', accs.std().item())
print(acc_mean_by_step, acc_std_by_step)

step: 0.0 acc: 0.8431372549019607 std: 0.05228758169934639
step: 0.25 acc: 0.8431372549019607 std: 0.05228758169934639
step: 0.5 acc: 0.8431372549019607 std: 0.05228758169934639
step: 0.75 acc: 0.8431372549019607 std: 0.05228758169934639
step: 1.0 acc: 0.6058823529411764 std: 0.06757676876111285
step: 1.25 acc: 0.8431372549019607 std: 0.05228758169934639
step: 1.5 acc: 0.8470588235294118 std: 0.04785099403326307
step: 1.75 acc: 0.8470588235294118 std: 0.04785099403326307
[0.8431372549019607, 0.8431372549019607, 0.8431372549019607, 0.8431372549019607, 0.6058823529411764, 0.8431372549019607, 0.8470588235294118, 0.8470588235294118] [0.05228758169934639, 0.05228758169934639, 0.05228758169934639, 0.05228758169934639, 0.06757676876111285, 0.05228758169934639, 0.04785099403326307, 0.04785099403326307]
