
Simple Example does not work #1

Closed
marcown opened this issue Oct 19, 2021 · 4 comments

Comments

@marcown

marcown commented Oct 19, 2021

Hey there!

Thanks for the great work and the open-source code.

I tried a very simple example but couldn't get it to work:

import torch
import torch.nn as nn
import torch.nn.functional as F
import ckconv
from ckconv.nn import CKConv
from omegaconf import OmegaConf


kernel_config = OmegaConf.create({
    "type": "MLP",
    "dim_linear": 2,
    "no_hidden": 2,
    "no_layers": 3,
    "activ_function": "ReLU",
    "norm": "BatchNorm",
    "omega_0": 1,
    "learn_omega_0": False,
    "weight_norm": False,
    "steerable": False,
    "init_spatial_value": 1.0,
    "bias_init": None,
    "input_scale": 25.6,
    "sampling_rate_norm": 1.0,
    "regularize": False,
    "regularize_params": {
        "res": 0,
        "res_offset": 0,
        "target": "gabor+mask",
        "fn": "l2_relu",
        "method": "together",
        "factor": 0.001,
        "gauss_stddevs": 2.0,
        "gauss_factor": 0.5,
    },
    "srf": {"scale": 0.0},
})


conv_config = OmegaConf.create({
    "type": "",
    "use_fft": False,
    "bias": True,
    "padding": "same",
    "stride": 1,
    "horizon": "same",
    "cache": False,
})

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.conv1 = CKConv(3, 6, kernel_config, conv_config) # nn.Conv2d(3, 6, 5) --> original conv that works
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        print("x: ", x.shape)
        y = self.conv1(x)
        print("y: ", y.shape)
        x = self.pool(F.relu(y))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


net = Net()


inn = torch.randn((1, 3, 28, 28))
out = net(inn)

This raises:

RuntimeError: Given weight of size [2, 2, 1, 1], expected bias to be 1-dimensional with 2 elements, but got bias of size [2, 2] instead

(You can ignore everything after the first conv; it is borrowed from the PyTorch examples.)

I tried different configurations (the above is only one example).

Thanks for any help :)

@dwromero
Collaborator

Hi Markus,

Thank you very much! :)

What is the size of the input signal?

@marcown
Author

marcown commented Oct 19, 2021

Hey, thanks for the fast response!

See above: inn = torch.randn((1, 3, 28, 28))

But I also tested it with other dimensions.

@dwromero
Collaborator

Oh sorry, I missed that.

I found the error; it was on our side. Sorry about that. The problem was that the bias of the Conv2d was being overwritten during the uniform initialization of the biases. I had only checked this for the 1D case, and it leads to errors in 2D. If you pull the repo again, it should work now.
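
For context, the exact shape mismatch in the traceback can be reproduced in isolation with a plain F.conv2d call; a standalone sketch, independent of ckconv:

import torch
import torch.nn.functional as F

x = torch.randn(1, 2, 8, 8)
weight = torch.randn(2, 2, 1, 1)  # [out_channels, in_channels, kH, kW]

F.conv2d(x, weight, bias=torch.randn(2))     # OK: 1-D bias with out_channels elements
F.conv2d(x, weight, bias=torch.randn(2, 2))  # RuntimeError: expected bias to be
                                             # 1-dimensional with 2 elements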

Regarding your code, please change the horizon in your conv_config from "same" to 29. This is necessary because the input is of even size (28); since the kernel is generated on the fly at the size of the input, it would also be of even size, which prevents symmetric "same" padding. An odd horizon such as 29 yields an odd-sized kernel.
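
Applied to the snippet above, the corrected configuration then reads:

conv_config = OmegaConf.create({
    "type": "",
    "use_fft": False,
    "bias": True,
    "padding": "same",
    "stride": 1,
    "horizon": 29,  # odd horizon -> odd-sized kernel, so "same" padding stays symmetric
    "cache": False,
})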

With that being said, I would not expect the results with the MLP kernels to be very good; normal MLPs actually perform pretty poorly as kernel generators. This is why we need implicit neural representations (e.g., MAGNets). I would try those out if the results with the MLP are not good.
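
Switching the kernel parameterization should then only amount to changing the type field of kernel_config. A minimal sketch; the exact type string below is an assumption, not verified against the repo's configs:

# NOTE: "MAGNet" as the type string is an assumption; check the repo's
# config files for the canonical kernel-type names.
kernel_config.type = "MAGNet"
net = Net()  # re-instantiate so CKConv picks up the new kernel type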

Please let me know if this solves the problems for now :)

Cheers,
David

@marcown
Author

marcown commented Oct 19, 2021

Perfect, thanks!

And yes, I will try MAGNets; that was just a first small test :)

@marcown marcown closed this as completed Oct 19, 2021