How to modify pre-trained models? #23

Closed
DWhettam opened this issue Jul 6, 2019 · 7 comments
Labels
question Further information is requested


@DWhettam

DWhettam commented Jul 6, 2019

Is there a good way to modify the pre-trained models? I want to tweak the forward() function so that it returns activations at a few layers.
I'm going to compare several CIFAR models, so training them all myself isn't really viable.

Thanks!

@osmr osmr self-assigned this Jul 7, 2019
@osmr osmr added the question Further information is requested label Jul 7, 2019
@osmr
Owner

osmr commented Jul 7, 2019

Hi. There are many ways. The easiest and dirtiest one is an instance method replacement (see a discussion). Example (for PyTorch):

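```python
import types

import torch
from pytorchcv.model_provider import get_model as ptcv_get_model

net = ptcv_get_model("resnet20_cifar10", pretrained=True)


def my_forward(self, x):
    # Run each stage of the feature extractor and keep its output.
    outs = []
    for module in self.features._modules.values():
        x = module(x)
        outs.append(x)
    # Flatten the final feature map and apply the classifier head.
    x = x.view(x.size(0), -1)
    x = self.output(x)
    return x, outs


# Bind my_forward to this instance only; net(x) now returns (logits, activations).
net.forward = types.MethodType(my_forward, net)
x = torch.randn(1, 3, 32, 32)
y = net(x)
```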
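To see why this works, here is a framework-free sketch of the same instance method replacement. `ToyModule` and `ToyNet` are hypothetical stand-ins for `nn.Module` and a pytorchcv model (the real `nn.Module.__call__` does more, such as running hooks); the point is only that a bound method stored on the instance shadows the class's `forward` for that one object.

```python
import types


class ToyModule:
    """Hypothetical stand-in for torch.nn.Module: calling the object runs self.forward."""

    def __call__(self, x):
        return self.forward(x)


class ToyNet(ToyModule):
    """Hypothetical model: a sequence of 'layers' followed by an output step."""

    def __init__(self):
        self.features = [lambda v: v + 1, lambda v: v * 2, lambda v: v - 3]
        self.output = lambda v: v * 10

    def forward(self, x):
        for layer in self.features:
            x = layer(x)
        return self.output(x)


def my_forward(self, x):
    # Same idea as the PyTorch snippet: keep every intermediate result.
    outs = []
    for layer in self.features:
        x = layer(x)
        outs.append(x)
    return self.output(x), outs


net = ToyNet()
other = ToyNet()
# Bind my_forward to this one instance; the instance attribute shadows ToyNet.forward.
net.forward = types.MethodType(my_forward, net)

print(net(1))    # (10, [2, 4, 1])
print(other(1))  # 10: the unpatched instance keeps the class method
```

Because the patch lives in the instance's `__dict__`, other instances of the same class are unaffected.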

@DWhettam
Author

DWhettam commented Jul 7, 2019

That seems to work well, thanks!
Do you know the best way to save this modified model and load it for inference elsewhere? I realise this may be a little out of the scope of your repo, so no worries if you don't have an answer.
Thanks again

@osmr
Owner

osmr commented Jul 7, 2019

First of all, if you modify only forward(), then the weights don't change. Resaving the model isn't necessary.
Secondly, ptcv_get_model() caches the weights file in the ~/.torch/models folder. You can always use standard methods for loading/saving models. Example:

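```python
import os

import torch
from pytorchcv.model_provider import get_model as ptcv_get_model

# Build the architecture without downloading weights...
net = ptcv_get_model("resnet20_cifar10", pretrained=False)
# ...then load the weights file cached earlier by ptcv_get_model(pretrained=True).
net.load_state_dict(torch.load(os.path.expanduser("~/.torch/models/resnet20_cifar10-0597-9b0024ac.pth")))
# Save the weights with the standard PyTorch API.
torch.save(obj=net.state_dict(), f="tmp.pth")
```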
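The cached file name in the snippet above carries a version-specific suffix, so hardcoding it is brittle. A small helper such as the hypothetical `find_cached_weights()` below could look the file up by model name instead. It assumes the `<model_name>-<...>.pth` naming and the default `~/.torch/models` folder seen in this thread; both may differ between pytorchcv versions.

```python
import glob
import os


def find_cached_weights(model_name, root="~/.torch/models"):
    """Hypothetical helper: list cached weight files named '<model_name>-*.pth'.

    Assumes the naming seen in this thread, e.g. resnet20_cifar10-0597-9b0024ac.pth.
    """
    folder = os.path.expanduser(root)
    # Escape the name so glob metacharacters in it are matched literally.
    pattern = os.path.join(folder, glob.escape(model_name) + "-*.pth")
    return sorted(glob.glob(pattern))
```

With it, the load line could become `net.load_state_dict(torch.load(find_cached_weights("resnet20_cifar10")[0]))`, assuming exactly one matching file is cached. The trailing `-` in the pattern keeps `resnet20_cifar10` from also matching `resnet20_cifar100`.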

@DWhettam
Author

DWhettam commented Jul 8, 2019

Great, thanks for the help!

@DWhettam DWhettam closed this as completed Jul 8, 2019
@DWhettam
Author

Hey, sorry to reopen this, but I've got a related question, if that's okay. I'm trying to run my model on multiple GPUs with the modified forward function above, but I keep getting the following error:

```
RuntimeError: Expected tensor for argument #1 'input' to have the same device as tensor for argument #2 'weight'; but device 1 does not equal 0 (while checking arguments for cudnn_convolution)
```
My code is as follows:

```python
import types

import torch
import torch.nn as nn
from pytorchcv.model_provider import get_model as ptcv_get_model

net = ptcv_get_model("densenet40_k12_cifar10", root='loc', pretrained=True)


def my_forward(self, x):
    activations = []
    for module in self.features._modules.values():
        x = module(x)  # error happens here
        activations.append(x)
    x = x.view(x.size(0), -1)
    x = self.output(x)
    return x, activations


net.forward = types.MethodType(my_forward, net)

device = torch.device("cuda:0")  # assumed: `device` was not defined in the original snippet
if torch.cuda.device_count() > 1:
    net = nn.DataParallel(net, device_ids=[0, 1, 2, 3])
net.to(device)
net.eval()
```

The error happens on the line `x = module(x)`. Do you have any ideas on how to fix this? I've seen [this](https://github.com/pytorch/pytorch/issues/8637), but I'm not sure how it helps in my use case. Thanks!

@DWhettam DWhettam reopened this Jul 10, 2019
@osmr
Owner

osmr commented Jul 10, 2019

Hi. This is a very good question for https://discuss.pytorch.org/.
I am sure you will get an exhaustive answer there on why your code will not work, even if you fix the bug with the incorrect placement of the network input and the network itself on the devices ;)
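A likely explanation for the error, offered as my reading rather than a confirmed diagnosis: `nn.DataParallel` builds each per-GPU replica by shallow-copying the module's `__dict__`, and an instance-level `forward` created with `types.MethodType` lives in that `__dict__`, still bound to the original network whose weights are on GPU 0. The framework-free sketch below mimics only that mechanism; `ToyModule`, `replicate()` and `PatchedModule` are hypothetical, and the "devices" are plain integers.

```python
import copy
import types


class ToyModule:
    """Hypothetical stand-in for nn.Module; `device` plays the role of the weights' GPU."""

    def __init__(self, device):
        self.device = device

    def __call__(self, input_device):
        # Like nn.Module.__call__, `forward` is looked up on the instance first.
        return self.forward(input_device)

    def forward(self, input_device):
        # Mimics the cuDNN check that input and weight live on the same device.
        if input_device != self.device:
            raise RuntimeError(f"device {input_device} does not equal {self.device}")
        return "ok"


def replicate(module, device):
    """Rough mimic of a DataParallel replica: shallow __dict__ copy, then move weights."""
    replica = copy.copy(module)
    replica.device = device
    return replica


def patched_forward(self, input_device):
    # Stand-in for my_forward from the thread; the layer loop is omitted.
    return ToyModule.forward(self, input_device)


# 1) Instance-level patch: the bound method is copied along with __dict__,
#    so the replica's forward still runs against the original (device 0) object.
original = ToyModule(device=0)
original.forward = types.MethodType(patched_forward, original)
bad_replica = replicate(original, device=1)
try:
    bad_replica(1)
except RuntimeError as err:
    print("patched instance:", err)  # device 1 does not equal 0


# 2) Class-level override: each replica binds forward to itself.
class PatchedModule(ToyModule):
    def forward(self, input_device):
        return patched_forward(self, input_device)


good_replica = replicate(PatchedModule(device=0), device=1)
print("subclass:", good_replica(1))  # ok
```

In PyTorch terms, the analogous fix would be to define the custom `forward` at class level, for example in a small `nn.Module` subclass that wraps the pretrained network, instead of assigning a bound method to the instance.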

@DWhettam
Author

Great, I'll post there. Thanks!
