In [None]:
import torch as t
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import tqdm
from dataclasses import dataclass

# Seed torch's global RNG so repeated runs start from the same state.
t.manual_seed(0)

# Root folder of the datasets and the name of the corpus used in this notebook.
DATA_PATH = "../../datasets"
DATASET_NAME = "dune"

# Hugging Face hub id of the base checkpoint to fine-tune.
MODEL_NAME = "Qwen/Qwen3-0.6B-base"

# Device the model is loaded onto and batches are moved to.
DEVICE = "cuda"

In [None]:
from huggingface_hub.constants import HF_HUB_CACHE

# Display the value of huggingface_hub's HF_HUB_CACHE constant as the cell output
# (presumably the local directory where hub downloads are cached — confirm).
HF_HUB_CACHE

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, Qwen3ForCausalLM, DataCollatorWithPadding

# Tokenizer belonging to the MODEL_NAME checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

def tokenize(batch, tokenizer: AutoTokenizer):
    """Tokenize the ``"text"`` entries of ``batch``.

    Args:
        batch: Mapping with a ``"text"`` key (used with ``Dataset.map(batched=True)``
            further down in this notebook).
        tokenizer: Callable tokenizer to apply; shadows the module-level
            ``tokenizer`` inside this function.

    Returns:
        Whatever ``tokenizer`` returns when called with truncation enabled,
        padding disabled and ``max_length=256``.
    """
    texts = batch["text"]
    encoding_options = {"truncation": True, "padding": False, "max_length": 256}
    return tokenizer(texts, **encoding_options)

In [None]:
from datasets import Dataset, load_from_disk
from torch.utils.data import DataLoader
import os

# Tokenized splits are cached on disk under DATA_PATH/processed/DATASET_NAME.
dataset_path = os.path.join(DATA_PATH, "processed", DATASET_NAME)

if os.path.exists(dataset_path):
    # Reuse the splits written by a previous run of this cell.
    ds = load_from_disk(dataset_path)
else:
    # Read the raw corpus. The file name follows DATASET_NAME so the cache
    # directory and the source text cannot silently drift apart. The context
    # manager closes the handle, and the explicit encoding avoids depending on
    # the platform's locale default.
    raw_text_path = os.path.join(DATA_PATH, f"{DATASET_NAME}.txt")
    with open(raw_text_path, encoding="utf-8") as text_file:
        text = text_file.read()

    # Cut the text into fixed-size character windows and drop whitespace-only ones.
    chunk_size = 1024
    chunks = [
        chunk
        for chunk in (text[i:i + chunk_size] for i in range(0, len(text), chunk_size))
        if chunk.strip()
    ]

    # Seed the split explicitly: t.manual_seed() does not influence it, and the
    # split is cached on disk, so an unseeded split would not be reproducible
    # when the cache is rebuilt.
    raw = Dataset.from_list([{"text": c} for c in chunks]).train_test_split(
        test_size=0.1,
        seed=0,
    )

    # Replace the raw "text" column by the tokenizer's output columns.
    ds = raw.map(
        tokenize,
        batched=True,
        remove_columns=raw["train"].column_names,
        fn_kwargs={"tokenizer": tokenizer}
    )

    os.makedirs(dataset_path, exist_ok=True)
    ds.save_to_disk(dataset_path)

# Display the resulting splits as the cell output.
ds

In [None]:
# Load the MODEL_NAME checkpoint with bfloat16 weights, placed on DEVICE.
model: Qwen3ForCausalLM = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    dtype=t.bfloat16,  # explicit bfloat16 instead of dtype="auto"
    device_map=DEVICE
)

# Turn on gradient checkpointing for the model
# (presumably to reduce activation memory during training — confirm).
model.gradient_checkpointing_enable()

# Sanity check: the model ended up on the requested device type.
assert model.device.type == DEVICE

# Display the module tree as the cell output.
model

In [None]:
# Prompt used to smoke-test generation in this cell.
prompt = "Paul"

