README.md: 22 changes (0 additions & 22 deletions)
@@ -113,7 +113,6 @@ To switch different backends (default is jax):

```
TPU_BACKEND_TYPE=jax
-TPU_BACKEND_TYPE=torchax
TPU_BACKEND_TYPE=pytorch_xla
```
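
For context, a backend would typically be selected by setting this variable before vLLM starts. Below is a minimal Python sketch, assuming the variable is read from the process environment when the engine is created, as the shell examples in this README imply; the model name and generation settings are illustrative, not taken from this repository.

```
# Hedged sketch: pick the TPU backend from Python rather than the shell.
# Assumption: TPU_BACKEND_TYPE is read from the environment at engine start.
import os

os.environ["TPU_BACKEND_TYPE"] = "pytorch_xla"  # or "jax" (the default)

from vllm import LLM, SamplingParams

# Model name and max_model_len are illustrative values.
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", max_model_len=1024)
outputs = llm.generate(["Hello from TPU!"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```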

@@ -209,27 +208,6 @@ docker run \
--max_model_len=1024 \
```

-## Torchax Guide
-
-**NOTE**: This is under development so the run may fail.
-
-### Install dependencies
-
-#### Install `vLLM`
-
-Follow the above [step](#install-vllm-tpu) to install vllm for TPU backend.
-
-#### Install `tpu_commons`
-
-Follow the above step to install [tpu_commons](#install-tpu_commons)
-
-### Run example script
-
-```
-cd vllm
-TPU_BACKEND_TYPE=torchax VLLM_TORCHAX_ENABLED=1 VLLM_USE_V1=1 python examples/offline_inference/tpu.py
-```
-
## How to test kernel?

Install dependencies:
tests/kernels/flash_attention_kernel_test.py: 126 changes (0 additions & 126 deletions)

This file was deleted.

@@ -6,7 +6,8 @@
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P

-from tpu_commons.kernels.ragged_kv_cache_update import kv_cache_update
+from tpu_commons.kernels.ragged_paged_attention.v2.ragged_kv_cache_update import \
+    kv_cache_update


def kv_cache_update_ref(new_kv, slot_mapping, kv_cache):
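
For readers skimming this diff, here is a minimal sketch of what a pure-JAX reference update could look like, assuming slot_mapping holds one flattened cache-slot index per new token; the slot_mapping layout used by the real kernel and by kv_cache_update_ref in this test may differ.

```
# Hedged sketch of a KV-cache update reference in plain JAX.
# Assumption: slot_mapping is [num_tokens] of flattened slot indices; the
# ragged Pallas kernel under test may use a different layout.
import jax
import jax.numpy as jnp


def kv_cache_update_sketch(new_kv, slot_mapping, kv_cache):
    # new_kv:       [num_tokens, num_kv_heads, head_dim] entries to insert
    # slot_mapping: [num_tokens] flattened slot index per token
    # kv_cache:     [num_slots, num_kv_heads, head_dim] cache to update
    # Scatter each new entry into its slot and return the updated cache.
    return kv_cache.at[slot_mapping].set(new_kv)


# Toy usage with illustrative shapes.
key = jax.random.PRNGKey(0)
new_kv = jax.random.normal(key, (4, 8, 128))
slot_mapping = jnp.array([3, 17, 42, 99], dtype=jnp.int32)
kv_cache = jnp.zeros((128, 8, 128))
updated = kv_cache_update_sketch(new_kv, slot_mapping, kv_cache)
```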