tracking torch.compile compatibility with lora serving #10617

@youkaichao

Your current environment

N/A

Model Input Dumps

No response

🐛 Describe the bug

Using torch.compile with LoRA will fail, because vLLM's multi-LoRA support (the punica kernels) is very complicated.

The punica wrapper defined in

class PunicaWrapper:

is very similar to the attention ops. To support torch.compile for it, we would need to do something similar to #10558, i.e. hide the whole punica operation from torch.compile.

The difference is that attention ops have a fairly uniform signature, so we only need to register one custom op, while punica ops have several signatures and are applied to various layers, including linear and embedding layers. Even if we wrapped all of these ops as PyTorch custom ops for torch.compile, there would not be much left for torch.compile to accelerate.

Therefore, I lean toward leaving LoRA as-is and simply not applying torch.compile to it.

Before submitting a new issue...

  • Make sure you have already searched for relevant issues and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.

Metadata

Assignees: no one assigned

Labels: bug (Something isn't working), stale (Over 90 days of inactivity)