In [None]:
import datetime
import random
from typing import Any

import datasets
import peft
import pydantic
import torch
import transformers
import wandb
from transformers import Trainer, TrainerCallback, TrainingArguments


  from pkg_resources import resource_filename


In [2]:
MODEL = "Qwen/Qwen3-4B-Instruct-2507"
MASK_TOKEN = "<|mask|>"  # special token added below and used to corrupt inputs
IM_START_TOKEN = "<|im_start|>"  # chat-template turn marker; used as the prompt/response separator

# Pick the best available backend: CUDA first, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device: torch.device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")


class LoraConfig(pydantic.BaseModel):
    """LoRA hyper-parameters; ``model_dump()`` is unpacked into ``peft.LoraConfig``."""

    # Adapter rank and alpha (standard LoRA scales the update by lora_alpha / r).
    r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05
    # Attention (q/k/v/o) and MLP (gate/up/down) projection module names to adapt.
    target_modules: list[str] = pydantic.Field(
        default_factory=lambda: [
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ]
    )
    bias: str = "none"


class TrainerConfig(pydantic.BaseModel):
    """Training hyper-parameters consumed by the LoRA and Trainer cells below."""

    model: str = MODEL  # Hugging Face model id passed to from_pretrained
    lr: float = 2e-4
    epochs: int = 1
    batch_size: int = 16  # per-device micro-batch size
    gradient_accumulation_steps: int = 4
    log_interval: int = 1000  # global steps between sample-prediction logs
    lora_config: LoraConfig = pydantic.Field(default_factory=LoraConfig)


# Single source of truth for hyper-parameters; override here, e.g. TrainerConfig(lr=1e-4).
trainer_config = TrainerConfig()

# Optional: start the W&B run explicitly to control project/entity/run name and log the
# config. The Trainer below is configured with report_to="wandb".
# _ = wandb.init(
#     project="diffusion-llm",
#     entity="sachinruk",
#     name=datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
#     config=trainer_config.model_dump(),
# )

In [3]:
# Load the base model in bf16 on accelerators (fp32 on CPU). attn_implementation is "sdpa"
# on every device: the bidirectional-attention patch further down wraps
# torch.nn.functional.scaled_dot_product_attention, which the flash_attention_2 kernel never
# calls, so with flash attention the model would silently remain causal.
llm_model: transformers.modeling_utils.PreTrainedModel = (
    transformers.AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path=trainer_config.model,
        dtype=torch.bfloat16 if device.type in {"cuda", "mps"} else torch.float32,
        attn_implementation="sdpa",
        device_map="auto",
    )
)
tokenizer: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained(
    trainer_config.model
)
# Register <|mask|> as a new special token while keeping the tokenizer's existing
# additional special tokens listed, then resize the embedding matrix to the new vocab size.
tokenizer.add_special_tokens(
    {"additional_special_tokens": [MASK_TOKEN]},
    replace_additional_special_tokens=False,
)
llm_model.resize_token_embeddings(len(tokenizer))


`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Embedding(151670, 2560)

## Converting to Universal Attention

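To replace causal attention with bidirectional ("universal") attention, we monkey-patch `torch.nn.functional.scaled_dot_product_attention` so that it ignores causality. Causal attention prevents each token from attending to later positions; bidirectional attention lets every position attend to every other non-padding position. The patch only takes effect when the model runs with `attn_implementation="sdpa"`: kernels such as FlashAttention-2 never call this function, so the model would remain causal.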


In [4]:
if not hasattr(torch.nn.functional.scaled_dot_product_attention, "_is_patched"):
    original_sdpa = torch.nn.functional.scaled_dot_product_attention

    def universal_sdpa(
        query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None
    ):
        if attn_mask is not None:
            last_row = attn_mask[..., -1, :]
            universal_mask = last_row.unsqueeze(-2).expand_as(attn_mask)
            attn_mask = universal_mask
        # ... your code ...
        return original_sdpa(
            query,
            key,
            value,
            attn_mask=attn_mask,
            dropout_p=dropout_p,
            is_causal=False,
            scale=scale,
        )

    universal_sdpa._is_patched = True
    torch.nn.functional.scaled_dot_product_attention = universal_sdpa


### Testing Universal Attention

Let's verify that the model now uses bidirectional attention:


In [27]:
# Run a padded sample batch through the patched model.
test_text = ["The cat sat on the mat.", "The dog"]
inputs = tokenizer(test_text, return_tensors="pt", padding=True).to(device)

with torch.no_grad():
    outputs = llm_model(**inputs)

    # Bidirectionality probe: two sentences that differ only in their final word. Under
    # causal attention the logits at the first real token cannot see that word, so they
    # would match; with the patched SDPA they should differ.
    probe_inputs = tokenizer(
        ["The cat sat on the mat.", "The cat sat on the rug."],
        return_tensors="pt",
        padding=True,
    ).to(device)
    probe_logits = llm_model(**probe_inputs).logits

# Index of the first non-padding token per row (robust to left or right padding).
first_real = probe_inputs["attention_mask"].argmax(dim=1)
# ~0 while attention is still causal; clearly non-zero once it is bidirectional.
first_position_logit_gap = (
    (probe_logits[0, first_real[0]] - probe_logits[1, first_real[1]]).abs().max().item()
)
first_position_logit_gap


In [28]:
outputs.__class__

transformers.modeling_outputs.CausalLMOutputWithPast

In [5]:
# Full Tulu-3 SFT mixture. For a quick demo run, use a slice instead, e.g. split="train[:1%]".
raw_ds = datasets.load_dataset(
    "allenai/tulu-3-sft-mixture-0225",
    split="train",
    cache_dir="./data",
    download_mode=datasets.DownloadMode.REUSE_DATASET_IF_EXISTS,
)

## Train for Next Token Prediction

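This section trains the model to fill in `<|mask|>` tokens. It uses the causal-LM loss the model computes when `labels` are passed, in which the token at position *i+1* is predicted from position *i*. The collator decides which tokens are corrupted and scored:
- `SFTCollateFn` (used for training below) restricts both to the final assistant turn, i.e. everything after the last `<|im_start|>`.
- `PretrainingCollateFn` covers the entire sequence.

The setup uses:
- **LoRA (Low-Rank Adaptation)**: efficient fine-tuning that trains only low-rank adapter matrices
- **Wandb integration**: logging losses and sample outputs
- **Batch size 16 per device**: with 4 gradient-accumulation steps (64 sequences per optimizer step per device)
- **Sample logging**: every `trainer_config.log_interval` (1000) steps we log the input, expected output, and predicted output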

In [6]:
def mask_input_ids(
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    mask_probability: float,
    mask_token_id: int,
    response_mask: torch.Tensor | None = None,
) -> torch.Tensor:
    mask_indices = torch.bernoulli(
        torch.full(size=input_ids.shape, fill_value=mask_probability, device=input_ids.device)
    ).bool()
    mask_indices = mask_indices & attention_mask.bool()
    if response_mask is not None:
        mask_indices: torch.Tensor = mask_indices & response_mask
    input_ids[mask_indices] = mask_token_id
    return input_ids


def get_prefix_mask(input_ids: torch.Tensor, separator_token_id: int) -> torch.Tensor:
    """Mark every position up to and including the last separator token in each row.

    Args:
        input_ids: ``(B, L)`` token ids.
        separator_token_id: Id of the separator token (here ``<|im_start|>``).

    Returns:
        ``(B, L)`` bool tensor, True at positions ``<=`` the last separator in that row.
        A row with no separator is all False.
    """
    sep_mask = input_ids == separator_token_id  # (B, L) booleans
    running = sep_mask.cumsum(dim=1)  # running count of separators
    total = running[:, -1:]  # separators per row, (B, 1)
    # True only at the last separator (all False if the row has none).
    last_sep_onehot = sep_mask & (running == total)

    # reverse -> cumsum -> reverse turns the one-hot into ones from the start of the row
    # up to and including the last separator.
    return last_sep_onehot.flip(dims=[1]).cumsum(dim=1).flip(dims=[1]).bool()


def tokenize_text(
    examples: list[dict[str, Any]],
    tokenizer: transformers.PreTrainedTokenizer,
    max_length: int = 1024,
) -> dict[str, torch.Tensor]:
    """Render chat examples with the tokenizer's chat template and tokenize them.

    Args:
        examples: Dataset rows, each with a ``"messages"`` list of chat turns.
        tokenizer: Tokenizer providing ``apply_chat_template``.
        max_length: Sequences longer than this are truncated.

    Returns:
        The tokenizer's batch encoding (``input_ids``, ``attention_mask``) as PyTorch
        tensors, padded to the longest sequence in the batch.
    """
    string_examples: list[str] = tokenizer.apply_chat_template(
        [example["messages"] for example in examples],
        tokenize=False,
        add_generation_prompt=False,
    )

    return tokenizer(
        string_examples,
        padding=True,
        # Without truncation=True, max_length is ignored and long samples pass through.
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )


class CollateFn:
    """Shared configuration for the masking collators.

    Resolves the ids of the mask token and the ``<|im_start|>`` separator once, and
    stores the range from which each batch's masking probability is drawn.
    """

    def __init__(
        self,
        tokenizer: transformers.PreTrainedTokenizer,
        max_length: int = 1024,
        min_mask_probability: float = 0.1,
        max_mask_probability: float = 0.95,
    ):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.min_mask_probability = min_mask_probability
        self.max_mask_probability = max_mask_probability
        # Label value skipped by the loss (PyTorch CrossEntropyLoss default ignore_index).
        self.ignore_index = -100
        self.mask_token_id: int = tokenizer.convert_tokens_to_ids(MASK_TOKEN)
        self.sep_token_id: int = tokenizer.convert_tokens_to_ids(IM_START_TOKEN)


class PretrainingCollateFn(CollateFn):
    """Collator that corrupts and scores every non-padding token.

    Labels are the original ids with padding set to ``ignore_index``. In the inputs, each
    non-padding token is replaced by the mask token with a probability drawn once per batch
    from ``[min_mask_probability, max_mask_probability]``.
    """

    def __call__(self, examples: list[dict[str, Any]]) -> dict[str, torch.Tensor]:
        tokenized = tokenize_text(examples, self.tokenizer, self.max_length)
        attention_mask = tokenized["attention_mask"]

        # Copy the labels before masking, because mask_input_ids edits the ids in place.
        targets: torch.Tensor = tokenized["input_ids"].clone()
        targets[attention_mask == 0] = self.ignore_index

        corrupted = mask_input_ids(
            tokenized["input_ids"],
            attention_mask,
            mask_probability=random.uniform(self.min_mask_probability, self.max_mask_probability),
            mask_token_id=self.mask_token_id,
        )
        return {"input_ids": corrupted, "attention_mask": attention_mask, "labels": targets}


class SFTCollateFn(CollateFn):
    """Collator that corrupts and scores only the response of each chat example.

    Everything up to and including the last ``<|im_start|>`` is treated as prompt: it gets
    ``ignore_index`` labels and is never masked. Tokens after it (the final turn) are
    masked with a probability drawn once per batch from
    ``[min_mask_probability, max_mask_probability]``.
    """

    def __call__(self, examples: list[dict[str, Any]]) -> dict[str, torch.Tensor]:
        encoded_batch = tokenize_text(examples, self.tokenizer, self.max_length)

        input_ids = encoded_batch["input_ids"]
        attn = encoded_batch["attention_mask"]
        prefix_mask = get_prefix_mask(input_ids, self.sep_token_id)

        # Copy the labels before masking; padding and prompt positions are not scored.
        labels: torch.Tensor = input_ids.clone()
        labels[(attn == 0) | prefix_mask] = self.ignore_index

        # Write the masked ids back explicitly instead of relying on in-place aliasing.
        encoded_batch["input_ids"] = mask_input_ids(
            input_ids,
            attn,
            mask_probability=random.uniform(self.min_mask_probability, self.max_mask_probability),
            mask_token_id=self.mask_token_id,
            response_mask=~prefix_mask,
        )
        encoded_batch["labels"] = labels
        return encoded_batch


Test out CollateFns.

In [None]:
sample_messages = [raw_ds[0], raw_ds[1]]

collate_fn = SFTCollateFn(tokenizer)
batch = collate_fn(sample_messages)

# For each example, show the corrupted model input (padding dropped via the attention
# mask), followed by the reference assistant reply from the dataset.
for example, row_ids, row_attn in zip(
    sample_messages, batch["input_ids"], batch["attention_mask"]
):
    print(tokenizer.decode(row_ids[row_attn.bool()], skip_special_tokens=False))
    print("=" * 100)
    expected_output = [
        message["content"] for message in example["messages"] if message["role"] == "assistant"
    ][0]
    print(expected_output.strip())

<|im_start|>user
Create a snippet of Terraform HCL code that create an AWS autoscaling group, and an ALB in front to expose an application to internet.<|im_end|>
<|im_start|>assistant
Sure, here's an example Terraform H<|mask|> code that creates an AWS Autoscaling Group<|mask|><|mask|><|mask|> Load Balancer to<|mask|> an<|mask|> to the internet:
``` 
# Configure<|mask|><|mask|> provider<|mask|>provider "aws" {
  region = "<|mask|><|mask|>-1"
}

# Create<|mask|> security group to allow traffic to the ALB
<|mask|><|mask|>aws_security_group" "alb<|mask|>" {
<|mask|> name<|mask|> =<|mask|>alb_sg<|mask|> <|mask|> {
<|mask|> from_port =<|mask|>80
    to<|mask|> = <|mask|>0
<|mask|><|mask|> =<|mask|>tcp<|mask|>   <|mask|>r_blocks = ["0.0.0.0<|mask|>0"]
<|mask|><|mask|>}

<|mask|> Create<|mask|><|mask|><|mask|> and<|mask|> group
resource<|mask|>aws<|mask|>"<|mask|>alb"<|mask|><|mask|> name              <|mask|> "example-al<|mask|><|mask|> <|mask|>           = false<|mask|><|mask|> load_balance



### Setup LoRA for Efficient Fine-tuning

LoRA (Low-Rank Adaptation) allows us to fine-tune large models efficiently by only training small adapter matrices. This significantly reduces memory usage and training time while maintaining good performance.


In [95]:
# Build the peft LoRA config from our pydantic settings and wrap the base model with adapters.
lora_config = peft.LoraConfig(**trainer_config.lora_config.model_dump())
lora_model = peft.get_peft_model(model=llm_model, peft_config=lora_config)
# NOTE(review): with the default target_modules, "embed_tokens" is not adapted, so the
# embedding of the newly added <|mask|> token is not trained — confirm this is intended.
lora_model.print_trainable_parameters()


trainable params: 33,030,144 || all params: 4,054,817,280 || trainable%: 0.8146


In [None]:
@torch.inference_mode()
def _calculate_accuracy(
    model: transformers.PreTrainedModel,
    batch: dict[str, torch.Tensor],
    mask_token_id: int,
    ignore_index: int = -100,
) -> tuple[float, float]:
    """Greedy token accuracy on visible vs. masked target positions.

    Predictions use the same alignment as the causal-LM loss computed when ``labels`` are
    passed: the logits at position ``i`` are compared with ``labels[:, i + 1]``. Targets
    equal to ``ignore_index`` (padding, and prompt tokens when the collator ignores them)
    are skipped.

    Args:
        model: Model whose output exposes ``.logits``.
        batch: Collated batch with ``input_ids``, ``attention_mask`` and ``labels``. Tensors
            are moved to the device of the model's first parameter.
        mask_token_id: A target counts as "masked" when the input at that position is this id.
        ignore_index: Label value excluded from both accuracies.

    Returns:
        ``(non_mask_accuracy, mask_accuracy)``; either is ``nan`` when the batch has no
        scored positions of that kind.
    """
    device = next(model.parameters()).device
    input_ids = batch["input_ids"].to(device)
    labels = batch["labels"].to(device)
    attention_mask = batch["attention_mask"].to(device)

    logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
    predicted_ids = logits[:, :-1].argmax(dim=-1)
    out_device = predicted_ids.device

    # Everything below is about the *target* position i + 1.
    target_ids = labels[:, 1:].to(out_device)
    target_is_masked = (input_ids[:, 1:] == mask_token_id).to(out_device)
    scored = ((target_ids != ignore_index) & (attention_mask[:, 1:] == 1)).to(out_device)

    def accuracy(selected: torch.Tensor) -> float:
        # Comparisons yield bool tensors, which cannot be averaged directly.
        if not bool(selected.any()):
            return float("nan")
        return (predicted_ids[selected] == target_ids[selected]).float().mean().item()

    non_mask_accuracy = accuracy(scored & ~target_is_masked)
    mask_accuracy = accuracy(scored & target_is_masked)
    return non_mask_accuracy, mask_accuracy


class AccuracyCallback(TrainerCallback):
    """Every ``log_interval`` steps, log token accuracy on a freshly masked batch to W&B.

    The batch is the ``batch_size`` dataset rows starting at the current global step
    (wrapping around the dataset), collated with ``collate_fn`` and scored with
    ``_calculate_accuracy``.
    """

    def __init__(
        self,
        model: transformers.PreTrainedModel,
        dataset: datasets.Dataset,
        log_interval: int,
        collate_fn: CollateFn,
        batch_size: int,
    ):
        # Fallback for when the Trainer does not pass the model to the callback.
        self.model = model
        self.log_interval = log_interval
        self.collate_fn = collate_fn
        self.dataset = dataset
        self.last_logged_step = 0
        self.batch_size = batch_size

    def on_step_end(self, args, state, control, **kwargs):
        step = state.global_step
        if step <= 0 or step % self.log_interval != 0 or step == self.last_logged_step:
            return

        model = kwargs.get("model")
        if model is None:
            model = self.model

        sample_data = [
            self.dataset[i % len(self.dataset)] for i in range(step, step + self.batch_size)
        ]
        batch = self.collate_fn(sample_data)
        device = next(model.parameters()).device
        batch = {
            key: value.to(device) if isinstance(value, torch.Tensor) else value
            for key, value in batch.items()
        }

        was_training = model.training
        model.eval()
        try:
            non_mask_accuracy, mask_accuracy = _calculate_accuracy(
                model, batch, self.collate_fn.mask_token_id
            )
        finally:
            if was_training:
                model.train()

        # The first value from _calculate_accuracy is accuracy on non-masked positions.
        wandb.log(
            {
                "non_mask_accuracy": non_mask_accuracy,
                "mask_accuracy": mask_accuracy,
                "step": step,
            }
        )
        self.last_logged_step = step

In [None]:
def first_idx(row_ids: torch.Tensor, end_token_id: int, max_length: int) -> int:
    """Return the index of the first ``end_token_id`` in 1-D ``row_ids``, else ``max_length``."""
    matches = torch.nonzero(row_ids == end_token_id, as_tuple=False)
    if matches.numel() == 0:
        return max_length
    return matches[0].item()


def get_quotas(init_mask_counts: torch.Tensor, steps: int) -> torch.Tensor:
    """Split each row's mask count across ``steps`` as evenly as possible.

    Returns a ``(steps, B)`` tensor whose columns sum to ``init_mask_counts``. The
    remainder of the division goes to the earliest steps, one extra each.
    """
    per_step = init_mask_counts // steps
    leftover = init_mask_counts % steps
    rows = []
    for step_idx in range(steps):
        extra = (leftover > step_idx).long()
        rows.append(per_step + extra)
    return torch.stack(rows, dim=0)


class DiffusionInference:
    """
    Iteratively infills mask tokens over `steps`. For each batch row:
      - Count the initial number of fillable masks *before* the first end token (if any).
      - Divide that count across steps (distributing remainders to earlier steps).
      - At each step, choose the top-k masked positions by confidence (max logit)
        and fill only those with their argmax token.
      - On the final step, fill every remaining mask before the end token.
    Positions after the first end token (pre-existing or newly predicted) are left untouched.

    Logit alignment: with ``shift_logits=True`` (default) the token at position ``i`` is
    predicted from ``logits[:, i - 1]``. This matches the shifted causal-LM loss the model
    computes when trained through ``Trainer`` with ``labels``. In that mode a mask at
    position 0 has no preceding logit and is left as is. Pass ``shift_logits=False`` for a
    model trained to predict the token at its own position.
    """

    def __init__(
        self,
        model: transformers.PreTrainedModel,
        mask_token_id: int,
        end_token_id: int,
        steps: int,
        shift_logits: bool = True,
    ):
        assert steps >= 1, "steps must be >= 1"
        self.model = model
        self.mask_token_id = mask_token_id
        self.end_token_id = end_token_id
        self.steps = steps
        self.shift_logits = shift_logits

    def _fillable_positions(self, row_ids: torch.Tensor) -> torch.Tensor:
        """Indices of mask tokens in ``row_ids`` that lie before its first end token."""
        end = first_idx(row_ids, self.end_token_id, row_ids.shape[0])
        positions = (row_ids[:end] == self.mask_token_id).nonzero(as_tuple=False).squeeze(-1)
        if self.shift_logits:
            positions = positions[positions > 0]  # position 0 has no preceding logit
        return positions

    @torch.inference_mode()
    def __call__(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        was_training = self.model.training
        self.model.eval()
        try:
            return self._infill(batch)
        finally:
            # Restore the caller's mode even if the forward pass raises.
            if was_training:
                self.model.train()

    def _infill(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Run the denoising loop on a copy of ``batch["input_ids"]``."""
        # Work on a copy to avoid mutating the caller's batch in place.
        input_ids = batch["input_ids"].clone()
        attention_mask = batch.get("attention_mask", None)
        batch_size = input_ids.shape[0]

        init_mask_counts = torch.tensor(
            [self._fillable_positions(row).numel() for row in input_ids], dtype=torch.long
        )
        quotas = get_quotas(init_mask_counts, self.steps)  # [steps, batch_size]

        for step in range(self.steps):
            # Recomputed every step: a newly predicted end token shortens the fillable region.
            candidates = [self._fillable_positions(input_ids[b]) for b in range(batch_size)]
            if all(positions.numel() == 0 for positions in candidates):
                break  # nothing left to fill; skip the remaining forward passes

            logits = self.model(input_ids=input_ids, attention_mask=attention_mask).logits  # [B, L, V]
            is_last_step = step == self.steps - 1

            for b, mask_pos in enumerate(candidates):
                if mask_pos.numel() == 0:
                    continue
                if is_last_step:
                    k = mask_pos.numel()
                else:
                    k = min(int(quotas[step, b]), mask_pos.numel())
                if k == 0:
                    continue

                logit_pos = mask_pos - 1 if self.shift_logits else mask_pos
                # Max logit (confidence) and argmax token per masked position: [num_masks].
                confidence, pred_tok = logits[b, logit_pos.to(logits.device)].max(dim=-1)
                topk_idx = confidence.topk(k=k, dim=0).indices  # [k]
                pos_to_update = mask_pos[topk_idx.to(mask_pos.device)]
                input_ids[b, pos_to_update] = pred_tok[topk_idx].to(input_ids.device)

        return {"input_ids": input_ids, "attention_mask": attention_mask}

In [None]:
# Custom callback to log predictions every N steps
class PredictionLoggingCallback(TrainerCallback):
    """Every ``log_interval`` steps, log a W&B table of input / expected / predicted text.

    Uses the first ``num_samples`` dataset rows, collated (and freshly masked) with
    ``collate_fn``. Predictions are greedy and aligned with the shifted causal-LM loss:
    logits at position ``i`` predict the label at ``i + 1``. Only positions with a real
    label are shown.
    """

    def __init__(self, log_interval: int, collate_fn, dataset, num_samples: int = 16):
        self.log_interval = log_interval
        self.collate_fn = collate_fn
        self.dataset = dataset
        self.num_samples = num_samples
        self.last_logged_step = 0

    def on_step_end(self, args, state, control, **kwargs):
        step = state.global_step
        if step <= 0 or step % self.log_interval != 0 or step == self.last_logged_step:
            return
        model = kwargs.get("model")
        if model is None or wandb.run is None:
            return

        sample_data = [self.dataset[i] for i in range(min(self.num_samples, len(self.dataset)))]
        batch = self.collate_fn(sample_data)
        self._log_predictions(model, batch, step)
        self.last_logged_step = step

    @torch.inference_mode()
    def _log_predictions(self, model, batch, step: int) -> None:
        """Run one greedy forward pass on ``batch`` and log the decoded results."""
        tokenizer = self.collate_fn.tokenizer
        ignore_index = getattr(self.collate_fn, "ignore_index", -100)
        device = next(model.parameters()).device
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        was_training = model.training
        model.eval()
        try:
            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
        finally:
            if was_training:
                model.train()

        predicted = logits[:, :-1].argmax(dim=-1).cpu()
        targets = labels[:, 1:].cpu()
        input_ids, attention_mask = input_ids.cpu(), attention_mask.cpu()

        table = wandb.Table(columns=["step", "input", "expected", "predicted"])
        for row in range(input_ids.shape[0]):
            scored = targets[row] != ignore_index
            table.add_data(
                step,
                tokenizer.decode(input_ids[row][attention_mask[row].bool()].tolist()),
                tokenizer.decode(targets[row][scored].tolist()),
                tokenizer.decode(predicted[row][scored].tolist()),
            )
        wandb.log({"predictions": table, "train/global_step": step})


### Training with HuggingFace Trainer

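We use the standard `Trainer` (not TRL's `SFTTrainer`) so that our own collator alone decides which tokens are masked and which contribute to the loss:
- `SFTTrainer` applies its own tokenization and loss-masking logic, which would conflict with the custom collator
- The collator is `SFTCollateFn`, which corrupts and scores only the final assistant turn; swap in `PretrainingCollateFn` to train on the **entire sequence**
- Uses `trainer_config.gradient_accumulation_steps` from config
- A custom callback logs a table of 16 sample predictions to wandb every `trainer_config.log_interval` (1000) steps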


In [None]:
# Build the training collator here rather than reusing the preview cell's `collate_fn`.
# SFTCollateFn corrupts and scores only the final assistant turn; swap in
# PretrainingCollateFn(tokenizer) to train on the entire sequence.
train_collate_fn = SFTCollateFn(tokenizer)

# Setup training arguments using values from trainer_config
training_args = TrainingArguments(
    output_dir="./lora_adapter",
    num_train_epochs=trainer_config.epochs,
    per_device_train_batch_size=trainer_config.batch_size,
    gradient_accumulation_steps=trainer_config.gradient_accumulation_steps,
    learning_rate=trainer_config.lr,
    logging_steps=10,
    save_strategy="epoch",
    report_to="wandb",
    # Keep the raw "messages" column: the collator applies the chat template itself.
    remove_unused_columns=False,
    bf16=device.type in {"cuda", "mps"},
)

# Plain Trainer (not TRL's SFTTrainer) so that the custom collator alone decides masking and labels.
trainer = Trainer(
    model=lora_model,
    args=training_args,
    train_dataset=raw_ds,
    data_collator=train_collate_fn,
    callbacks=[PredictionLoggingCallback(trainer_config.log_interval, train_collate_fn, raw_ds)],
)

try:
    trainer.train()
finally:
    # Close the W&B run even if training fails or is interrupted.
    wandb.finish()
