
Inconsistent output when mixing eager with torch.compile #8832

@liangfu

Description

🐛 Bug

When mixing eager execution with torch.compile, the output tensor result is inconsistent: the same write_to_kv_cache call leaves different contents in the cache depending on whether it runs through the compiled callable or eagerly.
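
Here torch.ops.xla.dynamo_set_buffer_donor_(t, True) marks the tensor's underlying buffer as donatable, allowing the XLA runtime to reuse or alias its storage for the graph's outputs; the reproduction below relies on this for the in-place cache updates.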

To Reproduce

import torch
import torch_xla.core.xla_model as xm

def write_to_kv_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
) -> None:
    # Mark the caches as buffer donors so XLA may reuse their storage for
    # the updated outputs instead of allocating new buffers.
    torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True)
    torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True)

    # Flatten so that dim 0 indexes individual slots, then write the new
    # key/value rows into the slots selected by slot_mapping.
    key = key.flatten(0, 2)
    value = value.flatten(0, 2)
    key_cache = key_cache.flatten(0, 2)
    value_cache = value_cache.flatten(0, 2)
    key_cache.index_copy_(0, slot_mapping, key)
    value_cache.index_copy_(0, slot_mapping, value)

if __name__ == '__main__':
    device = xm.xla_device()

    # Zero-initialized KV cache: (2, num_blocks, block_size, num_kv_heads, head_size).
    num_blocks = 128
    block_size = 128
    num_kv_heads = 4
    head_size = 64
    kv_cache_shape = (2, num_blocks, block_size, num_kv_heads, head_size)
    kv_cache = torch.zeros(kv_cache_shape,
                           dtype=torch.float,
                           device=device)
    key_cache, value_cache = kv_cache

    # Random keys/values of shape (batch=1, seq=3, k/v=2, num_kv_heads, head_size),
    # which flatten to 12 rows, matching the 12 slot indices below.
    kv = torch.empty(1, 3, 2, num_kv_heads, head_size, dtype=torch.float, device=device)
    kv.uniform_(-1, 1)
    key, value = kv.unbind(dim=2)
    slot_mapping = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
                                dtype=torch.int32, device=device).long()
    compiled_callable = torch.compile(write_to_kv_cache,
                                      backend="openxla",
                                      fullgraph=False,
                                      dynamic=False)
    compiled_callable(key, value, key_cache, value_cache, slot_mapping)
    print(f"k/v cache use torch compile {key_cache[0][:5]}")

    compiled_callable(key, value, key_cache, value_cache, slot_mapping)
    print(f"k/v cache use torch compile again func {key_cache[0][:5]}")

    write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)
    print(f"k/v cache use original func {key_cache[0][:5]}")

Expected behavior
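
All three prints should show identical cache contents: writing the same keys and values to the same slots should leave key_cache in the same state whether write_to_kv_cache runs through the compiled callable or eagerly.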

Environment

  • Reproducible on XLA backend [CPU/TPU/CUDA]:
  • torch_xla version:

Additional context

Labels

bug, dynamo
