In [1]:
import transformers
import torch
from matplotlib import pyplot as plt

In [None]:
# Local checkpoint directory for LLaDA-8B-Instruct (custom architecture, hence trust_remote_code=True).
MODEL_DIR = "./pretrained-llms/LLaDA-8B-Instruct/"

tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_DIR, trust_remote_code=True)
model = transformers.AutoModelForCausalLM.from_pretrained(
    MODEL_DIR,
    trust_remote_code=True,
    device_map="auto",
)
# End-of-turn token id appended to SFT targets in `prepare_sample`.
# NOTE(review): presumably the id of <|eot_id|> for this tokenizer — verify.
tokenizer.eot_id = 126348

In [None]:
import torch.nn.functional as F
import numpy as np

def get_num_transfer_tokens(mask_index, steps):
    """Split each row's number of masked positions as evenly as possible over `steps`.

    Args:
        mask_index: boolean tensor (batch, seq_len); True marks a masked position.
        steps: number of denoising steps.

    Returns:
        int64 tensor of shape (batch, steps). Row i sums to the number of masked
        positions in row i; the first `count % steps` steps get one extra token.
    """
    counts = mask_index.sum(dim=1, keepdim=True)  # (batch, 1)
    per_step, leftover = counts // steps, counts % steps
    step_ids = torch.arange(steps, device=mask_index.device).unsqueeze(0)  # (1, steps)
    # 1 for the first `leftover` steps of each row, 0 afterwards.
    extra = (step_ids < leftover).to(torch.int64)
    return per_step.to(torch.int64) + extra

def forward_process(input_ids, eps=1e-3, mask_id=126336):
    """Apply the LLaDA forward (masking) process to a batch of token ids.

    For every row a noise level t ~ U[0, 1) is drawn and mapped to a masking
    probability p = (1 - eps) * t + eps, so p lies in [eps, 1). Each token of that
    row is then independently replaced by `mask_id` with probability p.

    Args:
        input_ids: integer tensor of shape (batch, seq_len).
        eps: lower bound on the per-row masking probability.
        mask_id: id written into masked positions (default 126336, the mask id used
            throughout this notebook).

    Returns:
        (noisy_batch, masked_indices, p_mask), each of shape (batch, seq_len): the
        masked ids, a boolean mask of replaced positions, and the masking probability.
    """
    b, l = input_ids.shape
    t = torch.rand(b, device=input_ids.device)
    # Same probability for every token of a row.
    p_mask = ((1 - eps) * t + eps)[:, None].repeat(1, l)
    masked_indices = torch.rand((b, l), device=input_ids.device) < p_mask
    noisy_batch = torch.where(masked_indices, mask_id, input_ids)
    return noisy_batch, masked_indices, p_mask

@torch.no_grad()
def generate(model, input_ids, steps=128, gen_length=128, mask_id=126336, attention_mask=None):
    """Sample `gen_length` tokens after `input_ids` by LLaDA-style iterative unmasking.

    The answer region starts fully masked. At each of `steps` steps the model predicts
    every position and a scheduled number of still-masked positions (from
    `get_num_transfer_tokens`) are committed to their argmax prediction. Which masked
    positions are committed is chosen uniformly at random ("random remasking").

    Args:
        model: callable returning an object with `.logits` (batch, seq, vocab); must
            expose `.device`.
        input_ids: integer prompt tensor (batch, prompt_len); copied onto `model.device`.
        steps: number of denoising steps.
        gen_length: number of tokens to generate.
        mask_id: id of the mask token.
        attention_mask: optional (batch, prompt_len) padding mask for the prompt; it is
            extended with ones over the generated region. When None, every position is
            attended to.

    Returns:
        Tensor (batch, prompt_len + gen_length): prompt followed by generated ids.
    """
    batch_size, prompt_len = input_ids.shape
    x = torch.full((batch_size, prompt_len + gen_length), mask_id, dtype=torch.long, device=model.device)
    x[:, :prompt_len] = input_ids.to(x.device)

    if attention_mask is not None:
        # Generated positions (mask tokens included) must always be attended to.
        attention_mask = torch.cat(
            [
                attention_mask.to(x.device),
                torch.ones((batch_size, gen_length), dtype=attention_mask.dtype, device=x.device),
            ],
            dim=-1,
        )

    num_transfer_tokens = get_num_transfer_tokens(x == mask_id, steps)

    for i in range(steps):
        mask_index = x == mask_id
        # Do NOT hide [MASK] positions from attention: in training the model attends to
        # every position (train_step_with_adapter passes no attention mask), so masking
        # them out here would feed the model a context it never saw during training.
        if attention_mask is None:
            logits = model(x).logits
        else:
            logits = model(x, attention_mask=attention_mask).logits
        logits = logits.to(x.device)  # with device_map="auto" logits may live elsewhere
        x0 = torch.argmax(logits, dim=-1)

        # Random remasking: every masked position gets a uniform random score.
        x0_p = torch.rand(x0.shape, device=x0.device)
        x0 = torch.where(mask_index, x0, x)
        confidence = torch.where(mask_index, x0_p, -np.inf)

        transfer_index = torch.zeros_like(x0, dtype=torch.bool)
        for j in range(batch_size):
            _, select_index = torch.topk(confidence[j], k=int(num_transfer_tokens[j, i]))
            transfer_index[j, select_index] = True
        x[transfer_index] = x0[transfer_index]

    return x

