In [1]:
import torch
from transformers import AutoModel, AutoTokenizer

device = 'cuda'
model = AutoModel.from_pretrained('./LLaDA-8B-Instruct', trust_remote_code=True, torch_dtype=torch.bfloat16, use_cache=False).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained('./LLaDA-8B-Instruct', trust_remote_code=True)


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Some weights of LLaDAModelLM were not initialized from the model checkpoint at ./LLaDA-8B-Instruct and are newly initialized: ['model.transformer.mask_head.bias', 'model.transformer.mask_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
mask_head = model.model.transformer.mask_head
torch.nn.init.kaiming_uniform_(mask_head.weight, nonlinearity='linear')
if mask_head.bias is not None:
    torch.nn.init.zeros_(mask_head.bias)
print(mask_head.weight)

Parameter containing:
tensor([[ 0.0039, -0.0099, -0.0114,  ..., -0.0165,  0.0146,  0.0139]],
       device='cuda:0', dtype=torch.bfloat16, requires_grad=True)


In [11]:
from generate import generate

In [12]:
gen_length = 128
steps = 128
block_length = 128 # 32
print('*' * 66)
print(f'**  Answer Length: {gen_length}  |  Sampling Steps: {steps}  **')
print('*' * 66)

conversation_num = 0
user_input = "Write a poem about a Bulgarian hero!"

m = [{"role": "user", "content": user_input}]
user_input = tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False)
input_ids = tokenizer(user_input)['input_ids']
input_ids = torch.tensor(input_ids).to(device).unsqueeze(0)

if conversation_num == 0:
    prompt = input_ids
else:
    prompt = torch.cat([prompt, input_ids[:, 1:]], dim=1)

out = model.generate(prompt, steps=steps, gen_length=gen_length, block_length=block_length, temperature=0., cfg_scale=0., remasking='predicted')

answer = tokenizer.batch_decode(out[:, prompt.shape[1]:], skip_special_tokens=True)[0]
print(f"Bot's reply: {answer}")

******************************************************************
**  Answer Length: 128  |  Sampling Steps: 128  **
******************************************************************
Bot's reply: In the heart of the land,
A Bulgarian hero stands tall,
A symbol of strength and pride,
An inspiration to all.

With a heart of fire and gold,
And a spirit fierce and bright,
He fights with courage and might,
A champion of justice, day and night.

In every battle, challenge, and fight,
His spirit is strong and true,
A symbol of courage and pride,
A beacon of hope and you.


In [14]:
import re
import torch
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer

In [15]:
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""


In [16]:
def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip().replace(",", "").replace("$", "")

# uncomment middle messages for 1-shot prompting
def get_gsm8k_questions(split = "train") -> Dataset:
    data = load_dataset('openai/gsm8k', 'main')[split] # type: ignore
    data = data.map(lambda x: { # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            #{'role': 'user', 'content': 'What is the largest single-digit prime number?'},
            #{'role': 'assistant', 'content': XML_COT_FORMAT.format(
            #    reasoning="9 is divisble by 3 and 8 is divisible by 2, but 7 is prime.",
            #    answer="7"
            #)},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': extract_hash_answer(x['answer'])
    }) # type: ignore
    return data # type: ignore

dataset = get_gsm8k_questions()

In [17]:
# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

def int_reward_func(completions, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses] 
    return [0.5 if match else 0.0 for match in matches]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses] 
    return [0.5 if match else 0.0 for match in matches]

def count_xml(text) -> float:
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1])*0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
    return count

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]

In [18]:
output_dir = "outputs/LLaDA-GRPO"

In [19]:
import os
os.environ["NCCL_P2P_DISABLE"] = "1"
os.environ["NCCL_IB_DISABLE"] = "1"

In [20]:
training_args = GRPOConfig(
    output_dir=output_dir,
    # run_name=run_name,
    learning_rate=5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type='cosine',
    logging_steps=1,
    bf16=True,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    num_generations=16,
    max_prompt_length=256,
    max_completion_length=786,
    num_train_epochs=1,
    save_steps=100,
    max_grad_norm=0.1,
    report_to="wandb",
    log_on_each_node=False,
)

peft_config = LoraConfig(
    r=16,
    lora_alpha=64,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"],
    task_type="CAUSAL_LM",
    lora_dropout=0.05,
)

In [21]:
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func],
    args=training_args,
    train_dataset=dataset,
    # peft_config=peft_config
)

OutOfMemoryError: CUDA out of memory. Tried to allocate 96.00 MiB. GPU 0 has a total capacity of 23.65 GiB of which 17.81 MiB is free. Including non-PyTorch memory, this process has 23.62 GiB memory in use. Of the allocated memory 23.04 GiB is allocated by PyTorch, and 141.06 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
trainer.train()

In [19]:
tokenizer.decode([126080, 126346,   3840, 126347,    198,    198,  48851,  26612, 126348, 126346,    598,  10450, 126347,    198,    198,  14455,      0,   2071,560,    331,   6528,    362,   3342,     30])

'<|startoftext|>user\n\nsay hiassistant\n\nHello! How can I assist you today?'

In [None]:
# TODO: UNDERSTAND TOKEN STRUCTURE
# tensor([[126080, 126346,   3840, 126347,    198,    198,  48851,  26612, 126348,
#          126346,    598,  10450, 126347,    198,    198,  14455,      0,   2071,
#             560,    331,   6528,    362,   3342,     30, 126348, 126081, 126081,
#          126081, 126081, 126081, 126081, 126081, 126081, 126081, 126081, 126081,
#          126081, 126081, 126081, 126081, 126081, 126081, 126081, 126081, 126081,
#          126081, 126081, 126081, 126081, 126081, 126081, 126081, 126081, 126081,
#          126081, 126081, 126081, 126081, 126081, 126081, 126081, 126081, 126081,
#          126081, 126081, 126081, 126081, 126081, 126081, 126081, 126081, 126081,
#          126081, 126081, 126081, 126081, 126081, 126081, 126081, 126081, 126081,
#          126081, 126081, 126081, 126081, 126081, 126081, 126081, 126081, 126081,
#          126081, 126081, 126081, 126081, 126081, 126081, 126081, 126081, 126081,
#          126081, 126081, 126081, 126081, 126081, 126081, 126081, 126081, 126081,
#          126081, 126081, 126081, 126081, 126081, 126081, 126081, 126081, 126081,
#          126081, 126081, 126081, 126081, 126081, 126081, 126081, 126081, 126081,
#          126081, 126081, 126081, 126081, 126081, 126081, 126081, 126081, 126081,
#          126081, 126081, 126081, 126081, 126081, 126081, 126081, 126081]],
#        device='cuda:0')
# Bot's reply: Hello! How can I assist you today?