RuntimeError: pseudo.size(1) == kernel_size.numel() INTERNAL ASSERT FAILED. Input mismatch #6315

Open
Amirtmgr opened this issue Dec 30, 2022 · 1 comment

Amirtmgr commented Dec 30, 2022

🐛 Describe the bug

I tried to train a SplineCNN based on the example mnist_nn_conv.py and got the following error:

RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
File "...conda\envs\dl23\lib\site-packages\torch_spline_conv\basis.py", line 10, in spline_basis
is_open_spline: torch.Tensor,
degree: int) -> Tuple[torch.Tensor, torch.Tensor]:
return torch.ops.torch_spline_conv.spline_basis(pseudo, kernel_size,
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
is_open_spline, degree)
RuntimeError: pseudo.size(1) == kernel_size.numel() INTERNAL ASSERT FAILED at "D:\a\pytorch_spline_conv\pytorch_spline_conv\csrc\cuda\basis_cuda.cu":104, please report a bug to PyTorch. Input mismatch
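
Read literally, the assertion requires the number of columns in `pseudo` (the edge attributes passed to the convolution) to equal the number of entries in `kernel_size`. Below is a minimal plain-Python sketch of that check. It assumes a scalar `kernel_size` is repeated once per `dim`, and that `T.Cartesian` on 2-D node positions yields two columns. `expand_kernel_size` and `check_pseudo_shape` are hypothetical helpers written for illustration, not part of torch_spline_conv or PyG.

```python
def expand_kernel_size(kernel_size, dim):
    """Hypothetical helper: one kernel-size entry per pseudo-coordinate dimension."""
    if isinstance(kernel_size, int):
        return [kernel_size] * dim
    return list(kernel_size)


def check_pseudo_shape(pseudo_shape, kernel_size, dim):
    """Mirror the failing check: pseudo.size(1) == kernel_size.numel()."""
    num_edges, num_columns = pseudo_shape
    entries = expand_kernel_size(kernel_size, dim)
    if num_columns != len(entries):
        raise ValueError(
            f"Input mismatch: pseudo has {num_columns} column(s), "
            f"kernel_size has {len(entries)} entry/entries"
        )
    return True


# Assumed shape: 1000 edges with 2-D Cartesian pseudo-coordinates.
pseudo_shape = (1000, 2)

# Two columns against dim=2: the shapes agree.
print(check_pseudo_shape(pseudo_shape, kernel_size=3, dim=2))

# Two columns against dim=1: the same kind of mismatch as in the traceback.
try:
    check_pseudo_shape(pseudo_shape, kernel_size=3, dim=1)
except ValueError as err:
    print(err)
```

If this reading is right, comparing `data.edge_attr.size(1)` with the `dim` argument given to `SplineConv` would be a quick first thing to check.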

My Code

import os.path as osp

import torch
import torch.nn as nn
import torch.nn.functional as F

import torch_geometric.transforms as T
from torch_geometric.datasets import MNISTSuperpixels
from torch_geometric.loader import DataLoader
from torch_geometric.nn import (
    SplineConv,
    global_mean_pool,
    graclus,
    max_pool,
    max_pool_x,
)
from torch_geometric.utils import normalized_cut


# Datasets
path = osp.join(osp.dirname(osp.realpath("/")), '..', 'data', 'MNIST')
transform = T.Cartesian(cat=False)
train_dataset = MNISTSuperpixels(path, True, transform=transform)
test_dataset = MNISTSuperpixels(path, False, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
d = train_dataset

# Normalized Cut
def normalized_cut_2d(edge_index, pos):
    row, col = edge_index
    edge_attr = torch.norm(pos[row] - pos[col], p=2, dim=1)
    return normalized_cut(edge_index, edge_attr, num_nodes=pos.size(0))

# SplineCNN
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = SplineConv(in_channels=d.num_features, out_channels=32, dim=1, kernel_size=3)
        self.conv2 = SplineConv(in_channels=32, out_channels=64, dim=1, kernel_size=3)
        self.fc1 = torch.nn.Linear(64, 128)
        self.fc2 = torch.nn.Linear(128, d.num_classes)

    def forward(self, data):
        data.x = F.elu(self.conv1(data.x, data.edge_index, data.edge_attr))
        weight = normalized_cut_2d(data.edge_index, data.pos)
        cluster = graclus(data.edge_index, weight, data.x.size(0))
        data.edge_attr = None
        data = max_pool(cluster, data, transform=transform)

        data.x = F.elu(self.conv2(data.x, data.edge_index, data.edge_attr))
        weight = normalized_cut_2d(data.edge_index, data.pos)
        cluster = graclus(data.edge_index, weight, data.x.size(0))
        x, batch = max_pool_x(cluster, data.x, data.batch)

        x = global_mean_pool(x, batch)
        x = F.elu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        return F.log_softmax(self.fc2(x), dim=1)

# Create Model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Train Function
def train(epoch):
    model.train()

    if epoch == 16:
        for param_group in optimizer.param_groups:
            param_group['lr'] = 0.001

    if epoch == 26:
        for param_group in optimizer.param_groups:
            param_group['lr'] = 0.0001

    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        F.nll_loss(model(data), data.y).backward()
        optimizer.step()

# Test Function
def test():
    model.eval()
    correct = 0

    for data in test_loader:
        data = data.to(device)
        pred = model(data).max(1)[1]
        correct += pred.eq(data.y).sum().item()
    return correct / len(test_dataset)

# Run epochs
for epoch in range(1, 31):
    train(epoch)
    test_acc = test()
    print(f'Epoch: {epoch:02d}, Test: {test_acc:.4f}')

Environment

  • PyG version: 2.1.0
  • PyTorch version: 1.13.0
  • OS: Windows
  • Python version: 3.10.8
  • CUDA/cuDNN version: 11.7
  • How you installed PyTorch and PyG (conda, pip, source): pip
  • Any other relevant information (e.g., version of torch-scatter): pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.13.0+cu117.html

Amirtmgr added the bug label Dec 30, 2022

rusty1s (Member) commented Dec 30, 2022

Do other examples of SplineConv work for you, e.g., cora.py? One thing to look out for is that edge_attr.min() >= 0 and edge_attr.max() <= 1. Can you see if this is indeed the case for you?
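
A minimal sketch of such a check, assuming only PyTorch is available. `summarize_edge_attr` is a hypothetical helper, and the tensors below are made-up stand-ins for a real batch's `edge_attr`. The rescaling step imitates what a normalizing transform might do and is an assumption, not a copy of the library code.

```python
import torch


def summarize_edge_attr(edge_attr: torch.Tensor) -> dict:
    """Report the properties of edge_attr that SplineConv depends on."""
    # Treat a 1-D attribute vector as a single pseudo-coordinate column.
    if edge_attr.dim() == 1:
        edge_attr = edge_attr.view(-1, 1)
    lo = edge_attr.min().item()
    hi = edge_attr.max().item()
    return {
        "num_columns": edge_attr.size(1),
        "min": lo,
        "max": hi,
        "in_unit_range": lo >= 0.0 and hi <= 1.0,
    }


# Hypothetical stand-in for one batch's edge_attr: 5 edges, 2-D offsets.
offsets = torch.tensor(
    [[-1.0, 0.5], [0.0, 0.0], [2.0, -2.0], [1.0, 1.0], [-0.5, 0.25]]
)
print(summarize_edge_attr(offsets))  # raw offsets fall outside [0, 1]

# Assumed normalization: divide by twice the largest magnitude, then shift by 0.5.
scaled = offsets / (2 * offsets.abs().max()) + 0.5
print(summarize_edge_attr(scaled))
```

In the script above, this could be called on `next(iter(train_loader)).edge_attr`. The reported `num_columns` may also be worth comparing with the `dim` passed to `SplineConv`, since the failing assertion concerns that count.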
