# Fast-dLLM LLaDA Model Demo

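This notebook demonstrates text generation with the Fast-dLLM LLaDA model, covering:
- Basic generation
- Acceleration with Prefix Cache
- Acceleration with Dual Cache
- Acceleration with parallel decoding

## 1. Environment Setup

First, import the required libraries.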

In [1]:
import sys
import os

# Add the llada directory to the Python path
sys.path.append('/home/jovyan/Fast-dLLM/llada')

import torch
from transformers import AutoTokenizer
from model.modeling_llada import LLaDAModelLM
from generate import generate, generate_with_prefix_cache, generate_with_dual_cache
import time

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU device: {torch.cuda.get_device_name(0)}")


PyTorch version: 2.9.1+cu128
CUDA available: True
GPU device: NVIDIA GeForce RTX 3090


## 2. Load the Model and Tokenizer

Load the LLaDA-8B-Instruct model.

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# Load the model
print("Loading model...")
model = LLaDAModelLM.from_pretrained(
    'GSAI-ML/LLaDA-8B-Instruct',
    trust_remote_code=True,
    torch_dtype=torch.bfloat16
).to(device).eval()

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    'GSAI-ML/LLaDA-8B-Instruct',
    trust_remote_code=True
)

print("Model loaded!")

The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.


Using device: cuda
Loading model...


Loading checkpoint shards: 100%|██████████| 6/6 [00:01<00:00,  3.93it/s]


Model loaded!


## 3. Prepare the Input

Prepare a test question.

In [3]:
# Prepare the question
question = "What is the capital of France?"
print(f"Question: {question}")

# Apply the chat template
messages = [{"role": "user", "content": question}]
prompt_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
print(f"\nFormatted prompt:\n{prompt_text}")

# Tokenize
input_ids = tokenizer(prompt_text)['input_ids']
prompt = torch.tensor(input_ids).to(device).unsqueeze(0)
print(f"\nNumber of input tokens: {prompt.shape[1]}")

Question: What is the capital of France?

Formatted prompt:
<|startoftext|><|start_header_id|>user<|end_header_id|>

What is the capital of France?<|eot_id|><|start_header_id|>assistant<|end_header_id|>



Number of input tokens: 20


## 4. Method 1: Basic Generation (No Cache)

Use the standard LLaDA generation method.

In [4]:
# Generation parameters
gen_length = 128  # generation length
steps = 128       # number of sampling steps
block_size = 32   # block size

print("Method 1: Basic generation (no cache)")
print(f"Generation length: {gen_length}, sampling steps: {steps}, block size: {block_size}")
print("Generating...")

start_time = time.time()
out, nfe = generate(
    model,
    prompt,
    steps=steps,
    gen_length=gen_length,
    block_length=block_size,
    temperature=0.,
    remasking='low_confidence'
)
elapsed_time = time.time() - start_time

# Decode the result
answer = tokenizer.batch_decode(out[:, prompt.shape[1]:], skip_special_tokens=True)[0]
print(f"\nGenerated text: {answer}")
print(f"Forward passes: {nfe}")
print(f"Generation time: {elapsed_time:.2f} s")
print(f"Throughput: {gen_length / elapsed_time:.2f} tokens/s")

Method 1: Basic generation (no cache)
Generation length: 128, sampling steps: 128, block size: 32
Generating...

Generated text: The capital of France is Paris.
Forward passes: 128
Generation time: 9.78 s
Throughput: 13.08 tokens/s
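
The forward-pass count of 128 reported above is consistent with a block-wise schedule in which the generation window is split into `gen_length / block_length` blocks and the step budget is divided evenly among them. The sketch below only works through that arithmetic. `block_schedule` is a hypothetical helper written for this illustration, not part of Fast-dLLM, and the even split of steps and tokens is an assumption about how `generate` allocates its budget.

```python
def block_schedule(gen_length: int, steps: int, block_length: int) -> dict:
    """Work out an evenly divided block-wise schedule (illustrative assumption)."""
    if gen_length % block_length != 0:
        raise ValueError("gen_length must be a multiple of block_length")
    num_blocks = gen_length // block_length
    if steps % num_blocks != 0:
        raise ValueError("steps must be a multiple of the number of blocks")
    steps_per_block = steps // num_blocks

    # Spread the block's tokens over its steps; earlier steps absorb any remainder.
    base, extra = divmod(block_length, steps_per_block)
    tokens_per_step = [base + (1 if i < extra else 0) for i in range(steps_per_block)]

    return {
        "num_blocks": num_blocks,
        "steps_per_block": steps_per_block,
        "tokens_per_step": tokens_per_step,
        "total_forward_passes": num_blocks * steps_per_block,
    }


# Settings used in the cell above: 4 blocks, 32 steps each, 1 token per step.
schedule = block_schedule(gen_length=128, steps=128, block_length=32)
print(schedule["num_blocks"], schedule["steps_per_block"], schedule["total_forward_passes"])
```

Under this assumption, halving `steps` to 64 would commit two tokens per step in each block instead of one.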


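## 5. Method 2: Acceleration with Prefix Cache

Speed up generation by caching the key/value (KV) states of the prefix instead of recomputing them at every step.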

In [5]:
print("Method 2: Prefix Cache")
print(f"Generation length: {gen_length}, sampling steps: {steps}, block size: {block_size}")
print("Generating...")

start_time = time.time()
out, nfe = generate_with_prefix_cache(
    model,
    prompt,
    steps=steps,
    gen_length=gen_length,
    block_length=block_size,
    temperature=0.,
    remasking='low_confidence'
)
elapsed_time = time.time() - start_time

# Decode the result
answer = tokenizer.batch_decode(out[:, prompt.shape[1]:], skip_special_tokens=True)[0]
print(f"\nGenerated text: {answer}")
print(f"Forward passes: {nfe}")
print(f"Generation time: {elapsed_time:.2f} s")
print(f"Throughput: {gen_length / elapsed_time:.2f} tokens/s")

Method 2: Prefix Cache
Generation length: 128, sampling steps: 128, block size: 32
Generating...

Generated text: The capital of France is Paris.
Forward passes: 128
Generation time: 6.87 s
Throughput: 18.62 tokens/s


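## 6. Method 3: Acceleration with Dual Cache

Cache the KV states of both the prefix and the suffix. This is intended to speed up generation further, although in the run recorded below it was slower than Prefix Cache (8.67 s versus 6.87 s) while still faster than the uncached baseline (9.78 s).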

In [6]:
print("Method 3: Dual Cache")
print(f"Generation length: {gen_length}, sampling steps: {steps}, block size: {block_size}")
print("Generating...")

start_time = time.time()
out, nfe = generate_with_dual_cache(
    model,
    prompt,
    steps=steps,
    gen_length=gen_length,
    block_length=block_size,
    temperature=0.,
    remasking='low_confidence'
)
elapsed_time = time.time() - start_time

# Decode the result
answer = tokenizer.batch_decode(out[:, prompt.shape[1]:], skip_special_tokens=True)[0]
print(f"\nGenerated text: {answer}")
print(f"Forward passes: {nfe}")
print(f"Generation time: {elapsed_time:.2f} s")
print(f"Throughput: {gen_length / elapsed_time:.2f} tokens/s")

Method 3: Dual Cache
Generation length: 128, sampling steps: 128, block size: 32
Generating...

Generated text: The capital of France is Paris.
Forward passes: 128
Generation time: 8.67 s
Throughput: 14.77 tokens/s
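
The three cells so far repeat the same timing pattern around each generation call. A minimal sketch of how that pattern could be factored out is shown below. `timed_call` and `throughput` are hypothetical helpers, not part of Fast-dLLM, and the optional `sync` callback is an assumption for GPU runs, where passing something like `torch.cuda.synchronize` keeps pending GPU work from skewing the measurement.

```python
import time


def timed_call(fn, *args, sync=None, **kwargs):
    """Run fn(*args, **kwargs) and return (result, elapsed_seconds).

    sync is an optional zero-argument callable invoked before starting and
    after finishing the call, e.g. torch.cuda.synchronize on a GPU.
    """
    if sync is not None:
        sync()
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    if sync is not None:
        sync()
    return result, time.perf_counter() - start


def throughput(num_tokens, elapsed_seconds):
    """Tokens per second over the whole generation window."""
    return num_tokens / elapsed_seconds


# Stand-in workload so the sketch runs without a model.
result, elapsed = timed_call(sum, range(1000))
print(result, f"{elapsed:.6f} s")
```

Note that the notebook divides `gen_length` by the elapsed time, so the reported throughput counts every position in the generation window, not only the tokens of the visible answer.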


## 7. Method 4: Dual Cache + Confidence-Based Parallel Decoding

Combine Dual Cache with confidence-based parallel decoding for the largest speedup.

In [7]:
# Use a confidence threshold for parallel decoding
threshold = 0.8  # confidence threshold

print("Method 4: Dual Cache + confidence-based parallel decoding")
print(f"Generation length: {gen_length}, sampling steps: {steps}, block size: {block_size}")
print(f"Confidence threshold: {threshold}")
print("Generating...")

start_time = time.time()
out, nfe = generate_with_dual_cache(
    model,
    prompt,
    steps=steps,
    gen_length=gen_length,
    block_length=block_size,
    temperature=0.,
    remasking='low_confidence',
    threshold=threshold  # enable parallel decoding
)
elapsed_time = time.time() - start_time

# Decode the result
answer = tokenizer.batch_decode(out[:, prompt.shape[1]:], skip_special_tokens=True)[0]
print(f"\nGenerated text: {answer}")
print(f"Forward passes: {nfe}")
print(f"Generation time: {elapsed_time:.2f} s")
print(f"Throughput: {gen_length / elapsed_time:.2f} tokens/s")

Method 4: Dual Cache + confidence-based parallel decoding
Generation length: 128, sampling steps: 128, block size: 32
Confidence threshold: 0.8
Generating...

Generated text: The capital of France is Paris.
Forward passes: 6
Generation time: 0.37 s
Throughput: 346.54 tokens/s
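
The drop from 128 forward passes to 6 comes from committing several tokens per forward pass whenever the model is confident about them. The following is a simplified, pure-Python sketch of threshold-based selection, not Fast-dLLM's implementation. The function name and the dictionary input are hypothetical, and the fallback of unmasking the single most confident position when nothing clears the threshold is an assumption made so that every step makes progress.

```python
def select_positions_to_unmask(confidences, threshold):
    """Pick which still-masked positions to commit in one step.

    confidences maps a masked position to the probability of the model's
    top prediction at that position (illustrative input format).
    """
    if not confidences:
        return []
    chosen = [pos for pos, conf in confidences.items() if conf > threshold]
    if not chosen:
        # Assumed fallback: commit the most confident position so the step is not wasted.
        chosen = [max(confidences, key=confidences.get)]
    return sorted(chosen)


# Two positions clear the 0.8 threshold, so both are committed in a single pass.
print(select_positions_to_unmask({0: 0.95, 1: 0.85, 2: 0.40, 3: 0.79}, threshold=0.8))

# Nothing clears the threshold, so only the best position is committed.
print(select_positions_to_unmask({0: 0.30, 1: 0.50}, threshold=0.8))
```

A lower threshold commits more tokens per pass and needs fewer passes, at the risk of accepting less reliable predictions.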


## 8. Try Other Questions

You can modify the question below to test different inputs.

In [8]:
# Customize your question
custom_question = "Explain quantum computing in simple terms."

# Prepare the input
messages = [{"role": "user", "content": custom_question}]
prompt_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
input_ids = tokenizer(prompt_text)['input_ids']
prompt = torch.tensor(input_ids).to(device).unsqueeze(0)

print(f"Question: {custom_question}\n")

# Generate with the fastest method
start_time = time.time()
out, nfe = generate_with_dual_cache(
    model,
    prompt,
    steps=128,
    gen_length=256,  # longer generation
    block_length=32,
    temperature=0.,
    remasking='low_confidence',
    threshold=0.8
)
elapsed_time = time.time() - start_time

answer = tokenizer.batch_decode(out[:, prompt.shape[1]:], skip_special_tokens=True)[0]
print(f"Answer: {answer}")
print(f"\nGeneration time: {elapsed_time:.2f} s")
print(f"Forward passes: {nfe}")

Question: Explain quantum computing in simple terms.

Answer: Quantum computing is a type of computing that uses the principles of quantum mechanics to perform calculations. In classical computers, the smallest unit of is a bit, which can be either a 0 or a 1. In quantum computing, the smallest unit is called a quantum bit, or qubit. Qubits can be both 0 and 1 at the same time, thanks to a property called superposition. This allows quantum computers to perform many calculations at once, making them much faster than classical computers.

Another important concept in quantum computing is entanglement, where two more qubits in such a way that the state of one qubit can affect the state of another, even if they are far apart. This allows quantum computers to perform computers.

In summary, quantum computing uses the principles of quantum mechanics to perform calculations much faster and classical

Generation time: 7.72 s
Forward passes: 105


## 9. Multi-Turn Conversation Example

This section shows how to run a multi-turn conversation.

In [9]:
# First turn
question1 = "What is machine learning?"
messages = [{"role": "user", "content": question1}]
prompt_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
input_ids = tokenizer(prompt_text)['input_ids']
conversation = torch.tensor(input_ids).to(device).unsqueeze(0)

print(f"User: {question1}")

# Generate the answer
out, nfe = generate_with_dual_cache(
    model, conversation, steps=128, gen_length=128,
    block_length=32, temperature=0., remasking='low_confidence', threshold=0.8
)
answer1 = tokenizer.batch_decode(out[:, conversation.shape[1]:], skip_special_tokens=True)[0]
print(f"Assistant: {answer1}\n")

# Remove EOS tokens (id 126081) and update the conversation history
conversation = out[out != 126081].unsqueeze(0)

# Second turn
question2 = "Can you give me an example?"
messages = [{"role": "user", "content": question2}]
prompt_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
input_ids = tokenizer(prompt_text)['input_ids']
new_input = torch.tensor(input_ids).to(device).unsqueeze(0)
# Append the new turn, skipping its first token
conversation = torch.cat([conversation, new_input[:, 1:]], dim=1)

print(f"User: {question2}")

# Generate the answer
out, nfe = generate_with_dual_cache(
    model, conversation, steps=128, gen_length=128,
    block_length=32, temperature=0., remasking='low_confidence', threshold=0.8
)
answer2 = tokenizer.batch_decode(out[:, conversation.shape[1]:], skip_special_tokens=True)[0]
print(f"Assistant: {answer2}")

User: What is machine learning?
Assistant: Machine learning is a subset of artificial intelligence that involves the development of algorithms that enable computers to learn from data and make predictions or decisions based on that data. It is a method of training computers to perform tasks without being explicitly programmed. Instead, machine learning algorithms are designed to learn from data and improve their performance over time. This is achieved through the use of statistical techniques and mathematical models that allow computers to identify patterns in data and make predictions or decisions based on those patterns. Machine learning has numerous applications in various fields, including healthcare, finance, marketing, and more, and is an important component of modern data science and artificial intelligence.

User: Can you give me an example?
Assistant: Sure! One example of machine learning is the use of image recognition algorithms. These algorithms can be trained to recognize patterns in images
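
The history update in the cell above does two things: it filters out token id 126081 from the previous output, and it appends the new turn without its first token, presumably the leading `<|startoftext|>` marker seen in the formatted prompt in section 3. The sketch below mirrors that logic on plain Python lists so it can be checked without a model. `extend_conversation` is a hypothetical helper, and the small token ids in the example are made up for illustration.

```python
EOS_TOKEN_ID = 126081  # the id filtered out in the notebook cell


def extend_conversation(previous_output, new_turn_ids, eos_id=EOS_TOKEN_ID):
    """Return the token history for the next turn.

    previous_output: token ids of prompt plus generated answer, possibly padded with EOS.
    new_turn_ids: token ids of the newly formatted user turn, starting with one leading
    token that should not be repeated mid-conversation.
    """
    history = [tok for tok in previous_output if tok != eos_id]
    return history + new_turn_ids[1:]


# Made-up ids: 1 stands in for the leading start-of-text token.
previous_output = [1, 10, 11, 12, EOS_TOKEN_ID, EOS_TOKEN_ID]
new_turn = [1, 20, 21]
print(extend_conversation(previous_output, new_turn))
```

Filtering by value removes the EOS id wherever it occurs, which matches the boolean-mask indexing used on the tensor in the notebook.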

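## 10. Parameter Reference

Main parameters:

- **gen_length**: Maximum number of tokens to generate
- **steps**: Number of diffusion sampling steps (usually equal or close to gen_length)
- **block_size**: Block size for block-wise decoding, which affects cache efficiency; it is passed to the generation functions as `block_length`
- **threshold**: Confidence threshold (0-1) used for parallel decoding. Set it to None to disable parallel decoding
- **temperature**: Temperature controlling randomness (0 means deterministic generation)
- **remasking**: Remasking strategy; 'low_confidence' remasks low-confidence tokens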

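### Performance Tuning Suggestions:

1. **Basic generation**: For scenarios that need the highest quality
2. **Prefix Cache**: 2-3x speedup, suitable for most scenarios (about 1.4x in the short run recorded in this notebook: 9.78 s to 6.87 s)
3. **Dual Cache**: Further speedup with a slight quality loss (in the run recorded in this notebook it was slower than Prefix Cache: 8.67 s versus 6.87 s)
4. **Dual Cache + parallel decoding**: Largest speedup (8-11x), for speed-first scenarios (about 26x in the run recorded in this notebook, on a one-sentence answer: 9.78 s to 0.37 s)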
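
As a quick check of these suggestions against this notebook's own single run, the sketch below computes speedups from the generation times printed in sections 4 to 7. The `speedups` helper is hypothetical. The timings come from one short prompt on one GPU, so they are not a benchmark.

```python
def speedups(baseline_seconds, timings):
    """Return how many times faster each method ran than the baseline."""
    return {name: baseline_seconds / seconds for name, seconds in timings.items()}


# Generation times printed by the cells in sections 4 to 7.
baseline = 9.78
timings = {
    "Prefix Cache": 6.87,
    "Dual Cache": 8.67,
    "Dual Cache + parallel decoding": 0.37,
}

for name, factor in speedups(baseline, timings).items():
    print(f"{name}: {factor:.2f}x")
# Prefix Cache: 1.42x
# Dual Cache: 1.13x
# Dual Cache + parallel decoding: 26.43x
```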