def generate(model, prompt, max_new_tokens=150):
    """Return the model's completion of ``prompt`` as a string.

    Uses the module-level ``tokenizer`` to encode the prompt and decode the
    result.

    Args:
        model: Model exposing ``.device`` and ``.generate(...)``.
        prompt: Text to continue.
        max_new_tokens: Upper bound on the number of generated tokens
            (defaults to 150, the value that used to be hard-coded; exposed to
            match ``generate_stream``).

    Returns:
        The decoded continuation without the prompt, with special tokens
        skipped and leading/trailing newlines stripped.
    """
    model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # conduct text completion
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=max_new_tokens
    )

    # The returned sequence starts with the prompt tokens; keep only what
    # comes after them.
    prompt_length = len(model_inputs.input_ids[0])
    output_ids = generated_ids[0][prompt_length:].tolist()

    content = tokenizer.decode(output_ids, skip_special_tokens=True).strip("\n")

    return content


def generate_stream(model, prompt, max_new_tokens=150):
    """Greedily decode a continuation of ``prompt`` and print it as it is produced.

    Uses the module-level ``tokenizer``. Compared with a naive token-by-token
    loop this function:

    * runs the model in eval mode and restores the previous mode afterwards,
      so it behaves the same before and after ``train()`` was called;
    * stops as soon as the tokenizer's EOS token is produced;
    * reuses the key/value cache, so each step only feeds the newest token
      instead of re-encoding the whole sequence;
    * decodes the accumulated token ids rather than single tokens, so a
      character whose bytes are split across several tokens is printed once
      it is complete instead of as replacement characters.

    Args:
        model: Causal LM exposing ``.device`` and returning ``logits`` and
            ``past_key_values`` when called with ``use_cache=True``.
        prompt: Text to continue.
        max_new_tokens: Maximum number of tokens to generate.

    Returns:
        None. The generated text is written to stdout without a trailing newline.
    """
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    eos_token_id = tokenizer.eos_token_id

    was_training = model.training
    model.eval()

    generated_ids = []
    printed_text = ""
    step_input = input_ids
    past_key_values = None

    try:
        with t.no_grad():
            for _ in range(max_new_tokens):
                outputs = model(
                    input_ids=step_input,
                    past_key_values=past_key_values,
                    use_cache=True,  # explicit: train() sets config.use_cache = False
                )
                past_key_values = outputs.past_key_values

                next_token = outputs.logits[:, -1].argmax(dim=-1, keepdim=True)
                token_id = int(next_token[0, 0])
                if eos_token_id is not None and token_id == eos_token_id:
                    break

                generated_ids.append(token_id)
                # With the cache in place only the new token has to be fed next.
                step_input = next_token

                decoded = tokenizer.decode(generated_ids, skip_special_tokens=True)
                # A trailing U+FFFD means the newest token ended in the middle
                # of a multi-byte character; hold the output back until the
                # remaining bytes arrive.
                if decoded.endswith("\ufffd"):
                    continue
                print(decoded[len(printed_text):], end="", flush=True)
                printed_text = decoded

        # Flush anything still held back when the loop ended.
        final_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
        print(final_text[len(printed_text):], end="", flush=True)
    finally:
        model.train(was_training)

generate_stream(model, prompt)

In [None]:
# The label says "total", so count every parameter rather than only the
# trainable ones; the two only coincide while nothing has been frozen yet.
print(f"total num params: {model.num_parameters(only_trainable=False)}")

In [None]:
from lora import apply_lora, enable_lora

@dataclass
class LoraArguments:
    """Hyper-parameters of the LoRA fine-tuning run.

    The attributes carry type annotations so that ``@dataclass`` actually
    registers them as fields: without annotations they are plain class
    attributes and ``LoraArguments(rank=4)`` raises ``TypeError``. The defaults
    remain readable on the class itself (``LoraArguments.rank``), which is how
    the rest of the notebook uses them.
    """

    # Number of sequences per training batch.
    batch_size: int = 2
    # Value passed as ``rank`` to ``apply_lora``.
    rank: int = 8
    # Value passed as ``alpha`` to ``apply_lora``.
    alpha: float = 16.0


# Apply LoRA to the q/k/v/o projection modules using the rank and alpha from LoraArguments.
# NOTE(review): apply_lora / enable_lora come from the local `lora` module, which is not
# visible here. The value returned by enable_lora is used further down as the number of
# parameters added by LoRA. Re-running this cell on an already adapted model may apply
# the adapters a second time — confirm against the `lora` module.
apply_lora(model, target_modules=("q_proj", "k_proj", "v_proj", "o_proj"), rank=LoraArguments.rank, alpha=LoraArguments.alpha)
num_lora_params = enable_lora(model)

