Loading Pre-trained custom model #39

Closed
ishangupta3 opened this issue Mar 17, 2020 · 4 comments

Comments

@ishangupta3

I am following the project structure and the custom.yaml file, and I would like to load the pre-trained model myself. Before loading the model, I need to instantiate the network class, like this:

model = TheModelClass(*args, **kwargs)

According to the custom.yaml file, the 'task' is pvnet, but the 'network' is res, and the network code lives in clean-pvnet/lib/networks/pvnet.

Which network class is the appropriate one to use when loading the custom model in PyTorch?

Would it be this one: Resnet18?
https://github.com/zju3dv/clean-pvnet/blob/07e7c09e80cde938e013d8bc2880605add33a0ce/lib/networks/pvnet/resnet18.py

Or this one: PoseResNet?
https://github.com/zju3dv/clean-pvnet/blob/07e7c09e80cde938e013d8bc2880605add33a0ce/lib/networks/resnet.py

My end goal is simply to take the model I trained and build a brand-new inference pipeline that I can customize for my own use case.

@pengsida
Member

If you want to train PVNet on your custom data, please follow these steps: https://github.com/zju3dv/clean-pvnet#training-on-the-custom-object

@ishangupta3
Author

I already have the custom model trained.

I would like to load it myself; I was asking which network class I should use to load it.

@ishangupta3
Author

PyTorch requires you to set up the model in this format (I already have the model trained on my own custom dataset, and I would like to load it myself):

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
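One likely snag with the snippet above: a checkpoint written by a training loop is often a dict that wraps the weights under a key (assumed here to be 'net' for clean-pvnet) next to optimizer, scheduler and epoch entries, so passing torch.load(PATH) straight to load_state_dict would fail. The helper below is a minimal sketch under that assumption; extract_state_dict is a hypothetical name, not part of clean-pvnet or PyTorch. It only touches keys, so it works on any mapping.

```python
from collections import OrderedDict
from collections.abc import Mapping
from typing import Any


def extract_state_dict(checkpoint: Mapping, key: str = "net") -> "OrderedDict[str, Any]":
    """Return a flat {parameter_name: tensor} dict for model.load_state_dict().

    Assumption: a training checkpoint may wrap the weights under `key`
    (assumed "net" for clean-pvnet) next to optimizer/scheduler/epoch
    entries. A bare state dict with no wrapper is passed through.
    """
    wrapped = checkpoint.get(key)
    weights = wrapped if isinstance(wrapped, Mapping) else checkpoint

    # Defensive: nn.DataParallel prefixes every key with "module.";
    # strip it if present so the keys match an unwrapped model.
    prefix = "module."
    cleaned: "OrderedDict[str, Any]" = OrderedDict()
    for name, value in weights.items():
        if name.startswith(prefix):
            name = name[len(prefix):]
        cleaned[name] = value
    return cleaned
```

In practice it would be called as model.load_state_dict(extract_state_dict(torch.load(PATH, map_location="cpu"))); map_location="cpu" lets a GPU-trained checkpoint load on a machine without CUDA.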

@pengsida
Member

python run.py --type visualize --cfg_file configs/custom.yaml
The network is specified in configs/custom.yaml.
It is constructed here: https://github.com/zju3dv/clean-pvnet/blob/master/lib/networks/pvnet/__init__.py#L5
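The linked file acts as a lookup table: 'task' (pvnet) picks the package under lib/networks, and 'network' (res) picks the constructor registered inside it, so res is assumed to resolve to the Resnet18 class in lib/networks/pvnet/resnet18.py rather than PoseResNet. The sketch below reproduces that two-level dispatch in plain Python; NETWORK_FACTORIES, make_network, build_res_stub, and the vote_dim/seg_dim head keys are hypothetical stand-ins for illustration, not the repository's actual code.

```python
from typing import Any, Callable, Dict, Mapping


def build_res_stub(vote_dim: int, seg_dim: int) -> Dict[str, Any]:
    """Hypothetical stand-in for the real 'res' constructor.

    The repository's entry is assumed to return a Resnet18 module from
    lib/networks/pvnet/resnet18.py; this stub only records its arguments.
    """
    return {"arch": "Resnet18", "vote_dim": vote_dim, "seg_dim": seg_dim}


# Two-level lookup: cfg["task"] picks the package (lib/networks/<task>),
# then cfg["network"] picks the constructor registered inside it.
NETWORK_FACTORIES: Dict[str, Dict[str, Callable[..., Any]]] = {
    "pvnet": {"res": build_res_stub},
}


def make_network(cfg: Mapping[str, Any]) -> Any:
    task, arch = cfg["task"], cfg["network"]
    if task not in NETWORK_FACTORIES:
        raise ValueError(f"unknown task {task!r}")
    factory = NETWORK_FACTORIES[task]
    if arch not in factory:
        raise ValueError(f"unknown network {arch!r} for task {task!r}")
    heads = cfg["heads"]  # hypothetical head-size keys
    return factory[arch](heads["vote_dim"], heads["seg_dim"])
```

The practical consequence: instantiate the class through the same config-driven path (or with the same head sizes) used during training, because load_state_dict raises a size-mismatch error if the output layers differ from the checkpoint.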
