In [1]:
from config import *
from utilities import get_model_name
from dataset import ClassificationDataset

import functools
import numpy as np
import pandas as pd

# Accumulator for experiment results (not populated in the cells shown here).
results_dict = {}

# Split identifiers forwarded to ClassificationDataset as its first two arguments.
internal = 11
external = 12

# Preprocessing options, forwarded positionally to ClassificationDataset
# (see that class for the accepted values of each option).
encode_method = "dummy"
impute_method = "itbr"
fs_method = "chi2"
fs_ratio = 0.5
norm_method = "maxmin"

classification_dataset = ClassificationDataset(
    internal,
    external,
    encode_method,
    impute_method,
    fs_method,
    fs_ratio,
    norm_method,
    random_state=SEED,
)

In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.nn as gnn

class GraphConvNet(nn.Module):
    """Classifier that runs SAGEConv layers over a similarity graph built from the batch.

    The graph is rebuilt on every forward pass from the raw input rows: each row
    is a node, and an edge (i, j) is added when |cos(x_i, x_j)| is at or above the
    ``alpha`` nearest-rank quantile of all pairwise similarities in the batch.

    Args:
        input_size: number of input features per row.
        output_size: number of output classes.
        hidden_size: sequence of hidden widths. ``hidden_size[0]`` is the width of
            the bias-free input projection; each consecutive pair of widths adds
            one ``SAGEConv`` layer (none when ``len(hidden_size) == 1``).
        alpha: quantile in [0, 1] used as the similarity threshold. Higher values
            keep fewer edges.

    Raises:
        ValueError: if ``alpha`` is not in [0, 1].
    """

    def __init__(self, input_size, output_size, hidden_size, alpha):
        super().__init__()
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f"alpha must be in [0, 1], got {alpha!r}")
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)
        # Bias-free projection from the raw features to the first hidden width.
        self.input = nn.Linear(
            self.input_size, self.hidden_size[0], bias=False)
        self.alpha = alpha
        # One SAGEConv per consecutive pair of hidden widths.
        self.hiddens = nn.ModuleList([
            gnn.SAGEConv(in_width, out_width)
            for in_width, out_width in zip(self.hidden_size[:-1], self.hidden_size[1:])
        ])
        self.output = nn.Linear(self.hidden_size[-1], self.output_size)

    def forward(self, x):
        """Return class probabilities for each row of ``x``.

        Args:
            x: 2-D tensor of shape (num_nodes, input_size).

        Returns:
            Tensor of shape (num_nodes, output_size) whose rows sum to 1.
        """
        # The graph comes from the raw features, before the input projection.
        edge_index = self.create_edge_index(x)
        x = self.input(x)
        x = self.relu(x)
        for hidden in self.hiddens:
            x = hidden(x, edge_index)
            x = self.relu(x)
        x = self.output(x)
        # NOTE(review): this returns probabilities, not logits. If the model is
        # trained with nn.CrossEntropyLoss (which applies log-softmax itself), the
        # softmax is effectively applied twice. Confirm against the training loop.
        x = self.softmax(x)
        return x

    def create_edge_index(self, x):
        """Build a COO ``edge_index`` from thresholded absolute cosine similarity.

        The threshold is the nearest-rank ``self.alpha`` quantile of all N*N
        absolute pairwise similarities, the same rule as
        ``torch.quantile(..., interpolation='nearest')``. The rank is computed in
        Python float64, so large batches do not hit quantile's input-size limit.

        Args:
            x: 2-D tensor of node features, one row per node.

        Returns:
            Long tensor of shape (2, E) holding (row, col) node indices in
            row-major order. It is empty (shape (2, 0)) when ``x`` has no rows.

        Raises:
            ValueError: if ``x`` is not 2-D or ``self.alpha`` is not in [0, 1].
        """
        if x.dim() != 2:
            raise ValueError(f"expected a 2-D input, got shape {tuple(x.shape)}")
        if not 0.0 <= self.alpha <= 1.0:
            raise ValueError(f"alpha must be in [0, 1], got {self.alpha!r}")
        if x.size(0) == 0:
            return torch.empty((2, 0), dtype=torch.long, device=x.device)
        # The resulting graph is integer-valued, so no gradient is needed here.
        with torch.no_grad():
            # Normalising the rows once (same eps as F.cosine_similarity) and using a
            # matmul gives the full cosine matrix in O(N^2) memory instead of an
            # O(N^2 * D) broadcast.
            unit_rows = F.normalize(x, p=2.0, dim=-1, eps=1e-8)
            similarity_matrix = torch.abs(unit_rows @ unit_rows.t())
            similarity = torch.sort(similarity_matrix.reshape(-1)).values
            rank = int(round(self.alpha * (similarity.numel() - 1)))
            eps = similarity[rank]
            row, col = torch.where(similarity_matrix >= eps)
        return torch.stack((row, col), dim=0)

In [12]:
# Smoke test: build the model on a tiny random batch and run one forward pass.
torch.manual_seed(SEED)  # make the random batch and weight init reproducible on re-run
bsz = 4
input_size, output_size = 10, 2
x = torch.randn(bsz, input_size)
y = torch.randint(output_size, (bsz,))

model = GraphConvNet(input_size, output_size, hidden_size=[64, 32], alpha=0.95)
probs = model(x)
assert probs.shape == (bsz, output_size)
probs

In [15]:
# Display the first SAGEConv layer in `model.hiddens`; this exists only when
# len(hidden_size) >= 2, otherwise ModuleList indexing raises IndexError.
model.hiddens[0]

IndexError: index 1 is out of range