train error: RuntimeError: the derivative for 'weight' is not implemented #6
Comments
I guess that the derivative for
After reading the post you recommended, I think it needs the derivative, but I don't know how to implement it for binary_cross_entropy_with_logits. Could you give some suggestions?
To my knowledge, you don't need the derivative for the focal loss weight, so just detach it.
@thangvubk I have replaced w with w.detach() in focal_loss.py. Now it runs well. I will test the model to check whether the results reproduce well.
OK. Thanks so much. Best regards.
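For reference, the detach fix might look like the sketch below. The repository's focal_loss.py isn't shown in this thread beyond the one line in the traceback, so the formula used to compute the weight w from the logits is an assumption (a standard focal modulating factor, with gamma matching the fl_gamma: 1 setting), not necessarily PESR's exact code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    # A sketch of a focal loss on logits. The weight formula below is
    # an assumption, not necessarily what PESR's focal_loss.py computes.
    def __init__(self, gamma=1):
        super().__init__()
        self.gamma = gamma

    def forward(self, x, t):
        p = torch.sigmoid(x)
        # Focal weight: down-weights easy, well-classified examples.
        w = t * (1 - p).pow(self.gamma) + (1 - t) * p.pow(self.gamma)
        # Detach so autograd treats the weight as a constant:
        # BCE-with-logits has no derivative implemented for 'weight'.
        w = w.detach()
        return F.binary_cross_entropy_with_logits(x, t, w)
```

Detaching treats the modulating weight as a constant at each step, so no gradient flows through it; only the usual BCE gradient through the logits remains.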
Hi, thanks for your outstanding work.
Problem:
I met an error when fine-tuning from a pretrained model: RuntimeError: the derivative for 'weight' is not implemented. The details are as follows.
After pre-training with L1 loss on the DIV2K dataset (200 epochs), I plan to fine-tune the pre-trained model with the GAN loss (200 epochs).
The training log of the L1 pre-training is shown below.
< Pretrain >
Model check_point/my_model. Epoch [200/200]. Learning rate: 5e-05
100%|██████████████████████████████████████████████████████████████████████████████████| 1000/1000 [02:08<00:00, 7.77it/s]
Finish train [200/200]. Loss: 5.98
Validating...
100%|██████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:06<00:00, 1.62it/s]
Finish valid [200/200]. Best PSNR: 27.5345dB. Cur PSNR: 27.4683dB
The following is my pre-training command.
< Pretrain >
python train.py --phase pretrain --learning_rate 1e-4
YOUR SETTINGS
fl_gamma: 1
valid_dataset: PIRM
num_epochs: 200
gan_type: RSGAN
check_point: check_point/my_model
spectral_norm: False
batch_size: 16
alpha_vgg: 50
res_scale: 0.1
focal_loss: True
lr_step: 120
snapshot_every: 10
GP: False
scale: 4
train_dataset: DIV2K
alpha_tv: 1e-06
learning_rate: 0.0001
alpha_gan: 1
pretrained_model:
num_valids: 10
num_channels: 256
num_repeats: 20
patch_size: 24
phase: pretrain
num_blocks: 32
alpha_l1: 0
Then I use the pretrained model best_model.pt saved in check_point/my_model/pretrain to fine-tune the model with the GAN loss. It gives the error RuntimeError: the derivative for 'weight' is not implemented.
The fine-tuning command:
python train.py --pretrained_model check_point/my_model/pretrain/best_model.pt
YOUR SETTINGS
num_repeats: 20
GP: False
spectral_norm: False
snapshot_every: 10
num_epochs: 200
gan_type: RSGAN
num_channels: 256
lr_step: 120
alpha_vgg: 50
num_blocks: 32
alpha_l1: 0
phase: train
num_valids: 10
batch_size: 16
focal_loss: True
valid_dataset: PIRM
pretrained_model: check_point/my_model/pretrain/best_model.pt
scale: 4
res_scale: 0.1
alpha_gan: 1
train_dataset: DIV2K
check_point: check_point/my_model
alpha_tv: 1e-06
learning_rate: 5e-05
patch_size: 24
fl_gamma: 1
Loading dataset...
Loading model using 1 GPU(s)
Fetching pretrained model check_point/my_model/pretrain/best_model.pt
Model check_point/my_model/train. Epoch [1/200]. Learning rate: 5e-05
0%| | 0/1000 [00:02<?, ?it/s]
Traceback (most recent call last):
File "train.py", line 323, in
main()
File "train.py", line 257, in main
G_loss = f_loss_fn(pred_fake - pred_real, target_real) #Focal loss
File "/home/anaconda3/lib/python3.5/site-packages/torch/nn/modules/module.py", line 489, in call
result = self.forward(*input, **kwargs)
File "/home/PESR-master/model/focal_loss.py", line 13, in forward
return F.binary_cross_entropy_with_logits(x, t, w)
File "/home/anaconda3/lib/python3.5/site-packages/torch/nn/functional.py", line 2077, in binary_cross_entropy_with_logits
return torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction_enum)
RuntimeError: the derivative for 'weight' is not implemented
Does anyone have any idea about this problem?
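For what it's worth, the error can be reproduced in isolation: on the PyTorch version in the traceback, binary_cross_entropy_with_logits raises at call time whenever the weight argument requires grad, because autograd has no derivative formula for it (newer releases may behave differently). A minimal sketch, with made-up tensor names standing in for the discriminator outputs:

```python
import torch
import torch.nn.functional as F

# Stand-ins for the discriminator outputs in the traceback
# (names are illustrative, not PESR's actual variables).
pred_fake = torch.randn(16, requires_grad=True)
pred_real = torch.randn(16, requires_grad=True)
target_real = torch.ones(16)

x = pred_fake - pred_real  # relativistic logits, as in the RSGAN loss
w = torch.sigmoid(x)       # weight derived from x, so it requires grad

try:
    loss = F.binary_cross_entropy_with_logits(x, target_real, w)
except RuntimeError as e:
    print(e)  # the derivative for 'weight' is not implemented

# Detaching removes w from the autograd graph, so the call succeeds:
loss = F.binary_cross_entropy_with_logits(x, target_real, w.detach())
loss.backward()
```

Since the focal weight in focal_loss.py is computed from the logits, it requires grad, which is presumably why the error appears only in the GAN fine-tuning phase and not during the L1 pre-training.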