# Tiny toy QA set of (prompt, target) pairs used to sanity-check adapter training.
_qa_pairs = [
    ("What is 2+2?", "4"),
    ("What color is the sky?", "Blue"),
    ("Capital of France?", "Paris"),
    ("How many days in a week?", "Seven"),
    ("What is H2O?", "Water"),
]
sample_data = [{"prompt": question, "target": answer} for question, answer in _qa_pairs]

# Tokenize one SFT example with the chat template.
def prepare_sample(sample, tokenizer):
    """Build the token ids for one example: chat-formatted prompt followed by the answer.

    The answer is the encoded `sample["target"]` followed by `tokenizer.eot_id` and
    `tokenizer.eos_token_id`.

    Args:
        sample: dict with "prompt" and "target" strings.
        tokenizer: tokenizer providing `apply_chat_template`, `encode`, `eot_id` and
            `eos_token_id`.

    Returns:
        (full_ids, prompt_len): 1-D tensor of prompt + answer ids, and the number of
        leading prompt tokens.
    """
    chat = [{"role": "user", "content": sample["prompt"]}]
    prompt_ids = tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=True)
    answer_ids = [
        *tokenizer.encode(sample["target"], add_special_tokens=False),
        tokenizer.eot_id,
        tokenizer.eos_token_id,
    ]
    return torch.tensor(prompt_ids + answer_ids), len(prompt_ids)

tokenized_samples = [prepare_sample(s, tokenizer) for s in sample_data]

# Right-pad every example to the longest one. The pad id falls back to 0 when the
# tokenizer's pad_token_id is None (or 0).
pad_value = tokenizer.pad_token_id if tokenizer.pad_token_id else 0
max_len = max(ids.shape[0] for ids, _ in tokenized_samples)
padded_input_ids = torch.full((len(tokenized_samples), max_len), pad_value, dtype=torch.long)
prompt_lengths = torch.tensor([prompt_len for _, prompt_len in tokenized_samples], dtype=torch.long)
for i, (ids, prompt_len) in enumerate(tokenized_samples):
    padded_input_ids[i, : ids.shape[0]] = ids

# Move the batch onto the device of the model's first parameter.
device = next(model.parameters()).device
padded_input_ids, prompt_lengths = padded_input_ids.to(device), prompt_lengths.to(device)

# Freeze the base model: only the adapter below is trained.
for param in model.parameters():
    param.requires_grad = False

embed_dim = model.model.transformer.wte.weight.shape[1]

# Residual bottleneck adapter, used as `emb + linear_adapter(emb)` on token embeddings.
linear_adapter = torch.nn.Sequential(
    torch.nn.LayerNorm(embed_dim),
    torch.nn.Linear(embed_dim, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, embed_dim),
    torch.nn.LayerNorm(embed_dim),
).to(device)

# Zero-initialise the final LayerNorm's scale (its bias already starts at 0), so the
# adapter outputs exactly zero at initialisation and `emb + linear_adapter(emb)`
# equals the pretrained embeddings. With the default scale of 1, the trailing LayerNorm
# would add a unit-variance vector to every token embedding before any training,
# corrupting the frozen model's inputs. The scale still receives gradients, so the
# adapter learns to grow from zero (same idea as LoRA's zero-initialised B matrix).
torch.nn.init.zeros_(linear_adapter[-1].weight)

