
No distributed view in tensorboard #640

Closed
woolpeeker opened this issue Aug 2, 2022 · 7 comments
Labels: bug (Something isn't working), plugin (PyTorch Profiler TensorBoard Plugin related)

Comments


woolpeeker commented Aug 2, 2022

The Python code is below. I use slurm sbatch to start it on the cluster; the backend is nccl.
The generated .json files seem normal.
TensorBoard shows the "Overview" and the other views, but not "Distributed", which is exactly the one I need.
There are some error messages from TensorBoard; I paste part of them below because the rest are repetitions.

environment:
Python==3.8.13
torch==1.12.0
tensorboard==2.9.1
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
torch-tb-profiler==0.4.0

import logging
import os
import argparse
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import utils
from utils import info, error
import time
import traceback


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', default='convnext_xlarge', type=str, 
                        help='model name')
    parser.add_argument('--bs', default=256, type=int, 
                        help='batch size')
    parser.add_argument('--input_shape', default=224, type=int, 
                        help='input shape')
    parser.add_argument('--sample_num', default=10240, type=int, 
                        help='test iteration numbers.')
    parser.add_argument('--profiler', default=False, type=bool, 
                        help='enable pytorch profiler')
    parser.add_argument('--out_dir', default='results', type=str, 
                        help='output dir')
    args = parser.parse_args()
    return args

def main(args):
    utils.slurm_init()
    log_file = Path(args.out_dir) / 'log.txt'
    utils.set_logger(rank=dist.get_rank(), log_file=log_file)
    info('dist.init finished')
    _gpu_idxes = [int(x) for x in os.environ['CUDA_VISIBLE_DEVICES'].split(',')]
    node_name = os.environ["SLURMD_NODENAME"]
    logging.info(f'Host-{node_name}, GPU-{_gpu_idxes[torch.cuda.current_device()]}')

    info(f"command arguments: {args}")
    
    ### calibrate batch size and iter_num ##
    assert args.sample_num % args.bs == 0
    iter_num = args.sample_num // args.bs
    assert args.bs % dist.get_world_size() == 0
    bs = args.bs // dist.get_world_size()

    ### model ###
    model = utils.get_model(args.model)
    model.cuda()
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = torch.nn.parallel.DistributedDataParallel(model)
    info('model created')
        
    ### optimizer ###
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-5)
    scaler = torch.cuda.amp.GradScaler()

    torch.backends.cudnn.benchmark = True
    
    ### main loop ###
    info('training iteration starts')
    tic = time.time()

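    # Profiler schedule: skip the first 10 iterations, then repeat the
    # (wait=10, warmup=8, active=2) cycle three times; each active window
    # is written to args.out_dir by tensorboard_trace_handler, tagged per
    # rank via worker_name.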
    prof = torch.profiler.profile(
        activities=[ 
            torch.profiler.ProfilerActivity.CPU, 
            torch.profiler.ProfilerActivity.CUDA],
        schedule=torch.profiler.schedule(skip_first=10, wait=10, warmup=8, active=2, repeat=3),
        on_trace_ready=torch.profiler.tensorboard_trace_handler(args.out_dir, worker_name=f'worker_{dist.get_rank()}'),
        record_shapes=True,
        profile_memory=True,
        with_stack=True)
    prof.start()

    for i in range(iter_num):
        if i % 10 == 0:
            info(f'iteration: {i}')
        inputs = torch.randn([bs, 3, 224, 224]).cuda()
        targets = torch.argmax(torch.randn([bs, 21841]), dim=1).cuda()
        with torch.cuda.amp.autocast():
            outputs = model(inputs)

        loss = F.cross_entropy(outputs, targets)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        prof.step()

    prof.stop()
    toc = time.time()
    total_time = toc - tic
    fps = args.sample_num / total_time
    info(f'total_time: {total_time}')
    info(f'FPS: {fps:.3f}')
    info(f'max_memory_reserved: {torch.cuda.max_memory_reserved():,}')
    info(f'max_memory_allocated: {torch.cuda.max_memory_allocated():,}')

if __name__ == '__main__':
    try: # to print CUDA OOM to log.txt
        args = parse_args()
        main(args)
    except:
        error(traceback.format_exc())
        exit(-1)
TensorFlow installation not found - running with reduced feature set.

NOTE: Using experimental fast data loading logic. To disable, pass
    "--load_fast=false" and report issues on GitHub. More details:
    https://github.com/tensorflow/tensorboard/issues/4784

I0802 16:36:47.627506 123145378619392 plugin.py:429] Monitor runs begin
I0802 16:36:47.628741 123145378619392 plugin.py:444] Find run directory /Users/luojiapeng/Documents/projects/BigModelAccelerate/profiler_tb/convnext_xlarge_64_224
I0802 16:36:47.629572 123145412198400 plugin.py:493] Load run convnext_xlarge_64_224
I0802 16:36:47.973850 123145412198400 loader.py:57] started all processing
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.9.1 at http://localhost:6006/ (Press CTRL+C to quit)
W0802 16:37:02.206120 123145429524480 security_validator.py:46] In 3.0, this warning will become an error:
Requires default-src for Content-Security-Policy
WARNING: Logging before flag parsing goes to stderr.
W0802 16:37:06.869280 4427961856 loader.py:102] Failed to parse profile data for Run convnext_xlarge_64_224 on SH-IDC1-10-5-37-22_20298. Exception=Expecting ',' delimiter: line 418103 column 23 (char 25471711)
Traceback (most recent call last):
  File "/Users/luojiapeng/miniconda3/lib/python3.8/site-packages/torch_tb_profiler/profiler/data.py", line 126, in _preprocess_file
    trace_json = json.loads(data)
  File "/Users/luojiapeng/miniconda3/lib/python3.8/json/__init__.py", line 357, in loads
    return _default_decoder.decode(s)
  File "/Users/luojiapeng/miniconda3/lib/python3.8/json/decoder.py", line 337, in decode
    obj, end = self.raw_decode(s, idx=_w(s, 0).end())
  File "/Users/luojiapeng/miniconda3/lib/python3.8/json/decoder.py", line 353, in raw_decode
    obj, end = self.scan_once(s, idx)
json.decoder.JSONDecodeError: Expecting ',' delimiter: line 418103 column 23 (char 25471711)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/Users/luojiapeng/miniconda3/lib/python3.8/site-packages/torch_tb_profiler/profiler/data.py", line 131, in _preprocess_file
    trace_json = json.loads(data, strict=False)
  File "/Users/luojiapeng/miniconda3/lib/python3.8/json/__init__.py", line 370, in loads
    return cls(**kw).decode(s)
  File "/Users/luojiapeng/miniconda3/lib/python3.8/json/decoder.py", line 337, in decode
    obj, end = self.raw_decode(s, idx=_w(s, 0).end())
  File "/Users/luojiapeng/miniconda3/lib/python3.8/json/decoder.py", line 353, in raw_decode
    obj, end = self.scan_once(s, idx)
json.decoder.JSONDecodeError: Expecting ',' delimiter: line 418103 column 23 (char 25471711)
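
For reference, a minimal sketch to print the region of a trace file around the character offset reported by the JSONDecodeError (the path below is a hypothetical placeholder; substitute the worker trace file that failed to parse):

trace_path = 'results/worker_0.<timestamp>.pt.trace.json'  # hypothetical path; use the file that failed to parse
offset = 25471711  # character offset from the JSONDecodeError above
with open(trace_path, 'r') as f:
    data = f.read()
# Print the text surrounding the failing offset to see what is malformed.
print(data[max(0, offset - 200):offset + 200])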
aaronenyeshi added the bug (Something isn't working) label Nov 1, 2022
aaronenyeshi (Member) commented

A new torch-tb-profiler v0.4.1 has been released. Could you please check if your issue is fixed there: https://pypi.org/project/torch-tb-profiler/0.4.1
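
As a quick sanity check of which plugin version a given environment is actually running (a standard-library lookup, nothing plugin-specific):

from importlib.metadata import version  # Python 3.8+
print(version('torch-tb-profiler'))  # should report 0.4.1 after upgrading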

arthurfeeney commented

The same thing is happening to me on 0.4.1: I do a DDP training run, and the output traces include calls to DistributedDataParallel.forward and nccl allReduce, so it seems like the distributed view should be shown.

But when I load the traces into TensorBoard, all of the views appear except Distributed. The trace view also shows the calls to all reduce and DDP, so I think the tool is seeing them.
Tried with:

pytorch==2.0 and 1.13
tensorboard==2.11.2
torch_tb_profiler==0.4.1

arthurfeeney commented Jan 27, 2023

So it seems that tb_plugin looks for "DistributedDataParallel.forward" to decide whether DDP is being used. This is a user annotation, but tb_plugin does not seem to create events for all user annotations.

It looks for DDP here: https://github.com/pytorch/kineto/blob/main/tb_plugin/torch_tb_profiler/profiler/event_parser.py#L199, but in my trace, "DistributedDataParallel.forward" appears in a user_annotation:

  {
    "ph": "X", "cat": "user_annotation", "name": "DistributedDataParallel.forward", "pid": 208187, "tid": 208187,
    "ts": 1674842868181040, "dur": 1791,
    "args": {
      "External id": 9372,"Ev Idx": 25561
    }
  },

It also looks like user_annotations do not all have trace events constructed for them: https://github.com/pytorch/kineto/blob/main/tb_plugin/torch_tb_profiler/profiler/trace.py#L187. After adding a bunch of print statements, it definitely never sets use_ddp = True or sets the comm lib. (Or at least, that is my best guess!)
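
To check which category those events actually carry in a given trace, a small standalone sketch like the following can scan a worker trace for the names the plugin keys on (the path is a hypothetical placeholder; point it at an actual worker_*.pt.trace.json):

import json

trace_path = 'results/worker_0.<timestamp>.pt.trace.json'  # hypothetical path
with open(trace_path) as f:
    trace = json.load(f)

# List every DDP / nccl related event together with its category, to see
# whether they appear as user_annotation rather than plain op events.
for ev in trace.get('traceEvents', []):
    name = ev.get('name', '')
    if name.startswith('DistributedDataParallel.forward') or 'nccl' in name.lower():
        print(ev.get('cat'), name)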

srikanthmalla commented Feb 24, 2023

Without downgrading torch to 1.11.0, the distributed view doesn't work (even with the latest torch-tb-profiler v0.4.1).
To downgrade torch to the right version: pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113

aaronenyeshi commented Feb 24, 2023

Thank you for the issue report and updates. The issue was fixed in PR #717, but I missed releasing a new torch-tb-profiler package. Let me create a PR to get that going: #732. cc @woolpeeker, @arthurfeeney, @srikanthmalla

arthurfeeney commented

I think the distributed view still has issues. I was able to hack around it a bit and at least got the distributed view to appear by making changes similar to the ones you made in #717 (basically forcing it to look at user annotations for all reduce and DistributedDataParallel.forward). However, even with that change, the view still seemed to be missing information: for instance, the latency breakdown was not showing communication time.
I guess there are some other user_annotations it needs that I don't know about :-/

facebook-github-bot pushed a commit that referenced this issue Jun 1, 2023
Summary:
It's one of the reasons for the missing Distributed View #640

aaronenyeshi mrshenli kwen2501

Pull Request resolved: #763

Reviewed By: xuzhao9

Differential Revision: D46370583

Pulled By: aaronenyeshi

fbshipit-source-id: 064212a7890676a7724a81aa343c30b0a53bd408
aaronenyeshi added the plugin (PyTorch Profiler TensorBoard Plugin related) label Jun 23, 2023
npuichigo commented

Does 1.13.1 work for the distributed view?
