
Conversation

Contributor

@xiaohu2015 xiaohu2015 commented Nov 26, 2021

The pr is about #4509.
Since amp is supported on classification and detection training, I also modify some files to support amp training on segmentation models.

cc @datumbox
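
Presumably the wiring follows the same autocast + GradScaler pattern already used by the classification and detection reference scripts when --amp is passed. The snippet below is only a minimal sketch of that pattern, not the exact diff from this PR; the tiny model, fake data, and hyper-parameters are placeholders chosen just to keep the example self-contained.

# Minimal sketch of an AMP training step (torch.cuda.amp autocast + GradScaler).
# The conv "model" and random tensors are placeholders, not the reference-script
# code; only the autocast / GradScaler wiring is the point.
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"

model = nn.Conv2d(3, 21, kernel_size=1).to(device)           # stand-in for fcn_resnet50
criterion = nn.CrossEntropyLoss(ignore_index=255)
optimizer = torch.optim.SGD(model.parameters(), lr=0.02, momentum=0.9)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

image = torch.randn(4, 3, 480, 480, device=device)           # fake batch
target = torch.randint(0, 21, (4, 480, 480), device=device)  # fake masks

optimizer.zero_grad()
with torch.cuda.amp.autocast(enabled=use_amp):                # mixed-precision forward
    output = model(image)
    loss = criterion(output, target)
scaler.scale(loss).backward()   # backward on the scaled loss
scaler.step(optimizer)          # unscales grads; skips the step on inf/nan
scaler.update()                 # adjusts the scale factor for the next step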

Contributor

facebook-github-bot commented Nov 26, 2021

💊 CI failures summary and remediations

As of commit aab4e40 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚



Contributor

@datumbox datumbox left a comment


LGTM, thanks @xiaohu2015.

Could you please attach two logs, one with --amp and one without it (1 epoch would do fine), that show the reference script still works fine? We don't currently have tests covering those scripts, so this is a final check we do to make sure the code didn't break somewhere obvious.

Contributor Author

xiaohu2015 commented Nov 26, 2021

@datumbox
train fcn_resnet50:

torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model fcn_resnet50 --aux-loss

no amp:

Epoch: [0]  [2770/2891]  eta: 0:01:05  lr: 0.01942397522379751  loss: 0.4647 (0.8574)  time: 0.5327  data: 0.0057  max mem: 4723
Epoch: [0]  [2780/2891]  eta: 0:00:59  lr: 0.01942189305490391  loss: 0.5236 (0.8569)  time: 0.5326  data: 0.0059  max mem: 4723
Epoch: [0]  [2790/2891]  eta: 0:00:54  lr: 0.019419810861207375  loss: 0.7463 (0.8567)  time: 0.5328  data: 0.0060  max mem: 4723
Epoch: [0]  [2800/2891]  eta: 0:00:49  lr: 0.019417728642704662  loss: 0.6900 (0.8564)  time: 0.5331  data: 0.0057  max mem: 4723
Epoch: [0]  [2810/2891]  eta: 0:00:43  lr: 0.019415646399392514  loss: 0.5944 (0.8553)  time: 0.5350  data: 0.0059  max mem: 4723
Epoch: [0]  [2820/2891]  eta: 0:00:38  lr: 0.019413564131267682  loss: 0.5944 (0.8550)  time: 0.5355  data: 0.0058  max mem: 4723
Epoch: [0]  [2830/2891]  eta: 0:00:32  lr: 0.01941148183832691  loss: 0.7028 (0.8544)  time: 0.5334  data: 0.0056  max mem: 4723
Epoch: [0]  [2840/2891]  eta: 0:00:27  lr: 0.019409399520566945  loss: 0.5951 (0.8541)  time: 0.5304  data: 0.0054  max mem: 4723
Epoch: [0]  [2850/2891]  eta: 0:00:22  lr: 0.019407317177984537  loss: 0.5951 (0.8544)  time: 0.5301  data: 0.0053  max mem: 4723
Epoch: [0]  [2860/2891]  eta: 0:00:16  lr: 0.01940523481057643  loss: 0.5810 (0.8537)  time: 0.5333  data: 0.0063  max mem: 4723
Epoch: [0]  [2870/2891]  eta: 0:00:11  lr: 0.01940315241833936  loss: 0.5390 (0.8529)  time: 0.5330  data: 0.0064  max mem: 4723
Epoch: [0]  [2880/2891]  eta: 0:00:05  lr: 0.019401070001270077  loss: 0.8151 (0.8530)  time: 0.5314  data: 0.0056  max mem: 4723
Epoch: [0]  [2890/2891]  eta: 0:00:00  lr: 0.019398987559365324  loss: 0.5705 (0.8526)  time: 0.5324  data: 0.0057  max mem: 4723
Epoch: [0] Total time: 0:26:04
Test:  [  0/625]  eta: 0:29:29    time: 2.8311  data: 2.7424  max mem: 4723
Test:  [100/625]  eta: 0:00:53    time: 0.0716  data: 0.0025  max mem: 4723
Test:  [200/625]  eta: 0:00:37    time: 0.0742  data: 0.0028  max mem: 4723
Test:  [300/625]  eta: 0:00:27    time: 0.0747  data: 0.0034  max mem: 4723
Test:  [400/625]  eta: 0:00:18    time: 0.0731  data: 0.0024  max mem: 4723
Test:  [500/625]  eta: 0:00:10    time: 0.0740  data: 0.0026  max mem: 4723
Test:  [600/625]  eta: 0:00:01    time: 0.0701  data: 0.0026  max mem: 4723
Test: Total time: 0:00:50
global correct: 89.7
average row correct: ['97.1', '64.6', '45.8', '62.7', '33.5', '27.8', '39.6', '43.1', '71.1', '11.5', '54.6', '9.0', '48.0', '62.8', '72.9', '85.1', '1.3', '46.4', '15.1', '48.6', '34.1'] 
IoU: ['89.2', '52.5', '41.5', '39.2', '29.4', '24.5', '37.7', '36.7', '54.1', '10.8', '34.6', '8.6', '38.1', '50.0', '59.1', '71.1', '1.3', '38.5', '14.4', '45.4', '28.2'] 
mean IoU: 38.3

use amp:

Epoch: [0]  [2810/2891]  eta: 0:01:27  lr: 0.019415646399392514  loss: 0.5286 (0.8925)  time: 1.0529  data: 0.0079  max mem: 2963
Epoch: [0]  [2820/2891]  eta: 0:01:16  lr: 0.019413564131267682  loss: 0.5774 (0.8921)  time: 1.0840  data: 0.0069  max mem: 2963
Epoch: [0]  [2830/2891]  eta: 0:01:05  lr: 0.01941148183832691  loss: 0.5774 (0.8914)  time: 1.0806  data: 0.0066  max mem: 2963
Epoch: [0]  [2840/2891]  eta: 0:00:54  lr: 0.019409399520566945  loss: 0.4950 (0.8905)  time: 1.0685  data: 0.0076  max mem: 2963
Epoch: [0]  [2850/2891]  eta: 0:00:44  lr: 0.019407317177984537  loss: 0.5319 (0.8895)  time: 1.0873  data: 0.0076  max mem: 2963
Epoch: [0]  [2860/2891]  eta: 0:00:33  lr: 0.01940523481057643  loss: 0.5323 (0.8890)  time: 1.0632  data: 0.0071  max mem: 2963
Epoch: [0]  [2870/2891]  eta: 0:00:22  lr: 0.01940315241833936  loss: 0.5072 (0.8883)  time: 1.0518  data: 0.0072  max mem: 2963
Epoch: [0]  [2880/2891]  eta: 0:00:11  lr: 0.019401070001270077  loss: 0.6067 (0.8880)  time: 1.0619  data: 0.0073  max mem: 2963
Epoch: [0]  [2890/2891]  eta: 0:00:01  lr: 0.019398987559365324  loss: 0.5897 (0.8872)  time: 1.0614  data: 0.0076  max mem: 2963
Epoch: [0] Total time: 0:51:50
Test:  [  0/625]  eta: 0:32:00    time: 3.0732  data: 2.9764  max mem: 2963
Test:  [100/625]  eta: 0:00:56    time: 0.0729  data: 0.0030  max mem: 2963
Test:  [200/625]  eta: 0:00:39    time: 0.0758  data: 0.0031  max mem: 2963
Test:  [300/625]  eta: 0:00:28    time: 0.0740  data: 0.0030  max mem: 2963
Test:  [400/625]  eta: 0:00:19    time: 0.0746  data: 0.0029  max mem: 2963
Test:  [500/625]  eta: 0:00:10    time: 0.0803  data: 0.0029  max mem: 2963
Test:  [600/625]  eta: 0:00:02    time: 0.0687  data: 0.0021  max mem: 2963
Test: Total time: 0:00:51
global correct: 89.5
average row correct: ['97.1', '61.6', '52.4', '60.3', '28.3', '21.1', '37.8', '43.8', '79.2', '19.5', '55.6', '5.0', '43.6', '59.4', '62.2', '84.3', '9.1', '50.4', '19.9', '40.8', '33.9']
IoU: ['89.0', '51.7', '43.3', '40.6', '26.5', '19.6', '36.4', '37.1', '52.3', '16.6', '36.6', '4.9', '37.7', '48.9', '57.1', '70.1', '8.5', '39.9', '18.5', '38.7', '29.0']
mean IoU: 38.2

The GPU memory consumption drops a lot and the mean IoU is nearly the same. As for the training time, maybe AMP is not well supported on this GPU card.
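
If the slowdown persists, one quick sanity check is whether the card has fp16 Tensor Cores at all (compute capability 7.0 or newer, i.e. Volta/Turing/Ampere); older GPUs still save memory under autocast but often do not train faster. A small diagnostic sketch, not part of this PR:

# Hedged diagnostic: check whether the GPU is likely to benefit from AMP.
import torch

if torch.cuda.is_available():
    name = torch.cuda.get_device_name(0)
    major, minor = torch.cuda.get_device_capability(0)
    print(f"{name}: compute capability {major}.{minor}")
    if (major, minor) < (7, 0):
        print("No fp16 Tensor Cores: autocast can cut memory but may not speed up training.")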

Contributor

@datumbox datumbox left a comment


Thanks for the contribution @xiaohu2015. Keep them coming! :)

@github-actions

Hey @datumbox!

You merged this PR, but no labels were added. The list of valid labels is available at https://github.com/pytorch/vision/blob/main/.github/process_commit.py

@xiaohu2015 xiaohu2015 deleted the patch-1 branch November 26, 2021 13:13
facebook-github-bot pushed a commit that referenced this pull request Nov 30, 2021
Summary:
* support amp training for segmentation models

* fix lint

Reviewed By: NicolasHug

Differential Revision: D32694301

fbshipit-source-id: 904803e18783b70182409f048dea076a21c69c58

Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>