FSDP very slow on multi-node training #102434

Closed
JulioZhao97 opened this issue May 27, 2023 · 28 comments
Labels
module: fsdp triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@JulioZhao97

JulioZhao97 commented May 27, 2023

🐛 Describe the bug

When I try to train a model using torch.distributed.FullyShardedDataParallel, I found that:
when training on a single node with multiple GPUs (1x8 A100), the training speed is normal;
when training on multiple nodes (2x8 A100 or 4x8 A100), the training speed is very slow.
My FSDP code is as follows:

def wrap_model_using_fsdp(self):
        params_no_grad = [n for n, p in self._model.named_parameters() if not p.requires_grad]
        
        from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
            
        if len(params_no_grad) > 0:
            print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}'.format(len(params_no_grad), params_no_grad))
            print("[WARNING] Attempting to use FSDP with partially frozen parameters, this is experimental.")
            print("[WARNING] As of 4/30/23, this feature requires PyTorch-nightly build.  See here for details: https://github.com/haotian-liu/LLaVA#experimental-use-fsdp-to-save-memory-in-pretraining")

            def patch_FSDP_use_orig_params(func):
                def wrap_func(*args, **kwargs):
                    use_orig_params = kwargs.pop('use_orig_params', True)
                    return func(*args, **kwargs, use_orig_params=use_orig_params)
                return wrap_func

            FSDP.__init__ = patch_FSDP_use_orig_params(FSDP.__init__)

        from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision, ShardingStrategy
                
        dtype = torch.float16
        mixed_precision_policy = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)
        device_id = int(os.environ['RANK']) % torch.cuda.device_count()
                
        def get_module_class_from_name(module, name):
            modules_children = list(module.children())
            if module.__class__.__name__ == name:
                return module.__class__
            elif len(modules_children) == 0:
                return
            else:
                for child_module in modules_children:
                    module_class = get_module_class_from_name(child_module, name)
                    if module_class is not None:
                        return module_class
                
        from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
        import functools
        transformer_cls_to_wrap = set()
        for layer_class in ['LlamaDecoderLayer']:
            transformer_cls = get_module_class_from_name(self._model, layer_class)
            if transformer_cls is None:
                raise Exception("Could not find the transformer layer class to wrap in the model.")
            else:
                transformer_cls_to_wrap.add(transformer_cls)
        auto_wrap_policy = functools.partial(
            transformer_auto_wrap_policy,
            # Transformer layer class to wrap
            transformer_layer_cls=transformer_cls_to_wrap,
        )
                
        self._wrapped_model = self._model = FSDP(
            self._model,
            sharding_strategy=ShardingStrategy.FULL_SHARD,
            cpu_offload=CPUOffload(offload_params=False),
            mixed_precision=mixed_precision_policy,
            auto_wrap_policy=auto_wrap_policy,
            device_id=f'cuda:{device_id}'
        )

I print out the training speed; results are as follows
(three lines per block: the first is data loading time, the second is model inference and loss calculation time, the last is backward() time).
First is the speed using 4x8 A100; the model inference is very slow.

====================
1.2894313335418701
6.707820892333984
0.387603759765625
====================
0.5644333362579346
7.291527271270752
0.38623809814453125
====================
0.00037860870361328125
8.131412506103516
0.38402795791625977
====================
0.7470672130584717
7.245628356933594
0.3820772171020508
====================
0.6294591426849365
7.050528049468994
0.38535380363464355
====================

Next is the speed using 1x8 A100; the model inference is perfectly normal:

====================
0.0004448890686035156
2.280407190322876
0.12227678298950195
====================
0.5785701274871826
0.7706379890441895
0.11761355400085449
====================
1.5213301181793213
1.0720155239105225
0.1305837631225586
====================
0.7982532978057861
1.343599796295166
0.1258220672607422
====================
0.5638444423675537
1.0778486728668213
0.36525917053222656
====================
0.617133617401123
1.3708162307739258
0.11827230453491211
====================

Could someone tell me why this is happening?
My test code:

        import time
        for i in metric_logger.log_every(range(iters_per_epoch), log_freq, header):
            print('='*20)
            t1 = time.time()
            # if using iter-based runner, we stop after iters_per_epoch iterations.
            if i >= iters_per_epoch:
                break
            
            samples = next(data_loader)

            samples = prepare_sample(samples, cuda_enabled=cuda_enabled)
            samples.update(
                {
                    "epoch": inner_epoch,
                    "num_iters_per_epoch": iters_per_epoch,
                    "iters": i,
                }
            )
            print(time.time() - t1)
            t2 = time.time()
            lr_scheduler.step(cur_epoch=inner_epoch, cur_step=i)

            with torch.cuda.amp.autocast(enabled=use_amp):
                loss = self.train_step(model=model, samples=samples)

            # after_train_step()
            if use_amp:
                scaler.scale(loss).backward()
            else:
                loss.backward()
                
            print(time.time() - t2)
            t3 = time.time()
            # update gradients every accum_grad_iters iterations
            if (i + 1) % accum_grad_iters == 0:
                if use_amp:
                    scaler.step(optimizer)
                    scaler.update()                     
                else:    
                    optimizer.step()
                optimizer.zero_grad()

            metric_logger.update(loss=loss.item())
            metric_logger.update(lr=optimizer.param_groups[0]["lr"])
            print(time.time() - t3)

Versions

My versions:

pytorch-mutex             1.0                        cuda    pytorch
pytorch-triton            2.1.0+7d1a95b046          pypi_0    pypi
torch                     2.1.0.dev20230517+cu117          pypi_0    pypi
torchvision               0.16.0.dev20230517+cu117          pypi_0    pypi

My driver version:
浼佷笟寰俊鎴浘_16851761668393

cc @zhaojuanmao @mrshenli @rohan-varma @awgu

@awgu
Contributor

awgu commented May 27, 2023

Maybe your workload is communication bound, and when you go from single-node to multi-node, FSDP's communications (all-gather / reduce-scatter) are heavily exposed on the critical path?

I would recommend you collect profiler traces for both the single-node and multi-node cases.
https://pytorch.org/docs/stable/profiler.html
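
A minimal sketch of collecting such a trace (model, data_loader, and optimizer below are placeholders for the training loop in this issue):

import torch
from torch.profiler import profile, schedule, ProfilerActivity

def trace_handler(prof):
    # one trace file per rank; open it in chrome://tracing or Perfetto
    prof.export_chrome_trace(f"trace_rank{torch.distributed.get_rank()}.json")

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=schedule(wait=1, warmup=2, active=5),
    on_trace_ready=trace_handler,
) as prof:
    for step, samples in enumerate(data_loader):  # placeholder data loader
        loss = model(samples)                      # placeholder forward + loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        prof.step()                                # advance the profiler schedule
        if step >= 10:
            break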


By the way, regarding:

(three lines per block: the first is data loading time, the second is model inference and loss calculation time, the last is backward() time)

If I am understanding correctly, the second time is actually the forward and backward time, and the third time is the optimizer step? (It would be strange if forward takes much longer than backward.)

@JulioZhao97
Author

Maybe your workload is communication bound, and when you go from single-node to multi-node, FSDP's communications (all-gather / reduce-scatter) are heavily exposed on the critical path?

I would recommend you collect profiler traces for both the single-node and multi-node cases. https://pytorch.org/docs/stable/profiler.html

By the way, regarding:

(three lines per block: the first is data loading time, the second is model inference and loss calculation time, the last is backward() time)

If I am understanding correctly, the second time is actually the forward and backward time, and the third time is the optimizer step? (It would be strange if forward takes much longer than backward.)

Yes, the second time is actually the forward and backward time. Can I ask what "FSDP's communications (all-gather / reduce-scatter) are heavily exposed on the critical path" means? Can you be more specific? Thanks

@JulioZhao97
Author

JulioZhao97 commented May 28, 2023

Maybe your workload is communication bound, and when you go from single-node to multi-node, FSDP's communications (all-gather / reduce-scatter) are heavily exposed on the critical path?

I would recommend you collect profiler traces for both the single-node and multi-node cases. https://pytorch.org/docs/stable/profiler.html

By the way, regarding:

(three lines per block: the first is data loading time, the second is model inference and loss calculation time, the last is backward() time)

If I am understanding correctly, the second time is actually the forward and backward time, and the third time is the optimizer step? (It would be strange if forward takes much longer than backward.)

I further tested the forward time and backward time; results are as follows:
multi-node:

==========
0.22862625122070312
4.046004056930542
==========
0.9854505062103271
3.9894392490386963
==========
0.24094605445861816
2.315075159072876
==========
0.24256587028503418
0.7100715637207031
==========
0.22773265838623047
2.1520705223083496
==========
0.23922944068908691
1.5446033477783203
==========
1.0994553565979004
1.3673655986785889
==========
1.2900798320770264
1.36488676071167
==========
0.2466275691986084
1.9220454692840576
==========
0.22912001609802246
2.4840309619903564
==========
1.028911828994751
3.7588601112365723
==========
0.8808095455169678
1.895629644393921
==========
1.0707406997680664
3.225971221923828
==========
0.8065776824951172
1.9871690273284912
==========
0.8572988510131836
1.0078165531158447
==========
0.23727655410766602
0.7042930126190186

single-node:

==========
0.16225576400756836
0.6536195278167725
==========
0.20273613929748535
0.5899288654327393
==========
0.4853992462158203
0.8575475215911865
==========
0.20531225204467773
0.6352689266204834
==========
0.17136907577514648
0.9528191089630127
==========
0.1586010456085205
0.6574850082397461
==========
0.17702102661132812
0.744593620300293
==========
0.20734310150146484
0.7262928485870361
==========
0.12151241302490234
0.7816987037658691
==========
0.2062222957611084
0.7138354778289795
==========
0.19709062576293945
0.9603440761566162
==========
0.20880985260009766
0.7303977012634277

According to my observation, it seems that both inference and backward are slowed down in multi-node training.

@awgu
Contributor

awgu commented May 28, 2023

FSDP uses collective communications: all-gather for parameters and reduce-scatter for gradient reduction. The forward pass only uses all-gather, whereas the backward pass uses both all-gather and reduce-scatter (meaning twice as much communication as forward). If communication is exposed on the critical path, then it is not overlapped with computation. For multi-node, the communication may take longer due to the slower inter-node bandwidth, which may make it more easily exposed.
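
As a rough sanity check on raw inter-node bandwidth (independent of FSDP), you could run an all-gather microbenchmark like the sketch below with the same launcher on one node and then on two nodes. The tensor size and iteration count are arbitrary, and it assumes the process group is already initialized with one GPU per rank (all_gather_into_tensor is the newer name for _all_gather_base; use the latter on older versions):

import os, time, torch
import torch.distributed as dist

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
numel = 256 * 1024 * 1024  # 512 MB of fp16 per rank (arbitrary)
shard = torch.randn(numel, dtype=torch.float16, device="cuda")
out = torch.empty(numel * dist.get_world_size(), dtype=torch.float16, device="cuda")

for _ in range(5):  # warmup
    dist.all_gather_into_tensor(out, shard)
torch.cuda.synchronize()

iters = 20
t0 = time.time()
for _ in range(iters):
    dist.all_gather_into_tensor(out, shard)
torch.cuda.synchronize()
per_iter = (time.time() - t0) / iters

# each rank receives (world_size - 1) shards' worth of data per all-gather
recv_gb = shard.element_size() * shard.numel() * (dist.get_world_size() - 1) / 1e9
if dist.get_rank() == 0:
    print(f"all-gather: {per_iter * 1e3:.1f} ms/iter, ~{recv_gb / per_iter:.1f} GB/s per rank")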

The times you are getting look unintuitive to me. As I mentioned before, I would recommend getting a profiler trace. Then, it will be clear what is going on. In addition, I would not recommend using time.time() to time your program since it may not accurately capture GPU kernel execution. Instead, you can use CUDA events:
https://auro-227.medium.com/timing-your-pytorch-code-fragments-e1a556e81f2
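
The pattern is roughly the following (a sketch; fwd_and_loss is a placeholder for whatever region is being timed):

import torch

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()
loss = fwd_and_loss()     # placeholder for the forward + loss (or backward) being timed
end.record()
torch.cuda.synchronize()  # wait for both events before reading the elapsed time
print(f"elapsed: {start.elapsed_time(end):.1f} ms")  # elapsed_time() returns milliseconds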

@JulioZhao97
Author

OK, thank you. I will try to measure the time and update soon. Could I please ask what "critical path" means?

@awgu
Contributor

awgu commented May 28, 2023

Critical path refers to the ops that actually affect your end-to-end time. An op is not on the critical path if it is fully overlapped with other ops, which can happen since communication and computation can use separate GPU resources. For example, if FSDP can all-gather the 'next' layer's parameters before finishing the 'current' layer's forward computation, then that all-gather is not on the critical path because by the time we run the 'next' layer's forward computation, we already have the parameters materialized. On the other hand, if the 'next' all-gather takes longer than the 'current' computation, then the part that is not overlapped is exposed and delays the 'next' computation.

I highly recommend looking at profiler traces to make these ideas concrete.
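
For concreteness, the FSDP arguments that control this prefetching/overlap behavior in recent PyTorch releases are sketched below; whether changing them helps depends on what the trace shows, so treat this as something to experiment with rather than a fix. The sketch assumes the process group is already initialized and uses a toy model:

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, BackwardPrefetch

toy_model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024)).cuda()
fsdp_model = FSDP(
    toy_model,
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,  # prefetch the next all-gather during backward
    forward_prefetch=False,                           # optional explicit all-gather prefetch in forward
    limit_all_gathers=True,                           # rate-limit all-gathers to cap memory pressure
)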

@JulioZhao97
Author

JulioZhao97 commented May 29, 2023

Critical path refers to the ops that actually affect your end-to-end time. An op is not on the critical path if it is fully overlapped with other ops, which can happen since communication and computation can use separate GPU resources. For example, if FSDP can all-gather the 'next' layer's parameters before finishing the 'current' layer's forward computation, then that all-gather is not on the critical path because by the time we run the 'next' layer's forward computation, we already have the parameters materialized. On the other hand, if the 'next' all-gather takes longer than the 'current' computation, then the part that is not overlapped is exposed and delays the 'next' computation.

I highly recommend looking at profiler traces to make these ideas concrete.

I further tested the speed using the following code:

        print('Python VERSION:', sys.version)
        print('PyTorch VERSION:', torch.__version__)
        print('RCCL VERSION:', torch.cuda.nccl.version())
        print ('Available GPUs ', torch.cuda.device_count())
        for i in metric_logger.log_every(range(iters_per_epoch), log_freq, header):
            # if using iter-based runner, we stop after iters_per_epoch iterations.
            if i >= iters_per_epoch:
                break
            
            samples = next(data_loader)

            samples = prepare_sample(samples, cuda_enabled=cuda_enabled)
            samples.update(
                {
                    "epoch": inner_epoch,
                    "num_iters_per_epoch": iters_per_epoch,
                    "iters": i,
                }
            )
            lr_scheduler.step(cur_epoch=inner_epoch, cur_step=i)

            torch.cuda.synchronize()  # wait for move to complete
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            with torch.cuda.amp.autocast(enabled=use_amp):
                loss = self.train_step(model=model, samples=samples)
            torch.cuda.synchronize()  # wait for all_reduce to complete
            end.record()
            torch.cuda.synchronize()  # need to wait once more for op to finish
            print('-'*20)
            print(f"inference time: {start.elapsed_time(end)}")
                
            torch.cuda.synchronize()  # wait for move to complete
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            # after_train_step()
            if use_amp:
                scaler.scale(loss).backward()
            else:
                loss.backward()
            torch.cuda.synchronize()  # wait for all_reduce to complete
            end.record()
            torch.cuda.synchronize()  # need to wait once more for op to finish
            print(f"backward time: {start.elapsed_time(end)}")
            
            # update gradients every accum_grad_iters iterations
            if (i + 1) % accum_grad_iters == 0:
                if use_amp:
                    scaler.step(optimizer)
                    scaler.update()                     
                else:    
                    optimizer.step()
                optimizer.zero_grad()

            metric_logger.update(loss=loss.item())
            metric_logger.update(lr=optimizer.param_groups[0]["lr"])

Results are as follows:
single node (8xA100)

Python VERSION: 3.9.16 (main, Mar  8 2023, 14:00:05) 
[GCC 11.2.0]
PyTorch VERSION: 1.13.0+cu116
RCCL VERSION: (2, 14, 3)
Available GPUs  8
--------------------
inference time: 237.29110717773438
backward time: 655.5602416992188
--------------------
inference time: 203.00047302246094
backward time: 768.2235107421875
--------------------
inference time: 272.8787536621094
backward time: 683.4915771484375
--------------------
inference time: 306.6569519042969
backward time: 723.1790771484375
--------------------
inference time: 258.9256591796875
backward time: 862.365478515625
--------------------
inference time: 273.30902099609375
backward time: 698.7430419921875
--------------------
inference time: 225.83270263671875
backward time: 913.3738403320312
--------------------
inference time: 262.9409484863281
backward time: 719.9522705078125
--------------------
inference time: 290.29736328125
backward time: 725.3414916992188
--------------------
inference time: 272.18145751953125
backward time: 749.257080078125

multi-node (16xA100):

Python VERSION: 3.9.16 (main, Mar  8 2023, 14:00:05) 
[GCC 11.2.0]
PyTorch VERSION: 1.13.0+cu116
RCCL VERSION: (2, 14, 3)
Available GPUs  8
--------------------
inference time: 2360.786865234375
backward time: 5019.62890625
--------------------
inference time: 2220.7119140625
backward time: 4965.62353515625
--------------------
inference time: 2185.54833984375
backward time: 5020.552734375
--------------------
inference time: 2225.5
backward time: 4991.302734375
--------------------
inference time: 2175.977294921875
backward time: 4985.82177734375
--------------------
inference time: 2194.322265625
backward time: 5019.4736328125
--------------------
inference time: 2190.4521484375
backward time: 5011.6689453125
--------------------
inference time: 2168.658203125
backward time: 4985.7353515625

It seems that in multi-node training, the CUDA event times are roughly 10x slower.
I also tested torch 2.1 + cu117; results are similar. I will try torch.profiler later.
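
One more thing I plan to check (a sketch; these are standard NCCL environment variables, set before init_process_group or in the launch script): with debug logging on, each rank reports which inter-node transport it selects, e.g. "NET/IB" vs "NET/Socket":

import os

os.environ.setdefault("NCCL_DEBUG", "INFO")
os.environ.setdefault("NCCL_DEBUG_SUBSYS", "INIT,NET")  # limit logging to init / network setup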

@JulioZhao97
Author

JulioZhao97 commented May 29, 2023

This is my distributed init code, posted for possible bug checking @awgu:

_init_dist_slurm(args.dist_backend)
args.world_size = int(os.environ['WORLD_SIZE'])
args.local_rank = int(os.environ['LOCAL_RANK'])
args.rank = int(os.environ['RANK'])
args.gpu = args.rank % torch.cuda.device_count()
print(
    "| distributed init (rank {}, world {}, local_rank {}): {}".format(
    args.rank, args.world_size, args.local_rank, args.dist_url
    ),
    flush=True,
)

def _init_dist_slurm(backend, port=None) -> None:
    """Initialize slurm distributed training environment.
    If argument ``port`` is not specified, then the master port will be system
    environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
    environment variable, then a default port ``29500`` will be used.
    Args:
        backend (str): Backend of torch.distributed.
        port (int, optional): Master port. Defaults to None.
    """
    proc_id = int(os.environ['SLURM_PROCID'])
    ntasks = int(os.environ['SLURM_NTASKS'])
    node_list = os.environ['SLURM_NODELIST']
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(proc_id % num_gpus)
    addr = subprocess.getoutput(
        f'scontrol show hostname {node_list} | head -n1')
    # specify master port
    if port is not None:
        os.environ['MASTER_PORT'] = str(port)
    elif 'MASTER_PORT' in os.environ:
        pass  # use MASTER_PORT in the environment variable
    else:
        # 29500 is torch.distributed default port
        os.environ['MASTER_PORT'] = '29500'
    # use MASTER_ADDR in the environment variable if it already exists
    if 'MASTER_ADDR' not in os.environ:
        os.environ['MASTER_ADDR'] = addr
    os.environ['WORLD_SIZE'] = str(ntasks)
    os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
    os.environ['RANK'] = str(proc_id)
    torch.distributed.init_process_group(backend=backend)

Besides, I want to ask whether the device_id in my FSDP init is correct?
I tested passing device_id = f'cuda:{int(os.environ["RANK"]) % torch.cuda.device_count()}' and device_id = f'cuda:{int(os.environ["LOCAL_RANK"])}', and both work normally.
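
For reference, the two variants written out (a sketch; they give the same index when ranks are assigned to nodes in contiguous blocks, which is the usual SLURM block distribution, and FSDP's device_id accepts an int or a torch.device):

import os, torch

device_from_rank = int(os.environ["RANK"]) % torch.cuda.device_count()  # global rank modulo GPUs per node
device_from_local_rank = int(os.environ["LOCAL_RANK"])                  # rank within the node
assert device_from_rank == device_from_local_rank  # holds for block-wise rank assignment
# e.g. FSDP(..., device_id=torch.device(f"cuda:{device_from_local_rank}"))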

@JulioZhao97
Author

JulioZhao97 commented May 29, 2023

These are the results I gathered using torch.profiler; I warm up for 10 iterations and record 10 active iterations.
First is the multi-node result:

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                 c10d::_allgather_base_         0.04%      51.574ms         0.05%      61.336ms      75.723us       42.930s        54.65%       42.930s      53.000ms           810  
ncclKernel_AllGather_RING_LL_Sum_int8_t(ncclDevComm*...         0.00%       0.000us         0.00%       0.000us       0.000us       42.930s        54.65%       42.930s      53.000ms           810  
                                     record_param_comms         0.02%      29.285ms         0.03%      33.318ms      20.440us       27.985s        35.63%       27.985s      17.169ms          1630  
ncclKernel_ReduceScatter_RING_LL_Sum_half(ncclDevCom...         0.00%       0.000us         0.00%       0.000us       0.000us       27.985s        35.63%       27.985s      68.257ms           410  
                                               aten::mm         0.38%     456.046ms        10.75%       12.890s       1.233ms        3.376s         4.30%        3.376s     322.916us         10454  
                                            aten::copy_         0.11%     134.708ms         6.51%        7.805s     395.294us     791.111ms         1.01%     791.111ms      40.064us         19746  
                                              aten::cat         0.04%      47.347ms         0.05%      58.984ms      42.011us     728.563ms         0.93%     728.563ms     518.920us          1404  
ampere_fp16_s16816gemm_fp16_128x128_ldg8_f2f_stages_...         0.00%       0.000us         0.00%       0.000us       0.000us     724.303ms         0.92%     724.303ms     256.118us          2828  
ampere_fp16_s16816gemm_fp16_256x128_ldg8_f2f_stages_...         0.00%       0.000us         0.00%       0.000us       0.000us     644.382ms         0.82%     644.382ms     344.221us          1872  
void at::native::(anonymous namespace)::CatArrayBatc...         0.00%       0.000us         0.00%       0.000us       0.000us     620.840ms         0.79%     620.840ms     991.757us           626  
                                              aten::mul         0.35%     417.274ms         1.94%        2.324s     172.450us     573.682ms         0.73%     573.682ms      42.561us         13479  
ampere_fp16_s16816gemm_fp16_256x128_ldg8_f2f_stages_...         0.00%       0.000us         0.00%       0.000us       0.000us     493.032ms         0.63%     493.032ms     397.606us          1240  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     415.334ms         0.53%     415.334ms     191.046us          2174  
                                             aten::div_         0.01%       8.808ms         0.01%      16.280ms      19.854us     350.381ms         0.45%     350.381ms     427.294us           820  
                                              aten::bmm         0.10%     120.971ms         0.70%     839.702ms     222.851us     291.428ms         0.37%     291.428ms      77.343us          3768  
ampere_fp16_s16816gemm_fp16_128x256_ldg8_f2f_stages_...         0.00%       0.000us         0.00%       0.000us       0.000us     285.431ms         0.36%     285.431ms     548.906us           520  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us     277.532ms         0.35%     277.532ms      48.216us          5756  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     261.536ms         0.33%     261.536ms      40.101us          6522  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     243.290ms         0.31%     243.290ms      50.475us          4820  
ampere_fp16_s16816gemm_fp16_128x256_ldg8_f2f_stages_...         0.00%       0.000us         0.00%       0.000us       0.000us     198.475ms         0.25%     198.475ms     413.490us           480  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us     177.386ms         0.23%     177.386ms      40.224us          4410  
                                              aten::add         0.10%     118.880ms         0.96%        1.150s     277.533us     152.129ms         0.19%     152.129ms      36.728us          4142  
void cutlass::Kernel<cutlass_80_tensorop_f16_s16816g...         0.00%       0.000us         0.00%       0.000us       0.000us     147.563ms         0.19%     147.563ms     614.846us           240  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     146.335ms         0.19%     146.335ms      25.169us          5814  
                                             aten::add_         0.02%      29.949ms         2.02%        2.425s     491.691us     136.926ms         0.17%     136.926ms      27.768us          4931  
ampere_fp16_s16816gemm_fp16_256x128_ldg8_f2f_stages_...         0.00%       0.000us         0.00%       0.000us       0.000us     131.339ms         0.17%     131.339ms     328.348us           400  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     128.817ms         0.16%     128.817ms      76.223us          1690  
                                    aten::_foreach_mul_         0.01%       6.658ms         0.01%      10.563ms     352.100us     126.358ms         0.16%     126.358ms       4.212ms            30  
void at::native::(anonymous namespace)::multi_tensor...         0.00%       0.000us         0.00%       0.000us       0.000us     126.358ms         0.16%     126.358ms     110.840us          1140  
void cutlass::Kernel<cutlass_80_tensorop_f16_s16816g...         0.00%       0.000us         0.00%       0.000us       0.000us     107.500ms         0.14%     107.500ms     671.875us           160  
void at::native::(anonymous namespace)::CatArrayBatc...         0.00%       0.000us         0.00%       0.000us       0.000us     107.260ms         0.14%     107.260ms     134.075us           800  
                                            aten::fill_         0.02%      24.850ms        10.42%       12.493s       1.759ms     100.706ms         0.13%     100.706ms      14.180us          7102  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      98.604ms         0.13%      98.604ms      14.295us          6898  
ampere_fp16_s16816gemm_fp16_128x128_ldg8_f2f_stages_...         0.00%       0.000us         0.00%       0.000us       0.000us      97.693ms         0.12%      97.693ms     163.914us           596  
void cutlass::Kernel<cutlass_80_tensorop_f16_s16816g...         0.00%       0.000us         0.00%       0.000us       0.000us      96.321ms         0.12%      96.321ms     401.337us           240  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      96.216ms         0.12%      96.216ms      42.423us          2268  
ampere_fp16_s16816gemm_fp16_256x128_ldg8_f2f_stages_...         0.00%       0.000us         0.00%       0.000us       0.000us      89.900ms         0.11%      89.900ms     280.938us           320  
                                              aten::div         0.06%      69.216ms         0.53%     633.447ms     294.079us      88.402ms         0.11%      88.402ms      41.041us          2154  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      84.409ms         0.11%      84.409ms      41.995us          2010  
                                              aten::sum         0.07%      86.647ms         0.74%     889.445ms     332.379us      80.434ms         0.10%      80.508ms      30.085us          2676  
                                aten::_foreach_addcdiv_         0.00%       1.214ms         0.00%       2.462ms     246.200us      76.592ms         0.10%      76.592ms       7.659ms            10  
void at::native::(anonymous namespace)::multi_tensor...         0.00%       0.000us         0.00%       0.000us       0.000us      76.592ms         0.10%      76.592ms     201.558us           380  
void cutlass::Kernel<cutlass_80_tensorop_f16_s16816g...         0.00%       0.000us         0.00%       0.000us       0.000us      64.760ms         0.08%      64.760ms     323.800us           200  
                         Memcpy DtoD (Device -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us      63.591ms         0.08%      63.591ms      30.602us          2078  
void at::native::(anonymous namespace)::multi_tensor...         0.00%       0.000us         0.00%       0.000us       0.000us      59.807ms         0.08%      59.807ms     157.387us           380  
                                aten::_foreach_addcmul_         0.00%       1.901ms         0.00%       4.574ms     457.400us      59.646ms         0.08%      59.646ms       5.965ms            10  
                                    aten::_foreach_add_         0.00%       3.413ms         0.01%       8.764ms     438.200us      59.272ms         0.08%      59.272ms       2.964ms            20  
void at::native::(anonymous namespace)::multi_tensor...         0.00%       0.000us         0.00%       0.000us       0.000us      59.272ms         0.08%      59.272ms     155.979us           380  
                                              aten::neg         0.04%      48.800ms         0.27%     318.002ms     198.751us      58.062ms         0.07%      58.062ms      36.289us          1600  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us      58.062ms         0.07%      58.062ms      36.289us          1600  
void cutlass::Kernel<cutlass_75_tensorop_f16_s1688ge...         0.00%       0.000us         0.00%       0.000us       0.000us      49.890ms         0.06%      49.890ms     242.184us           206  
                           aten::_softmax_backward_data         0.02%      20.122ms         0.18%     215.717ms     343.498us      49.724ms         0.06%     103.079ms     164.139us           628  
void at::native::reduce_kernel<128, 4, at::native::R...         0.00%       0.000us         0.00%       0.000us       0.000us      46.408ms         0.06%      46.408ms      25.140us          1846  
void cutlass::Kernel<cutlass_80_tensorop_f16_s16816g...         0.00%       0.000us         0.00%       0.000us       0.000us      46.151ms         0.06%      46.151ms     576.888us            80  
void cutlass::Kernel<cutlass_75_tensorop_f16_s1688ge...         0.00%       0.000us         0.00%       0.000us       0.000us      45.461ms         0.06%      45.461ms     114.801us           396  
ampere_fp16_s16816gemm_fp16_128x128_ldg8_f2f_stages_...         0.00%       0.000us         0.00%       0.000us       0.000us      45.333ms         0.06%      45.333ms     283.331us           160  
                                         aten::_softmax         0.01%      15.268ms         2.62%        3.141s       5.001ms      43.942ms         0.06%      43.942ms      69.971us           628  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us      43.625ms         0.06%      43.625ms      53.858us           810  
                                    aten::_foreach_div_         0.00%       1.138ms         0.00%       2.425ms     242.500us      42.890ms         0.05%      42.890ms       4.289ms            10  
void at::native::(anonymous namespace)::multi_tensor...         0.00%       0.000us         0.00%       0.000us       0.000us      42.890ms         0.05%      42.890ms     112.868us           380  
       aten::_amp_foreach_non_finite_check_and_unscale_         0.00%       1.199ms         0.30%     362.932ms      36.293ms      42.725ms         0.05%      42.725ms       4.272ms            10  
void at::native::(anonymous namespace)::multi_tensor...         0.00%       0.000us         0.00%       0.000us       0.000us      42.725ms         0.05%      42.725ms     112.434us           380  
                                    aten::_foreach_sqrt         0.00%       2.635ms         0.01%       8.805ms     880.500us      42.099ms         0.05%      42.099ms       4.210ms            10  
void at::native::(anonymous namespace)::multi_tensor...         0.00%       0.000us         0.00%       0.000us       0.000us      42.099ms         0.05%      42.099ms     110.787us           380  
                                     aten::_foreach_add         0.00%       1.932ms         0.00%       4.547ms     454.700us      41.753ms         0.05%      41.753ms       4.175ms            10  
void at::native::(anonymous namespace)::multi_tensor...         0.00%       0.000us         0.00%       0.000us       0.000us      41.753ms         0.05%      41.753ms     109.876us           380  
void cutlass::Kernel<cutlass_75_tensorop_f16_s1688ge...         0.00%       0.000us         0.00%       0.000us       0.000us      41.105ms         0.05%      41.105ms      97.869us           420  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us      39.457ms         0.05%      39.457ms      48.712us           810  
void cutlass::Kernel<cutlass_75_tensorop_f16_s1688ge...         0.00%       0.000us         0.00%       0.000us       0.000us      37.561ms         0.05%      37.561ms       3.756ms            10  
void (anonymous namespace)::softmax_warp_backward<fl...         0.00%       0.000us         0.00%       0.000us       0.000us      36.433ms         0.05%      36.433ms      72.866us           500  
void cutlass::Kernel<cutlass_75_tensorop_f16_s1688ge...         0.00%       0.000us         0.00%       0.000us       0.000us      35.290ms         0.04%      35.290ms       3.529ms            10  
                                              aten::pow         0.26%     306.898ms         3.14%        3.761s       1.526ms      33.340ms         0.04%      65.561ms      26.608us          2464  
void at::native::reduce_kernel<512, 1, at::native::R...         0.00%       0.000us         0.00%       0.000us       0.000us      33.236ms         0.04%      33.236ms      41.032us           810  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us      32.864ms         0.04%      32.864ms      40.573us           810  
                                            aten::addmm         0.03%      35.553ms         0.04%      48.963ms      48.382us      32.863ms         0.04%      32.863ms      32.473us          1012  
                                             aten::mean         0.02%      27.646ms         0.03%      35.133ms      43.374us      32.826ms         0.04%      32.826ms      40.526us           810  
void at::native::reduce_kernel<512, 1, at::native::R...         0.00%       0.000us         0.00%       0.000us       0.000us      32.826ms         0.04%      32.826ms      40.526us           810  
void (anonymous namespace)::softmax_warp_forward<c10...         0.00%       0.000us         0.00%       0.000us       0.000us      31.860ms         0.04%      31.860ms      63.720us           500  
                                            aten::where         0.00%       4.834ms         0.11%     134.290ms     335.725us      31.716ms         0.04%      31.716ms      79.290us           400  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      31.716ms         0.04%      31.716ms      79.290us           400  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      30.919ms         0.04%      30.919ms      19.086us          1620  
sm80_xmma_gemm_f16f16_f16f32_f32_tn_n_tilesize160x12...         0.00%       0.000us         0.00%       0.000us       0.000us      30.074ms         0.04%      30.074ms     187.963us           160  
sm80_xmma_gemm_f16f16_f16f32_f32_nn_n_tilesize160x12...         0.00%       0.000us         0.00%       0.000us       0.000us      27.677ms         0.04%      27.677ms     172.981us           160  
void cutlass::Kernel<cutlass_75_tensorop_f16_s1688ge...         0.00%       0.000us         0.00%       0.000us       0.000us      26.877ms         0.03%      26.877ms      67.871us           396  
                                        Memset (Device)         0.00%       0.000us         0.00%       0.000us       0.000us      26.333ms         0.03%      26.333ms       5.532us          4760  
                                          aten::maximum         0.01%      11.198ms         0.04%      43.289ms     108.222us      24.632ms         0.03%      24.632ms      61.580us           400  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      24.632ms         0.03%      24.632ms      61.580us           400  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      24.138ms         0.03%      24.138ms      29.800us           810  
void cutlass::Kernel<cutlass_75_tensorop_f16_s1688ge...         0.00%       0.000us         0.00%       0.000us       0.000us      24.133ms         0.03%      24.133ms     120.665us           200  
void cutlass::Kernel<cutlass_75_tensorop_f16_s1688ge...         0.00%       0.000us         0.00%       0.000us       0.000us      24.024ms         0.03%      24.024ms     100.100us           240  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us      23.449ms         0.03%      23.449ms      28.949us           810  
                                             aten::silu         0.01%      10.941ms         0.04%      52.046ms     130.115us      22.632ms         0.03%      22.632ms      56.580us           400  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      22.632ms         0.03%      22.632ms      56.580us           400  
                                     aten::masked_fill_         0.00%       3.290ms         0.07%      86.398ms     200.926us      22.608ms         0.03%      22.608ms      52.577us           430  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      22.528ms         0.03%      22.528ms      54.946us           410  
                                    aten::silu_backward         0.01%      14.473ms         0.13%     151.733ms     379.332us      22.021ms         0.03%      22.021ms      55.053us           400  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      22.021ms         0.03%      22.021ms      55.053us           400  
                                           aten::gather         0.02%      27.190ms         0.03%      31.839ms      39.799us      20.830ms         0.03%      20.830ms      26.038us           800  
void at::native::_scatter_gather_elementwise_kernel<...         0.00%       0.000us         0.00%       0.000us       0.000us      20.830ms         0.03%      20.830ms      26.038us           800  
void cutlass::Kernel<cutlass_75_tensorop_f16_s1688ge...         0.00%       0.000us         0.00%       0.000us       0.000us      20.314ms         0.03%      20.314ms       5.079ms             4  
                                               aten::eq         0.01%      12.039ms         7.53%        9.025s      22.011ms      18.493ms         0.02%      18.493ms      45.105us           410  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      18.454ms         0.02%      18.454ms      46.135us           400  
                                               aten::lt         0.01%      11.691ms         0.13%     151.848ms     370.361us      17.289ms         0.02%      17.289ms      42.168us           410  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      17.234ms         0.02%      17.234ms      43.085us           400  
                       aten::native_layer_norm_backward         0.01%       7.192ms         0.04%      44.265ms     100.602us      13.157ms         0.02%      13.157ms      29.902us           440  
void (anonymous namespace)::softmax_warp_backward<fl...         0.00%       0.000us         0.00%       0.000us       0.000us      12.817ms         0.02%      12.817ms     160.213us            80  
void (anonymous namespace)::softmax_warp_forward<c10...         0.00%       0.000us         0.00%       0.000us       0.000us      11.889ms         0.02%      11.889ms     148.613us            80  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      10.966ms         0.01%      10.966ms      13.538us           810  
                                            aten::rsqrt         0.02%      22.649ms         0.06%      74.578ms      91.958us      10.823ms         0.01%      10.827ms      13.350us           811  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      10.823ms         0.01%      10.823ms      13.362us           810  
ampere_fp16_s1688gemm_fp16_256x128_ldg8_f2f_stages_3...         0.00%       0.000us         0.00%       0.000us       0.000us      10.667ms         0.01%      10.667ms     133.338us            80  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us      10.431ms         0.01%      10.431ms      25.441us           410  
ampere_fp16_s16816gemm_fp16_128x64_ldg8_relu_f2f_sta...         0.00%       0.000us         0.00%       0.000us       0.000us       9.753ms         0.01%       9.753ms      31.260us           312  
ampere_fp16_s1688gemm_fp16_256x128_ldg8_f2f_stages_3...         0.00%       0.000us         0.00%       0.000us       0.000us       9.727ms         0.01%       9.727ms     121.588us            80  
ampere_fp16_s16816gemm_fp16_128x128_ldg8_relu_f2f_st...         0.00%       0.000us         0.00%       0.000us       0.000us       9.665ms         0.01%       9.665ms      61.955us           156  
void cutlass::Kernel<cutlass_75_tensorop_f16_s1688ge...         0.00%       0.000us         0.00%       0.000us       0.000us       9.448ms         0.01%       9.448ms      48.204us           196  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       9.202ms         0.01%       9.202ms      11.360us           810  
void cutlass::Kernel<cutlass_75_tensorop_f16_s1688ge...         0.00%       0.000us         0.00%       0.000us       0.000us       8.591ms         0.01%       8.591ms      43.832us           196  
ampere_fp16_s16816gemm_fp16_128x64_ldg8_relu_f2f_sta...         0.00%       0.000us         0.00%       0.000us       0.000us       8.282ms         0.01%       8.282ms      53.090us           156  
void cutlass::Kernel<cutlass_75_wmma_tensorop_f16_s1...         0.00%       0.000us         0.00%       0.000us       0.000us       8.051ms         0.01%       8.051ms      51.609us           156  
ampere_fp16_s16816gemm_fp16_128x128_ldg8_f2f_stages_...         0.00%       0.000us         0.00%       0.000us       0.000us       6.909ms         0.01%       6.909ms      57.575us           120  
ampere_fp16_s16816gemm_fp16_128x64_ldg8_f2f_stages_3...         0.00%       0.000us         0.00%       0.000us       0.000us       6.724ms         0.01%       6.724ms      27.557us           244  
void at::native::(anonymous namespace)::layer_norm_g...         0.00%       0.000us         0.00%       0.000us       0.000us       6.664ms         0.01%       6.664ms      15.145us           440  
sm80_xmma_gemm_f16f16_f16f32_f32_nt_n_tilesize96x128...         0.00%       0.000us         0.00%       0.000us       0.000us       4.671ms         0.01%       4.671ms      29.942us           156  
void cutlass::Kernel<cutlass_80_tensorop_f16_s16816g...         0.00%       0.000us         0.00%       0.000us       0.000us       4.328ms         0.01%       4.328ms     108.200us            40  
void at::native::(anonymous namespace)::GammaBetaBac...         0.00%       0.000us         0.00%       0.000us       0.000us       4.282ms         0.01%       4.282ms      13.551us           316  
                                    aten::gelu_backward         0.00%       3.575ms         0.19%     227.445ms       1.115ms       4.229ms         0.01%       4.229ms      20.730us           204  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       4.229ms         0.01%       4.229ms      20.730us           204  
void cutlass::Kernel<cutlass_80_tensorop_f16_s16816g...         0.00%       0.000us         0.00%       0.000us       0.000us       3.974ms         0.01%       3.974ms      99.350us            40  
void cutlass::Kernel<cutlass_80_tensorop_f16_s16816g...         0.00%       0.000us         0.00%       0.000us       0.000us       3.807ms         0.00%       3.807ms      95.175us            40  
void cutlass::Kernel<cutlass_80_tensorop_f16_s16816g...         0.00%       0.000us         0.00%       0.000us       0.000us       3.743ms         0.00%       3.743ms      93.575us            40  
ampere_fp16_s16816gemm_fp16_64x64_sliced1x2_ldg8_f2f...         0.00%       0.000us         0.00%       0.000us       0.000us       3.735ms         0.00%       3.735ms      15.562us           240  
                                aten::native_layer_norm         0.01%       9.083ms         0.02%      18.680ms      42.455us       3.618ms         0.00%       3.618ms       8.223us           440  
void at::native::(anonymous namespace)::vectorized_l...         0.00%       0.000us         0.00%       0.000us       0.000us       3.618ms         0.00%       3.618ms       8.223us           440  
void cutlass::Kernel<cutlass_80_tensorop_f16_s16816g...         0.00%       0.000us         0.00%       0.000us       0.000us       3.240ms         0.00%       3.240ms      81.000us            40  
void cutlass::Kernel<cutlass_80_tensorop_f16_s16816g...         0.00%       0.000us         0.00%       0.000us       0.000us       3.143ms         0.00%       3.143ms      78.575us            40  
ampere_fp16_s16816gemm_fp16_128x64_ldg8_f2f_stages_6...         0.00%       0.000us         0.00%       0.000us       0.000us       3.124ms         0.00%       3.124ms      13.017us           240  
void at::native::(anonymous namespace)::GammaBetaBac...         0.00%       0.000us         0.00%       0.000us       0.000us       2.211ms         0.00%       2.211ms      17.831us           124  
void cutlass::Kernel<cutlass_80_tensorop_f16_s16816g...         0.00%       0.000us         0.00%       0.000us       0.000us       2.183ms         0.00%       2.183ms       7.580us           288  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       2.182ms         0.00%       2.182ms      10.102us           216  
                                     aten::_log_softmax         0.00%     363.000us         0.00%     440.000us      44.000us       2.058ms         0.00%       2.058ms     205.800us            10  
void at::native::(anonymous namespace)::cunn_SoftMax...         0.00%       0.000us         0.00%       0.000us       0.000us       2.058ms         0.00%       2.058ms     205.800us            10  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       1.982ms         0.00%       1.982ms      14.157us           140  
                       aten::_log_softmax_backward_data         0.00%     275.000us         0.00%     333.000us      33.300us       1.974ms         0.00%       1.974ms     197.400us            10  
void at::native::(anonymous namespace)::cunn_SoftMax...         0.00%       0.000us         0.00%       0.000us       0.000us       1.974ms         0.00%       1.974ms     197.400us            10  
                                             aten::gelu         0.00%       3.302ms         0.00%       4.357ms      21.358us       1.926ms         0.00%       1.926ms       9.441us           204  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       1.926ms         0.00%       1.926ms       9.441us           204  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       1.807ms         0.00%       1.807ms      11.018us           164  
ampere_fp16_s16816gemm_fp16_128x64_ldg8_f2f_stages_3...         0.00%       0.000us         0.00%       0.000us       0.000us       1.789ms         0.00%       1.789ms      44.725us            40  
ampere_fp16_s16816gemm_fp16_128x64_ldg8_f2f_stages_3...         0.00%       0.000us         0.00%       0.000us       0.000us       1.634ms         0.00%       1.634ms      40.850us            40  
void splitKreduce_kernel<32, 16, int, __half, __half...         0.00%       0.000us         0.00%       0.000us       0.000us       1.443ms         0.00%       1.443ms       5.010us           288  
void cutlass::Kernel<cutlass_80_wmma_tensorop_f16_s1...         0.00%       0.000us         0.00%       0.000us       0.000us       1.091ms         0.00%       1.091ms      11.365us            96  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us       1.065ms         0.00%       1.065ms       6.827us           156  
sm80_xmma_gemm_f16f16_f16f32_f32_nt_n_tilesize64x96x...         0.00%       0.000us         0.00%       0.000us       0.000us       1.043ms         0.00%       1.043ms      21.729us            48  
                         aten::embedding_dense_backward         0.00%     320.000us         0.00%       5.079ms     149.382us       1.024ms         0.00%       9.861ms     290.029us            34  
void at::native::(anonymous namespace)::embedding_ba...         0.00%       0.000us         0.00%       0.000us       0.000us       1.024ms         0.00%       1.024ms      30.118us            34  
sm80_xmma_gemm_f16f16_f16f32_f32_nt_n_tilesize160x12...         0.00%       0.000us         0.00%       0.000us       0.000us       1.018ms         0.00%       1.018ms      19.577us            52  
ampere_fp16_s1688gemm_fp16_256x64_ldg8_f2f_stages_32...         0.00%       0.000us         0.00%       0.000us       0.000us     846.000us         0.00%     846.000us      17.625us            48  
ampere_s16816gemm_fp16_64x64_sliced1x2_ldg8_stages_6...         0.00%       0.000us         0.00%       0.000us       0.000us     819.000us         0.00%     819.000us      17.062us            48  
ampere_fp16_s16816gemm_fp16_64x64_sliced1x2_ldg8_f2f...         0.00%       0.000us         0.00%       0.000us       0.000us     780.000us         0.00%     780.000us      16.250us            48  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us     725.000us         0.00%     725.000us       6.042us           120  
                             aten::convolution_backward         0.00%     401.000us         0.00%     806.000us     201.500us     716.000us         0.00%       1.102ms     275.500us             4  
ampere_fp16_s16816gemm_fp16_64x64_ldg8_relu_f2f_stag...         0.00%       0.000us         0.00%       0.000us       0.000us     697.000us         0.00%     697.000us      14.521us            48  
void cutlass::Kernel<cutlass_75_wmma_tensorop_f16_s1...         0.00%       0.000us         0.00%       0.000us       0.000us     660.000us         0.00%     660.000us      27.500us            24  
                                     aten::index_select         0.00%     498.000us         0.00%       1.087ms      31.971us     648.000us         0.00%     648.000us      19.059us            34  
void cutlass::Kernel<cutlass_80_tensorop_f16_s16816g...         0.00%       0.000us         0.00%       0.000us       0.000us     584.000us         0.00%     584.000us      12.167us            48  
void cutlass::Kernel<cutlass_80_wmma_tensorop_f16_s1...         0.00%       0.000us         0.00%       0.000us       0.000us     560.000us         0.00%     560.000us      11.667us            48  
void splitKreduce_kernel<32, 16, int, float, __half,...         0.00%       0.000us         0.00%       0.000us       0.000us     542.000us         0.00%     542.000us      11.292us            48  
void cutlass::Kernel<cutlass_80_wmma_tensorop_f16_s1...         0.00%       0.000us         0.00%       0.000us       0.000us     542.000us         0.00%     542.000us      11.292us            48  
void at::native::(anonymous namespace)::indexSelectL...         0.00%       0.000us         0.00%       0.000us       0.000us     485.000us         0.00%     485.000us      48.500us            10  
void (anonymous namespace)::softmax_warp_backward<fl...         0.00%       0.000us         0.00%       0.000us       0.000us     474.000us         0.00%     474.000us       9.875us            48  
                                 aten::nll_loss_forward         0.00%     572.000us         0.00%     782.000us      78.200us     444.000us         0.00%     444.000us      44.400us            10  
void at::native::(anonymous namespace)::nll_loss_for...         0.00%       0.000us         0.00%       0.000us       0.000us     444.000us         0.00%     444.000us      44.400us            10  
sm80_xmma_wgrad_implicit_gemm_indexed_wo_smem_f16f16...         0.00%       0.000us         0.00%       0.000us       0.000us     428.000us         0.00%     428.000us     107.000us             4  
                                aten::cudnn_convolution         0.00%     459.000us         0.00%     636.000us     159.000us     395.000us         0.00%     395.000us      98.750us             4  
void cutlass::Kernel<cutlass_75_wmma_tensorop_f16_s1...         0.00%       0.000us         0.00%       0.000us       0.000us     390.000us         0.00%     390.000us      16.250us            24  
void cutlass::Kernel<cutlass_75_tensorop_f16_s1688ge...         0.00%       0.000us         0.00%       0.000us       0.000us     370.000us         0.00%     370.000us      15.417us            24  
void cutlass::Kernel<cutlass_75_wmma_tensorop_f16_s1...         0.00%       0.000us         0.00%       0.000us       0.000us     366.000us         0.00%     366.000us      15.250us            24  
                                aten::nll_loss_backward         0.00%     385.000us         0.00%     701.000us      70.100us     364.000us         0.00%       1.134ms     113.400us            10  
void at::native::(anonymous namespace)::nll_loss_bac...         0.00%       0.000us         0.00%       0.000us       0.000us     364.000us         0.00%     364.000us      36.400us            10  
void at::native::(anonymous namespace)::CatArrayBatc...         0.00%       0.000us         0.00%       0.000us       0.000us     276.000us         0.00%     276.000us      27.600us            10  
void cudnn::ops::nchwToNhwcKernel<__half, __half, fl...         0.00%       0.000us         0.00%       0.000us       0.000us     260.000us         0.00%     260.000us      16.250us            16  
void cutlass::Kernel<cutlass_80_wmma_tensorop_f16_s1...         0.00%       0.000us         0.00%       0.000us       0.000us     240.000us         0.00%     240.000us       5.000us            48  
void cutlass::Kernel<cutlass_80_wmma_tensorop_f16_s1...         0.00%       0.000us         0.00%       0.000us       0.000us     240.000us         0.00%     240.000us       5.000us            48  
void cutlass::Kernel<cutlass_75_wmma_tensorop_f16_s1...         0.00%       0.000us         0.00%       0.000us       0.000us     240.000us         0.00%     240.000us      10.000us            24  
sm80_xmma_fprop_implicit_gemm_indexed_wo_smem_f16f16...         0.00%       0.000us         0.00%       0.000us       0.000us     232.000us         0.00%     232.000us      58.000us             4  
                         Memcpy HtoD (Pinned -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us     217.000us         0.00%     217.000us      21.700us            10  
void (anonymous namespace)::softmax_warp_forward<c10...         0.00%       0.000us         0.00%       0.000us       0.000us     193.000us         0.00%     193.000us       4.021us            48  
void at::native::(anonymous namespace)::indexSelectS...         0.00%       0.000us         0.00%       0.000us       0.000us     163.000us         0.00%     163.000us       6.792us            24  
                       Memcpy HtoD (Pageable -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us     151.000us         0.00%     151.000us       1.480us           102  
void cudnn::ops::nhwcToNchwKernel<__half, __half, fl...         0.00%       0.000us         0.00%       0.000us       0.000us     148.000us         0.00%     148.000us      18.500us             8  
ampere_fp16_s16816gemm_fp16_128x128_ldg8_f2f_stages_...         0.00%       0.000us         0.00%       0.000us       0.000us     105.000us         0.00%     105.000us      26.250us             4  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     102.000us         0.00%     102.000us       3.000us            34  
void at::native::(anonymous namespace)::CatArrayBatc...         0.00%       0.000us         0.00%       0.000us       0.000us      94.000us         0.00%      94.000us      11.750us             8  
void at::native::reduce_kernel<512, 1, at::native::R...         0.00%       0.000us         0.00%       0.000us       0.000us      92.000us         0.00%      92.000us       4.600us            20  
                              aten::_local_scalar_dense         0.00%     298.000us         8.62%       10.339s      11.884ms      89.000us         0.00%      89.000us       0.102us           870  
                       Memcpy DtoH (Device -> Pageable)         0.00%       0.000us         0.00%       0.000us       0.000us      89.000us         0.00%      89.000us       1.780us            50  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      75.000us         0.00%      75.000us       3.409us            22  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us      74.000us         0.00%      74.000us       3.700us            20  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us      71.000us         0.00%      71.000us       7.100us            10  
                                              aten::sub         0.00%     229.000us         0.00%     324.000us      18.000us      66.000us         0.00%      66.000us       3.667us            18  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      66.000us         0.00%      66.000us       3.667us            18  
                                           aten::arange         0.00%     268.000us         0.00%       1.045ms      26.125us      61.000us         0.00%     119.000us       2.975us            40  
void (anonymous namespace)::elementwise_kernel_with_...         0.00%       0.000us         0.00%       0.000us       0.000us      61.000us         0.00%      61.000us       3.050us            20  
                                               aten::ne         0.00%     317.000us         0.00%     507.000us      25.350us      60.000us         0.00%      60.000us       3.000us            20  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us      60.000us         0.00%      60.000us       6.000us            10  
ampere_fp16_s16816gemm_fp16_128x64_ldg8_relu_f2f_sta...         0.00%       0.000us         0.00%       0.000us       0.000us      60.000us         0.00%      60.000us      15.000us             4  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us      55.000us         0.00%      55.000us       5.500us            10  
void splitKreduce_kernel<32, 16, int, __half, __half...         0.00%       0.000us         0.00%       0.000us       0.000us      52.000us         0.00%      52.000us      13.000us             4  
void at::native::(anonymous namespace)::CatArrayBatc...         0.00%       0.000us         0.00%       0.000us       0.000us      50.000us         0.00%      50.000us       5.000us            10  
void at::native::(anonymous namespace)::CatArrayBatc...         0.00%       0.000us         0.00%       0.000us       0.000us      43.000us         0.00%      43.000us       4.300us            10  
                               aten::_amp_update_scale_         0.00%     121.000us         0.00%     199.000us      19.900us      42.000us         0.00%      42.000us       4.200us            10  
at::native::amp_update_scale_cuda_kernel(float*, int...         0.00%       0.000us         0.00%       0.000us       0.000us      42.000us         0.00%      42.000us       4.200us            10  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us      41.000us         0.00%      41.000us       4.100us            10  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      40.000us         0.00%      40.000us       4.000us            10  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      40.000us         0.00%      40.000us       4.000us            10  
                                       aten::reciprocal         0.00%     305.000us         0.00%     389.000us      38.900us      40.000us         0.00%      40.000us       4.000us            10  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      40.000us         0.00%      40.000us       4.000us            10  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      30.000us         0.00%      30.000us       3.000us            10  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      30.000us         0.00%      30.000us       3.000us            10  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us      28.000us         0.00%      28.000us       7.000us             4  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us      18.000us         0.00%      18.000us       3.000us             6  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us      12.000us         0.00%      12.000us       3.000us             4  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       6.000us         0.00%       6.000us       3.000us             2  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       6.000us         0.00%       6.000us       3.000us             2  
                                          ProfilerStep*        39.56%       47.440s        60.79%       72.897s        7.290s       0.000us         0.00%       25.810s        2.581s            10  
                                    cudaStreamWaitEvent         0.01%       7.538ms         0.01%       7.538ms       0.505us       0.000us         0.00%       0.000us       0.000us         14924  
                                    aten::record_stream         0.01%       8.912ms         0.01%       8.912ms       3.623us       0.000us         0.00%       0.000us       0.000us          2460  
enumerate(DataLoader)#_MultiProcessingDataLoaderIter...         0.00%     727.000us         0.00%     727.000us      72.700us       0.000us         0.00%       0.000us       0.000us            10  
                                               aten::to         0.06%      68.618ms         5.59%        6.701s     461.557us       0.000us         0.00%     339.387ms      23.377us         14518  
                                         aten::_to_copy         0.04%      49.423ms         5.57%        6.685s     672.119us       0.000us         0.00%     344.177ms      34.605us          9946  
                                    aten::empty_strided         0.17%     201.915ms         0.18%     212.565ms      18.775us       0.000us         0.00%       0.000us       0.000us         11322  
                                        cudaMemcpyAsync         8.99%       10.781s         8.99%       10.781s       4.813ms       0.000us         0.00%       0.000us       0.000us          2240  
                                  cudaStreamSynchronize         1.07%        1.282s         1.07%        1.282s       7.913ms       0.000us         0.00%       0.000us       0.000us           162  
                       FullyShardedDataParallel.forward         0.13%     150.628ms        22.22%       26.647s      64.993ms       0.000us         0.00%       48.363s     117.957ms           410  
             FullyShardedDataParallel._root_pre_forward         0.01%      10.496ms         0.01%      11.802ms      28.785us       0.000us         0.00%      28.000us       0.068us           410  
                    FullyShardedDataParallel._to_kwargs         0.00%       1.006ms         0.00%       1.006ms     100.600us       0.000us         0.00%       0.000us       0.000us            10  
                  FullyShardedDataParallel._pre_forward         0.24%     282.624ms         0.32%     382.291ms     932.417us       0.000us         0.00%       23.014s      56.132ms           410  
                                         cudaEventQuery         0.04%      52.042ms         0.04%      52.062ms       0.041us       0.000us         0.00%       0.000us       0.000us       1257386  
                                       cudaLaunchKernel        45.83%       54.959s        45.83%       54.960s     640.720us       0.000us         0.00%       0.000us       0.000us         85778  
                aten::_has_compatible_shallow_copy_type         0.00%      19.000us         0.00%      19.000us       0.001us       0.000us         0.00%       0.000us       0.000us         36850  
                                  cudaStreamIsCapturing         0.00%       1.004ms         0.00%       1.004ms       0.811us       0.000us         0.00%       0.000us       0.000us          1238  
                                  nccl:_all_gather_base         0.00%       0.000us             0      45.463ms      56.127us       0.000us         0.00%       0.000us       0.000us           810  
                                                INVALID         0.00%      45.000us         0.00%      45.000us       0.037us       0.000us         0.00%       0.000us       0.000us          1220  
                                            aten::slice         0.03%      38.062ms         0.03%      41.027ms       2.214us       0.000us         0.00%       0.000us       0.000us         18532  
                                       aten::as_strided         0.01%       9.696ms         0.01%       9.696ms       0.117us       0.000us         0.00%       0.000us       0.000us         82919  
                                             aten::view         0.05%      63.421ms         0.05%      63.421ms       1.222us       0.000us         0.00%       0.000us       0.000us         51894  
                                 aten::split_with_sizes         0.03%      36.028ms         0.03%      36.205ms      44.152us       0.000us         0.00%       0.000us       0.000us           820  
                               cudaPointerGetAttributes         0.00%     123.000us         0.00%     132.000us      13.200us       0.000us         0.00%       0.000us       0.000us            10  
                                        aten::expand_as         0.00%       2.266ms         0.01%       6.506ms       4.041us       0.000us         0.00%       0.000us       0.000us          1610  
                                           aten::expand         0.02%      18.657ms         0.02%      19.860ms       3.215us       0.000us         0.00%       0.000us       0.000us          6178  
                                            aten::empty         0.25%     296.365ms         0.25%     299.472ms      16.347us       0.000us         0.00%       0.000us       0.000us         18320  
                                       aten::lift_fresh         0.00%       1.000us         0.00%       1.000us       0.002us       0.000us         0.00%       0.000us       0.000us           498  
                                          aten::detach_         0.00%     823.000us         0.00%     831.000us       1.738us       0.000us         0.00%       0.000us       0.000us           478  
                                                detach_         0.00%      12.000us         0.00%      12.000us       0.025us       0.000us         0.00%       0.000us       0.000us           478  
                                        aten::embedding         0.00%     423.000us         0.00%       1.572ms      46.235us       0.000us         0.00%     648.000us      19.059us            34  
                                          aten::reshape         0.03%      41.430ms         0.69%     821.677ms      30.650us       0.000us         0.00%     185.350ms       6.914us         26808  
                                          aten::resize_         0.01%      13.067ms         0.01%      15.731ms      33.903us       0.000us         0.00%       0.000us       0.000us           464  
                                             aten::ones         0.00%     146.000us         0.00%     418.000us      11.000us       0.000us         0.00%      42.000us       1.105us            38  
                                             aten::set_         0.00%     139.000us         0.00%     139.000us       4.633us       0.000us         0.00%       0.000us       0.000us            30  
                                        aten::new_empty         0.00%      23.000us         0.00%      66.000us       6.600us       0.000us         0.00%       0.000us       0.000us            10  
                                           aten::unbind         0.00%     124.000us         0.00%     214.000us      10.700us       0.000us         0.00%       0.000us       0.000us            20  
                                           aten::select         0.00%       3.192ms         0.00%       3.224ms       3.303us       0.000us         0.00%       0.000us       0.000us           976  
                                             aten::item         0.03%      40.586ms         8.62%       10.340s      11.886ms       0.000us         0.00%      85.000us       0.098us           870  
                                      aten::masked_fill         0.00%     107.000us         0.00%     861.000us      43.050us       0.000us         0.00%     156.000us       7.800us            20  
                                            aten::clone         0.02%      19.450ms         0.63%     753.430ms     157.424us       0.000us         0.00%     186.250ms      38.916us          4786  
                                       aten::empty_like         0.01%       9.887ms         0.10%     123.089ms      22.973us       0.000us         0.00%       0.000us       0.000us          5358  
                                        aten::unsqueeze         0.00%       2.131ms         0.00%       2.175ms       2.512us       0.000us         0.00%       0.000us       0.000us           866  
                                             aten::full         0.00%      92.000us         0.00%     600.000us      30.000us       0.000us         0.00%      64.000us       3.200us            20  
                                             aten::rsub         0.00%      87.000us         0.00%     411.000us      22.833us       0.000us         0.00%      66.000us       3.667us            18  
                                      aten::result_type         0.00%       6.000us         0.00%       6.000us       0.001us       0.000us         0.00%       0.000us       0.000us          5630  
                                           aten::linear         0.03%      33.247ms         3.76%        4.504s     983.694us       0.000us         0.00%        1.254s     273.887us          4579  
                                                aten::t         0.04%      42.732ms         0.06%      67.863ms       3.551us       0.000us         0.00%       0.000us       0.000us         19110  
                                        aten::transpose         0.04%      42.781ms         0.04%      46.639ms       1.767us       0.000us         0.00%       0.000us       0.000us         26398  
                                           aten::matmul         0.03%      40.326ms         3.62%        4.340s     999.268us       0.000us         0.00%        1.410s     324.646us          4343  
cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFla...         0.01%       7.243ms         0.01%       7.243ms       0.262us       0.000us         0.00%       0.000us       0.000us         27672  
                                        cudaMemsetAsync         0.18%     215.154ms         0.18%     215.154ms      45.200us       0.000us         0.00%       0.000us       0.000us          4760  
                                     aten::_unsafe_view         0.01%      16.048ms         0.01%      16.048ms       1.842us       0.000us         0.00%       0.000us       0.000us          8710  
                                           aten::repeat         0.01%      17.174ms         0.07%      80.515ms      67.096us       0.000us         0.00%      29.981ms      24.984us          1200  
                                            aten::alias         0.00%      52.000us         0.00%      52.000us       0.043us       0.000us         0.00%       0.000us       0.000us          1200  
                                           aten::unfold         0.01%       6.574ms         0.01%       7.355ms       1.532us       0.000us         0.00%       0.000us       0.000us          4800  
                                              aten::max         0.00%       1.324ms         0.04%      43.850ms     109.625us       0.000us         0.00%      24.369ms      60.922us           400  
                                          aten::softmax         0.00%       2.712ms         2.74%        3.284s       5.014ms       0.000us         0.00%      43.942ms      67.087us           655  
                                   cudaFuncSetAttribute         0.00%     782.000us         0.00%     782.000us       0.243us       0.000us         0.00%       0.000us       0.000us          3222  
                 FullyShardedDataParallel._post_forward         0.15%     176.626ms         0.15%     183.157ms     446.724us       0.000us         0.00%       0.000us       0.000us           410  
                                       aten::contiguous         0.00%     298.000us         0.00%       5.186ms      43.949us       0.000us         0.00%       2.296ms      19.458us           118  
                               aten::cross_entropy_loss         0.00%      89.000us         0.00%       1.905ms     190.500us       0.000us         0.00%       4.215ms     421.500us            10  
                                      aten::log_softmax         0.00%      60.000us         0.00%     500.000us      50.000us       0.000us         0.00%       2.058ms     205.800us            10  
                                      aten::nll_loss_nd        -0.00%     -96.000us         0.00%       1.316ms     131.600us       0.000us         0.00%       2.157ms     215.700us            10  
                                         aten::nll_loss         0.00%     213.000us         0.00%       2.097ms     104.850us       0.000us         0.00%       2.510ms     125.500us            20  
                                        aten::ones_like         0.00%      59.000us         0.00%     831.000us      83.100us       0.000us         0.00%     210.000us      21.000us            10  
      autograd::engine::evaluate_function: MulBackward0         0.04%      47.899ms         2.64%        3.160s     834.725us       0.000us         0.00%     418.234ms     110.469us          3786  
                                           MulBackward0         0.02%      21.019ms         1.03%        1.239s     327.356us       0.000us         0.00%     252.561ms      66.709us          3786  
autograd::engine::evaluate_function: NllLossBackward...         0.00%     452.000us         0.08%      95.974ms       9.597ms       0.000us         0.00%     525.147ms      52.515ms            10  
            FullyShardedDataParallel._pre_backward_hook         0.25%     300.098ms         0.30%     360.494ms     879.254us       0.000us         0.00%       20.025s      48.842ms           410  
                                       NllLossBackward0         0.00%      54.000us         0.00%     755.000us      75.500us       0.000us         0.00%       1.134ms     113.400us            10  
                                            aten::zero_         0.01%      15.176ms        10.43%       12.507s       1.786ms       0.000us         0.00%     100.195ms      14.305us          7004  
autograd::engine::evaluate_function: ToCopyBackward0...        -0.05%  -61905.000us         4.82%        5.782s       1.448ms       0.000us         0.00%     156.814ms      39.262us          3994  
                                        ToCopyBackward0         0.04%      50.384ms         4.04%        4.848s       1.214ms       0.000us         0.00%     127.126ms      31.829us          3994  
autograd::engine::evaluate_function: LogSoftmaxBackw...         0.00%      54.000us         0.00%     443.000us      44.300us       0.000us         0.00%       1.974ms     197.400us            10  
                                    LogSoftmaxBackward0         0.00%      56.000us         0.00%     389.000us      38.900us       0.000us         0.00%       1.974ms     197.400us            10  
     autograd::engine::evaluate_function: ViewBackward0        -0.04%  -53246.000us         1.26%        1.507s     108.985us       0.000us         0.00%      83.226ms       6.018us         13830  
                                          ViewBackward0         0.10%     124.071ms         0.55%     664.172ms      48.024us       0.000us         0.00%      51.826ms       3.747us         13830  
    autograd::engine::evaluate_function: CloneBackward0         0.01%       8.719ms         0.01%       8.743ms       3.467us       0.000us         0.00%       0.000us       0.000us          2522  
                                         CloneBackward0         0.00%      24.000us         0.00%      24.000us       0.010us       0.000us         0.00%       0.000us       0.000us          2522  
    autograd::engine::evaluate_function: SliceBackward0        -0.09%  -110827.000us        11.00%       13.186s       7.684ms       0.000us         0.00%     118.180ms      68.869us          1716  
                                         SliceBackward0         0.10%     123.424ms        10.82%       12.977s       7.563ms       0.000us         0.00%      96.627ms      56.309us          1716  
                                   aten::slice_backward        -0.18%  -212669.000us        10.82%       12.975s       7.561ms       0.000us         0.00%     100.646ms      58.652us          1716  
                                            aten::zeros         0.20%     239.912ms        10.56%       12.667s       1.853ms       0.000us         0.00%      94.339ms      13.796us          6838  
autograd::engine::evaluate_function: UnsafeViewBackw...         0.02%      23.535ms         0.07%      80.513ms      12.375us       0.000us         0.00%      19.191ms       2.950us          6506  
                                    UnsafeViewBackward0         0.01%       9.019ms         0.05%      54.130ms       8.320us       0.000us         0.00%      17.714ms       2.723us          6506  
       autograd::engine::evaluate_function: MmBackward0         0.02%      26.582ms         6.59%        7.906s       2.814ms       0.000us         0.00%        2.090s     743.652us          2810  
                                            MmBackward0         0.02%      27.907ms         6.57%        7.879s       2.804ms       0.000us         0.00%        2.089s     743.359us          2810  
        autograd::engine::evaluate_function: TBackward0         0.04%      51.006ms         0.06%      69.612ms      18.214us       0.000us         0.00%       0.000us       0.000us          3822  
                                             TBackward0         0.00%       4.575ms         0.01%      16.079ms       4.207us       0.000us         0.00%       0.000us       0.000us          3822  
    autograd::engine::evaluate_function: RsqrtBackward0         0.00%       4.980ms         0.31%     366.713ms     452.732us       0.000us         0.00%      27.255ms      33.648us           810  
                                         RsqrtBackward0         0.01%       7.189ms         0.30%     361.733ms     446.584us       0.000us         0.00%      27.255ms      33.648us           810  
      autograd::engine::evaluate_function: AddBackward0         0.02%      25.958ms         0.25%     303.615ms      91.505us       0.000us         0.00%       19.503s       5.878ms          3318  
                                           AddBackward0         0.00%     456.000us         0.00%     456.000us       0.137us       0.000us         0.00%       0.000us       0.000us          3318  
     autograd::engine::evaluate_function: MeanBackward1         0.00%       5.280ms         0.16%     195.266ms     241.069us       0.000us         0.00%      23.449ms      28.949us           810  
                                          MeanBackward1         0.00%       4.885ms         0.16%     188.969ms     233.295us       0.000us         0.00%      23.240ms      28.691us           810  
      autograd::engine::evaluate_function: PowBackward0         0.00%       5.820ms         0.63%     760.792ms     939.249us       0.000us         0.00%      87.498ms     108.022us           810  
                                           PowBackward0         0.00%       5.589ms         0.63%     754.972ms     932.064us       0.000us         0.00%      87.498ms     108.022us           810  
     autograd::engine::evaluate_function: SiluBackward0         0.00%       3.312ms         0.13%     156.150ms     390.375us       0.000us         0.00%      22.021ms      55.053us           400  
                                          SiluBackward0         0.00%       1.105ms         0.13%     152.701ms     381.752us       0.000us         0.00%      21.819ms      54.547us           400  
autograd::engine::evaluate_function: TransposeBackwa...         0.01%      12.263ms         0.02%      21.030ms       8.807us       0.000us         0.00%       0.000us       0.000us          2388  
                                     TransposeBackward0         0.00%       2.979ms         0.01%       8.649ms       3.622us       0.000us         0.00%       0.000us       0.000us          2388  
      autograd::engine::evaluate_function: BmmBackward0         0.01%      10.929ms         0.66%     786.578ms     626.256us       0.000us         0.00%     187.685ms     149.431us          1256  
                                           BmmBackward0         0.01%      10.361ms         0.65%     774.725ms     616.819us       0.000us         0.00%     187.216ms     149.057us          1256  
autograd::engine::evaluate_function: ExpandBackward0...         0.01%      10.152ms         0.01%      11.020ms       4.349us       0.000us         0.00%     236.000us       0.093us          2534  
                                        ExpandBackward0         0.00%     141.000us         0.00%     842.000us       0.332us       0.000us         0.00%     221.000us       0.087us          2534  
autograd::engine::evaluate_function: SoftmaxBackward...         0.00%       3.334ms         0.19%     221.934ms     353.398us       0.000us         0.00%     103.079ms     164.139us           628  
                                       SoftmaxBackward0         0.00%       2.883ms         0.18%     218.127ms     347.336us       0.000us         0.00%     101.915ms     162.285us           628  
autograd::engine::evaluate_function: MaximumBackward...         0.00%       3.694ms         7.89%        9.460s      23.650ms       0.000us         0.00%     108.514ms     271.285us           400  
                                       MaximumBackward0         0.00%       5.206ms         7.89%        9.456s      23.640ms       0.000us         0.00%     108.514ms     271.285us           400  
      autograd::engine::evaluate_function: DivBackward0         0.00%       2.347ms         0.31%     377.339ms     799.447us       0.000us         0.00%      21.026ms      44.547us           472  
                                           DivBackward0         0.00%       1.162ms         0.31%     374.910ms     794.301us       0.000us         0.00%      20.968ms      44.424us           472  
      autograd::engine::evaluate_function: CatBackward0         0.01%       7.143ms         0.02%      20.106ms      20.643us       0.000us         0.00%       0.000us       0.000us           974  
                                           CatBackward0         0.00%       3.615ms         0.01%      12.880ms      13.224us       0.000us         0.00%       0.000us       0.000us           974  
                                           aten::narrow         0.01%      10.976ms         0.02%      18.633ms       2.147us       0.000us         0.00%       0.000us       0.000us          8678  
      autograd::engine::evaluate_function: NegBackward0         0.00%       2.331ms         0.24%     292.449ms     365.561us       0.000us         0.00%      19.146ms      23.933us           800  
                                           NegBackward0         0.00%       2.045ms         0.24%     289.226ms     361.533us       0.000us         0.00%      18.868ms      23.585us           800  
autograd::engine::evaluate_function: SplitWithSizesB...         0.02%      18.752ms         0.17%     199.896ms     487.551us       0.000us         0.00%     674.355ms       1.645ms           410  
                                SplitWithSizesBackward0         0.01%       6.774ms         0.15%     180.149ms     439.388us       0.000us         0.00%     643.090ms       1.569ms           410  
autograd::engine::evaluate_function: torch::autograd...         0.02%      23.000ms         0.36%     427.548ms       1.043ms       0.000us         0.00%       28.385s      69.233ms           410  
                        torch::autograd::AccumulateGrad         0.00%       1.009ms         0.00%       2.329ms       5.680us       0.000us         0.00%       0.000us       0.000us           410  
                                           aten::detach         0.00%     517.000us         0.00%       1.345ms       3.280us       0.000us         0.00%       0.000us       0.000us           410  
                                                 detach         0.00%     849.000us         0.00%     849.000us       2.071us       0.000us         0.00%       0.000us       0.000us           410  
           FullyShardedDataParallel._post_backward_hook         0.23%     270.341ms         0.34%     402.173ms     980.910us       0.000us         0.00%       28.385s      69.233ms           410  
                                            aten::chunk         0.00%       2.009ms         0.02%      23.401ms      57.076us       0.000us         0.00%       0.000us       0.000us           410  
                                            aten::split         0.01%      11.597ms         0.02%      22.341ms      54.490us       0.000us         0.00%       0.000us       0.000us           410  
                            c10d::_reduce_scatter_base_         0.00%       2.274ms         0.03%      32.895ms      80.232us       0.000us         0.00%       27.985s      68.257ms           410  
                              nccl:_reduce_scatter_base         0.00%       0.000us             0      22.752ms      55.493us       0.000us         0.00%       0.000us       0.000us           410  
autograd::engine::evaluate_function: EmbeddingBackwa...         0.00%     550.000us         0.01%       6.157ms     181.088us       0.000us         0.00%      23.651ms     695.618us            34  
                                     EmbeddingBackward0         0.00%     193.000us         0.00%       5.237ms     154.029us       0.000us         0.00%       9.330ms     274.412us            34  
                               aten::embedding_backward         0.00%      98.000us         0.00%       5.131ms     150.912us       0.000us         0.00%       9.598ms     282.294us            34  
                              Optimizer.step#AdamW.step         0.01%      10.380ms         0.04%      53.885ms       5.388ms       0.000us         0.00%     448.610ms      44.861ms            10  
                    Optimizer.zero_grad#AdamW.zero_grad         0.00%       1.375ms         0.00%       1.375ms     137.500us       0.000us         0.00%       0.000us       0.000us            10  
                                           aten::conv2d         0.00%      46.000us         0.00%     931.000us     232.750us       0.000us         0.00%     426.000us     106.500us             4  
                                      aten::convolution         0.00%      67.000us         0.00%     885.000us     221.250us       0.000us         0.00%     426.000us     106.500us             4  
                                     aten::_convolution         0.00%      72.000us         0.00%     818.000us     204.500us       0.000us         0.00%     426.000us     106.500us             4  
                                  cudaStreamGetPriority         0.00%       1.000us         0.00%       1.000us       0.125us       0.000us         0.00%       0.000us       0.000us             8  
                       cudaDeviceGetStreamPriorityRange         0.00%       2.000us         0.00%       2.000us       0.250us       0.000us         0.00%       0.000us       0.000us             8  
                                          aten::flatten         0.00%       8.000us         0.00%      25.000us       6.250us       0.000us         0.00%       0.000us       0.000us             4  
                                          aten::dropout         0.00%       3.000us         0.00%       3.000us       0.004us       0.000us         0.00%       0.000us       0.000us           668  
                                       aten::layer_norm         0.00%       3.529ms         0.06%      68.509ms      77.851us       0.000us         0.00%      13.901ms      15.797us           880  
                                       aten::zeros_like         0.00%     563.000us         0.00%       3.810ms      24.423us       0.000us         0.00%     448.000us       2.872us           156  
                                          aten::permute         0.00%       3.386ms         0.00%       3.601ms       4.055us       0.000us         0.00%       0.000us       0.000us           888  
    autograd::engine::evaluate_function: AddmmBackward0         0.02%      21.797ms         1.08%        1.291s       1.275ms       0.000us         0.00%      97.321ms      96.167us          1012  
                                         AddmmBackward0         0.01%      10.352ms         0.84%        1.008s     995.763us       0.000us         0.00%      78.565ms      77.633us          1012  
autograd::engine::evaluate_function: NativeLayerNorm...         0.00%       5.221ms         0.04%      51.336ms     116.673us       0.000us         0.00%      13.157ms      29.902us           440  
                               NativeLayerNormBackward0         0.00%       1.850ms         0.04%      45.697ms     103.857us       0.000us         0.00%      12.913ms      29.348us           440  
     autograd::engine::evaluate_function: GeluBackward0         0.00%       1.435ms         0.19%     229.524ms       1.125ms       0.000us         0.00%       4.229ms      20.730us           204  
                                          GeluBackward0         0.00%     644.000us         0.19%     227.983ms       1.118ms       0.000us         0.00%       4.107ms      20.132us           204  
autograd::engine::evaluate_function: PermuteBackward...         0.00%       1.756ms         0.00%       3.966ms       8.932us       0.000us         0.00%       0.000us       0.000us           444  
                                       PermuteBackward0         0.00%     745.000us         0.00%       2.104ms       4.739us       0.000us         0.00%       0.000us       0.000us           444  
                                   aten::_reshape_alias         0.00%     146.000us         0.00%     146.000us       1.921us       0.000us         0.00%       0.000us       0.000us            76  
autograd::engine::evaluate_function: SelectBackward0...         0.00%       3.267ms         0.03%      36.295ms      77.553us       0.000us         0.00%      25.108ms      53.650us           468  
                                        SelectBackward0         0.00%     999.000us         0.02%      24.843ms      53.083us       0.000us         0.00%      20.466ms      43.731us           468  
                                  aten::select_backward         0.00%       1.854ms         0.02%      24.241ms      51.797us       0.000us         0.00%      20.789ms      44.421us           468  
autograd::engine::evaluate_function: ConvolutionBack...         0.00%     108.000us         0.00%     940.000us     235.000us       0.000us         0.00%       1.102ms     275.500us             4  
                                   ConvolutionBackward0         0.00%      26.000us         0.00%     832.000us     208.000us       0.000us         0.00%       1.102ms     275.500us             4  
                                  cudaDeviceSynchronize         0.00%      52.000us         0.00%      52.000us      52.000us       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
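For reference, the `ProfilerStep*` rows in the tables come from a step-based torch.profiler schedule with the results printed via `key_averages().table(...)`. Below is a minimal sketch of how a table like this can be produced; the `train_step`/`dataloader` names and the exact wait/warmup values are placeholders, not my actual training code:

from torch.profiler import profile, schedule, ProfilerActivity

def run_profiled(train_step, dataloader, num_steps=12):
    # Collect both CPU-side operator timings and CUDA kernel timings.
    # Each call to prof.step() produces one of the "ProfilerStep*" rows above.
    with profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        schedule=schedule(wait=1, warmup=1, active=10),
    ) as prof:
        for step, batch in enumerate(dataloader):
            if step >= num_steps:
                break
            train_step(batch)   # forward + backward + optimizer step
            prof.step()         # advance the profiler schedule
    # Sorting by self CUDA time puts the NCCL all-gather / reduce-scatter
    # kernels and FSDP hooks near the top, as in the tables shown here.
    print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))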

And here is the single-node result for comparison:

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                 c10d::_allgather_base_         0.32%      56.208ms         0.38%      67.078ms      82.812us        4.689s        29.27%        4.689s       5.789ms           810  
ncclKernel_AllGather_RING_LL_Sum_int8_t(ncclDevComm*...         0.00%       0.000us         0.00%       0.000us       0.000us        4.689s        29.27%        4.689s       5.789ms           810  
                                               aten::mm         2.46%     430.458ms         5.16%     902.353ms      86.317us        3.460s        21.60%        3.460s     330.939us         10454  
                                     record_param_comms         0.18%      31.059ms         0.20%      35.489ms      21.772us        3.337s        20.83%        3.337s       2.047ms          1630  
ncclKernel_ReduceScatter_RING_LL_Sum_half(ncclDevCom...         0.00%       0.000us         0.00%       0.000us       0.000us        3.337s        20.83%        3.337s       8.140ms           410  
                                            aten::copy_         0.77%     134.901ms         7.44%        1.302s      64.611us     878.095ms         5.48%     878.095ms      43.587us         20146  
ampere_fp16_s16816gemm_fp16_128x128_ldg8_f2f_stages_...         0.00%       0.000us         0.00%       0.000us       0.000us     851.835ms         5.32%     851.835ms     281.319us          3028  
ampere_fp16_s16816gemm_fp16_256x128_ldg8_f2f_stages_...         0.00%       0.000us         0.00%       0.000us       0.000us     849.478ms         5.30%     849.478ms     349.292us          2432  
ampere_fp16_s16816gemm_fp16_256x128_ldg8_f2f_stages_...         0.00%       0.000us         0.00%       0.000us       0.000us     783.768ms         4.89%     783.768ms     399.882us          1960  
                                              aten::cat         0.24%      42.263ms         0.41%      71.871ms      51.190us     651.682ms         4.07%     651.682ms     464.161us          1404  
void at::native::(anonymous namespace)::CatArrayBatc...         0.00%       0.000us         0.00%       0.000us       0.000us     613.315ms         3.83%     613.315ms     979.736us           626  
                                              aten::mul         2.17%     379.254ms         3.84%     672.051ms      49.863us     499.485ms         3.12%     499.485ms      37.059us         13478  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     434.046ms         2.71%     434.046ms     199.653us          2174  
                                             aten::div_         0.05%       9.122ms         0.10%      17.476ms      21.312us     377.147ms         2.35%     377.147ms     459.935us           820  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us     373.290ms         2.33%     373.290ms      64.852us          5756  
                                              aten::bmm         0.66%     116.150ms         1.57%     273.963ms      72.708us     303.164ms         1.89%     303.164ms      80.458us          3768  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us     255.600ms         1.60%     255.600ms      57.959us          4410  
                                    aten::_foreach_mul_         0.05%       9.346ms         0.11%      19.014ms     633.800us     252.931ms         1.58%     252.931ms       8.431ms            30  
void at::native::(anonymous namespace)::multi_tensor...         0.00%       0.000us         0.00%       0.000us       0.000us     252.931ms         1.58%     252.931ms     110.935us          2280  
void cutlass::Kernel<cutlass_80_tensorop_f16_s16816g...         0.00%       0.000us         0.00%       0.000us       0.000us     217.189ms         1.36%     217.189ms     452.477us           480  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     201.673ms         1.26%     201.673ms      30.922us          6522  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     196.401ms         1.23%     196.401ms      40.747us          4820  
void cutlass::Kernel<cutlass_80_tensorop_f16_s16816g...         0.00%       0.000us         0.00%       0.000us       0.000us     178.622ms         1.12%     178.622ms     744.258us           240  
                                aten::_foreach_addcdiv_         0.01%       2.342ms         0.03%       5.493ms     549.300us     153.018ms         0.96%     153.018ms      15.302ms            10  
void at::native::(anonymous namespace)::multi_tensor...         0.00%       0.000us         0.00%       0.000us       0.000us     153.018ms         0.96%     153.018ms     201.339us           760  
void cutlass::Kernel<cutlass_80_tensorop_f16_s16816g...         0.00%       0.000us         0.00%       0.000us       0.000us     135.466ms         0.85%     135.466ms     338.665us           400  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     125.264ms         0.78%     125.264ms      21.545us          5814  
                                aten::_foreach_addcmul_         0.01%       2.365ms         0.03%       5.940ms     594.000us     119.682ms         0.75%     119.682ms      11.968ms            10  
void at::native::(anonymous namespace)::multi_tensor...         0.00%       0.000us         0.00%       0.000us       0.000us     119.682ms         0.75%     119.682ms     157.476us           760  
                                    aten::_foreach_add_         0.05%       9.424ms         0.11%      19.188ms     959.400us     118.507ms         0.74%     118.507ms       5.925ms            20  
void at::native::(anonymous namespace)::multi_tensor...         0.00%       0.000us         0.00%       0.000us       0.000us     118.507ms         0.74%     118.507ms     155.930us           760  
                                             aten::add_         0.18%      31.232ms         1.83%     319.686ms      59.643us     118.274ms         0.74%     118.274ms      22.066us          5360  
                                              aten::add         0.58%     101.757ms         0.97%     170.493ms      41.162us     114.876ms         0.72%     114.876ms      27.734us          4142  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     100.275ms         0.63%     100.275ms      59.334us          1690  
ampere_fp16_s16816gemm_fp16_128x256_ldg8_f2f_stages_...         0.00%       0.000us         0.00%       0.000us       0.000us      95.594ms         0.60%      95.594ms     398.308us           240  
                                    aten::_foreach_div_         0.01%       2.206ms         0.03%       5.354ms     535.400us      85.788ms         0.54%      85.788ms       8.579ms            10  
void at::native::(anonymous namespace)::multi_tensor...         0.00%       0.000us         0.00%       0.000us       0.000us      85.788ms         0.54%      85.788ms     112.879us           760  
       aten::_amp_foreach_non_finite_check_and_unscale_         0.01%       1.989ms         0.03%       5.343ms     534.300us      85.313ms         0.53%      85.313ms       8.531ms            10  
void at::native::(anonymous namespace)::multi_tensor...         0.00%       0.000us         0.00%       0.000us       0.000us      85.313ms         0.53%      85.313ms     112.254us           760  
                                    aten::_foreach_sqrt         0.02%       2.977ms         0.06%      11.152ms       1.115ms      83.868ms         0.52%      83.868ms       8.387ms            10  
void at::native::(anonymous namespace)::multi_tensor...         0.00%       0.000us         0.00%       0.000us       0.000us      83.868ms         0.52%      83.868ms     110.353us           760  
                                     aten::_foreach_add         0.02%       3.406ms         0.05%       9.092ms     909.200us      83.168ms         0.52%      83.168ms       8.317ms            10  
void at::native::(anonymous namespace)::multi_tensor...         0.00%       0.000us         0.00%       0.000us       0.000us      83.168ms         0.52%      83.168ms     109.432us           760  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      81.010ms         0.51%      81.010ms      35.719us          2268  
ampere_fp16_s16816gemm_fp16_128x256_ldg8_f2f_stages_...         0.00%       0.000us         0.00%       0.000us       0.000us      79.533ms         0.50%      79.533ms     662.775us           120  
                                              aten::div         0.38%      66.574ms         0.56%      97.925ms      45.462us      78.680ms         0.49%      78.680ms      36.527us          2154  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      76.412ms         0.48%      76.412ms      38.016us          2010  
                                              aten::sum         0.42%      74.047ms         0.83%     145.223ms      54.269us      73.978ms         0.46%      74.042ms      27.669us          2676  
                                            aten::fill_         0.15%      26.940ms         2.48%     434.424ms      61.169us      67.110ms         0.42%      67.110ms       9.449us          7102  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      65.900ms         0.41%      65.900ms       9.553us          6898  
void cutlass::Kernel<cutlass_75_tensorop_f16_s1688ge...         0.00%       0.000us         0.00%       0.000us       0.000us      62.214ms         0.39%      62.214ms     130.702us           476  
void cutlass::Kernel<cutlass_75_tensorop_f16_s1688ge...         0.00%       0.000us         0.00%       0.000us       0.000us      57.919ms         0.36%      57.919ms     202.514us           286  
void cutlass::Kernel<cutlass_80_tensorop_f16_s16816g...         0.00%       0.000us         0.00%       0.000us       0.000us      55.403ms         0.35%      55.403ms     692.538us            80  
void cutlass::Kernel<cutlass_75_tensorop_f16_s1688ge...         0.00%       0.000us         0.00%       0.000us       0.000us      52.633ms         0.33%      52.633ms     105.266us           500  
                           aten::_softmax_backward_data         0.10%      17.824ms         0.28%      49.466ms      78.768us      45.710ms         0.29%      91.665ms     145.963us           628  
void (anonymous namespace)::softmax_warp_backward<fl...         0.00%       0.000us         0.00%       0.000us       0.000us      44.640ms         0.28%      44.640ms      82.667us           540  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us      43.731ms         0.27%      43.731ms      53.989us           810  
                         Memcpy DtoD (Device -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us      43.453ms         0.27%      43.453ms      20.911us          2078  
void at::native::reduce_kernel<128, 4, at::native::R...         0.00%       0.000us         0.00%       0.000us       0.000us      41.334ms         0.26%      41.334ms      22.391us          1846  
                                              aten::neg         0.24%      41.156ms         0.66%     116.086ms      72.554us      41.016ms         0.26%      41.016ms      25.635us          1600  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us      41.016ms         0.26%      41.016ms      25.635us          1600  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us      40.370ms         0.25%      40.370ms      49.840us           810  
ampere_fp16_s16816gemm_fp16_128x128_ldg8_f2f_stages_...         0.00%       0.000us         0.00%       0.000us       0.000us      40.110ms         0.25%      40.110ms     250.688us           160  
void at::native::(anonymous namespace)::CatArrayBatc...         0.00%       0.000us         0.00%       0.000us       0.000us      37.920ms         0.24%      37.920ms      47.400us           800  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us      33.883ms         0.21%      33.883ms      41.831us           810  
                                         aten::_softmax         0.07%      13.060ms         0.15%      25.694ms      40.914us      33.681ms         0.21%      33.681ms      53.632us           628  
                                            aten::addmm         0.20%      35.246ms         0.28%      49.657ms      49.068us      32.942ms         0.21%      32.942ms      32.551us          1012  
void at::native::reduce_kernel<512, 1, at::native::R...         0.00%       0.000us         0.00%       0.000us       0.000us      32.553ms         0.20%      32.553ms      40.189us           810  
void (anonymous namespace)::softmax_warp_forward<c10...         0.00%       0.000us         0.00%       0.000us       0.000us      32.410ms         0.20%      32.410ms      60.019us           540  
void cutlass::Kernel<cutlass_75_tensorop_f16_s1688ge...         0.00%       0.000us         0.00%       0.000us       0.000us      31.524ms         0.20%      31.524ms       3.152ms            10  
void cutlass::Kernel<cutlass_75_tensorop_f16_s1688ge...         0.00%       0.000us         0.00%       0.000us       0.000us      31.266ms         0.20%      31.266ms      97.706us           320  
void cutlass::Kernel<cutlass_75_tensorop_f16_s1688ge...         0.00%       0.000us         0.00%       0.000us       0.000us      31.072ms         0.19%      31.072ms      71.266us           436  
void cutlass::Kernel<cutlass_75_tensorop_f16_s1688ge...         0.00%       0.000us         0.00%       0.000us       0.000us      30.593ms         0.19%      30.593ms     126.942us           241  
                                            aten::where         0.03%       4.656ms         0.62%     108.856ms     272.140us      30.107ms         0.19%      30.107ms      75.267us           400  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      30.107ms         0.19%      30.107ms      75.267us           400  
                                             aten::mean         0.12%      21.399ms         0.18%      31.785ms      39.241us      30.030ms         0.19%      30.030ms      37.074us           810  
void at::native::reduce_kernel<512, 1, at::native::R...         0.00%       0.000us         0.00%       0.000us       0.000us      30.030ms         0.19%      30.030ms      37.074us           810  
void cutlass::Kernel<cutlass_80_tensorop_f16_s16816g...         0.00%       0.000us         0.00%       0.000us       0.000us      28.160ms         0.18%      28.160ms     704.000us            40  
void cutlass::Kernel<cutlass_75_tensorop_f16_s1688ge...         0.00%       0.000us         0.00%       0.000us       0.000us      27.342ms         0.17%      27.342ms       3.418ms             8  
                                              aten::pow         0.41%      72.388ms         0.81%     141.718ms      57.445us      26.938ms         0.17%      60.058ms      24.345us          2467  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      25.851ms         0.16%      25.851ms      15.957us          1620  
                                     aten::masked_fill_         0.02%       3.142ms         0.65%     114.083ms     265.309us      22.715ms         0.14%      22.715ms      52.826us           430  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      22.635ms         0.14%      22.635ms      55.207us           410  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us      21.781ms         0.14%      21.781ms      26.890us           810  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      21.705ms         0.14%      21.705ms      26.796us           810  
                                    aten::silu_backward         0.08%      13.361ms         0.09%      16.531ms      41.328us      18.899ms         0.12%      18.899ms      47.248us           400  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      18.899ms         0.12%      18.899ms      47.248us           400  
                                               aten::eq         0.07%      12.493ms         2.87%     502.275ms       1.225ms      18.203ms         0.11%      18.203ms      44.398us           410  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      18.167ms         0.11%      18.167ms      45.417us           400  
                                             aten::silu         0.05%       8.964ms         0.15%      26.433ms      66.082us      18.033ms         0.11%      18.033ms      45.083us           400  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      18.033ms         0.11%      18.033ms      45.083us           400  
void cutlass::Kernel<cutlass_75_tensorop_f16_s1688ge...         0.00%       0.000us         0.00%       0.000us       0.000us      17.740ms         0.11%      17.740ms       4.435ms             4  
                                          aten::maximum         0.05%       9.274ms         0.13%      22.551ms      56.377us      17.540ms         0.11%      17.540ms      43.850us           400  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      17.540ms         0.11%      17.540ms      43.850us           400  
                                               aten::lt         0.06%      10.860ms         0.08%      13.481ms      32.880us      17.165ms         0.11%      17.165ms      41.866us           410  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      17.111ms         0.11%      17.111ms      42.778us           400  
                                        Memset (Device)         0.00%       0.000us         0.00%       0.000us       0.000us      12.642ms         0.08%      12.642ms       2.451us          5157  
void cutlass::Kernel<cutlass_75_tensorop_f16_s1688ge...         0.00%       0.000us         0.00%       0.000us       0.000us      11.222ms         0.07%      11.222ms      56.964us           197  
ampere_fp16_s16816gemm_fp16_128x64_ldg8_relu_f2f_sta...         0.00%       0.000us         0.00%       0.000us       0.000us       9.845ms         0.06%       9.845ms      31.554us           312  
ampere_fp16_s16816gemm_fp16_128x128_ldg8_relu_f2f_st...         0.00%       0.000us         0.00%       0.000us       0.000us       9.654ms         0.06%       9.654ms      61.885us           156  
ampere_fp16_s16816gemm_fp16_128x128_ldg8_f2f_stages_...         0.00%       0.000us         0.00%       0.000us       0.000us       9.174ms         0.06%       9.174ms      58.808us           156  
void cutlass::Kernel<cutlass_75_tensorop_f16_s1688ge...         0.00%       0.000us         0.00%       0.000us       0.000us       8.511ms         0.05%       8.511ms      43.423us           196  
ampere_fp16_s16816gemm_fp16_128x64_ldg8_relu_f2f_sta...         0.00%       0.000us         0.00%       0.000us       0.000us       8.284ms         0.05%       8.284ms      53.103us           156  
                       aten::native_layer_norm_backward         0.03%       5.880ms         0.24%      41.941ms      95.320us       8.133ms         0.05%       8.133ms      18.484us           440  
                                           aten::gather         0.14%      24.211ms         0.17%      29.570ms      36.962us       7.464ms         0.05%       7.464ms       9.330us           800  
void at::native::_scatter_gather_elementwise_kernel<...         0.00%       0.000us         0.00%       0.000us       0.000us       7.464ms         0.05%       7.464ms       9.330us           800  
void cutlass::Kernel<cutlass_80_tensorop_f16_s16816g...         0.00%       0.000us         0.00%       0.000us       0.000us       7.450ms         0.05%       7.450ms      93.125us            80  
void cutlass::Kernel<cutlass_80_tensorop_f16_s16816g...         0.00%       0.000us         0.00%       0.000us       0.000us       7.020ms         0.04%       7.020ms      87.750us            80  
void cutlass::Kernel<cutlass_75_wmma_tensorop_f16_s1...         0.00%       0.000us         0.00%       0.000us       0.000us       6.383ms         0.04%       6.383ms      40.917us           156  
void cutlass::Kernel<cutlass_80_tensorop_f16_s16816g...         0.00%       0.000us         0.00%       0.000us       0.000us       6.379ms         0.04%       6.379ms      79.737us            80  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       5.833ms         0.04%       5.833ms       7.201us           810  
                                            aten::rsqrt         0.11%      18.773ms         0.22%      37.722ms      46.342us       5.708ms         0.04%       5.708ms       7.012us           814  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       5.708ms         0.04%       5.708ms       7.047us           810  
void cutlass::Kernel<cutlass_80_tensorop_f16_s16816g...         0.00%       0.000us         0.00%       0.000us       0.000us       5.468ms         0.03%       5.468ms      68.350us            80  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       5.233ms         0.03%       5.233ms       6.460us           810  
void cutlass::Kernel<cutlass_80_tensorop_f16_s16816g...         0.00%       0.000us         0.00%       0.000us       0.000us       5.221ms         0.03%       5.221ms      65.263us            80  
void cutlass::Kernel<cutlass_80_tensorop_f16_s16816g...         0.00%       0.000us         0.00%       0.000us       0.000us       5.010ms         0.03%       5.010ms      62.625us            80  
void at::native::(anonymous namespace)::layer_norm_g...         0.00%       0.000us         0.00%       0.000us       0.000us       4.001ms         0.02%       4.001ms       9.093us           440  
                                aten::native_layer_norm         0.05%       9.368ms         0.10%      18.318ms      41.632us       3.615ms         0.02%       3.615ms       8.216us           440  
void at::native::(anonymous namespace)::vectorized_l...         0.00%       0.000us         0.00%       0.000us       0.000us       3.615ms         0.02%       3.615ms       8.216us           440  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us       3.453ms         0.02%       3.453ms       8.422us           410  
ampere_fp16_s16816gemm_fp16_128x64_ldg8_f2f_stages_3...         0.00%       0.000us         0.00%       0.000us       0.000us       3.439ms         0.02%       3.439ms      16.858us           204  
                                    aten::gelu_backward         0.02%       3.061ms         0.09%      15.351ms      75.250us       3.028ms         0.02%       3.028ms      14.843us           204  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       3.028ms         0.02%       3.028ms      14.843us           204  
sm80_xmma_gemm_f16f16_f16f32_f32_nt_n_tilesize96x128...         0.00%       0.000us         0.00%       0.000us       0.000us       2.960ms         0.02%       2.960ms      18.974us           156  
void at::native::(anonymous namespace)::GammaBetaBac...         0.00%       0.000us         0.00%       0.000us       0.000us       2.528ms         0.02%       2.528ms       8.000us           316  
ampere_fp16_s16816gemm_fp16_64x64_sliced1x2_ldg8_f2f...         0.00%       0.000us         0.00%       0.000us       0.000us       2.459ms         0.02%       2.459ms      10.246us           240  
void cutlass::Kernel<cutlass_80_tensorop_f16_s16816g...         0.00%       0.000us         0.00%       0.000us       0.000us       2.189ms         0.01%       2.189ms       7.601us           288  
ampere_fp16_s16816gemm_fp16_128x64_ldg8_f2f_stages_6...         0.00%       0.000us         0.00%       0.000us       0.000us       1.926ms         0.01%       1.926ms       8.025us           240  
                                             aten::gelu         0.02%       3.090ms         0.03%       4.395ms      21.544us       1.919ms         0.01%       1.919ms       9.407us           204  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       1.919ms         0.01%       1.919ms       9.407us           204  
void cutlass::Kernel<cutlass_75_tensorop_f16_s1688ge...         0.00%       0.000us         0.00%       0.000us       0.000us       1.890ms         0.01%       1.890ms      47.250us            40  
void at::native::(anonymous namespace)::GammaBetaBac...         0.00%       0.000us         0.00%       0.000us       0.000us       1.604ms         0.01%       1.604ms      12.935us           124  
void cutlass::Kernel<cutlass_75_tensorop_f16_s1688ge...         0.00%       0.000us         0.00%       0.000us       0.000us       1.503ms         0.01%       1.503ms      37.575us            40  
void splitKreduce_kernel<32, 16, int, __half, __half...         0.00%       0.000us         0.00%       0.000us       0.000us       1.438ms         0.01%       1.438ms       4.993us           288  
                       aten::_log_softmax_backward_data         0.00%     156.000us         0.00%     231.000us      23.100us       1.251ms         0.01%       1.251ms     125.100us            10  
void at::native::(anonymous namespace)::cunn_SoftMax...         0.00%       0.000us         0.00%       0.000us       0.000us       1.251ms         0.01%       1.251ms     125.100us            10  
                                     aten::_log_softmax         0.00%     247.000us         0.00%     334.000us      33.400us       1.211ms         0.01%       1.211ms     121.100us            10  
void at::native::(anonymous namespace)::cunn_SoftMax...         0.00%       0.000us         0.00%       0.000us       0.000us       1.211ms         0.01%       1.211ms     121.100us            10  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       1.090ms         0.01%       1.090ms       7.786us           140  
void (anonymous namespace)::softmax_warp_forward<c10...         0.00%       0.000us         0.00%       0.000us       0.000us       1.079ms         0.01%       1.079ms      26.975us            40  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us       1.064ms         0.01%       1.064ms       6.821us           156  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       1.025ms         0.01%       1.025ms       4.745us           216  
                         aten::embedding_dense_backward         0.00%     291.000us         0.02%       3.145ms      92.500us     898.000us         0.01%      10.133ms     298.029us            34  
void at::native::(anonymous namespace)::embedding_ba...         0.00%       0.000us         0.00%       0.000us       0.000us     898.000us         0.01%     898.000us      26.412us            34  
sm80_xmma_gemm_f16f16_f16f32_f32_nt_n_tilesize64x96x...         0.00%       0.000us         0.00%       0.000us       0.000us     863.000us         0.01%     863.000us      17.979us            48  
void (anonymous namespace)::softmax_warp_backward<fl...         0.00%       0.000us         0.00%       0.000us       0.000us     818.000us         0.01%     818.000us      20.450us            40  
sm80_xmma_gemm_f16f16_f16f32_f32_nt_n_tilesize160x12...         0.00%       0.000us         0.00%       0.000us       0.000us     748.000us         0.00%     748.000us      14.385us            52  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     733.000us         0.00%     733.000us       4.470us           164  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us     723.000us         0.00%     723.000us       6.025us           120  
ampere_fp16_s16816gemm_fp16_64x64_ldg8_relu_f2f_stag...         0.00%       0.000us         0.00%       0.000us       0.000us     700.000us         0.00%     700.000us      14.583us            48  
                                     aten::index_select         0.00%     471.000us         0.01%       1.050ms      30.882us     604.000us         0.00%     604.000us      17.765us            34  
void cutlass::Kernel<cutlass_80_wmma_tensorop_f16_s1...         0.00%       0.000us         0.00%       0.000us       0.000us     595.000us         0.00%     595.000us       6.198us            96  
ampere_fp16_s1688gemm_fp16_256x64_ldg8_f2f_stages_32...         0.00%       0.000us         0.00%       0.000us       0.000us     587.000us         0.00%     587.000us      12.229us            48  
void cutlass::Kernel<cutlass_80_tensorop_f16_s16816g...         0.00%       0.000us         0.00%       0.000us       0.000us     579.000us         0.00%     579.000us      12.062us            48  
void cutlass::Kernel<cutlass_75_wmma_tensorop_f16_s1...         0.00%       0.000us         0.00%       0.000us       0.000us     573.000us         0.00%     573.000us      23.875us            24  
ampere_s16816gemm_fp16_64x64_sliced1x2_ldg8_stages_6...         0.00%       0.000us         0.00%       0.000us       0.000us     568.000us         0.00%     568.000us      11.833us            48  
ampere_fp16_s16816gemm_fp16_64x64_sliced1x2_ldg8_f2f...         0.00%       0.000us         0.00%       0.000us       0.000us     557.000us         0.00%     557.000us      11.604us            48  
                             aten::convolution_backward         0.00%     380.000us         0.00%     789.000us     197.250us     548.000us         0.00%     763.000us     190.750us             4  
void at::native::(anonymous namespace)::indexSelectL...         0.00%       0.000us         0.00%       0.000us       0.000us     441.000us         0.00%     441.000us      44.100us            10  
                                aten::cudnn_convolution         0.00%     379.000us         0.00%     552.000us     138.000us     397.000us         0.00%     397.000us      99.250us             4  
sm80_xmma_wgrad_implicit_gemm_indexed_wo_smem_f16f16...         0.00%       0.000us         0.00%       0.000us       0.000us     376.000us         0.00%     376.000us      94.000us             4  
void cutlass::Kernel<cutlass_80_wmma_tensorop_f16_s1...         0.00%       0.000us         0.00%       0.000us       0.000us     313.000us         0.00%     313.000us       6.521us            48  
void cutlass::Kernel<cutlass_80_wmma_tensorop_f16_s1...         0.00%       0.000us         0.00%       0.000us       0.000us     308.000us         0.00%     308.000us       6.417us            48  
void splitKreduce_kernel<32, 16, int, float, __half,...         0.00%       0.000us         0.00%       0.000us       0.000us     278.000us         0.00%     278.000us       5.792us            48  
void cutlass::Kernel<cutlass_75_wmma_tensorop_f16_s1...         0.00%       0.000us         0.00%       0.000us       0.000us     267.000us         0.00%     267.000us      11.125us            24  
                                 aten::nll_loss_forward         0.00%     329.000us         0.00%     584.000us      58.400us     265.000us         0.00%     265.000us      26.500us            10  
void at::native::(anonymous namespace)::nll_loss_for...         0.00%       0.000us         0.00%       0.000us       0.000us     265.000us         0.00%     265.000us      26.500us            10  
void at::native::(anonymous namespace)::CatArrayBatc...         0.00%       0.000us         0.00%       0.000us       0.000us     260.000us         0.00%     260.000us      26.000us            10  
void (anonymous namespace)::softmax_warp_backward<fl...         0.00%       0.000us         0.00%       0.000us       0.000us     252.000us         0.00%     252.000us       5.250us            48  
void cutlass::Kernel<cutlass_80_wmma_tensorop_f16_s1...         0.00%       0.000us         0.00%       0.000us       0.000us     240.000us         0.00%     240.000us       5.000us            48  
void cutlass::Kernel<cutlass_80_wmma_tensorop_f16_s1...         0.00%       0.000us         0.00%       0.000us       0.000us     240.000us         0.00%     240.000us       5.000us            48  
void cutlass::Kernel<cutlass_75_wmma_tensorop_f16_s1...         0.00%       0.000us         0.00%       0.000us       0.000us     240.000us         0.00%     240.000us      10.000us            24  
sm80_xmma_fprop_implicit_gemm_indexed_wo_smem_f16f16...         0.00%       0.000us         0.00%       0.000us       0.000us     236.000us         0.00%     236.000us      59.000us             4  
void cutlass::Kernel<cutlass_75_wmma_tensorop_f16_s1...         0.00%       0.000us         0.00%       0.000us       0.000us     218.000us         0.00%     218.000us       9.083us            24  
                         Memcpy HtoD (Pinned -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us     215.000us         0.00%     215.000us      21.500us            10  
void cutlass::Kernel<cutlass_75_tensorop_f16_s1688ge...         0.00%       0.000us         0.00%       0.000us       0.000us     209.000us         0.00%     209.000us       8.708us            24  
void cudnn::ops::nchwToNhwcKernel<__half, __half, fl...         0.00%       0.000us         0.00%       0.000us       0.000us     198.000us         0.00%     198.000us      12.375us            16  
                                aten::nll_loss_backward         0.00%     245.000us         0.00%     605.000us      60.500us     192.000us         0.00%     759.000us      75.900us            10  
void at::native::(anonymous namespace)::nll_loss_bac...         0.00%       0.000us         0.00%       0.000us       0.000us     192.000us         0.00%     192.000us      19.200us            10  
void (anonymous namespace)::softmax_warp_forward<c10...         0.00%       0.000us         0.00%       0.000us       0.000us     192.000us         0.00%     192.000us       4.000us            48  
void at::native::(anonymous namespace)::indexSelectS...         0.00%       0.000us         0.00%       0.000us       0.000us     163.000us         0.00%     163.000us       6.792us            24  
                       Memcpy HtoD (Pageable -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us     153.000us         0.00%     153.000us       1.500us           102  
void cudnn::ops::nhwcToNchwKernel<__half, __half, fl...         0.00%       0.000us         0.00%       0.000us       0.000us     123.000us         0.00%     123.000us      15.375us             8  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     108.000us         0.00%     108.000us       3.000us            36  
                              aten::_local_scalar_dense         0.00%     321.000us        12.66%        2.216s       1.327ms     104.000us         0.00%     104.000us       0.062us          1670  
                       Memcpy DtoH (Device -> Pageable)         0.00%       0.000us         0.00%       0.000us       0.000us     104.000us         0.00%     104.000us       2.080us            50  
void at::native::(anonymous namespace)::CatArrayBatc...         0.00%       0.000us         0.00%       0.000us       0.000us      92.000us         0.00%      92.000us      11.500us             8  
void at::native::reduce_kernel<512, 1, at::native::R...         0.00%       0.000us         0.00%       0.000us       0.000us      91.000us         0.00%      91.000us       4.550us            20  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      81.000us         0.00%      81.000us       3.240us            25  
ampere_fp16_s16816gemm_fp16_128x128_ldg8_f2f_stages_...         0.00%       0.000us         0.00%       0.000us       0.000us      79.000us         0.00%      79.000us      19.750us             4  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us      69.000us         0.00%      69.000us       6.900us            10  
                                              aten::sub         0.00%     234.000us         0.00%     348.000us      19.333us      65.000us         0.00%      65.000us       3.611us            18  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      65.000us         0.00%      65.000us       3.611us            18  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us      64.000us         0.00%      64.000us       3.200us            20  
                                               aten::ne         0.00%     304.000us         0.00%     508.000us      25.400us      60.000us         0.00%      60.000us       3.000us            20  
                                           aten::arange         0.00%     252.000us         0.01%       1.128ms      28.200us      60.000us         0.00%     120.000us       3.000us            40  
void (anonymous namespace)::elementwise_kernel_with_...         0.00%       0.000us         0.00%       0.000us       0.000us      60.000us         0.00%      60.000us       3.000us            20  
ampere_fp16_s16816gemm_fp16_128x64_ldg8_relu_f2f_sta...         0.00%       0.000us         0.00%       0.000us       0.000us      58.000us         0.00%      58.000us      14.500us             4  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us      57.000us         0.00%      57.000us       5.700us            10  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us      54.000us         0.00%      54.000us       5.400us            10  
void at::native::(anonymous namespace)::CatArrayBatc...         0.00%       0.000us         0.00%       0.000us       0.000us      50.000us         0.00%      50.000us       5.000us            10  
void at::native::(anonymous namespace)::CatArrayBatc...         0.00%       0.000us         0.00%       0.000us       0.000us      45.000us         0.00%      45.000us       4.500us            10  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      41.000us         0.00%      41.000us       4.100us            10  
                                       aten::reciprocal         0.00%     267.000us         0.00%     356.000us      35.600us      40.000us         0.00%      40.000us       4.000us            10  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us      40.000us         0.00%      40.000us       4.000us            10  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      40.000us         0.00%      40.000us       4.000us            10  
                               aten::_amp_update_scale_         0.00%     114.000us         0.00%     229.000us      22.900us      40.000us         0.00%      40.000us       4.000us            10  
at::native::amp_update_scale_cuda_kernel(float*, int...         0.00%       0.000us         0.00%       0.000us       0.000us      40.000us         0.00%      40.000us       4.000us            10  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      39.000us         0.00%      39.000us       3.900us            10  
void splitKreduce_kernel<32, 16, int, __half, __half...         0.00%       0.000us         0.00%       0.000us       0.000us      31.000us         0.00%      31.000us       7.750us             4  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      30.000us         0.00%      30.000us       3.000us            10  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      30.000us         0.00%      30.000us       3.000us            10  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us      28.000us         0.00%      28.000us       7.000us             4  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       9.000us         0.00%       9.000us       3.000us             3  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       6.000us         0.00%       6.000us       3.000us             2  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       6.000us         0.00%       6.000us       3.000us             2  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       6.000us         0.00%       6.000us       3.000us             2  
                                          ProfilerStep*        35.97%        6.296s        66.72%       11.678s        1.168s       0.000us         0.00%        5.957s     595.651ms            10  
                                    cudaStreamWaitEvent         0.04%       7.844ms         0.04%       7.844ms       0.526us       0.000us         0.00%       0.000us       0.000us         14924  
                                    aten::record_stream         0.03%       5.134ms         0.03%       5.136ms       2.088us       0.000us         0.00%       0.000us       0.000us          2460  
enumerate(DataLoader)#_MultiProcessingDataLoaderIter...         0.00%     759.000us         0.00%     821.000us      82.100us       0.000us         0.00%       0.000us       0.000us            10  
                                               aten::to         0.16%      28.311ms         5.82%        1.019s      68.283us       0.000us         0.00%     393.864ms      26.402us         14918  
                                         aten::_to_copy         0.28%      48.655ms         5.73%        1.003s      96.971us       0.000us         0.00%     398.629ms      38.530us         10346  
                                    aten::empty_strided         0.98%     172.270ms         1.03%     181.038ms      14.458us       0.000us         0.00%       0.000us       0.000us         12522  
                                        cudaMemcpyAsync        12.99%        2.274s        12.99%        2.274s       1.015ms       0.000us         0.00%       0.000us       0.000us          2240  
                                  cudaStreamSynchronize         3.35%     586.782ms         3.35%     586.942ms       3.623ms       0.000us         0.00%       0.000us       0.000us           162  
                       FullyShardedDataParallel.forward         2.11%     369.792ms        28.17%        4.930s      12.025ms       0.000us         0.00%        9.056s      22.088ms           410  
             FullyShardedDataParallel._root_pre_forward         0.06%      10.108ms         0.06%      11.026ms      26.893us       0.000us         0.00%      29.000us       0.071us           410  
                    FullyShardedDataParallel._to_kwargs         0.00%     674.000us         0.00%     674.000us      67.400us       0.000us         0.00%       0.000us       0.000us            10  
                  FullyShardedDataParallel._pre_forward         1.65%     288.009ms         2.29%     400.909ms     977.827us       0.000us         0.00%        2.932s       7.152ms           410  
                                         cudaEventQuery         0.26%      45.485ms         0.26%      45.506ms       0.046us       0.000us         0.00%       0.000us       0.000us        994449  
                                       cudaLaunchKernel        18.06%        3.161s        18.06%        3.161s      35.289us       0.000us         0.00%       0.000us       0.000us         89578  
                aten::_has_compatible_shallow_copy_type         0.00%      12.000us         0.00%      12.000us       0.000us       0.000us         0.00%       0.000us       0.000us         36850  
                                  cudaStreamIsCapturing         0.01%     930.000us         0.01%     930.000us       0.745us       0.000us         0.00%       0.000us       0.000us          1249  
                                  nccl:_all_gather_base         0.00%       0.000us             0      50.711ms      62.606us       0.000us         0.00%       0.000us       0.000us           810  
                                                INVALID         0.00%      42.000us         0.00%      42.000us       0.034us       0.000us         0.00%       0.000us       0.000us          1220  
                                            aten::slice         0.23%      40.759ms         0.25%      43.373ms       2.511us       0.000us         0.00%       0.000us       0.000us         17272  
                                       aten::as_strided         0.05%       9.623ms         0.05%       9.623ms       0.118us       0.000us         0.00%       0.000us       0.000us         81657  
                                             aten::view         0.38%      65.666ms         0.38%      65.666ms       1.256us       0.000us         0.00%       0.000us       0.000us         52294  
                                 aten::split_with_sizes         0.22%      38.531ms         0.22%      38.764ms      47.273us       0.000us         0.00%       0.000us       0.000us           820  
                               cudaPointerGetAttributes         0.00%     256.000us         0.00%     360.000us      36.000us       0.000us         0.00%      26.000us       2.600us            10  
                                        aten::expand_as         0.01%       2.060ms         0.03%       6.089ms       3.782us       0.000us         0.00%       0.000us       0.000us          1610  
                                           aten::expand         0.10%      18.194ms         0.11%      19.387ms       3.138us       0.000us         0.00%       0.000us       0.000us          6178  
                                            aten::empty         1.51%     264.620ms         1.56%     272.421ms      14.869us       0.000us         0.00%       0.000us       0.000us         18322  
                                       aten::lift_fresh         0.00%       0.000us         0.00%       0.000us       0.000us       0.000us         0.00%       0.000us       0.000us           498  
                                          aten::detach_         0.01%     892.000us         0.01%     897.000us       1.877us       0.000us         0.00%       0.000us       0.000us           478  
                                                detach_         0.00%      14.000us         0.00%      14.000us       0.029us       0.000us         0.00%       0.000us       0.000us           478  
                                        aten::embedding         0.00%     376.000us         0.01%       1.487ms      43.735us       0.000us         0.00%     604.000us      17.765us            34  
                                          aten::reshape         0.25%      43.468ms         1.63%     285.236ms      10.484us       0.000us         0.00%     149.815ms       5.506us         27208  
                                          aten::resize_         0.06%       9.975ms         0.06%      11.293ms      24.338us       0.000us         0.00%       0.000us       0.000us           464  
                                             aten::ones         0.00%     160.000us         0.00%     441.000us      11.605us       0.000us         0.00%      42.000us       1.105us            38  
                                             aten::set_         0.00%     143.000us         0.00%     143.000us       4.767us       0.000us         0.00%       0.000us       0.000us            30  
                                        aten::new_empty         0.00%      20.000us         0.00%      64.000us       6.400us       0.000us         0.00%       0.000us       0.000us            10  
                                           aten::unbind         0.00%     110.000us         0.00%     205.000us      10.250us       0.000us         0.00%       0.000us       0.000us            20  
                                           aten::select         0.02%       3.097ms         0.02%       3.127ms       3.204us       0.000us         0.00%       0.000us       0.000us           976  
                                             aten::item         0.00%     337.000us        12.67%        2.217s       1.327ms       0.000us         0.00%     104.000us       0.062us          1670  
                                      aten::masked_fill         0.00%     103.000us         0.01%     894.000us      44.700us       0.000us         0.00%     151.000us       7.550us            20  
                                            aten::clone         0.11%      19.753ms         1.22%     212.791ms      44.461us       0.000us         0.00%     150.488ms      31.443us          4786  
                                       aten::empty_like         0.06%      10.988ms         0.61%     107.000ms      19.970us       0.000us         0.00%       0.000us       0.000us          5358  
                                        aten::unsqueeze         0.01%       2.077ms         0.01%       2.119ms       2.447us       0.000us         0.00%       0.000us       0.000us           866  
                                             aten::full         0.00%      91.000us         0.00%     583.000us      29.150us       0.000us         0.00%      62.000us       3.100us            20  
                                             aten::rsub         0.00%      93.000us         0.00%     441.000us      24.500us       0.000us         0.00%      65.000us       3.611us            18  
                                      aten::result_type         0.00%      35.000us         0.00%      35.000us       0.004us       0.000us         0.00%       0.000us       0.000us          8830  
                                           aten::linear         0.19%      32.772ms         2.84%     497.144ms     108.476us       0.000us         0.00%        1.272s     277.621us          4583  
                                                aten::t         0.25%      44.183ms         0.40%      70.092ms       3.668us       0.000us         0.00%       0.000us       0.000us         19110  
                                        aten::transpose         0.25%      43.395ms         0.27%      47.586ms       1.803us       0.000us         0.00%       0.000us       0.000us         26398  
                                           aten::matmul         0.24%      41.528ms         3.03%     529.723ms     121.357us       0.000us         0.00%        1.405s     321.855us          4365  
cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFla...         0.04%       6.339ms         0.04%       6.339ms       0.224us       0.000us         0.00%       0.000us       0.000us         28352  
                                        cudaMemsetAsync         0.82%     143.614ms         0.82%     143.614ms      27.848us       0.000us         0.00%       0.000us       0.000us          5157  
                                     aten::_unsafe_view         0.09%      16.137ms         0.09%      16.138ms       1.853us       0.000us         0.00%       0.000us       0.000us          8710  
                                           aten::repeat         0.10%      17.450ms         0.47%      81.460ms      67.883us       0.000us         0.00%       8.896ms       7.413us          1200  
                                            aten::alias         0.00%      54.000us         0.00%      54.000us       0.045us       0.000us         0.00%       0.000us       0.000us          1200  
                                           aten::unfold         0.04%       6.472ms         0.04%       7.210ms       1.502us       0.000us         0.00%       0.000us       0.000us          4800  
                                              aten::max         0.01%     943.000us         0.13%      23.089ms      57.722us       0.000us         0.00%      17.238ms      43.095us           400  
                                          aten::softmax         0.02%       2.938ms         0.17%      28.936ms      44.586us       0.000us         0.00%      33.681ms      51.897us           649  
                                   cudaFuncSetAttribute         0.01%       1.054ms         0.01%       1.054ms       0.315us       0.000us         0.00%       0.000us       0.000us          3341  
                 FullyShardedDataParallel._post_forward         1.02%     177.684ms         1.06%     186.202ms     454.151us       0.000us         0.00%       0.000us       0.000us           410  
                                       aten::contiguous         0.00%     231.000us         0.02%       3.796ms      32.169us       0.000us         0.00%       2.019ms      17.110us           118  
                               aten::cross_entropy_loss        -0.00%     -13.000us         0.01%       1.475ms     147.500us       0.000us         0.00%       2.862ms     286.200us            10  
                                      aten::log_softmax         0.00%      52.000us         0.00%     386.000us      38.600us       0.000us         0.00%       1.211ms     121.100us            10  
                                      aten::nll_loss_nd         0.00%     112.000us         0.01%       1.021ms     102.100us       0.000us         0.00%       1.460ms     146.000us            10  
                                         aten::nll_loss         0.00%      67.000us         0.01%       1.603ms      80.150us       0.000us         0.00%       1.916ms      95.800us            20  
                                        aten::ones_like         0.00%      28.000us         0.00%     327.000us      32.700us       0.000us         0.00%      32.000us       3.200us            10  
      autograd::engine::evaluate_function: MulBackward0         0.28%      49.286ms         3.72%     650.458ms     171.806us       0.000us         0.00%     389.313ms     102.830us          3786  
                                           MulBackward0         0.10%      18.338ms         2.21%     386.217ms     102.012us       0.000us         0.00%     236.805ms      62.548us          3786  
autograd::engine::evaluate_function: NllLossBackward...         0.00%     450.000us         0.58%     101.574ms      10.157ms       0.000us         0.00%     327.952ms      32.795ms            10  
            FullyShardedDataParallel._pre_backward_hook         1.75%     306.081ms         2.13%     372.635ms     908.866us       0.000us         0.00%        1.985s       4.841ms           410  
                                       NllLossBackward0         0.00%      53.000us         0.00%     658.000us      65.800us       0.000us         0.00%     759.000us      75.900us            10  
                                            aten::zero_         0.10%      17.965ms         2.56%     448.602ms      64.049us       0.000us         0.00%      66.634ms       9.514us          7004  
autograd::engine::evaluate_function: ToCopyBackward0...         0.13%      22.956ms         1.57%     275.194ms      68.902us       0.000us         0.00%     135.984ms      34.047us          3994  
                                        ToCopyBackward0         0.04%       7.008ms         1.07%     188.104ms      47.097us       0.000us         0.00%     112.123ms      28.073us          3994  
autograd::engine::evaluate_function: LogSoftmaxBackw...         0.00%      65.000us         0.00%     362.000us      36.200us       0.000us         0.00%       1.251ms     125.100us            10  
                                    LogSoftmaxBackward0         0.00%      66.000us         0.00%     297.000us      29.700us       0.000us         0.00%       1.251ms     125.100us            10  
     autograd::engine::evaluate_function: ViewBackward0         0.28%      48.480ms         1.57%     274.824ms      19.872us       0.000us         0.00%      75.487ms       5.458us         13830  
                                          ViewBackward0         0.12%      21.123ms         0.80%     140.881ms      10.187us       0.000us         0.00%      48.514ms       3.508us         13830  
    autograd::engine::evaluate_function: CloneBackward0         0.05%       8.777ms         0.05%       8.826ms       3.500us       0.000us         0.00%       0.000us       0.000us          2522  
                                         CloneBackward0         0.00%      49.000us         0.00%      49.000us       0.019us       0.000us         0.00%       0.000us       0.000us          2522  
    autograd::engine::evaluate_function: SliceBackward0        -0.05%   -8405.000us         5.28%     924.450ms     538.724us       0.000us         0.00%      92.225ms      53.744us          1716  
                                         SliceBackward0         0.12%      20.624ms         4.34%     758.956ms     442.282us       0.000us         0.00%      72.248ms      42.103us          1716  
                                   aten::slice_backward        -0.11%  -18815.000us         4.32%     756.577ms     440.896us       0.000us         0.00%      74.239ms      43.263us          1716  
                                            aten::zeros         0.25%      43.919ms         3.45%     604.460ms      88.397us       0.000us         0.00%      61.971ms       9.063us          6838  
autograd::engine::evaluate_function: UnsafeViewBackw...         0.14%      23.670ms         0.45%      78.443ms      12.057us       0.000us         0.00%      16.331ms       2.510us          6506  
                                    UnsafeViewBackward0         0.05%       8.810ms         0.30%      52.272ms       8.034us       0.000us         0.00%      15.273ms       2.348us          6506  
       autograd::engine::evaluate_function: MmBackward0         0.14%      24.612ms         3.11%     544.736ms     193.856us       0.000us         0.00%        2.170s     772.219us          2810  
                                            MmBackward0         0.16%      27.950ms         2.97%     520.081ms     185.082us       0.000us         0.00%        2.169s     771.987us          2810  
        autograd::engine::evaluate_function: TBackward0         0.30%      52.004ms         0.41%      71.462ms      18.698us       0.000us         0.00%       0.000us       0.000us          3822  
                                             TBackward0         0.03%       4.705ms         0.10%      16.633ms       4.352us       0.000us         0.00%       0.000us       0.000us          3822  
    autograd::engine::evaluate_function: RsqrtBackward0         0.03%       5.000ms         0.91%     160.097ms     197.651us       0.000us         0.00%      15.109ms      18.653us           810  
                                         RsqrtBackward0         0.04%       7.169ms         0.89%     155.097ms     191.478us       0.000us         0.00%      15.109ms      18.653us           810  
      autograd::engine::evaluate_function: AddBackward0         0.15%      25.900ms         1.75%     306.733ms      92.445us       0.000us         0.00%        1.659s     499.854us          3318  
                                           AddBackward0         0.00%     481.000us         0.00%     481.000us       0.145us       0.000us         0.00%       0.000us       0.000us          3318  
     autograd::engine::evaluate_function: MeanBackward1         0.03%       5.911ms         0.28%      49.025ms      60.525us       0.000us         0.00%      21.781ms      26.890us           810  
                                          MeanBackward1         0.02%       4.014ms         0.24%      42.703ms      52.720us       0.000us         0.00%      21.509ms      26.554us           810  
      autograd::engine::evaluate_function: PowBackward0         0.03%       5.820ms         0.70%     122.316ms     151.007us       0.000us         0.00%      84.062ms     103.780us           810  
                                           PowBackward0         0.03%       5.632ms         0.67%     116.496ms     143.822us       0.000us         0.00%      84.062ms     103.780us           810  
     autograd::engine::evaluate_function: SiluBackward0         0.02%       3.284ms         0.12%      20.869ms      52.172us       0.000us         0.00%      18.899ms      47.248us           400  
                                          SiluBackward0         0.01%       1.054ms         0.10%      17.499ms      43.748us       0.000us         0.00%      18.809ms      47.023us           400  
autograd::engine::evaluate_function: TransposeBackwa...         0.07%      11.976ms         0.12%      20.902ms       8.753us       0.000us         0.00%       0.000us       0.000us          2388  
                                     TransposeBackward0         0.02%       3.070ms         0.05%       8.748ms       3.663us       0.000us         0.00%       0.000us       0.000us          2388  
      autograd::engine::evaluate_function: BmmBackward0         0.06%      11.332ms         1.38%     240.783ms     191.706us       0.000us         0.00%     199.582ms     158.903us          1256  
                                           BmmBackward0         0.06%      10.094ms         1.31%     229.058ms     182.371us       0.000us         0.00%     198.557ms     158.087us          1256  
autograd::engine::evaluate_function: ExpandBackward0...         0.05%       8.946ms         0.05%       9.593ms       3.786us       0.000us         0.00%     140.000us       0.055us          2534  
                                        ExpandBackward0         0.00%     110.000us         0.00%     646.000us       0.255us       0.000us         0.00%     140.000us       0.055us          2534  
autograd::engine::evaluate_function: SoftmaxBackward...         0.02%       3.294ms         0.32%      55.786ms      88.831us       0.000us         0.00%      91.665ms     145.963us           628  
                                       SoftmaxBackward0         0.02%       3.026ms         0.30%      51.992ms      82.790us       0.000us         0.00%      90.675ms     144.387us           628  
autograd::engine::evaluate_function: MaximumBackward...         0.02%       3.737ms         4.38%     766.399ms       1.916ms       0.000us         0.00%     106.968ms     267.420us           400  
                                       MaximumBackward0         0.03%       4.939ms         4.36%     762.662ms       1.907ms       0.000us         0.00%     106.968ms     267.420us           400  
      autograd::engine::evaluate_function: DivBackward0         0.01%       2.194ms         0.15%      26.428ms      55.992us       0.000us         0.00%      19.973ms      42.316us           472  
                                           DivBackward0         0.01%       1.327ms         0.14%      23.979ms      50.803us       0.000us         0.00%      19.560ms      41.441us           472  
      autograd::engine::evaluate_function: CatBackward0         0.04%       7.206ms         0.11%      20.118ms      20.655us       0.000us         0.00%       0.000us       0.000us           974  
                                           CatBackward0         0.02%       3.715ms         0.07%      12.891ms      13.235us       0.000us         0.00%       0.000us       0.000us           974  
                                           aten::narrow         0.04%       7.783ms         0.08%      14.435ms       2.674us       0.000us         0.00%       0.000us       0.000us          5398  
      autograd::engine::evaluate_function: NegBackward0         0.02%       3.162ms         0.55%      96.544ms     120.680us       0.000us         0.00%      20.558ms      25.698us           800  
                                           NegBackward0         0.01%       1.209ms         0.53%      93.282ms     116.603us       0.000us         0.00%      20.505ms      25.631us           800  
autograd::engine::evaluate_function: SplitWithSizesB...         0.13%      22.028ms         1.29%     225.040ms     548.878us       0.000us         0.00%     642.364ms       1.567ms           410  
                                SplitWithSizesBackward0         0.04%       6.626ms         1.15%     202.134ms     493.010us       0.000us         0.00%     612.912ms       1.495ms           410  
autograd::engine::evaluate_function: torch::autograd...         0.12%      20.362ms         2.43%     424.861ms       1.036ms       0.000us         0.00%        3.852s       9.396ms           410  
                        torch::autograd::AccumulateGrad         0.01%     934.000us         0.01%       2.262ms       5.517us       0.000us         0.00%       0.000us       0.000us           410  
                                           aten::detach         0.00%     478.000us         0.01%       1.325ms       3.232us       0.000us         0.00%       0.000us       0.000us           410  
                                                 detach         0.00%     862.000us         0.00%     862.000us       2.102us       0.000us         0.00%       0.000us       0.000us           410  
           FullyShardedDataParallel._post_backward_hook         1.60%     279.699ms         2.30%     402.225ms     981.037us       0.000us         0.00%        3.852s       9.396ms           410  
                                            aten::chunk         0.01%       1.268ms         0.07%      12.601ms      30.734us       0.000us         0.00%       0.000us       0.000us           410  
                                            aten::split         0.03%       5.712ms         0.07%      11.624ms      28.351us       0.000us         0.00%       0.000us       0.000us           410  
                            c10d::_reduce_scatter_base_         0.01%       2.141ms         0.20%      34.830ms      84.951us       0.000us         0.00%        3.337s       8.140ms           410  
                              nccl:_reduce_scatter_base         0.00%       0.000us             0      24.888ms      60.702us       0.000us         0.00%       0.000us       0.000us           410  
autograd::engine::evaluate_function: EmbeddingBackwa...         0.00%     675.000us         0.02%       4.212ms     123.882us       0.000us         0.00%      24.508ms     720.824us            34  
                                     EmbeddingBackward0        -0.00%    -114.000us         0.02%       3.278ms      96.412us       0.000us         0.00%      10.133ms     298.029us            34  
                               aten::embedding_backward         0.00%     247.000us         0.02%       3.190ms      93.824us       0.000us         0.00%       9.527ms     280.206us            34  
                                             cudaMalloc         0.16%      27.665ms         0.16%      27.665ms       2.515ms       0.000us         0.00%       0.000us       0.000us            11  
                              Optimizer.step#AdamW.step         0.08%      13.150ms         0.51%      88.954ms       8.895ms       0.000us         0.00%     896.962ms      89.696ms            10  
                    Optimizer.zero_grad#AdamW.zero_grad         0.01%       1.585ms         0.01%       1.585ms     158.500us       0.000us         0.00%       0.000us       0.000us            10  
                                           aten::conv2d         0.00%      30.000us         0.00%     804.000us     201.000us       0.000us         0.00%     425.000us     106.250us             4  
                                      aten::convolution         0.00%      56.000us         0.00%     774.000us     193.500us       0.000us         0.00%     425.000us     106.250us             4  
                                     aten::_convolution         0.00%      57.000us         0.00%     718.000us     179.500us       0.000us         0.00%     425.000us     106.250us             4  
                                  cudaStreamGetPriority         0.00%       0.000us         0.00%       0.000us       0.000us       0.000us         0.00%       0.000us       0.000us             8  
                       cudaDeviceGetStreamPriorityRange         0.00%       0.000us         0.00%       0.000us       0.000us       0.000us         0.00%       0.000us       0.000us             8  
                                          aten::flatten         0.00%       8.000us         0.00%      24.000us       6.000us       0.000us         0.00%       0.000us       0.000us             4  
                                          aten::dropout         0.00%       2.000us         0.00%       2.000us       0.003us       0.000us         0.00%       0.000us       0.000us           668  
                                       aten::layer_norm         0.02%       3.241ms         0.39%      67.850ms      77.102us       0.000us         0.00%      13.997ms      15.906us           880  
                                       aten::zeros_like         0.00%     565.000us         0.02%       3.819ms      24.481us       0.000us         0.00%     455.000us       2.917us           156  
                                          aten::permute         0.02%       3.454ms         0.02%       3.600ms       4.054us       0.000us         0.00%       0.000us       0.000us           888  
    autograd::engine::evaluate_function: AddmmBackward0         0.13%      22.372ms         1.35%     236.981ms     234.171us       0.000us         0.00%      76.389ms      75.483us          1012  
                                         AddmmBackward0         0.06%      10.733ms         0.95%     166.055ms     164.086us       0.000us         0.00%      63.948ms      63.190us          1012  
autograd::engine::evaluate_function: NativeLayerNorm...         0.03%       5.749ms         0.28%      49.483ms     112.461us       0.000us         0.00%       8.133ms      18.484us           440  
                               NativeLayerNormBackward0         0.01%       1.793ms         0.25%      43.568ms      99.018us       0.000us         0.00%       8.079ms      18.361us           440  
     autograd::engine::evaluate_function: GeluBackward0         0.01%       1.534ms         0.10%      17.462ms      85.598us       0.000us         0.00%       3.028ms      14.843us           204  
                                          GeluBackward0         0.00%     577.000us         0.09%      15.875ms      77.819us       0.000us         0.00%       2.975ms      14.583us           204  
autograd::engine::evaluate_function: PermuteBackward...         0.01%       1.792ms         0.02%       3.974ms       8.950us       0.000us         0.00%       0.000us       0.000us           444  
                                       PermuteBackward0         0.00%     707.000us         0.01%       2.106ms       4.743us       0.000us         0.00%       0.000us       0.000us           444  
                                   aten::_reshape_alias         0.00%     145.000us         0.00%     145.000us       1.908us       0.000us         0.00%       0.000us       0.000us            76  
autograd::engine::evaluate_function: SelectBackward0...         0.02%       2.742ms         0.20%      35.181ms      75.173us       0.000us         0.00%       9.484ms      20.265us           468  
                                        SelectBackward0         0.01%       1.110ms         0.15%      25.568ms      54.632us       0.000us         0.00%       7.066ms      15.098us           468  
                                  aten::select_backward         0.01%       2.301ms         0.14%      24.922ms      53.252us       0.000us         0.00%       7.272ms      15.538us           468  
autograd::engine::evaluate_function: ConvolutionBack...         0.00%     103.000us         0.01%     917.000us     229.250us       0.000us         0.00%     763.000us     190.750us             4  
                                   ConvolutionBackward0         0.00%      25.000us         0.00%     814.000us     203.500us       0.000us         0.00%     763.000us     190.750us             4  
                                  cudaDeviceSynchronize         0.00%      61.000us         0.00%      61.000us      61.000us       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------

I don't quite understand all these metrics, but I can see that:

  1. Most of the CUDA time is spent on all_gather and matmul, which is reasonable.
  2. From Self CUDA and CUDA total it seems that two-node is 10x slower than single-node for both the gathers and the matmuls, which is very strange. Something must be wrong; maybe I configured FSDP incorrectly, but I can't find where.

@awgu
Copy link
Contributor

awgu commented May 29, 2023

Multi-node:

-------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                             Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------

aten::mm         0.38%     456.046ms        10.75%       12.890s       1.233ms        3.376s         4.30%        3.376s     322.916us         10454  

Single-node:

aten::mm         2.46%     430.458ms         5.16%     902.353ms      86.317us        3.460s        21.60%        3.460s     330.939us         10454  

I do not see the matmuls taking more time (you can similarly check aten::bmm for batched matmul). I mainly see that the communications take 10x longer on multi-node than single-node. This may just be because of your inter-node network bandwidth being 10x slower than your intra-node network bandwidth. I am not sure how you can check your inter-node connection type, but for example, if it is Ethernet (as opposed to Infiniband), the 10x slowdown is not unreasonable to me.
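
If it is useful, here is a minimal sketch of a standalone collective benchmark (not part of this issue; the tensor size and iteration count are arbitrary choices) that you can launch with torchrun once within a single node and once across two nodes to compare the raw interconnect speed without FSDP involved:

# bench_comm.py: time an all-reduce to compare intra-node vs. inter-node bandwidth.
# Assumes launch via torchrun, so RANK/LOCAL_RANK/WORLD_SIZE are set and NCCL is used.
import os
import time

import torch
import torch.distributed as dist

def main():
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    # 256M fp16 elements, roughly 512 MB per rank; the size is an arbitrary choice.
    x = torch.ones(256 * 1024 * 1024, dtype=torch.float16, device="cuda")

    for _ in range(5):                 # warm up NCCL communicators before timing
        dist.all_reduce(x)
    torch.cuda.synchronize()

    iters = 20
    start = time.perf_counter()
    for _ in range(iters):
        dist.all_reduce(x)
    torch.cuda.synchronize()
    elapsed_ms = (time.perf_counter() - start) / iters * 1e3

    if dist.get_rank() == 0:
        print(f"all_reduce of {x.numel() * x.element_size() / 1e9:.2f} GB: {elapsed_ms:.1f} ms/iter")
    dist.destroy_process_group()

if __name__ == "__main__":
    main()

A large gap between the single-node and two-node runs points at the network rather than at FSDP.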

@JulioZhao97
Copy link
Author

Multi-node:

-------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                             Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------

aten::mm         0.38%     456.046ms        10.75%       12.890s       1.233ms        3.376s         4.30%        3.376s     322.916us         10454  

Single-node:

aten::mm         2.46%     430.458ms         5.16%     902.353ms      86.317us        3.460s        21.60%        3.460s     330.939us         10454  

I do not see the matmuls taking more time (you can similarly check aten::bmm for batched matmul). I mainly see that the communications take 10x longer on multi-node than single-node. This may just be because of your inter-node network bandwidth being 10x slower than your intra-node network bandwidth. I am not sure how you can check your inter-node connection type, but for example, if it is Ethernet (as opposed to Infiniband), the 10x slowdown is not unreasonable to me.

Sorry, I misread it. The main time cost comes from ncclKernel_ReduceScatter_RING_LL_Sum_half, record_param_comms, ncclKernel_AllGather_RING_LL_Sum_int8_t, and c10d::_allgather_base_; are these all communication operations? The inter-node slowdown you mentioned may be the reason, but when I train a smaller model using DDP there is no such big gap. Is it possible that I configured FSDP wrong?

@awgu
Copy link
Contributor

awgu commented May 29, 2023

DDP only uses all-reduce for communication, not all-gather/reduce-scatter, and all-reduce is more optimized in practice. At the same time, how much larger is your FSDP model than your DDP model? This affects the communication volume.
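
As a rough back-of-the-envelope comparison (a sketch with made-up parameter counts, not numbers from this issue), the per-step communication volume differs by orders of magnitude when DDP only syncs a few trainable layers while FULL_SHARD FSDP moves the whole model:

# Rough per-step communication volume; ring-algorithm constant factors and
# compute/communication overlap are ignored. Parameter counts are illustrative assumptions.
BYTES_PER_PARAM = 2  # fp16

ddp_trainable_params = 20e6          # e.g. only a few linear layers trained under DDP
fsdp_total_params = 13e9 + 20e6      # a 13B model plus the extra linear layers under FSDP

# DDP: one gradient all-reduce over the trainable parameters per step.
ddp_bytes = ddp_trainable_params * BYTES_PER_PARAM

# FULL_SHARD FSDP: all-gather parameters in forward, all-gather again in backward,
# and reduce-scatter all gradients.
fsdp_bytes = fsdp_total_params * BYTES_PER_PARAM * 3

print(f"DDP  ~ {ddp_bytes / 1e9:.3f} GB per step")
print(f"FSDP ~ {fsdp_bytes / 1e9:.1f} GB per step")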

@JulioZhao97
Copy link
Author

JulioZhao97 commented May 29, 2023

DDP only uses all-reduce for communication, not all-gather/reduce-scatter, and all-reduce is more optimized in practice. At the same time, how much larger is your FSDP model than your DDP model? This affects the communication volume.

The DDP model is much smaller: the FSDP model is a LLaMA-13B model plus some linear layers, while the DDP run only tunes the linear layers.

@awgu
Copy link
Contributor

awgu commented May 30, 2023

I am going to mark this as closed because I see no evidence to suggest this is an issue with FSDP. From our discussion, it seems that you have a slow inter-node interconnect.

@awgu awgu closed this as completed May 30, 2023
@awgu awgu added triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module module: fsdp labels May 30, 2023
@JulioZhao97
Copy link
Author

JulioZhao97 commented May 31, 2023

Finally, I set ShardingStrategy=HYBRID_SHARD instead of FULL_SHARD as https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.ShardingStrategy describes, and it turns out that multi-node training is faster.
From what I observe, with 4-node training the time for a single iteration is about 3x slower than on a single node, but overall, training is faster.
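
In case it helps others, here is a minimal sketch of the relevant change (not my full training code; the toy Linear model is a placeholder and it assumes the process group is already initialized via torchrun). HYBRID_SHARD shards parameters within each node and replicates across nodes, so only the gradient all-reduce has to cross the slower inter-node link:

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

# Assumes torch.distributed is already initialized (e.g. via torchrun) and the
# CUDA device is set; nn.Linear is only a stand-in for the real model.
model = nn.Linear(1024, 1024).cuda()
model = FSDP(
    model,
    # Shard within each node, replicate across nodes, so the per-layer
    # all-gather/reduce-scatter traffic stays on the fast intra-node links.
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
    device_id=torch.cuda.current_device(),
)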

@GasolSun36
Copy link

Finally, I set ShardingStrategy=HYBRID_SHARD instead of FULL_SHARD as https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.ShardingStrategy describes, and it turns out that multi-node training is faster. From what I observe, with 4-node training the time for a single iteration is about 3x slower than on a single node, but overall, training is faster.

Hi JulioZhao97, do you have your full code for training with FSDP on multiple nodes? Can you share it with me? I would appreciate it.

@JulioZhao97
Copy link
Author

Finally, I set ShardingStrategy=HYBRID_SHARD instead of FULL_SHARD as https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.ShardingStrategy describes, and it turns out that multi-node training is faster. From what I observe, with 4-node training the time for a single iteration is about 3x slower than on a single node, but overall, training is faster.

Hi JulioZhao97, do you have your full code for training with FSDP on multiple nodes? Can you share it with me? I would appreciate it.

def wrap_model_using_fsdp(self):
        # Collect parameters that do not require gradients (partially frozen model).
        params_no_grad = [n for n, p in self._model.named_parameters() if not p.requires_grad]
        
        from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
            
        if len(params_no_grad) > 0:
            print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}'.format(len(params_no_grad), params_no_grad))
            print("[WARNING] Attempting to use FSDP with partially frozen parameters, this is experimental.")
            print("[WARNING] As of 4/30/23, this feature requires PyTorch-nightly build.  See here for details: https://github.com/haotian-liu/LLaVA#experimental-use-fsdp-to-save-memory-in-pretraining")

            # Force use_orig_params=True so FSDP can handle a mix of frozen and trainable parameters.
            def patch_FSDP_use_orig_params(func):
                def wrap_func(*args, **kwargs):
                    use_orig_params = kwargs.pop('use_orig_params', True)
                    return func(*args, **kwargs, use_orig_params=use_orig_params)
                return wrap_func

            FSDP.__init__ = patch_FSDP_use_orig_params(FSDP.__init__)

        from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision, ShardingStrategy
                
        # fp16 mixed precision for parameters, gradient reduction, and buffers.
        dtype = torch.float16
        mixed_precision_policy = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)
        
        local_rank = int(os.environ['LOCAL_RANK'])
        torch.cuda.set_device(local_rank)
                
        # Walk the module tree to find the class object for a layer given its name.
        def get_module_class_from_name(module, name):
            modules_children = list(module.children())
            if module.__class__.__name__ == name:
                return module.__class__
            elif len(modules_children) == 0:
                return
            else:
                for child_module in modules_children:
                    module_class = get_module_class_from_name(child_module, name)
                    if module_class is not None:
                        return module_class
                
        from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
        import functools
        transformer_cls_to_wrap = set()
        for layer_class in ['LlamaDecoderLayer']:
            transformer_cls = get_module_class_from_name(self._model, layer_class)
            if transformer_cls is None:
                raise Exception("Could not find the transformer layer class to wrap in the model.")
            else:
                transformer_cls_to_wrap.add(transformer_cls)
        # Wrap each transformer block (LlamaDecoderLayer) in its own FSDP unit.
        auto_wrap_policy = functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls=transformer_cls_to_wrap,
        )
                
        self._wrapped_model = self._model = FSDP(
            self._model,
            # HYBRID_SHARD: shard within each node, replicate across nodes.
            #sharding_strategy=ShardingStrategy.FULL_SHARD,
            sharding_strategy=ShardingStrategy.HYBRID_SHARD,
            #sharding_strategy=ShardingStrategy._HYBRID_SHARD_ZERO2,
            cpu_offload=CPUOffload(offload_params=False),
            mixed_precision=mixed_precision_policy,
            auto_wrap_policy=auto_wrap_policy,
            device_id=torch.cuda.current_device()
        )

The model wrapping is something like this, mainly adapted from LLaVA and Huggingface.

@GasolSun36
Copy link

Finally, I set ShardingStrategy=HYBRID_SHARD instead of FULL_SHARD as https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.ShardingStrategy describes, and it turns out that multi-node training is faster. From what I observe, with 4-node training the time for a single iteration is about 3x slower than on a single node, but overall, training is faster.

Hi JulioZhao97, do you have your full code for training with FSDP on multiple nodes? Can you share it with me? I would appreciate it.


Thanks! By the way, are there any changes needed in the main training file? I used FSDP to run successfully on a single node, but now I want to run on multiple nodes and I don't know how to write the code.

@JulioZhao97
Copy link
Author

As far as I'm concerned, multi-node training and single-node training are basically the same. If you are running on a Slurm cluster, the srun parameters need to be set carefully. Besides, what is your error in multi-node training? @GasolSun36

@GasolSun36
Copy link

As far as I'm concerned, multi-node training and single-node training are basically the same. If you are running on a Slurm cluster, the srun parameters need to be set carefully. Besides, what is your error in multi-node training? @GasolSun36

here is my script:

#!/bin/bash
#SBATCH --job-name=bloom
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=4
#SBATCH --cpus-per-task=8
#SBATCH --mem=64gb
#SBATCH --gres=gpu:4
export MASTER_PORT=11343
export WORLD_SIZE=8

echo "NODELIST="${SLURM_NODELIST}
master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
export MASTER_ADDR=$master_addr
echo "MASTER_ADDR="$MASTER_ADDR

srun train.py \
    --model_name_or_path news_bloom_7b1_qa \
    --data_path ./news_training.json \
    --bf16 True \
    --output_dir ./test_multi-node-training/ \
    --num_train_epochs 4 \
    --per_device_train_batch_size 2 \
    --per_device_eval_batch_size 2 \
    --gradient_accumulation_steps 4 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 1000 \
    --save_total_limit 1 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --fsdp "full_shard auto_wrap offload" \
    --fsdp_transformer_layer_cls_to_wrap 'BloomBlock' \
    --tf32 True

and I applied for 2 nodes with 4 cards each on the cluster, for a total of 8 cards. The error is slurmstepd: error: execve(): train.py: Permission denied. Can you help me see if there are any obvious errors in my sh script?

@GasolSun36
Copy link

@JulioZhao97 Hi, the above problem is solved, but there is a new problem: ValueError: Using fsdp only works in distributed training. Can I see your setting in the main training file, please? My main reference is this link: https://gist.github.com/TengdaHan/1dd10d335c7ca6f13810fff41e809904

@JulioZhao97
Copy link
Author

@JulioZhao97 Hi, the above problem is solved, but there is a new problem: ValueError: Using fsdp only works in distributed training. Can I see your setting in the main training file, please? My main reference is this link: https://gist.github.com/TengdaHan/1dd10d335c7ca6f13810fff41e809904
How did you launch the job? You should launch it using torchrun xxx.py or python -m torch.distributed.launch xxx.py so that the distributed environment is set up when your script runs.
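
For example, under Slurm something along these lines should work (a sketch only; the job name, node/GPU counts, port, and the training arguments are assumptions you will need to adapt):

#!/bin/bash
# One torchrun launcher per node (ntasks-per-node=1); torchrun then spawns one
# worker process per GPU on that node.
#SBATCH --job-name=fsdp-multinode
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=1
#SBATCH --gres=gpu:8

export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
export MASTER_PORT=29500

srun torchrun \
    --nnodes=$SLURM_NNODES \
    --nproc_per_node=8 \
    --rdzv_backend=c10d \
    --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \
    train.py

torchrun sets RANK, LOCAL_RANK, and WORLD_SIZE for every worker, which is what the Trainer checks before allowing the fsdp option.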

@GasolSun36
Copy link

GasolSun36 commented Jun 6, 2023

@GasolSun36 see this: https://github.com/huggingface/transformers/blob/7631db0fdcfbd95b1f21d8034a0b8df73b9380ff/src/transformers/trainer.py#L446

@JulioZhao97
my running script is:

#!/bin/bash
#SBATCH --job-name=bloom
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=2
#SBATCH --cpus-per-task=8
#SBATCH --mem=64gb
#SBATCH --gres=gpu:2
#SBATCH --mail-type=ALL
#SBATCH --mail-user=790567648@qq.com
export MASTER_PORT=11343
export WORLD_SIZE=4

echo "NODELIST="${SLURM_NODELIST}
master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
export MASTER_ADDR=$master_addr
echo "MASTER_ADDR="$MASTER_ADDR

srun python train.py \
    --model_name_or_path bigscience/bloom-7b1 \
    --data_path ./news_training.json \
    --bf16 True \
    --output_dir ./test_multi-node-training/ \
    --num_train_epochs 4 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 4 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 1000 \
    --save_total_limit 1 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --fsdp "full_shard auto_wrap offload" \
    --fsdp_transformer_layer_cls_to_wrap 'BloomBlock' \
    --tf32 True

Did you mean srun torchrun train.py?

@g-h-chen
Copy link

Same problem for me: multi-node training time increases roughly linearly with the number of nodes.

Can anyone help? @awgu

@dydxdt
Copy link

dydxdt commented Mar 13, 2024

Thanks for the great discussion! I also ran into the problem that multi-node training is much slower than single-node: two-node training takes twice as long as single-node. My training doesn't use FSDP. Do you have any suggestions? Thank you very much! @JulioZhao97

@JulioZhao97
Copy link
Author

Thanks for the great discussion! I also ran into the problem that multi-node training is much slower than single-node: two-node training takes twice as long as single-node. My training doesn't use FSDP. Do you have any suggestions? Thank you very much! @JulioZhao97

I suggest you check the timing using torch.profiler as discussed above. Most likely the communication between the two nodes is very slow; you can check with your cluster administrator.
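
Something like the following is what I mean (a minimal sketch; the toy Linear model and optimizer stand in for your real training loop, so swap in your own forward/backward):

import torch
from torch.profiler import profile, schedule, ProfilerActivity

# Stand-in model and optimizer; replace with your own training step.
model = torch.nn.Linear(1024, 1024).cuda()
opt = torch.optim.AdamW(model.parameters())

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=schedule(wait=1, warmup=1, active=3, repeat=1),
) as prof:
    for _ in range(6):
        x = torch.randn(32, 1024, device="cuda")
        loss = model(x).sum()
        loss.backward()
        opt.step()
        opt.zero_grad()
        prof.step()  # advance the profiler schedule once per iteration

# Sort by CUDA time to see whether compute or NCCL collectives dominate.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))

If ncclKernel_* rows dominate cuda_time_total rather than aten::mm, the bottleneck is communication, not compute.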

@GasolSun36
Copy link

Maybe try DeepSpeed; it works perfectly well on my 8-node setup. FSDP just didn't work for me.
