Force synced KJT to trace unbacked SymInt #108960

Closed

Commits on Sep 11, 2023

  1. Force synced KJT to trace unbacked SymInt (pytorch#108960)

    Summary:
    
    The basic concept behind this diff is to modify Dynamo's tracing behavior when it encounters a KeyedJaggedTensor (KJT) that is synced (i.e., has `_length_per_key` and `_offset_per_key` populated). These fields are lists of plain Python integers. Ordinarily, Dynamo optimistically specializes on integers; for KJTs, however, we know these integers will definitely vary from run to run. Furthermore, Dynamo would ordinarily also specialize integers that are 0/1, yet features in KJTs frequently have lengths of 0 or 1.
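
    To make this concrete, here is a minimal sketch (assuming torchrec's public `KeyedJaggedTensor` API) of the per-key integer metadata a synced KJT carries; these are the plain Python ints Dynamo would otherwise specialize on:

    ```
    import torch
    from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

    # Two features over a batch of two; per-feature lengths of 0/1 are common.
    kjt = KeyedJaggedTensor(
        keys=["f1", "f2"],
        values=torch.tensor([10, 20, 30]),
        lengths=torch.tensor([2, 0, 1, 0]),
    )
    kjt.sync()  # populates _length_per_key and _offset_per_key

    print(kjt.length_per_key())  # [2, 1] -- plain ints that vary run to run
    print(kjt.offset_per_key())  # [0, 2, 3]
    ```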
    
    The fix is to detect KJTs and treat these integers as *unbacked integers*. This is NOT a universally sound optimization: when we treat these integers as unbacked, we never report them as equal to zero or one. In return, we always generate graphs that generalize no matter the lengths of the features' values. This is enough to trace through the APS sparse arch, torchrec_dlrm, and some small split-cat examples.
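
    As a hedged illustration of what "unbacked" means here (this shows standard `torch.compile` scalar capture, not this diff's KJT path): an integer produced by `Tensor.item()` becomes an unbacked SymInt, is never assumed to equal 0 or 1, and size constraints become runtime asserts instead of guards baked into the graph:

    ```
    import torch

    torch._dynamo.config.capture_scalar_outputs = True

    @torch.compile(fullgraph=True)
    def f(lengths):
        n = lengths.sum().item()  # n is an unbacked SymInt during tracing
        torch._check_is_size(n)   # runtime assert; n is never assumed to be 0/1
        return torch.zeros(n)     # one graph serves any value of n

    out = f(torch.tensor([2, 3]))  # no retrace needed if the lengths change
    ```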
    
    The special integer behavior is triggered by a dynamically scoped `force_unspec_int_unbacked_size_like` variable on TracingContext, which we set while wrapping a KJT. There are probably other ways to do this, but this was simple and worked.
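
    A minimal sketch of that dynamic-scoping pattern (the attribute name comes from this diff's description; the context-manager shape below is an assumption, not the diff's exact code):

    ```
    from contextlib import contextmanager
    from torch._guards import TracingContext

    @contextmanager
    def _force_unbacked_ints():
        # Hypothetical helper: set the flag while wrapping a KJT's int fields,
        # then restore it so the behavior stays scoped to the KJT.
        ctx = TracingContext.get()
        prior = getattr(ctx, "force_unspec_int_unbacked_size_like", False)
        ctx.force_unspec_int_unbacked_size_like = True
        try:
            yield
        finally:
            ctx.force_unspec_int_unbacked_size_like = prior
    ```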
    
    Test Plan:
    ```
    buck2 test mode/dev-nosan //pytorch/benchmark/fb/test_gpu:run_test_gpu
    ```
    
    From aakhundov:
    
    1. First, build feed_lower_benchmark:
    ```
    buck2 build --show-output mode/opt -c python.package_style=inplace -c fbcode.enable_gpu_sections=true -c fbcode.platform=platform010 -c fbcode.split-dwarf=true hpc/new/models/feed/benchmark:feed_lower_benchmark
    ```
    2. Then run the lowering of the model with it:
    ```
    TORCHINDUCTOR_MAX_AUTOTUNE=1 TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1 TORCH_LOGS="output_code,graph_code" TORCH_COMPILE_DEBUG=1 ../buck-out/v2/gen/fbcode/79c6b019ee0f9469/hpc/new/models/feed/benchmark/__feed_lower_benchmark__/feed_lower_benchmark.par --load=manifold://ig_inference_model/tree/user/facebook/fblearner/predictor/960999465/60/gpu_lowering/input.predictor --skip-trt --skip-ait --sync-mode=0 --enable-aot-inductor --lower-presets="ig_stories" --gpu-trace
    ```
    cf. https://docs.google.com/document/d/1yD30xYrdmM8r2HTdmXnZTg0-MHVexfVrAa0294m1AUE/edit?pli=1#heading=h.qiv3fp7e6zg0
    
    From torchrec: https://www.internalfb.com/intern/wiki/Torchrec/Development/Testing_production_models/
    
    From ge0405:
    Baseline (without this diff): f477293168
    This diff: f477292363
    
    Reviewed By: voznesenskym
    
    Differential Revision: D49019987
    ezyang authored and facebook-github-bot committed Sep 11, 2023
    Commit: 28dee98