
About the training code: a model trained with BAM's train_base.py cannot be loaded by this repo's src/model/pspnet.py #3

Closed
Nevaeh7 opened this issue Aug 30, 2023 · 7 comments

Comments


Nevaeh7 commented Aug 30, 2023

No description provided.


Nevaeh7 commented Aug 30, 2023

Dear Sina,
Hello, I am very interested in your work; it's excellent work in the field of GFSS. However, I found a problem while running the code. I used train_base.py to train the model on my dataset, but found that the trained model can't be loaded in this repo. I discovered that BAM's PSPNet and DIaM's PSPNet are not written in the same way, which results in the model not being loaded properly. Can you please provide me with the training code that you were using?


sinahmr commented Aug 30, 2023

Hi,
Thanks for your interest in our work!

I used the same train_base.py code that you are using, but you're right, the module names in BAM differ from ours. To resolve this, what I did was run the following script once on BAM's trained model to align it with our naming:

import os
from collections import OrderedDict

import torch

# 'checkpoint' and 'root' come from the surrounding BAM code where this snippet is pasted
state_dict = OrderedDict()
for k, v in checkpoint['state_dict'].items():
    k = k.replace('module.', '')  # If BAM is run on multiple GPUs, 'module.' will precede keys; otherwise, we should add it ourselves
    if k.startswith('cls.4.'):
        k = k.replace('cls.4', 'classifier')  # BAM's final classification layer -> our 'classifier'
    elif k.startswith('cls.'):
        k = k.replace('cls', 'bottleneck')  # the rest of BAM's 'cls' head -> our 'bottleneck'
    elif k.startswith('encoder.'):
        continue  # keys under 'encoder.' are not used by this repo, so skip them
    k = 'module.' + k
    state_dict[k] = v

filename = os.path.join(root, 'model.pth')
torch.save({'epoch': checkpoint['epoch'], 'state_dict': state_dict, 'optimizer': checkpoint['optimizer']}, filename)

You can paste it somewhere after the BAM model is loaded, for example after this line, run the code until the end of this snippet (maybe add an exit() after it). Then delete this snippet and run the code once again, but with model.pth as the checkpoint (setting ckpt_used in the config file to model).
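If it's useful, here's a quick way to sanity-check model.pth before rerunning. This is just a sketch (the path is a placeholder for wherever the snippet above saved the file); it prints a few of the converted keys and checks that the renaming went through:

import torch

# Placeholder path: wherever the snippet above saved the converted checkpoint
converted = torch.load('model.pth', map_location='cpu')
keys = list(converted['state_dict'].keys())

# After conversion, every key should start with 'module.', use 'classifier'/'bottleneck'
# instead of BAM's 'cls.4'/'cls', and no 'encoder.' keys should remain.
print(len(keys), 'parameters')
print([k for k in keys if k.startswith('module.classifier')][:5])
print([k for k in keys if k.startswith('module.bottleneck')][:5])
assert all(k.startswith('module.') for k in keys)
assert not any(k.startswith('module.encoder.') for k in keys)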

Let me know if this doesn't resolve the issue.


Nevaeh7 commented Aug 30, 2023

Thanks for the answer! It perfectly solved the problem I was having. In essence, the code you provided loads the original PSPNet's parameters into the structure of this repo's PSPNet. Very ingenious. Thank you again!


sinahmr commented Aug 31, 2023

No worries, I'm happy it helped!

sinahmr closed this as completed Aug 31, 2023

silsgah commented Feb 27, 2024

@Nevaeh7 are you available for a quick talk, please?


Jason-u commented Mar 14, 2024

Hello, I would like to ask if I can use a network other than PSPNet, such as SegFormer from the MMSEG framework, since you mentioned in the readme that any model trained on the base class can be used. If so, how can I modify the code in your repository?


sinahmr commented Mar 15, 2024

Hi,
Sure, you should be able to use other networks. Our contribution is what is implemented in the src/classifier.py file. Functions in the Classifier class take base classes' prototypes as input, alongside support and query features, and perform the inference. You can use another network to generate features for the support and query images, and also use that network's final classifier (Linear) weights as the base prototypes. Our proposed approach happens from this line to this line in src/test.py, so you can change the rest.
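To make the idea concrete, here is a rough, hypothetical sketch (not this repo's exact API; all names and sizes below are toy placeholders) of where the base prototypes and features would come from with another network: the rows of the final classification layer's weight act as one prototype per base class, and the feature maps right before that layer are what you'd pass as support/query features.

import torch
import torch.nn as nn

# Hypothetical stand-in for another network's final classification layer
# (e.g. the 1x1 conv at the end of a SegFormer-style decode head); sizes are toy numbers.
feat_dim, num_base_classes = 256, 16
head = nn.Conv2d(feat_dim, num_base_classes, kernel_size=1)

# One prototype per base class: the classifier weight rows, flattened to [num_base_classes, feat_dim]
base_prototypes = head.weight.detach().reshape(num_base_classes, feat_dim)
base_bias = head.bias.detach() if head.bias is not None else None

# Support/query features would come from the same network, taken right before 'head'
support_feats = torch.randn(1, feat_dim, 60, 60)  # placeholder tensor
query_feats = torch.randn(1, feat_dim, 60, 60)    # placeholder tensor

# These are the ingredients src/classifier.py works with (base prototypes and bias,
# plus support/query features); plug them in where src/test.py runs the inference.
print(base_prototypes.shape, support_feats.shape, query_feats.shape)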

sinahmr reopened this Mar 15, 2024
sinahmr closed this as completed Mar 21, 2024