
train error: RuntimeError: the derivative for 'weight' is not implemented #6

Marshall-yao opened this issue Sep 11, 2019 · 5 comments

Marshall-yao commented Sep 11, 2019

Hi, thanks for your outstanding work.

Problem:
I met an error when fine-tuning with the pretrained model: RuntimeError: the derivative for 'weight' is not implemented. The details are as follows.

After pre-training with L1 loss on the DIV2K dataset (200 epochs), I plan to fine-tune the pre-trained model with the GAN loss (200 epochs).

The L1 pretraining log is shown below.

< Pretrain >
Model check_point/my_model. Epoch [200/200]. Learning rate: 5e-05
100%|██████████████████████████████████████████████████████████████████████████████████| 1000/1000 [02:08<00:00, 7.77it/s]
Finish train [200/200]. Loss: 5.98
Validating...
100%|██████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:06<00:00, 1.62it/s]
Finish valid [200/200]. Best PSNR: 27.5345dB. Cur PSNR: 27.4683dB

The following is my pretraining command.
< Pretrain >
python train.py --phase pretrain --learning_rate 1e-4
YOUR SETTINGS
fl_gamma: 1
valid_dataset: PIRM
num_epochs: 200
gan_type: RSGAN
check_point: check_point/my_model
spectral_norm: False
batch_size: 16
alpha_vgg: 50
res_scale: 0.1
focal_loss: True
lr_step: 120
snapshot_every: 10
GP: False
scale: 4
train_dataset: DIV2K
alpha_tv: 1e-06
learning_rate: 0.0001
alpha_gan: 1
pretrained_model:
num_valids: 10
num_channels: 256
num_repeats: 20
patch_size: 24
phase: pretrain
num_blocks: 32
alpha_l1: 0

Then I used the pretrained model best_model.pt saved in check_point/my_model/pretrain to fine-tune the model with the GAN loss. It raised RuntimeError: the derivative for 'weight' is not implemented.

My fine-tuning command:
python train.py --pretrained_model check_point/my_model/pretrain/best_model.pt

YOUR SETTINGS
num_repeats: 20
GP: False
spectral_norm: False
snapshot_every: 10
num_epochs: 200
gan_type: RSGAN
num_channels: 256
lr_step: 120
alpha_vgg: 50
num_blocks: 32
alpha_l1: 0
phase: train
num_valids: 10
batch_size: 16
focal_loss: True
valid_dataset: PIRM
pretrained_model: check_point/my_model/pretrain/best_model.pt
scale: 4
res_scale: 0.1
alpha_gan: 1
train_dataset: DIV2K
check_point: check_point/my_model
alpha_tv: 1e-06
learning_rate: 5e-05
patch_size: 24
fl_gamma: 1

Loading dataset...
Loading model using 1 GPU(s)
Fetching pretrained model check_point/my_model/pretrain/best_model.pt
Model check_point/my_model/train. Epoch [1/200]. Learning rate: 5e-05
0%| | 0/1000 [00:02<?, ?it/s]
Traceback (most recent call last):
File "train.py", line 323, in
main()
File "train.py", line 257, in main
G_loss = f_loss_fn(pred_fake - pred_real, target_real) #Focal loss
File "/home/anaconda3/lib/python3.5/site-packages/torch/nn/modules/module.py", line 489, in call
result = self.forward(*input, **kwargs)
File "/home/PESR-master/model/focal_loss.py", line 13, in forward
return F.binary_cross_entropy_with_logits(x, t, w)
File "/home/anaconda3/lib/python3.5/site-packages/torch/nn/functional.py", line 2077, in binary_cross_entropy_with_logits
return torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction_enum)
RuntimeError: the derivative for 'weight' is not implemented

Does anyone have any ideas about this problem?

thangvubk (Owner) commented:

I guess that the derivative for weight is not supported in binary_cross_entropy. You can follow this post.
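For reference, a minimal standalone sketch that reproduces the error (the tensors and shapes below are made up for illustration, not taken from PESR): in the PyTorch versions around the one in this traceback, binary_cross_entropy_with_logits raises this RuntimeError at the forward call whenever the weight tensor is still attached to the autograd graph, because no backward formula is implemented for the weight argument.

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 1, requires_grad=True)   # e.g. discriminator outputs
target = torch.ones(4, 1)

# A focal-style weight computed from the logits is still part of the autograd
# graph, so it requires grad as well.
weight = 1 - torch.sigmoid(logits)

try:
    F.binary_cross_entropy_with_logits(logits, target, weight)
except RuntimeError as e:
    print(e)   # the derivative for 'weight' is not implemented

# Detaching the weight treats it as a constant and the call succeeds.
loss = F.binary_cross_entropy_with_logits(logits, target, weight.detach())
loss.backward()
```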

Marshall-yao (Author) commented:

OK. Thanks so much.
I will give it a try.

Best regards.

Marshall-yao (Author) commented:

@thangvubk

After reading the post you recommended, I think it needs the derivative, but I don't know how to modify binary_cross_entropy_with_logits.

Could you give some suggestions?

thangvubk (Owner) commented:

To my knowledge, you don't need the derivative for the focal loss weight, so just detach it.
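For anyone hitting the same issue, here is a minimal sketch of what a detached focal-style loss could look like. The exact weight formula in PESR's model/focal_loss.py is not shown in this thread, so the class body below (with gamma taken from the fl_gamma setting) is an assumption rather than the repository's code; the essential change is calling .detach() on the weight before passing it to binary_cross_entropy_with_logits.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    """Focal-style BCE-with-logits where the focal weight is detached,
    so autograd never needs the (unimplemented) derivative w.r.t. 'weight'."""
    def __init__(self, gamma=1):
        super(FocalLoss, self).__init__()
        self.gamma = gamma

    def forward(self, x, t):
        # p_t: model probability assigned to the ground-truth class (t in {0, 1})
        p = torch.sigmoid(x)
        p_t = p * t + (1 - p) * (1 - t)
        # (1 - p_t)^gamma down-weights easy examples; detach() keeps it out of the graph
        w = (1 - p_t).pow(self.gamma).detach()
        return F.binary_cross_entropy_with_logits(x, t, w)
```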

Marshall-yao (Author) commented:

@thangvubk
Thanks so much. I misunderstood the post you linked.

I have replaced w with w.detach() in focal_loss.py. Now it runs well.

Next I will test the model to check whether the results reproduce well.