In [None]:
# Report the parameter counts with and without the LoRA parameters.
total_params = model.num_parameters(False)
original_params = total_params - num_lora_params
lora_percentage = num_lora_params / total_params * 100.

print(f"num params (original): {original_params}")
print(f"num params (after lora): {total_params}")

print(f"num params added by lora: {num_lora_params}")
print(f"lora params %: {lora_percentage}%")

In [None]:
from torchinfo import summary

# torchinfo overview of the model with the parameter-count and trainable columns.
summary(model, col_names=["num_params", "trainable"])

In [None]:
# Show which columns the tokenized training split contains.
print(ds["train"].column_names)

In [None]:
def train(model: nn.Module, trainset: Dataset, epochs=1):
    """Fine-tune the trainable parameters of ``model`` with a next-token loss.

    Batches are built with ``DataCollatorWithPadding`` using the module-level
    ``tokenizer`` and ``LoraArguments.batch_size``. Only parameters with
    ``requires_grad=True`` are handed to the optimizer.

    The model's training/eval mode and ``config.use_cache`` are restored when
    the function exits, so generation afterwards does not run in training mode.

    Args:
        model: Causal LM exposing ``.config``, ``.device``,
            ``.enable_input_require_grads()`` and returning ``.logits``.
        trainset: Dataset whose items provide ``input_ids`` and ``attention_mask``.
        epochs: Number of passes over ``trainset``.

    Returns:
        List with the loss (a float) of every optimizer step, in order.

    Raises:
        ValueError: If the model has no parameter with ``requires_grad=True``.
    """
    was_training = model.training
    previous_use_cache = model.config.use_cache

    model.train()
    model.config.use_cache = False
    model.enable_input_require_grads()

    try:
        data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
        trainloader = DataLoader(
            trainset,
            batch_size=LoraArguments.batch_size,
            shuffle=True,
            collate_fn=data_collator,
        )

        trainable = [p for p in model.parameters() if p.requires_grad]
        if not trainable:
            raise ValueError("model has no trainable parameters (requires_grad=True)")

        optimizer = t.optim.AdamW(trainable, lr=2e-4, weight_decay=0.0)
        loss_list = []
        ema = None
        # Follow the model's own device instead of the global constant so the
        # batches always land where the weights are.
        device = model.device

        for epoch in range(epochs):
            pbar = tqdm.tqdm(trainloader)
            for batch in pbar:
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)

                logits = model(input_ids=input_ids, attention_mask=attention_mask).logits  # [B,T,V]

                # Position i predicts token i + 1.
                logits = logits[:, :-1, :]
                targets = input_ids[:, 1:]

                # A prediction only counts when BOTH the input position and its
                # target are real tokens. Checking the target alone is enough
                # for right padding, but with left padding it would train the
                # last pad position to predict the first real token.
                valid = attention_mask.bool()
                mask = valid[:, :-1] & valid[:, 1:]
                if not mask.any():
                    # Nothing to predict (e.g. only one-token sequences); the
                    # mean over zero targets would be NaN.
                    continue
                targets = targets.masked_fill(~mask, -100)

                loss = F.cross_entropy(
                    logits.reshape(-1, logits.size(-1)),
                    targets.reshape(-1),
                    ignore_index=-100,
                )

                optimizer.zero_grad(set_to_none=True)
                loss.backward()
                t.nn.utils.clip_grad_norm_(trainable, 1.0)
                optimizer.step()

                li = loss.item()
                loss_list.append(li)
                # Exponential moving average of the loss for a smoother progress readout.
                ema = li if ema is None else 0.98 * ema + 0.02 * li
                pbar.set_postfix(epoch=f"{epoch+1}/{epochs}", loss=f"{li:.3f}", ema=f"{ema:.3f}")

        return loss_list
    finally:
        model.config.use_cache = previous_use_cache
        model.train(was_training)


# Keep the per-step losses in a variable instead of letting the cell echo the
# whole list as its output.
loss_history = train(model, ds["train"], epochs=1)

In [None]:
# Stream a completion for a short prompt from the model after the training cell has run.
generate_stream(model, "hello")