def train_step_with_adapter(model, linear_adapter, input_ids, prompt_lengths):
    """Compute the LLaDA SFT loss for one batch with the embedding adapter in the loop.

    `forward_process` masks random positions; prompt positions are then restored to
    their clean ids, so only the answer region (everything after the prompt) is noised.
    The noisy ids are embedded with `wte`, passed through the residual adapter, and fed
    to the model via `inputs_embeds`. Per-token cross-entropy on masked positions is
    divided by p_mask and by the row's answer length, summed, then divided by the batch size.

    Note: "answer length" counts every non-prompt position, padding included. Masked
    padding positions also contribute to the loss, so the model learns to predict the
    pad tokens (and hence where the answer ends). To exclude padding, AND
    `masked_indices` with `input_ids != tokenizer.pad_token_id`.

    Args:
        model: model with `.model.transformer.wte` that accepts `inputs_embeds`.
        linear_adapter: module mapping embeddings to a same-shaped residual.
        input_ids: (batch, seq_len) clean ids (prompt + answer + padding).
        prompt_lengths: (batch,) number of prompt tokens per row.

    Returns:
        Scalar loss tensor.
    """
    noisy_batch, _, p_mask = forward_process(input_ids)
    batch_size, seq_len = noisy_batch.shape

    positions = torch.arange(seq_len, device=noisy_batch.device).unsqueeze(0).expand(batch_size, seq_len)
    prompt_mask = positions < prompt_lengths.unsqueeze(1)
    # Never noise the prompt: copy its clean ids back in.
    noisy_batch[prompt_mask] = input_ids[prompt_mask]

    # Number of non-prompt positions per row, broadcast to (batch, seq_len).
    answer_lengths = (~prompt_mask).sum(dim=-1, keepdim=True).expand(batch_size, seq_len)

    # 126336 is the mask id that forward_process writes.
    masked_indices = noisy_batch == 126336

    wte = model.model.transformer.wte
    embeddings = wte(noisy_batch)
    logits = model(inputs_embeds=embeddings + linear_adapter(embeddings)).logits

    masked_logits = logits[masked_indices]
    masked_targets = input_ids[masked_indices]
    token_loss = F.cross_entropy(masked_logits, masked_targets, reduction='none') / p_mask[masked_indices]
    return (token_loss / answer_lengths[masked_indices]).sum() / batch_size

num_epochs = 200
optimizer = torch.optim.AdamW(linear_adapter.parameters(), lr=1e-3)
train_losses, val_losses = [], []

linear_adapter.train()
for epoch in range(num_epochs):
    # One optimisation step on the full (tiny) batch.
    optimizer.zero_grad()
    train_loss = train_step_with_adapter(model, linear_adapter, padded_input_ids, prompt_lengths)
    train_loss.backward()
    optimizer.step()
    train_losses.append(train_loss.item())

    # "Val" loss: a gradient-free pass over the SAME training batch with a freshly
    # sampled mask. It tracks mask-sampling noise, not generalisation to held-out data.
    linear_adapter.eval()
    with torch.no_grad():
        val_loss = train_step_with_adapter(model, linear_adapter, padded_input_ids, prompt_lengths)
    val_losses.append(val_loss.item())
    linear_adapter.train()

    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {train_loss.item():.4f}, Val Loss: {val_loss.item():.4f}")

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(train_losses, label='Train Loss')
ax.plot(val_losses, label='Val Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training and Validation Loss')
ax.legend()
ax.grid(True)
plt.show()


In [None]:
# === Generation Demo ===
# Baseline sample from the frozen model (no adapter) for a free-form chat prompt.
sample_text = "are you a chat LLM?"
chat = [{"role": "user", "content": sample_text}]
sample_text_ids = torch.tensor(
    tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=True)
)[None, :]

output = generate(model=model, input_ids=sample_text_ids)
print(tokenizer.decode(output[0]))

In [None]:

