
Loss becoming nan while training on large batch size on COCO #60

Closed
subhankar-ghosh opened this issue Jul 22, 2020 · 10 comments

@subhankar-ghosh

Hey @rwightman,
I am trying to train EfficientDet D0 model on COCO from scratch, it works perfectly and converges when I use your settings:
./distributed_train.sh 4 /mscoco --model efficientdet_d0 -b 22 --amp --lr .12 --sync-bn --opt fusedmomentum --warmup-epochs 5 --lr-noise 0.4 0.9 --model-ema --model-ema-decay 0.9999

But when I use a larger batch size setting like the following, the loss becomes nan:

  1. Number of GPUs = 8, per-GPU batch size = 30, lr = 0.32 (single node)
  2. Number of GPUs = 32, per-GPU batch size = 30, lr = 1.39 (multi-node), both FP32 and AMP
  3. Number of GPUs = 32, per-GPU batch size = 30, lr = 0.86 (multi-node), an LR I chose midway between linear and square-root scaling (see the scaling sketch below), both FP32 and AMP
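
For reference, a minimal sketch (mine, not from the repo) of the two scaling rules being compared here, taking the reference recipe above (4 GPUs x batch 22 = global batch 88 at lr 0.12) as the base:

import math

base_lr, base_global_batch = 0.12, 4 * 22   # reference recipe: 4 GPUs x batch 22

def linear_scaled_lr(global_batch):
    return base_lr * global_batch / base_global_batch

def sqrt_scaled_lr(global_batch):
    return base_lr * math.sqrt(global_batch / base_global_batch)

for gpus, per_gpu in [(8, 30), (32, 30)]:
    gb = gpus * per_gpu
    print(f"{gpus} GPUs x {per_gpu}/GPU: linear lr ~ {linear_scaled_lr(gb):.2f}, "
          f"sqrt lr ~ {sqrt_scaled_lr(gb):.2f}")
# 8 GPUs x 30/GPU:  linear lr ~ 0.33, sqrt lr ~ 0.20
# 32 GPUs x 30/GPU: linear lr ~ 1.31, sqrt lr ~ 0.40 (0.86 sits roughly midway between those)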

With AMP there is a cascade of loss-scale reductions and then the loss becomes nan, and this does not necessarily happen in the first few epochs.

Strangely enough, the TF1 Google automl code base has no problem with linear scaling. Do you know what might be the problem here? Do I need to change the recipe when using a large batch size?
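
As a sketch of where that cascade comes from (a generic torch.cuda.amp step, not necessarily what train.py does): GradScaler skips the optimizer step and cuts the loss scale whenever it sees inf/nan gradients, so a long run of scale reductions usually means the loss itself is diverging.

import torch

scaler = torch.cuda.amp.GradScaler()

def train_step(model, optimizer, loss_fn, images, targets):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        output = model(images)
        loss = loss_fn(output, targets)
    scaler.scale(loss).backward()
    scaler.step(optimizer)   # skipped if the scaled grads contain inf/nan
    scaler.update()          # halves the scale after a skipped step (default backoff)
    return loss.item()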

@rwightman
Owner

@subhankar-ghosh that is odd, it looks like you scaled the LR appropriately... are you using the exact training script here on COCO with the standard model head (not changing it), and still using the other params in the command above, especially warmup and sync-bn?

The first versions of this were very touchy, but that's mostly because there were issues in the model and head init; the head currently does not get re-initialized properly if you replace it after model creation, unless you do it manually.
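
Purely as a hypothetical sketch (not the repo's actual init code), if you did swap the head after model creation, "doing it manually" could look something like re-initializing its conv/bn layers with a generic scheme:

import torch.nn as nn

def reinit_head(head: nn.Module):
    # generic re-init for illustration only; the real head init scheme lives in the model code
    for m in head.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)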

@subhankar-ghosh
Author

@rwightman, I took this commit and changed nothing:
8fc03d4f0090896e0814bc7ed0b8af744d0c1133 Port and updated D7 weights from official TF repo (53.1 mAP in PyTorch). Fix a train arg desc.

My training script is the same as yours:

python -m torch.distributed.launch \
    --nproc_per_node=8 \
    --nnodes=${NGC_ARRAY_SIZE} \
    --node_rank=${NGC_ARRAY_INDEX} \
    --master_addr=${NGC_MASTER_ADDR} \
    --master_port=49152 \
    train.py \
    /workspace/object_detection/datasets/coco \
    --model efficientdet_d0 \
    -b 30 \
    --lr .86 \
    --sync-bn \
    --amp \
    --opt fusedmomentum \
    --warmup-epochs 5 \
    --lr-noise 0.4 0.9 \
    --fill-color mean \
    --model-ema \
    --model-ema-decay 0.99

@subhankar-ghosh
Author

Very strange: batch size 30, lr 0.7 with AMP does not lead to nan loss, and I got an mAP of 33.14 after 300 epochs. @rwightman any idea what is happening?

@rwightman
Owner

@subhankar-ghosh nope, I have no idea. I don't have access to those sorts of resources, so no ability to test; I'm actually surprised it works at all with 4 nodes, 8 GPUs per node. There are limits to the scaling w.r.t. batch size and LR, and I'm not sure if that's what you're running into or if there are other issues. Does single-node 8-GPU training definitely have stability issues? Someone else was doing something like that and didn't mention any issues.

@wangraying

wangraying commented Jul 27, 2020

I have the same problem while training with
--model tf_efficientdet_d0 -b 32 --opt sgd --epochs 300 --warmup-epochs 5 --lr 0.32

and I found that the output x of the backbone network becomes all nan:

class EfficientDet(nn.Module):

    def __init__(self, config, norm_kwargs=None, pretrained_backbone=True, alternate_init=False):
        super(EfficientDet, self).__init__()
        ...

    def forward(self, x):
        x = self.backbone(x)   ### x becomes all nan here
        x = self.fpn(x)
        x_class = self.class_net(x)
        x_box = self.box_net(x)
        return x_class, x_box
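
One way to pin this down further (my own sketch, not part of the repo) is to register forward hooks on the top-level submodules and report the first one whose output goes non-finite:

import torch

def add_nan_hooks(model):
    def make_hook(name):
        def hook(module, inputs, output):
            # backbone/fpn outputs are lists of feature maps, so handle both cases
            outs = output if isinstance(output, (list, tuple)) else [output]
            for o in outs:
                if torch.is_tensor(o) and not torch.isfinite(o).all():
                    print(f"non-finite output from: {name}")
        return hook
    for name, module in model.named_children():   # backbone, fpn, class_net, box_net
        module.register_forward_hook(make_hook(name))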

@subhankar-ghosh
Author

@wangraying From my experiments, I found that as I scaled the LR linearly with batch size, I also had to increase the warmup epochs, either linearly or at least square-root-linearly, with batch size. This prevents the loss from becoming nan. Maybe this will work for you too.
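
A rough sketch of that heuristic (mine, assuming the reference 5 warmup epochs at global batch 88 as the base):

import math

base_warmup_epochs, base_global_batch = 5, 4 * 22

def scaled_warmup(global_batch, rule="linear"):
    ratio = global_batch / base_global_batch
    factor = ratio if rule == "linear" else math.sqrt(ratio)
    return math.ceil(base_warmup_epochs * factor)

print(scaled_warmup(32 * 30, "sqrt"))    # ~17 epochs
print(scaled_warmup(32 * 30, "linear"))  # ~55 epochs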

@rwightman
Owner

rwightman commented Jul 27, 2020

@wangraying don't use 'sgd' as the opt string, it is not stable with the default hparams. For legacy reasons, 'sgd' in my optimizer factory is SGD + Nesterov, while 'momentum' is SGD without Nesterov, and the latter is what the hparams from the official paper/impl were based on... also, the official version ramps the warmup per step, while I only ramp per epoch; one could try a different scheduler.
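
For clarity, the plain torch.optim equivalents of those two opt strings (a sketch only; the repo routes them through its own optimizer factory, and the momentum/lr values here are illustrative):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # stand-in module just to make the snippet runnable

# opt string 'sgd'      -> SGD with Nesterov momentum
# opt string 'momentum' -> SGD without Nesterov
opt_sgd      = torch.optim.SGD(model.parameters(), lr=0.12, momentum=0.9, nesterov=True)
opt_momentum = torch.optim.SGD(model.parameters(), lr=0.12, momentum=0.9, nesterov=False)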

I'm going to close this issue now. I don't think there is any major bug or defect here (please let me know if one is found), just the usual hparam tuning; the default hparams for these models are on the edge of stability.

@pichuang1984

pichuang1984 commented Aug 1, 2020

@subhankar-ghosh I also noticed that for large-scale, multi-node training, increasing the number of warmup epochs seems to avoid the NaN (divergence) issue; however, the training can still diverge in later epochs. Have you found a set of hyper-params that can reach 33.6 mAP for d0 under the 32-GPU setting? So far the best I can do is 33.5 mAP with 8 GPUs, and that just linearly scales the LR.

@subhankar-ghosh
Author

Hey @pichuang1984, the best result I have gotten with the d0 32-GPU setting is 33.35 mAP. I simply scaled the LR and warmup linearly and used ema-decay 0.999 instead of 0.9999. Lowering ema-decay a little actually might lead to faster convergence.
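
For reference, a sketch of the update that --model-ema-decay controls (same idea as the model EMA used here, not the repo's exact code); a lower decay like 0.999 tracks the live weights faster than 0.9999:

import copy
import torch

def make_ema(model):
    ema = copy.deepcopy(model)
    for p in ema.parameters():
        p.requires_grad_(False)
    return ema

@torch.no_grad()
def update_ema(ema_model, model, decay=0.999):
    # ema <- decay * ema + (1 - decay) * current weights, applied at every update
    for ema_p, p in zip(ema_model.parameters(), model.parameters()):
        ema_p.mul_(decay).add_(p, alpha=1.0 - decay)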

@subhankar-ghosh
Author

@rwightman what is the role of the --dist-bn parameter? I am currently using its default value '', is that how it is supposed to be?
