Fix node provenance tracking #95901

Closed · 6 commits

Commits on Mar 2, 2023

  1. Fix node provenance tracking

    [ghstack-poisoned]
    Chillee committed Mar 2, 2023
    0095ad2
  2. Update on "Fix node provenance tracking"

    cc soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire
    
    [ghstack-poisoned]
    Chillee committed Mar 2, 2023
    24dfafe
  3. Update on "Fix node provenance tracking"

    cc soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire
    
    [ghstack-poisoned]
    Chillee committed Mar 2, 2023
    b9a7407
  4. Update on "Fix node provenance tracking"

    cc soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire
    
    [ghstack-poisoned]
    Chillee committed Mar 2, 2023
    db335ad

Commits on Mar 3, 2023

  1. Update on "Fix node provenance tracking"

    Before:
    ```
    triton_fused_add_83_add_84_convolution_15_relu_12_relu_13_squeeze_46_var_mean_15_14
    ```
    
    After (`convolution_15` and `relu_12` no longer appear in the name):
    ```
    triton_fused_add_83_add_84_relu_13_squeeze_46_var_mean_15_14
    ```
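
To make the naming change concrete, here is a minimal sketch of how a descriptive kernel name could be assembled from the names of the graph nodes attributed to a kernel. It is not Inductor's actual implementation: `fused_kernel_name`, the `triton_fused` prefix argument, and reading the trailing `14` as a kernel index are all assumptions made for illustration. The node names are taken from the two kernel names above.

```python
def fused_kernel_name(origin_names, kernel_index, prefix="triton_fused"):
    """Hypothetical helper: join sorted, de-duplicated node names into a kernel name."""
    unique_sorted = sorted(set(origin_names))
    return "_".join([prefix, *unique_sorted, str(kernel_index)])


# Node names read off the "Before" kernel name.
before_origins = [
    "add_83", "add_84", "convolution_15", "relu_12",
    "relu_13", "squeeze_46", "var_mean_15",
]

# The two names that disappear in the "After" kernel name.
# Assumption: they were attributed to the kernel by mistake.
wrongly_attributed = {"convolution_15", "relu_12"}
after_origins = [n for n in before_origins if n not in wrongly_attributed]

before_name = fused_kernel_name(before_origins, 14)
after_name = fused_kernel_name(after_origins, 14)
print(before_name)
print(after_name)
```

With accurate provenance, the name only lists nodes whose computation actually lives in the kernel, which makes the name usable for mapping a slow kernel back to the model code.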
    
    For this kernel:
    ```
    @persistent_reduction(
        size_hints=[512, 64],
        reduction_hint=ReductionHint.INNER,
        filename=__file__,
        meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: '*fp32', 8: '*fp32', 9: '*fp32', 10: 'i32', 11: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': ['in_out_ptr0'], 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10), equal_to_1=())]}
    )
    @triton.jit
    def triton_(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, out_ptr2, out_ptr3, out_ptr4, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
        xnumel = 512
        rnumel = 49
        xoffset = tl.program_id(0) * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
        xmask = xindex < xnumel
        rindex = tl.arange(0, RBLOCK)[None, :]
        rmask = rindex < rnumel
        r1 = rindex
        x0 = xindex
        tmp0 = tl.load(in_ptr0 + (r1 + (49*x0)), rmask & xmask, other=0)
        tmp8 = tl.load(in_ptr1 + (x0), xmask)
        tmp22 = tl.load(in_ptr2 + (x0), xmask)
        tmp24 = tl.load(in_ptr3 + (x0), xmask)
        tmp30 = tl.load(in_ptr4 + (x0), xmask)
        tmp2 = tl.where(rmask & xmask, tmp0, 0)
        tmp3 = tl.sum(tmp2, 1)[:, None]
        tmp4 = 49.0
        tmp5 = tmp3 / tmp4
        tmp6 = 0.1
        tmp7 = tmp5 * tmp6
        tmp9 = 0.9
        tmp10 = tmp8 * tmp9
        tmp11 = tmp7 + tmp10
        tmp12 = tmp0 - tmp5
        tmp13 = tmp12 * tmp12
        tmp15 = tl.where(rmask & xmask, tmp13, 0)
        tmp16 = tl.sum(tmp15, 1)[:, None]
        tmp17 = tmp16 / tmp4
        tmp18 = 1e-05
        tmp19 = tmp17 + tmp18
        tmp20 = tl.libdevice.rsqrt(tmp19)
        tmp21 = tmp12 * tmp20
        tmp23 = tmp21 * tmp22
        tmp25 = tmp23 + tmp24
        tmp26 = tl.where(0 != 0, 0, tl.where(0 > tmp25, 0, tmp25))
        tmp27 = 1.0208333333333333
        tmp28 = tmp17 * tmp27
        tmp29 = tmp28 * tmp6
        tmp31 = tmp30 * tmp9
        tmp32 = tmp29 + tmp31
        tl.store(in_out_ptr0 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp5, xmask)
        tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp11, xmask)
        tl.store(out_ptr2 + (r1 + (49*x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), tmp26, rmask & xmask)
        tl.store(out_ptr3 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp20, xmask)
        tl.store(out_ptr4 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp32, xmask)
    ```
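
For readers who want to check what the generated kernel computes, the following NumPy sketch mirrors its arithmetic step by step. Reading the inputs as batch-norm-style running mean, weight, bias, and running variance is an interpretation of the constants in the kernel (`0.1`, `0.9`, `1e-05`, and `1.0208333333333333`, which equals 49/48), not something the PR states. The function name and argument names are hypothetical.

```python
import numpy as np


def reference(x, running_mean, weight, bias, running_var, momentum=0.1, eps=1e-5):
    """NumPy mirror of the kernel body; comments name the matching tmp values."""
    n = x.shape[1]                                                # rnumel = 49
    mean = x.sum(axis=1, keepdims=True) / n                       # tmp5
    centered = x - mean                                           # tmp12
    var = (centered * centered).sum(axis=1, keepdims=True) / n    # tmp17
    rstd = 1.0 / np.sqrt(var + eps)                               # tmp20
    out = np.maximum(centered * rstd * weight + bias, 0.0)        # tmp26 (relu)
    new_running_mean = mean * momentum + running_mean * (1 - momentum)   # tmp11
    # n / (n - 1) is the unbiased-variance correction, 49/48 in the kernel (tmp27).
    new_running_var = var * (n / (n - 1)) * momentum + running_var * (1 - momentum)  # tmp32
    return mean, new_running_mean, out, rstd, new_running_var


# Same logical shape as the kernel: xnumel = 512 rows, rnumel = 49 reduced elements.
rng = np.random.default_rng(0)
x = rng.standard_normal((512, 49))
ones = np.ones((512, 1))
zeros = np.zeros((512, 1))
mean, new_rm, out, rstd, new_rv = reference(x, zeros, ones, zeros, ones)
```

The five return values correspond, in order, to the five `tl.store` calls in the kernel.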
    
    Tbh, this still isn't great provenance tracking, since ops like layer norm are decomposed before fusion and the kernel name only reflects the decomposed pieces. I might add some extra provenance tracking during decompositions.
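
One way to picture the extra provenance tracking mentioned above is to stamp every node produced by a decomposition with the name of the op it came from, and to prefer that name when labelling a kernel. This is only a toy sketch of the idea: `Node`, `source_op`, `decompose_layer_norm`, `provenance_label`, and the list of primitives are all hypothetical and do not correspond to existing PyTorch APIs.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Node:
    name: str
    # Hypothetical field: the op this node was decomposed from, if any.
    source_op: Optional[str] = None


def decompose_layer_norm(op_name: str) -> List[Node]:
    """Toy decomposition; the primitive names are illustrative only."""
    primitives = ["var_mean", "sub", "rsqrt", "mul", "add"]
    return [Node(name=f"{p}_{i}", source_op=op_name) for i, p in enumerate(primitives)]


def provenance_label(nodes: List[Node]) -> str:
    """Label a group of nodes by their original ops, keeping first-seen order."""
    labels: List[str] = []
    for node in nodes:
        label = node.source_op or node.name
        if label not in labels:
            labels.append(label)
    return "_".join(labels)


nodes = decompose_layer_norm("layer_norm_3") + [Node("relu_13")]
label = provenance_label(nodes)
print(label)
```

Under this scheme the five decomposed primitives collapse back into a single `layer_norm_3` entry instead of appearing individually in the label.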
    
    
    
    cc soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire
    
    [ghstack-poisoned]
    Chillee committed Mar 3, 2023
    6f3c63d

Commits on Mar 5, 2023

  1. Update on "Fix node provenance tracking"

    cc soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire
    
    [ghstack-poisoned]
    Chillee committed Mar 5, 2023
    d26cdcf