@torch.no_grad()
def generate_with_adapter(model, linear_adapter, input_ids, steps=32, gen_length=8, mask_id=126336):
    """Low-confidence-remasking generation with the embedding adapter applied.

    Mirrors `train_step_with_adapter`: at every step the CURRENT sequence (prompt,
    committed tokens and remaining mask tokens alike) is embedded with `wte`, passed
    through the residual adapter and fed to the model via `inputs_embeds`. Each step
    commits the most confident still-masked argmax predictions. Generation stops early
    once no mask tokens remain.

    Assumes batch size 1: the top-k selection runs over the flattened batch.

    Returns:
        Tensor (1, prompt_len + gen_length) of token ids.
    """
    wte = model.model.transformer.wte
    prompt_len = input_ids.shape[1]
    x = torch.full((input_ids.shape[0], prompt_len + gen_length), mask_id, dtype=torch.long, device=input_ids.device)
    x[:, :prompt_len] = input_ids

    for step in range(steps):
        mask_index = x == mask_id
        remaining = int(mask_index.sum())
        if remaining == 0:
            break  # fully unmasked; further forward passes would change nothing

        # Adapter applied to every position, exactly as during training.
        embeds = wte(x)
        logits = model(inputs_embeds=embeds + linear_adapter(embeds)).logits.to(x.device)
        x0 = torch.argmax(logits, dim=-1)
        x0_p = torch.gather(F.softmax(logits, dim=-1), dim=-1, index=x0.unsqueeze(-1)).squeeze(-1)

        x0 = torch.where(mask_index, x0, x)
        confidence = torch.where(mask_index, x0_p, -float('inf'))

        # Spread the remaining masks over the remaining steps, at least one per step.
        num_transfer = min(max(1, remaining // (steps - step)), remaining)
        _, select_index = torch.topk(confidence.flatten(), k=num_transfer)
        transfer_mask = torch.zeros_like(x, dtype=torch.bool)
        transfer_mask.view(-1)[select_index] = True
        x[transfer_mask] = x0[transfer_mask]

    return x


GEN_LENGTH = 8
STEPS = 32

linear_adapter.eval()
with torch.no_grad():
    for sample in sample_data:
        messages = [{"role": "user", "content": sample["prompt"]}]
        prompt_ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True)
        input_ids = torch.tensor(prompt_ids).unsqueeze(0).to(device)

        # Same generation budget for both runs so the outputs are comparable.
        # (The baseline `generate` still uses random rather than confidence remasking.)
        output_without_adapter = generate(model=model, input_ids=input_ids, steps=STEPS, gen_length=GEN_LENGTH)
        output_with_adapter = generate_with_adapter(
            model, linear_adapter, input_ids, steps=STEPS, gen_length=GEN_LENGTH
        )

        print(f"Prompt: {sample['prompt']}")
        print(f"Expected: {sample['target']}")
        print(f"Without adapter: {tokenizer.decode(output_without_adapter[0], skip_special_tokens=False)}")
        print(f"With adapter: {tokenizer.decode(output_with_adapter[0], skip_special_tokens=False)}")
        print("-" * 50)

In [None]:
import torch
from joblib import load
from src.llm import LLaDA_8B_Instruct
from src.dataset import AsrDataset

from matplotlib import pyplot as plt

In [None]:
model = LLaDA_8B_Instruct()

# NOTE(review): absolute, machine-specific manifest path — consider a configurable data dir.
asr_dataset_kwargs = dict(
    prompt="Transcribe speech to text. ",
    mel_size=128,
    fix_length_audio=-1,
    fix_audio_duration=30,
    inference_mode=False,
    normalize=True,
    input_type="mel",
    path_to_jsonl_file="/mnt/disk-n8/data/librispeech/manifests/test-other.jsonl",
    target_column="processed_target",
    llm_name="llada-8b-instruct",
)
dataset = AsrDataset(tokenizer=model.tokenizer, **asr_dataset_kwargs)

In [None]:
# Fetch a batch here: `batch` is otherwise only assigned in later cells, so running the
# notebook top-to-bottom on a fresh kernel raised NameError at this point.
generation_batch = next(
    iter(torch.utils.data.DataLoader(dataset, batch_size=3, collate_fn=dataset.collator))
)
model.generate(
    input_ids=generation_batch["input_ids"],
    attention_mask=generation_batch["attention_mask"],
    max_new_tokens=100,
)

In [None]:
thingy = dataset[5]
# Decode a copy: replacing the -1 entries (audio placeholders elsewhere in this
# notebook) in place would mutate the tensor returned by the dataset.
display_ids = torch.as_tensor(thingy["input_ids"]).clone()
display_ids[display_ids == -1] = 0
print(display_ids)
model.tokenizer.decode(display_ids)

In [None]:
dataloader = torch.utils.data.DataLoader(dataset, batch_size=3, collate_fn=dataset.collator)
batch = next(iter(dataloader))  # inspect only the first batch

# Replace -1 placeholder ids with 126081 so the sequences can be decoded.
decodable_ids = batch["input_ids"].clone()
decodable_ids[decodable_ids == -1] = 126081
# Loop variable is `decoded` (a string); it previously reused the name `input_ids`,
# silently replacing the id tensor with text after the loop.
for i, decoded in enumerate(model.tokenizer.batch_decode(decodable_ids)):
    print("----")
    print(decoded)
    print(batch["input_ids"][i])
    print(batch["audio_mel"][i].shape)
    plt.imshow(batch["audio_mel"][i])
    plt.show()
    plt.plot(batch["attention_mask"][i])
    plt.title("Attention Mask")
    plt.show()
    print(batch["modality_mask"][i])
    print(batch["audio_duration"][i])
    print(batch["labels"][i])
    print("-----")

In [None]:
# Inspect which string token id 6 decodes to.
model.tokenizer.decode([6])

In [None]:
# Reload a previously saved batch from disk.
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary code —
# only load batch_0.pkl if it was produced locally or comes from a trusted source.
batch = load("batch_0.pkl")

In [None]:
# Replace -1 entries of the first sample's ids with 126081, in place:
# indexing [0] yields a view, so `batch` itself is modified.
first_row_ids = batch["input_ids"][0]
first_row_ids[first_row_ids == -1] = 126081

In [None]:
# NOTE(review): `label` is not used in this cell.
prompt = "transcribe the following audio"
label = "this is what the audio says"
messages = [{"role": "user", "content": prompt}]
prompt_ids = model.tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True)

