Torch vs Tensorflow AGNP is different #22
@DrJonnyT: I've been converting my code from TensorFlow to PyTorch, and it's much easier to get it training quickly. However, the performance after n epochs is worse in Torch. After lots of digging, it seems like the model architectures come out different for AGNP, but it doesn't seem to be an issue for a GNP.

Comments
Hey @DrJonnyT! Hmmm, this is curious. I think you might be onto something here. I can tell that the difference is in the encoder. Will investigate and get back to you!
The discrepancy indeed is in the encoder.
Actually, @DrJonnyT, I'm thinking that things might be fine after all. Could you try running the following and seeing whether you find that things are equal too?

```python
import lab as B
import neuralprocesses.tensorflow as nps_tf
import neuralprocesses.torch as nps_torch
import tensorflow as tf
import torch

# Construct a GNP with identical hyperparameters in both frameworks.
model_tf = nps_tf.construct_gnp(
    dim_x=17,
    dim_y=9,
    dim_embedding=128,
    num_enc_layers=6,
    num_dec_layers=6,
)
model_torch = nps_torch.construct_gnp(
    dim_x=17,
    dim_y=9,
    dim_embedding=128,
    num_enc_layers=6,
    num_dec_layers=6,
)

# Run one forward pass through each model so that all weights are created.
model_torch(
    B.randn(torch.float32, 16, 17, 10),
    B.randn(torch.float32, 16, 9, 10),
    B.randn(torch.float32, 16, 17, 10),
)
model_tf(
    B.randn(tf.float32, 16, 17, 10),
    B.randn(tf.float32, 16, 9, 10),
    B.randn(tf.float32, 16, 17, 10),
)

# Check that the parameters line up one to one. A TF `Dense` kernel is the
# transpose of the corresponding `torch.nn.Linear` weight, so a transposed
# shape also counts as a match.
assert len(model_tf.get_weights()) == len(list(model_torch.parameters()))
for x, y in zip(model_tf.get_weights(), model_torch.parameters()):
    assert x.shape == y.shape or x.shape == (y.shape[1], y.shape[0])
print("Ok!")

# Now repeat exactly the same check for the AGNP.
model_tf = nps_tf.construct_agnp(
    dim_x=17,
    dim_y=9,
    dim_embedding=128,
    num_enc_layers=6,
    num_dec_layers=6,
)
model_torch = nps_torch.construct_agnp(
    dim_x=17,
    dim_y=9,
    dim_embedding=128,
    num_enc_layers=6,
    num_dec_layers=6,
)
model_torch(
    B.randn(torch.float32, 16, 17, 10),
    B.randn(torch.float32, 16, 9, 10),
    B.randn(torch.float32, 16, 17, 10),
)
model_tf(
    B.randn(tf.float32, 16, 17, 10),
    B.randn(tf.float32, 16, 9, 10),
    B.randn(tf.float32, 16, 17, 10),
)
assert len(model_tf.get_weights()) == len(list(model_torch.parameters()))
for x, y in zip(model_tf.get_weights(), model_torch.parameters()):
    assert x.shape == y.shape or x.shape == (y.shape[1], y.shape[0])
print("Ok!")
```
Hmm, the performance may be sensitive to the initialisation, which could be different between PyTorch and TF, and to the precise optimiser settings (how do the learning rate and batch size interact?). Are you sure that those are equal?
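One concrete, easy-to-miss example of such a setting: the two frameworks' Adam implementations use slightly different defaults. A minimal sketch of pinning them to identical values, assuming the `model_torch` from the snippet above (the particular values here are illustrative):

```python
import tensorflow as tf
import torch

# Keras Adam defaults to epsilon=1e-7, while torch.optim.Adam defaults to
# eps=1e-8. Setting every hyperparameter explicitly removes one source of
# divergence between the two training runs.
opt_tf = tf.keras.optimizers.Adam(
    learning_rate=1e-3, beta_1=0.9, beta_2=0.999, epsilon=1e-8
)
opt_torch = torch.optim.Adam(
    model_torch.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-8
)
```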
@wesselb That script gives me OK for both the GNP and the AGNP, so all good there.
@wesselb Here's a messy ChatGPT script that builds two identical ReLU networks, one in PyTorch and one in TensorFlow. It seems to train much quicker in TensorFlow, so I think it's probably just a TensorFlow vs Torch thing rather than an issue with neuralprocesses? I would have thought that would be more well known, though.
@DrJonnyT Your ReLU example is a good one; I would chase that down. It should be possible to configure things so that the convergence is exactly the same between TF and PyTorch. Do PyTorch and TF initialise the weights in the same way? That could make a big difference.
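(For context, the default initialisations do differ: `tf.keras.layers.Dense` uses Glorot-uniform kernels and zero biases, while `torch.nn.Linear` uses Kaiming-uniform. A hypothetical helper that sidesteps the question by copying a Keras model's initial weights into a matching PyTorch model:)

```python
import numpy as np
import torch

def copy_tf_weights_to_torch(model_tf, model_torch):
    """Copy a Keras model's weights into an equivalent PyTorch model.

    Assumes the parameter orderings match (as the shape check above
    verified) and that 2D kernels need transposing: a Keras `Dense` kernel
    is stored as (in, out), a `torch.nn.Linear` weight as (out, in).
    """
    with torch.no_grad():
        for w_tf, p_torch in zip(model_tf.get_weights(), model_torch.parameters()):
            w = np.asarray(w_tf)
            if w.shape != tuple(p_torch.shape):
                w = w.T  # transpose Dense kernels into Linear layout
            p_torch.copy_(torch.from_numpy(w))
```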
@wesselb This gets you pretty close! I get these losses for 5 epochs (the first data point is testing the untrained model). I suspect/hope that if you set the same random seed it would come out exactly the same. I tried the optimizer in this configuration with some very basic TF/Torch models and the loss was exactly the same, so I'm happy to close. Phew!
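A minimal sketch of that kind of "very basic model" check, assuming nothing beyond stock TF and PyTorch: one Dense/Linear layer with the weights copied across, one SGD step in each framework, and the losses compared.

```python
import numpy as np
import tensorflow as tf
import torch

np.random.seed(0)
x = np.random.randn(32, 4).astype(np.float32)
y = np.random.randn(32, 1).astype(np.float32)

# One Dense/Linear layer each, with the Keras initialisation copied into
# the PyTorch layer so that both start from identical weights.
dense = tf.keras.layers.Dense(1)
dense.build((None, 4))
linear = torch.nn.Linear(4, 1)
with torch.no_grad():
    linear.weight.copy_(torch.from_numpy(dense.kernel.numpy().T))
    linear.bias.copy_(torch.from_numpy(dense.bias.numpy()))

opt_tf = tf.keras.optimizers.SGD(learning_rate=0.1)
opt_torch = torch.optim.SGD(linear.parameters(), lr=0.1)

# One mean-squared-error gradient step in each framework.
with tf.GradientTape() as tape:
    loss_tf = tf.reduce_mean((dense(x) - y) ** 2)
grads = tape.gradient(loss_tf, dense.trainable_variables)
opt_tf.apply_gradients(zip(grads, dense.trainable_variables))

loss_torch = ((linear(torch.from_numpy(x)) - torch.from_numpy(y)) ** 2).mean()
opt_torch.zero_grad()
loss_torch.backward()
opt_torch.step()

# Identical weights, data, and optimiser: the losses should agree to
# float32 precision.
print(float(loss_tf), float(loss_torch))
```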
@DrJonnyT That's some impressive investigative work! :) Very nice!! Did you also check the attentive models? Perhaps it's worthwhile to do that too?
@DrJonnyT That's amazing. This is a super good check. :) I think the convolutional models do not line up exactly because TF adopts a channels-last convention whereas PyTorch is channels-first, so you may need to reorder the convolutional weights to get equality.
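Concretely (illustrative shapes only): a TF `Conv2D` kernel has shape `(kh, kw, c_in, c_out)`, while `torch.nn.Conv2d` stores `(c_out, c_in, kh, kw)`, so the axes must be permuted before comparing or copying.

```python
import numpy as np

# A TF Conv2D kernel is (kh, kw, c_in, c_out); torch.nn.Conv2d stores
# (c_out, c_in, kh, kw). Permute the axes to move between the two layouts.
w_tf = np.random.randn(3, 3, 16, 32)        # (kh, kw, c_in, c_out)
w_torch = np.transpose(w_tf, (3, 2, 0, 1))  # (c_out, c_in, kh, kw)
assert w_torch.shape == (32, 16, 3, 3)
```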
@wesselb Cool, I've made a minor tweak to that gist and now it works for the ConvGNP as well 👍
@DrJonnyT That's super good. How would you feel about me linking the gist from the documentation? I think this is a super important check. A more ambitious plan would be to turn it into a unit test for the library, but that might not be so simple.
@wesselb Sounds good!