Open
Description
I am trying to quantize llava-v1.6-34b:
python3 -m mlc_llm.build --model /data/models/mlc/dist/models/llava-v1.6-34b \
--quantization q4f16_ft \
--target cuda \
--use-cuda-graph \
--use-flash-attn-mqa \
--sep-embed \
--max-seq-len 256 \
--artifact-path /data/models/mlc/dist/llava-v1.6-34b/ctx256 \
--use-safetensors
Expected behavior
The file
llava-v1.6-34b-q4f16_ft-cuda.so
is created in /data/models/mlc/dist/llava-v1.6-34b/ctx256/llava-v1.6-34b-q4f16_ft/
(i.e. /data/models/mlc/dist/llava-v1.6-34b/ctx256/llava-v1.6-34b-q4f16_ft/llava-v1.6-34b-q4f16_ft-cuda.so).
Actual behavior
Traceback (most recent call last):
File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/usr/local/lib/python3.10/dist-packages/mlc_llm/build.py", line 47, in <module>
main()
File "/usr/local/lib/python3.10/dist-packages/mlc_llm/build.py", line 43, in main
core.build_model_from_args(parsed_args)
File "/usr/local/lib/python3.10/dist-packages/mlc_llm/core.py", line 961, in build_model_from_args
mod = mod_transform_before_build(mod, param_manager, args, model_config)
File "/usr/local/lib/python3.10/dist-packages/mlc_llm/core.py", line 613, in mod_transform_before_build
mod = fuse_split_rotary_embedding(
File "/usr/local/lib/python3.10/dist-packages/tvm/ir/transform.py", line 238, in __call__
return _ffi_transform_api.RunPass(self, mod)
File "tvm/_ffi/_cython/./packed_func.pxi", line 339, in tvm._ffi._cy3.core.PackedFuncBase.__call__
File "tvm/_ffi/_cython/./packed_func.pxi", line 270, in tvm._ffi._cy3.core.FuncCall
File "tvm/_ffi/_cython/./packed_func.pxi", line 259, in tvm._ffi._cy3.core.FuncCall3
File "tvm/_ffi/_cython/./base.pxi", line 185, in tvm._ffi._cy3.core.CHECK_CALL
File "/usr/local/lib/python3.10/dist-packages/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
raise py_err
File "tvm/_ffi/_cython/./packed_func.pxi", line 56, in tvm._ffi._cy3.core.tvm_callback
File "/usr/local/lib/python3.10/dist-packages/mlc_llm/transform/fuse_split_rotary_embedding.py", line 118, in ir_module_pass
split_rotary = get_dynamic_split_rotary()
File "/usr/local/lib/python3.10/dist-packages/mlc_llm/transform/fuse_split_rotary_embedding.py", line 100, in get_dynamic_split_rotary
relax.expr._update_struct_info(
File "/usr/local/lib/python3.10/dist-packages/tvm/relax/expr.py", line 1224, in _update_struct_info
_ffi_api.UpdateStructInfo(expr, struct_info) # type: ignore
File "tvm/_ffi/_cython/./packed_func.pxi", line 339, in tvm._ffi._cy3.core.PackedFuncBase.__call__
File "tvm/_ffi/_cython/./packed_func.pxi", line 270, in tvm._ffi._cy3.core.FuncCall
File "tvm/_ffi/_cython/./packed_func.pxi", line 259, in tvm._ffi._cy3.core.FuncCall3
File "tvm/_ffi/_cython/./base.pxi", line 185, in tvm._ffi._cy3.core.CHECK_CALL
File "/usr/local/lib/python3.10/dist-packages/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
raise py_err
tvm.error.InternalError: Traceback (most recent call last):
[bt] (5) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(TVMFuncCall+0x68) [0xffff73860f98]
[bt] (4) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(+0x1944cf4) [0xffff720a4cf4]
[bt] (3) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(tvm::relax::UpdateStructInfo(tvm::RelayExpr, tvm::relax::StructInfo)+0x1bc) [0xffff7209fd10]
[bt] (2) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(+0x193f8ec) [0xffff7209f8ec]
[bt] (1) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(tvm::runtime::detail::LogFatal::Entry::Finalize()+0x68) [0xffff7199c7f8]
[bt] (0) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(tvm::runtime::Backtrace[abi:cxx11]()+0x30) [0xffff738abfc0]
File "/opt/mlc-llm/3rdparty/tvm/src/relax/ir/struct_info.cc", line 211
InternalError: Check failed: (!expr->struct_info_.defined()) is false: To ensure idempotency, the expression passed to UpdateStructInfo must not have any prior StructInfo. However, expression # from tvm.script import tir as T
@T.prim_func(private=True)
def main(fused_qkv_handle: T.handle, embedded_query_handle: T.handle, embedded_key_handle: T.handle, value_handle: T.handle, rotary_offset: T.int64, batch_size: T.int64, seq_len: T.int64, num_query_heads: T.int64, num_kv_heads: T.int64, head_dim: T.int64, position_embedding_base: T.float32):
T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
Fused_QKV = T.match_buffer(fused_qkv_handle, (batch_size, seq_len, num_query_heads + num_kv_heads * T.int64(2), head_dim), "float16")
EmbeddedQuery = T.match_buffer(embedded_query_handle, (batch_size, seq_len, num_query_heads, head_dim), "float16")
EmbeddedKey = T.match_buffer(embedded_key_handle, (batch_size, seq_len, num_kv_heads, head_dim), "float16")
Value = T.match_buffer(value_handle, (batch_size, seq_len, num_kv_heads, head_dim), "float16")
# with T.block("root"):
for iters_0, iters_1, iters_2, iters_3 in T.grid(batch_size, seq_len, num_query_heads + num_kv_heads * T.int64(2), head_dim):
with T.block("FusedRotaryEmbeddingAndSplitQKV"):
batch_i, seq_i, head_num, head_i = T.axis.remap("SSSS", [iters_0, iters_1, iters_2, iters_3])
T.reads(Fused_QKV[batch_i, seq_i, head_num, T.min(T.min(head_i, head_dim // T.int64(2) + head_i), head_i - head_dim // T.int64(2)):T.min(T.min(head_i, head_dim // T.int64(2) + head_i), head_i - head_dim // T.int64(2)) + (T.max(T.max(head_i, head_dim // T.int64(2) + head_i), head_i - head_dim // T.int64(2)) + T.int64(1) - T.min(T.min(head_i, head_dim // T.int64(2) + head_i), head_i - head_dim // T.int64(2)))])
T.writes(EmbeddedQuery[batch_i, seq_i, head_num, head_i], EmbeddedKey[batch_i, seq_i, head_num - num_query_heads, head_i], Value[batch_i, seq_i, head_num - num_query_heads - num_kv_heads, head_i])
pos: T.float32 = T.Cast("float32", rotary_offset + seq_i - seq_len)
inv_freq: T.float32 = T.float32(1.0) / T.pow(position_embedding_base, T.Cast("float32", head_i * T.int64(2) % head_dim) / T.Cast("float32", head_dim))
freq: T.float32 = pos * inv_freq
cos_value: T.float16 = T.Cast("float16", T.cos(freq))
sin_value: T.float16 = T.Cast("float16", T.sin(freq))
input_value: T.float16 = Fused_QKV[batch_i, seq_i, head_num, head_i]
embedded_value: T.float16 = cos_value * input_value + sin_value * T.Select(head_i < head_dim // T.int64(2), Fused_QKV[batch_i, seq_i, head_num, head_i + head_dim // T.int64(2)] * T.float16(-1.0), Fused_QKV[batch_i, seq_i, head_num, head_i - head_dim // T.int64(2)])
if head_num < num_query_heads:
EmbeddedQuery[batch_i, seq_i, head_num, head_i] = embedded_value
else:
if head_num < num_query_heads + num_kv_heads:
EmbeddedKey[batch_i, seq_i, head_num - num_query_heads, head_i] = embedded_value
else:
Value[batch_i, seq_i, head_num - num_query_heads - num_kv_heads, head_i] = input_value has struct info R.Callable((R.Tensor((batch_size, seq_len, num_query_heads + num_kv_heads * 2, head_dim), dtype="float16"), R.Tensor((batch_size, seq_len, num_query_heads, head_dim), dtype="float16"), R.Tensor((batch_size, seq_len, num_kv_heads, head_dim), dtype="float16"), R.Tensor((batch_size, seq_len, num_kv_heads, head_dim), dtype="float16"), R.Prim("int64"), R.Prim("int64"), R.Prim("int64"), R.Prim("int64"), R.Prim("int64"), R.Prim("int64"), R.Prim("float32")), R.Tuple, False), which cannot be overwritten with R.Callable((R.Tensor((batch_size, seq_len, num_query_heads + num_kv_heads * 2, head_dim), dtype="float16"), R.Tensor((batch_size, seq_len, num_query_heads, head_dim), dtype="float16"), R.Tensor((batch_size, seq_len, num_kv_heads, head_dim), dtype="float16"), R.Tensor((batch_size, seq_len, num_kv_heads, head_dim), dtype="float16"), R.Prim("int64"), R.Prim("int64"), R.Prim("int64"), R.Prim("int64"), R.Prim("int64"), R.Prim("int64"), R.Prim("float32")), R.Tuple, False)
Environment
Any environment details, such as: Operating System, TVM version, etc
Platform : Jetson Orin AGX, JetPack 6.2
docker image : dustynv/nano_llm:r36.4.0
package : tvm==0.19.0, mlc-llm==0.1.0
Steps to reproduce
Inside the nano_llm:r36.4.0 container, run:
python3 -m nano_llm.vision.video \
--api=mlc \
--model liuhaotian/llava-v1.6-34b \
--max-images 8 \
--max-context-len 256 \
--max-new-tokens 48 \
--video-input <video input> \
--video-output <video output> \
--prompt <prompt>