# Reserve positions for audio embeddings; -1 marks a placeholder, not a real token id.
num_audio_tokens = 10
audio_placeholder_ids = [-1] * num_audio_tokens

# Insert the audio placeholders at the END of the user turn, i.e. immediately before
# the first <|eot_id|>, so the audio becomes part of the user message. The <|eot_id|>
# and the assistant header that follows it come after the audio.
# NOTE(review): list.index finds the FIRST <|eot_id|>; if the chat template ever emits an
# earlier one (e.g. a system turn), the split point would be wrong.
eot_id = model.tokenizer.encode("<|eot_id|>", add_special_tokens=False)[0]
eot_idx = prompt_ids.index(eot_id)

# Split: user message part + audio placeholders + (<|eot_id|> + assistant header)
user_part = prompt_ids[:eot_idx]  # up to but not including <|eot_id|>
assistant_header = prompt_ids[eot_idx:]  # <|eot_id|> followed by the assistant header; placed after the audio

full_ids = user_part + audio_placeholder_ids + assistant_header
print("Token IDs:", full_ids)
print("Audio placeholder starts at index:", len(user_part))

# Placeholders (-1) are decoded as token 0 purely for display.
decoded_text = model.tokenizer.decode([t if t >= 0 else 0 for t in full_ids])
print(decoded_text)

In [None]:
# Display the decoded prompt built in the previous cell (audio placeholders shown as token 0).
decoded_text

In [None]:
# Inspect the tokenizer's EOS token string.
model.tokenizer.eos_token

In [None]:
# Show the raw text the chat template produces for `messages`.
chat_prompt_ids = model.tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True)
print(model.tokenizer.decode(chat_prompt_ids))

In [None]:
# Round-trip a plain (non-f-string) SFT template through the tokenizer to see how its
# literal markers tokenize; `{prompt}` / `{target}` are NOT substituted here.
# NOTE(review): <BOS>/<EOS>/<start_id>/<end_id> look like placeholder names rather than
# this tokenizer's special tokens (compare "<|eot_id|>" used above) — confirm.
model.tokenizer.decode(model.tokenizer(["<BOS><start_id>user<end_id>\n{prompt}<eot_id><start_id>assistant<end_id>\n{target}<EOS>"])["input_ids"][0])

In [None]:
# Print each decoded token of the first sample, separated (and followed) by a space.
first_sample_tokens = model.tokenizer.batch_decode(batch["input_ids"][0])
print("".join(f"{tok} " for tok in first_sample_tokens), end="")

In [None]:
from joblib import dump, load
from matplotlib import pyplot as plt


# Load the saved (input_embeds, attention_mask, labels) file ONCE instead of
# deserialising it three times.
# NOTE(review): joblib.load unpickles — only load trusted files.
saved = load("input_embeds_attention_mask_labels.joblib")
input_embeds, attention_masks, labels = saved[0], saved[1], saved[2]

# Entries are plotted with imshow, i.e. 2-D per sample — embeddings, not token ids.
for i, embeds in enumerate(input_embeds):
    print(embeds.shape)
    plt.imshow(embeds.cpu().detach().numpy())
    plt.title(f"Input Embeds {i}")
    plt.show()
    plt.plot(attention_masks[i].cpu().detach().numpy())
    plt.title(f"Attention Mask {i}")
    plt.show()
    print(labels[i].cpu().detach().numpy())
