41 commits
f00737f
Adding Compute-Context-Length(CCL)
vjanfaza Oct 17, 2025
5410733
Adding Compute-Context-Length(CCL)
vjanfaza Oct 17, 2025
13271c6
Adding Compute-Context-Length(CCL)
vjanfaza Oct 17, 2025
3332962
Delete examples/granite_example/ccl_granitemoe_inference.py
vjanfaza Oct 17, 2025
b4bf5f9
Adding Compute-Context-Length(CCL)
vjanfaza Oct 19, 2025
a4f8b4b
Merge remote-tracking branch 'origin/CCL-main' into CCL-main
vjanfaza Oct 19, 2025
9363689
Adding Compute-Context-Length(CCL)
vjanfaza Oct 20, 2025
a4fca59
Adding Compute-Context-Length(CCL)
vjanfaza Oct 20, 2025
acc4e40
Merge branch 'quic:main' into CCL-main
vjanfaza Oct 21, 2025
5f047b4
improving handeling CCL lists
vjanfaza Oct 22, 2025
71c5182
improving handeling CCL lists
vjanfaza Oct 22, 2025
1d74b42
improving handeling CCL lists
vjanfaza Oct 22, 2025
811b1ce
improving handeling CCL lists
vjanfaza Oct 22, 2025
7b57d90
improving handeling CCL lists
vjanfaza Oct 22, 2025
0b88a32
Adding Compute-Context-Length(CCL)
vjanfaza Oct 22, 2025
2ade913
fixing lora testing
vjanfaza Oct 23, 2025
acf3544
Adding Compute-Context-Length(CCL)
vjanfaza Oct 23, 2025
2643e9f
Adding Compute-Context-Length(CCL)
vjanfaza Oct 17, 2025
495b44f
Adding Compute-Context-Length(CCL)
vjanfaza Oct 17, 2025
9d1a63a
Adding Compute-Context-Length(CCL)
vjanfaza Oct 17, 2025
eb3aea5
Adding Compute-Context-Length(CCL)
vjanfaza Oct 19, 2025
736c775
Delete examples/granite_example/ccl_granitemoe_inference.py
vjanfaza Oct 17, 2025
027625c
Adding Compute-Context-Length(CCL)
vjanfaza Oct 20, 2025
42b4b7f
Adding Compute-Context-Length(CCL)
vjanfaza Oct 20, 2025
fa3c2f6
improving handeling CCL lists
vjanfaza Oct 22, 2025
8fb3265
improving handeling CCL lists
vjanfaza Oct 22, 2025
ee2f54e
improving handeling CCL lists
vjanfaza Oct 22, 2025
bb2a207
improving handeling CCL lists
vjanfaza Oct 22, 2025
0e9c851
improving handeling CCL lists
vjanfaza Oct 22, 2025
6cedad2
Adding Compute-Context-Length(CCL)
vjanfaza Oct 22, 2025
528ad38
fixing lora testing
vjanfaza Oct 23, 2025
ba18a3e
Adding Compute-Context-Length(CCL)
vjanfaza Oct 23, 2025
6d056f9
Updated the test
quic-rishinr Oct 23, 2025
c9aaaec
Lint fix
quic-rishinr Oct 23, 2025
8f29a42
Merge remote-tracking branch 'origin/CCL-main' into CCL-main
vjanfaza Oct 23, 2025
5765779
Adding the support of modeling_gpt_bigcode with CCL
vjanfaza Oct 23, 2025
0468a90
Removed redendunt test
quic-rishinr Oct 24, 2025
d8f4eab
Adding Compute-Context-Length(CCL)
vjanfaza Oct 24, 2025
8b6ab58
Merge remote-tracking branch 'origin/CCL-main' into CCL-main
vjanfaza Oct 24, 2025
7e952ad
Adding Compute-Context-Length(CCL)
vjanfaza Oct 24, 2025
2d137f9
Add CCL support to molmo model
vjanfaza Oct 25, 2025
12 changes: 12 additions & 0 deletions QEfficient/cloud/infer.py
@@ -340,6 +340,18 @@ def main(
"--prompt-len", "--prompt_len", default=32, type=int, help="Sequence length for text generation."
)
parser.add_argument("--ctx-len", "--ctx_len", default=128, type=int, help="Context length for text generation.")
parser.add_argument(
"--comp-ctx-lengths-prefill",
type=lambda comp_ctx_lengths_prefill: [int(x) for x in comp_ctx_lengths_prefill.split(",")],
default=[512],
help="Define ccl list in csv format (e.g.,--comp-ctx-lengths 512,1024,2048).",
)
parser.add_argument(
"--comp-ctx-lengths-decode",
type=lambda comp_ctx_lengths_decode: [int(x) for x in comp_ctx_lengths_decode.split(",")],
default=[2048],
help="Define ccl list in csv format (e.g.,--comp-ctx-lengths 512,1024,2048).",
)
parser.add_argument(
"--mxfp6",
"--mxfp6_matmul",
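The `type=lambda ...` converters above turn each comma-separated flag value into a Python list of ints. A minimal, self-contained sketch of that behavior (the parser below is illustrative, not the one in infer.py):

import argparse

# Illustrative parser mirroring the converters added above: a string such as
# "512,1024,2048" becomes the list [512, 1024, 2048].
parser = argparse.ArgumentParser()
parser.add_argument(
    "--comp-ctx-lengths-prefill",
    type=lambda s: [int(x) for x in s.split(",")],
    default=[512],
)
parser.add_argument(
    "--comp-ctx-lengths-decode",
    type=lambda s: [int(x) for x in s.split(",")],
    default=[2048],
)

args = parser.parse_args(
    ["--comp-ctx-lengths-prefill", "256,512", "--comp-ctx-lengths-decode", "512,1024,2048"]
)
print(args.comp_ctx_lengths_prefill)  # [256, 512]
print(args.comp_ctx_lengths_decode)   # [512, 1024, 2048]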
16 changes: 11 additions & 5 deletions QEfficient/customop/ctx_scatter_gather.py
@@ -115,8 +115,14 @@ def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> torch.Value:


@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
def CtxGather(data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32) -> onnxscript.FLOAT:
ctx_indices = ops.Expand(ctx_indices, ops.Slice(ops.Shape(data), starts=[0], ends=[3], axes=[0]))
def CtxGather(
data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32, comp_ctx_len: onnxscript.INT32
) -> onnxscript.FLOAT:
# Create a shape tensor based on comp_ctx_len
shape_tensor = ops.Concat(ops.Shape(data)[:2], ops.Reshape(comp_ctx_len, [1]), axis=0)

# Directly use the shape tensor without validation
ctx_indices = ops.Expand(ctx_indices, shape_tensor)
ctx_indices = ops.Unsqueeze(ctx_indices, [-1])
return ops.GatherND(data, ctx_indices, batch_dims=2)

@@ -127,7 +133,7 @@ class CtxGatherFunc(torch.autograd.Function):
"""

@staticmethod
def forward(data: torch.Tensor, ctx_indices: torch.Tensor):
def forward(data: torch.Tensor, ctx_indices: torch.Tensor, comp_ctx_len: int):
batch_indices = torch.arange(data.shape[0]).view(-1, 1, 1)
head_indices = torch.arange(data.shape[1]).view(1, -1, 1)
return data[batch_indices, head_indices, ctx_indices]
@@ -137,5 +143,5 @@ def setup_context(ctx, inputs, outputs):
pass

@staticmethod
def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> torch.Value:
return g.onnxscript_op(CtxGather, data, ctx_indices).setTypeAs(data)
def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value, comp_ctx_len: int) -> torch.Value:
return g.onnxscript_op(CtxGather, data, ctx_indices, comp_ctx_len).setTypeAs(data)
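In eager mode the forward above still indexes the cache directly; `comp_ctx_len` only matters for the exported ONNX graph, where it caps how many cache positions are gathered. A rough, self-contained sketch of the intended effect (shapes and names are illustrative, not part of the exported op):

import torch

def ctx_gather_reference(data: torch.Tensor, ctx_indices: torch.Tensor) -> torch.Tensor:
    # data: (batch, heads, ctx_len, head_dim); ctx_indices: (batch, heads, comp_ctx_len)
    batch_indices = torch.arange(data.shape[0]).view(-1, 1, 1)
    head_indices = torch.arange(data.shape[1]).view(1, -1, 1)
    return data[batch_indices, head_indices, ctx_indices]

data = torch.randn(2, 4, 8, 16)  # KV cache with full context length 8
comp_ctx_len = 4                 # only the first 4 positions are attended to
ctx_indices = torch.arange(comp_ctx_len).view(1, 1, -1).expand(2, 4, comp_ctx_len)

out = ctx_gather_reference(data, ctx_indices)
print(out.shape)  # torch.Size([2, 4, 4, 16]) -> gathered cache is CCL-sized, not ctx_len-sized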
18 changes: 12 additions & 6 deletions QEfficient/customop/ctx_scatter_gather_cb.py
@@ -97,16 +97,20 @@ def symbolic(

@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
def CtxGatherCB(
data: onnxscript.FLOAT, batch_index: onnxscript.INT32, ctx_indices: onnxscript.INT32
data: onnxscript.FLOAT, batch_index: onnxscript.INT32, ctx_indices: onnxscript.INT32, comp_ctx_len: onnxscript.INT32
) -> onnxscript.FLOAT:
batch_size = ops.Gather(ops.Shape(batch_index), [0])
num_heads = ops.Gather(ops.Shape(data), [1])
ctx_len = ops.Gather(ops.Shape(data), [2])
# Use the compute-context-length (CCL) instead of the full context length, so both this gather and the subsequent attention computation operate on CCL positions only.
ctx_len = ops.Reshape(comp_ctx_len, [1])

# Expanded shape to create indices
zero = ops.Constant(value_ints=[0])
one = ops.Constant(value_ints=[1])
exp_shape = ops.Concat(batch_size, num_heads, ctx_len, one, axis=0)
# exp_shape = ops.Concat(batch_size, num_heads, ctx_len, one, axis=0)
exp_shape = ops.Concat(
ops.Reshape(batch_size, [1]), ops.Reshape(num_heads, [1]), ops.Reshape(ctx_len, [1]), one, axis=0
)

# Create indices
batch_idx = ops.Expand(ops.Unsqueeze(batch_index, [2, 3]), exp_shape)
@@ -119,7 +123,7 @@ def CtxGatherCB(

class CtxGatherFuncCB(torch.autograd.Function):
@staticmethod
def forward(data: torch.Tensor, batch_index: torch.Tensor, ctx_indices: torch.Tensor):
def forward(data: torch.Tensor, batch_index: torch.Tensor, ctx_indices: torch.Tensor, comp_ctx_len: int):
batch_indices = batch_index.view(-1, 1, 1)
head_indices = torch.arange(data.shape[1]).view(1, -1, 1)
return data[batch_indices, head_indices, ctx_indices]
@@ -129,8 +133,10 @@ def setup_context(ctx, inputs, outputs):
pass

@staticmethod
def symbolic(g: torch.Graph, data: torch.Value, batch_index: torch.Value, ctx_indices: torch.Value) -> torch.Value:
return g.onnxscript_op(CtxGatherCB, data, batch_index, ctx_indices).setTypeAs(data)
def symbolic(
g: torch.Graph, data: torch.Value, batch_index: torch.Value, ctx_indices: torch.Value, comp_ctx_len: int
) -> torch.Value:
return g.onnxscript_op(CtxGatherCB, data, batch_index, ctx_indices, comp_ctx_len).setTypeAs(data)


@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
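For continuous batching, `CtxGatherCB` additionally routes each active decode slot to its cache row via `batch_index`, while `comp_ctx_len` again bounds the gathered length in the exported graph. A hedged eager-mode analogue (sizes are made up):

import torch

data = torch.randn(4, 8, 32, 64)        # (cache_rows, heads, ctx_len, head_dim)
batch_index = torch.tensor([[2], [0]])  # two active slots mapped to cache rows 2 and 0
comp_ctx_len = 16
ctx_indices = torch.arange(comp_ctx_len).view(1, 1, -1).expand(2, 8, comp_ctx_len)

batch_indices = batch_index.view(-1, 1, 1)
head_indices = torch.arange(data.shape[1]).view(1, -1, 1)
gathered = data[batch_indices, head_indices, ctx_indices]
print(gathered.shape)  # torch.Size([2, 8, 16, 64]) -> per-slot cache truncated to the CCL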
76 changes: 76 additions & 0 deletions QEfficient/generation/text_generation_inference.py
@@ -318,6 +318,8 @@ def cloud_ai_100_exec_kv(
prompts_txt_file_path: Optional[str] = None,
device_id: Optional[List[int]] = None,
generation_len: Optional[int] = None,
comp_ctx_lengths_prefill: Optional[List[int]] = None,
comp_ctx_lengths_decode: Optional[List[int]] = None,
enable_debug_logs: bool = False,
stream: bool = True,
write_io_dir: Optional[str] = None,
@@ -384,6 +386,8 @@
qpc_path=qpc_path,
device_id=device_id,
ctx_len=ctx_len,
comp_ctx_lengths_prefill=comp_ctx_lengths_prefill,
comp_ctx_lengths_decode=comp_ctx_lengths_decode,
enable_debug_logs=enable_debug_logs,
write_io_dir=write_io_dir,
full_batch_size=full_batch_size,
@@ -430,6 +434,8 @@ def __init__(
qpc_path: str,
full_batch_size: Optional[int] = None,
ctx_len: Optional[int] = None,
comp_ctx_lengths_prefill: Optional[List[int]] = None,
comp_ctx_lengths_decode: Optional[List[int]] = None,
device_id: Optional[List[int]] = None,
enable_debug_logs: bool = False,
write_io_dir: Optional[str] = None,
@@ -439,6 +445,8 @@
sampling_params: Optional[Dict[str, Any]] = None,
) -> None:
self._ctx_len = ctx_len
self.comp_ctx_lengths_prefill = comp_ctx_lengths_prefill
self.comp_ctx_lengths_decode = comp_ctx_lengths_decode
self._write_io_dir = write_io_dir
self.is_tlm = is_tlm
self.return_pdfs = return_pdfs
@@ -797,7 +805,17 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_id=None):
batch_lora_ids = [self._prompt_to_lora_id_mapping_prefill.popleft() for i in range(self.batch_size)]
inputs["lora_ids"] = np.array(batch_lora_ids, dtype=np.int64).reshape(self.batch_size, 1)

if self.comp_ctx_lengths_prefill is not None:
self.list_of_comp_ctx_lengths_prefill = [np.zeros(length) for length in self.comp_ctx_lengths_prefill]
prefill_ccl_id = 0
inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id]

for i in range(num_chunks):
if self.comp_ctx_lengths_prefill is not None:
if (i + 1) * self._prefill_seq_len > self.comp_ctx_lengths_prefill[prefill_ccl_id]:
prefill_ccl_id = min(prefill_ccl_id + 1, len(self.comp_ctx_lengths_prefill) - 1)
inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id]

chunk_inputs = inputs.copy()
chunk_inputs["input_ids"] = inputs["input_ids"][
:, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len
Expand All @@ -816,6 +834,19 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
generation_len,
)

def initialize_ccl(self, decode_inputs):
self.list_of_comp_ctx_lengths_decode = [np.zeros(length) for length in self.comp_ctx_lengths_decode]
max_ccl_id = len(self.comp_ctx_lengths_decode) - 1
max_position_id = np.max(decode_inputs["position_ids"])
ccl_id_initial = 0
ccl_id = ccl_id_initial
for i in range(ccl_id_initial, len(self.comp_ctx_lengths_decode)):
if max_position_id < self.comp_ctx_lengths_decode[i]:
ccl_id = i
break

return ccl_id, max_ccl_id
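
Taken together, the prefill loop above steps through increasingly large CCL buckets as chunks are processed, and `initialize_ccl` picks the first decode bucket that still covers the longest prompt; note that the runner encodes the active bucket as a zero-filled array whose length is the CCL (the `np.zeros(length)` entries above). A small standalone sketch of both selections (bucket sizes are illustrative):

# Prefill: advance to the next CCL bucket once the processed prefix outgrows the current one.
comp_ctx_lengths_prefill = [512, 1024, 2048]
prefill_seq_len = 128
num_chunks = 12  # 12 * 128 = 1536 prompt tokens

prefill_ccl_id = 0
for i in range(num_chunks):
    if (i + 1) * prefill_seq_len > comp_ctx_lengths_prefill[prefill_ccl_id]:
        prefill_ccl_id = min(prefill_ccl_id + 1, len(comp_ctx_lengths_prefill) - 1)
print(prefill_ccl_id)  # 2 -> the last chunks run with the 2048 bucket

# Decode start: the first bucket large enough for the current max position id wins.
comp_ctx_lengths_decode = [512, 1024, 2048]
max_position_id = 700
ccl_id = 0
for i in range(len(comp_ctx_lengths_decode)):
    if max_position_id < comp_ctx_lengths_decode[i]:
        ccl_id = i
        break
print(ccl_id)  # 1 -> decode begins with the 1024 bucket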

def run_continuous_batching_decode(self, prompt_queue, generation_len):
"""
Runs continuous batching decode for the given prompt queue and generation length.
@@ -847,6 +878,10 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
# Prepare decode inputs.
decode_inputs = self.prepare_decode_inputs()

if self.comp_ctx_lengths_decode is not None:
ccl_id, max_ccl_id = self.initialize_ccl(decode_inputs)
decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_decode[ccl_id]

while prompt_queue or current_decode_ongoing.any():
outputs = self._session.run(decode_inputs)

@@ -884,6 +919,20 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
batch_id_map[decode_batch_id]
]

if self.comp_ctx_lengths_decode is not None:
# Recalculate ccl_id based on position ids
# Determine the maximum value of position_ids across all batch elements
max_position_id = np.max(decode_inputs["position_ids"])

# Update ccl_id and comp_ctx_lengths_decode based on the maximum position id
ccl_id_initial = 0
ccl_id = ccl_id_initial
for i in range(ccl_id_initial, len(self.comp_ctx_lengths_decode)):
if max_position_id < self.comp_ctx_lengths_decode[i]:
ccl_id = i
break
decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_decode[ccl_id]

else:
current_decode_ongoing[decode_batch_id] = False
else:
@@ -896,6 +945,15 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
if self.include_sampler:
decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"]

if self.comp_ctx_lengths_decode is not None:
# Update ccl_id and comp_ctx_lengths_decode based on the maximum position id
if (
decode_inputs["position_ids"][decode_batch_id, -1]
>= self.comp_ctx_lengths_decode[ccl_id] - 1
):
ccl_id = min(ccl_id + 1, max_ccl_id)
decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_decode[ccl_id]

generated_id_current_index[decode_batch_id] += 1

return decode_pause_time
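
During decode the bucket is only ever grown: once the running position index reaches the current bucket's boundary, the next `comp_ctx_lengths` entry is fed to the session, capped at the largest bucket. A compact sketch of that escalation (values are illustrative):

comp_ctx_lengths_decode = [512, 1024, 2048]
max_ccl_id = len(comp_ctx_lengths_decode) - 1

ccl_id = 0
cache_index = 505  # current max position id after prefill
for _ in range(20):  # 20 decode steps
    if cache_index >= comp_ctx_lengths_decode[ccl_id] - 1:
        ccl_id = min(ccl_id + 1, max_ccl_id)
        # in the real loop: decode_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_decode[ccl_id]
    cache_index += 1
print(ccl_id)  # 1 -> the 512 boundary was crossed mid-generation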
@@ -922,7 +980,18 @@ def run_decode(
self._session.set_buffers({"logits": logits_out_placeholder})
finished_sequences = decode_inputs["input_ids"] == self.tokenizer.eos_token_id
num_token = 0

if self.comp_ctx_lengths_decode is not None:
ccl_id, max_ccl_id = self.initialize_ccl(decode_inputs)
decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_decode[ccl_id]

cache_index = np.max(decode_inputs["position_ids"])
for num_token in range(1, generation_len):
if self.comp_ctx_lengths_decode is not None:
if cache_index >= self.comp_ctx_lengths_decode[ccl_id] - 1:
ccl_id = min(ccl_id + 1, max_ccl_id)
decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_decode[ccl_id]

if streamer:
streamer.put(decode_inputs["input_ids"][0])
outputs = self._session.run(decode_inputs)
@@ -934,6 +1003,7 @@ def run_decode(
# Prepare inputs for next iteration
decode_inputs["input_ids"] = self._fetch_next_token_id(outputs)
decode_inputs["position_ids"][:, -1] += 1
cache_index += 1
self.generated_ids[:, num_token] = decode_inputs["input_ids"][:, -1]
finished_sequences |= decode_inputs["input_ids"] == self.tokenizer.eos_token_id
if self.include_sampler:
@@ -983,6 +1053,8 @@ def __init__(
qpc_path: str,
full_batch_size: Optional[int] = None,
ctx_len: Optional[int] = None,
comp_ctx_lengths_prefill: Optional[List[int]] = None,
comp_ctx_lengths_decode: Optional[List[int]] = None,
device_id: Optional[List[int]] = None,
enable_debug_logs: bool = False,
write_io_dir: Optional[str] = None,
@@ -996,6 +1068,8 @@ def __init__(
qpc_path=qpc_path,
full_batch_size=full_batch_size,
ctx_len=ctx_len,
comp_ctx_lengths_prefill=comp_ctx_lengths_prefill,
comp_ctx_lengths_decode=comp_ctx_lengths_decode,
device_id=device_id,
enable_debug_logs=enable_debug_logs,
write_io_dir=write_io_dir,
@@ -1007,6 +1081,8 @@ def __init__(
self._full_batch_size = self._qaic_model.full_batch_size
self._tokenizer = self._qaic_model.tokenizer
self._ctx_len = ctx_len
self.comp_ctx_lengths_prefill = comp_ctx_lengths_prefill
self.comp_ctx_lengths_decode = comp_ctx_lengths_decode
self._perf_metrics = None
self._prompt_queue = None
self._text_streamer = None
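A hedged end-to-end sketch of how these options reach the runner. Only `comp_ctx_lengths_prefill` / `comp_ctx_lengths_decode` come from this diff; the positional `tokenizer`, `qpc_path`, and `prompt` arguments are assumed from the existing cloud_ai_100_exec_kv API and are illustrative here:

from transformers import AutoTokenizer
from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumed model; any compiled model works
exec_info = cloud_ai_100_exec_kv(
    tokenizer,
    qpc_path="path/to/qpcs",                 # assumed location of the compiled QPC
    prompt=["My name is"],
    generation_len=128,
    comp_ctx_lengths_prefill=[256, 512],     # prefill walks up these buckets
    comp_ctx_lengths_decode=[512, 1024],     # decode grows toward the full context
)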
6 changes: 3 additions & 3 deletions QEfficient/peft/lora/layers.py
@@ -42,15 +42,15 @@ def forward(self, x: torch.Tensor, lora_ids: torch.Tensor):
# multilora implementation: lora_ids <batch_size, 1>
other_indices_a = torch.arange(self.lora_a_weights.shape[2]).view(1, 1, -1)
selected_lora_a_weights = CtxGatherFuncCB.apply(
self.lora_a_weights, lora_ids, other_indices_a
self.lora_a_weights, lora_ids, other_indices_a, self.lora_a_weights.shape[2]
) # <num_loras, 1, feature, r>
other_indices_b = torch.arange(self.lora_b_weights.shape[2]).view(1, 1, -1)
selected_lora_b_weights = CtxGatherFuncCB.apply(
self.lora_b_weights, lora_ids, other_indices_b
self.lora_b_weights, lora_ids, other_indices_b, self.lora_b_weights.shape[2]
) # <num_loras, 1, r, feature>
other_indices_s = torch.arange(self.lora_scalings.shape[2]).view(1, 1, -1)
selected_lora_scalings = CtxGatherFuncCB.apply(
self.lora_scalings, lora_ids, other_indices_s
self.lora_scalings, lora_ids, other_indices_s, self.lora_scalings.shape[2]
) # <num_loras, 1, 1, 1>

selected_lora_a_weights = selected_lora_a_weights.squeeze(1)
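The LoRA gathers do not use a compute-context-length at all; passing `shape[2]` (the full size of the gathered dimension) simply satisfies the new `comp_ctx_len` argument while keeping behavior identical. A tiny eager-mode illustration (adapter counts and sizes are made up):

import torch

lora_a_weights = torch.randn(4, 1, 64, 8)   # (num_loras, 1, feature, r)
lora_ids = torch.tensor([[1], [3]])          # adapter id per batch element
other_indices_a = torch.arange(lora_a_weights.shape[2]).view(1, 1, -1)  # all 64 rows, i.e. "comp_ctx_len" == 64

# Equivalent eager gather: pick each batch element's adapter, keep every feature row.
selected = lora_a_weights[
    lora_ids.view(-1, 1, 1),
    torch.arange(lora_a_weights.shape[1]).view(1, -1, 1),
    other_indices_a,
]
print(selected.shape)  # torch.Size([2, 1, 64, 8]) -> nothing is truncated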