In [None]:
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc

import torch
from models.qwen3.modeling_qwen3_dms import Qwen3ForCausalLMDMS
from transformers import AutoTokenizer, TextStreamer

# Auxiliary functions


def get_model_and_tokenizer(model_name: str, *, device="cuda:0", **model_kwargs):
    """Load a DMS Qwen3 model and its tokenizer, ready for inference.

    Args:
        model_name: Hugging Face model id or local path, used for both the
            model and the tokenizer.
        device: device the model is moved to after loading. Defaults to
            ``"cuda:0"``, the previously hard-coded value.
        **model_kwargs: extra keyword arguments forwarded to
            ``Qwen3ForCausalLMDMS.from_pretrained`` (e.g. the ``dms_*`` options
            used by the examples below). ``dtype`` may be supplied here to
            override the default ``torch.bfloat16``.

    Returns:
        ``(model, tokenizer)`` with the model switched to eval mode and moved
        to ``device``.
    """
    # setdefault (instead of a fixed ``dtype=`` argument) lets callers override
    # the dtype without a "multiple values for keyword argument" TypeError.
    model_kwargs.setdefault("dtype", torch.bfloat16)
    model = Qwen3ForCausalLMDMS.from_pretrained(model_name, **model_kwargs)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model.eval()
    model.to(torch.device(device))
    return model, tokenizer


def clean_cache():
    """Free memory held by unreachable Python objects and by the CUDA caching allocator.

    Called before each example so a newly loaded model does not have to compete
    with memory left behind by earlier cells.
    """
    # Order matters: collecting garbage first lets tensors it frees go back to
    # the allocator, whose unused cached blocks are then released.
    for release in (gc.collect, torch.cuda.empty_cache):
        release()


def get_example_input_ids(
    tokenizer: "AutoTokenizer",
    device: torch.device,
    content: str = "Solve x^2 -2x + 1 = 0",
    enable_thinking: bool = True,
):
    """Build a single-turn chat prompt, tokenize it, and move it to ``device``.

    Args:
        tokenizer: a Hugging Face tokenizer with a chat template (as returned by
            ``AutoTokenizer.from_pretrained``).
        device: device the resulting token tensor is moved to.
        content: text of the single user message; defaults to the original
            example question.
        enable_thinking: forwarded to ``apply_chat_template`` (the chat template
            decides what it does); defaults to the previously hard-coded ``True``.

    Returns:
        The ``input_ids`` tensor produced by ``tokenizer(..., return_tensors="pt")``.
    """
    messages = [{"role": "user", "content": content}]
    # Render the template to text (tokenize=False) and append the assistant
    # generation prompt, then tokenize that text in a separate call.
    prompt_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=enable_thinking,
    )
    encoded = tokenizer(prompt_text, return_tensors="pt").to(device)
    return encoded["input_ids"]

In [None]:
# The following function runs the prefill in chunks to avoid OOM,
# and then continues with regular generation.


def example_prefill_generate(
    model_name: str, max_new_tokens: int = 4096, dms_chunked_prefill: int = 4
):
    """Stream a greedy ``model.generate`` completion for the example prompt.

    The model is loaded with ``dms_chunked_prefill`` so the prompt is prefilled
    in chunks (to avoid OOM) before generation continues.

    Args:
        model_name: Hugging Face model id or local path of a DMS checkpoint.
        max_new_tokens: upper bound on the number of generated tokens
            (previously hard-coded to 4096).
        dms_chunked_prefill: forwarded to ``from_pretrained``; its exact meaning
            is defined by the DMS model code (NOTE(review): confirm whether it is
            a chunk size or a chunk count). Previously hard-coded to 4.
    """
    model, tokenizer = get_model_and_tokenizer(
        model_name, dms_chunked_prefill=dms_chunked_prefill
    )

    input_ids = get_example_input_ids(tokenizer, model.device)

    # Echo the prompt as well (skip_prompt=False) and keep special tokens visible.
    streamer = TextStreamer(tokenizer, skip_prompt=False, skip_special_tokens=False)

    # generate() on the DMS model automatically replaces DynamicCache with DMSCache.
    model.generate(
        input_ids, max_new_tokens=max_new_tokens, do_sample=False, streamer=streamer
    )


# Release leftovers from earlier cells before loading the model, then run the example.
clean_cache()
example_prefill_generate("nvidia/Qwen3-8B-DMS-8x")

In [None]:
# Multi-turn conversations interleave prefill with generation. To run them with this code,
# fall back to a prefill-only mode, in which every step (including one-token decoding) uses prefill.


def example_prefill_only_mode(
    model_name: str, max_new_tokens: int = 4096, dms_preallocate_for_tokens: int = 512
):
    """Greedy-decode one token at a time while keeping the DMS cache in prefill mode.

    Prefill mode is slower than regular decoding, but it allows prefill and
    generation steps to be interleaved (as multi-turn conversations require).

    Args:
        model_name: Hugging Face model id or local path of a DMS checkpoint.
        max_new_tokens: upper bound on the number of generated tokens
            (previously hard-coded to 4096).
        dms_preallocate_for_tokens: forwarded to ``from_pretrained``; its exact
            semantics are defined by the DMS model code (previously hard-coded
            to 512).
    """
    model, tokenizer = get_model_and_tokenizer(
        model_name,
        dms_manual_inference_mode=True,
        dms_preallocate_for_tokens=dms_preallocate_for_tokens,
    )

    # Stop on every EOS id that model.generate would honor (generation_config may
    # list several), not only tokenizer.eos_token_id.
    stop_token_ids = {tokenizer.eos_token_id}
    generation_config = getattr(model, "generation_config", None)
    config_eos = getattr(generation_config, "eos_token_id", None)
    if isinstance(config_eos, int):
        stop_token_ids.add(config_eos)
    elif config_eos is not None:
        stop_token_ids.update(config_eos)
    stop_token_ids.discard(None)

    # Decoding each token on its own turns characters that span several tokens
    # into U+FFFD; TextStreamer buffers until the text decodes cleanly.
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)

    cache = model.get_cache()
    input_ids = get_example_input_ids(tokenizer, model.device)

    with torch.no_grad():
        cache.prefill_mode()
        # With skip_prompt=True the first put() is treated as the prompt and not printed.
        streamer.put(input_ids.cpu())

        for _ in range(max_new_tokens):
            assert cache.is_prefill_mode()
            outputs = model(input_ids, past_key_values=cache, use_cache=True)
            # Greedy pick of the next token; keepdim keeps a 2-D tensor for the next step.
            input_ids = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
            streamer.put(input_ids.cpu())
            if int(input_ids[0, 0]) in stop_token_ids:
                break

    # Flush any text still buffered in the streamer.
    streamer.end()


# Release what the previous example left behind before loading the model again.
clean_cache()
example_prefill_only_mode("nvidia/Qwen3-8B-DMS-8x")