In [None]:
import torch as t
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import tqdm
from dataclasses import dataclass

# Seed torch's RNG so LoRA initialisation and DataLoader shuffling are reproducible.
t.manual_seed(0)

DATA_PATH = "../../datasets"
DATASET_NAME = "dune"

# Canonical Hub repo id (the repository is named "...-Base"); using the exact casing also
# keeps a single entry in the local Hugging Face cache.
MODEL_NAME = "Qwen/Qwen3-0.6B-Base"

# Prefer the GPU, but fall back to CPU so the notebook still runs on machines without CUDA.
DEVICE = "cuda" if t.cuda.is_available() else "cpu"

In [None]:
from huggingface_hub.constants import HF_HUB_CACHE

# Show the huggingface_hub cache location (where downloaded checkpoints are stored).
hub_cache_dir = HF_HUB_CACHE
hub_cache_dir

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, Qwen3ForCausalLM

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)


def tokenize(batch, tokenizer: AutoTokenizer):
    """Tokenize the "text" column of a batch of dataset rows.

    Truncation is enabled (to the tokenizer's default maximum length) and padding is
    disabled, so every example keeps its own length.
    """
    texts = batch["text"]
    encoded = tokenizer(texts, padding=False, truncation=True)
    return encoded

In [None]:
from datasets import Dataset, load_from_disk
from torch.utils.data import DataLoader
import os

dataset_path = os.path.join(DATA_PATH, "processed", DATASET_NAME)

# The tokenized dataset is cached on disk. Delete `dataset_path` after changing the tokenizer
# (MODEL_NAME) or the chunking below, otherwise the stale cache is silently reused.
if os.path.exists(dataset_path):
    ds = load_from_disk(dataset_path)
else:
    # Context manager closes the file; explicit encoding avoids locale-dependent decoding.
    with open(os.path.join(DATA_PATH, f"{DATASET_NAME}.txt"), encoding="utf-8") as fh:
        text = fh.read()

    # Fixed-size character windows; whitespace-only windows are dropped.
    chunk_size = 1024
    chunks = [
        chunk
        for chunk in (text[i:i + chunk_size] for i in range(0, len(text), chunk_size))
        if chunk.strip()
    ]

    # Seeded so the train/test split is reproducible whenever the cache is rebuilt.
    raw = Dataset.from_list([{"text": c} for c in chunks]).train_test_split(test_size=0.1, seed=0)

    ds = raw.map(
        tokenize,
        batched=True,
        remove_columns=raw["train"].column_names,
        fn_kwargs={"tokenizer": tokenizer},
    )

    os.makedirs(dataset_path, exist_ok=True)
    ds.save_to_disk(dataset_path)

ds

In [None]:
# dtype="auto": let transformers use the dtype recorded for the checkpoint.
model: Qwen3ForCausalLM = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map=DEVICE,
    dtype="auto",
)

# Sanity check: device_map should have placed the model on the requested device type.
assert model.device.type == DEVICE

model

In [None]:
# test input
prompt = "Paul"


def generate(model, prompt, max_new_tokens=150):
    """Complete `prompt` with `model.generate` and return only the newly generated text.

    Args:
        model: causal LM exposing ``.device`` and ``.generate``.
        prompt: text to continue.
        max_new_tokens: generation budget (same default as ``generate_stream``).

    Returns:
        The decoded continuation, special tokens skipped and leading/trailing newlines stripped.
    """
    model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_len = model_inputs.input_ids.shape[1]

    # conduct text completion
    generated_ids = model.generate(**model_inputs, max_new_tokens=max_new_tokens)
    # generate() returns prompt + continuation; keep only the continuation.
    output_ids = generated_ids[0][prompt_len:].tolist()

    return tokenizer.decode(output_ids, skip_special_tokens=True).strip("\n")


def generate_stream(model, prompt, max_new_tokens=150):
    """Greedily continue `prompt`, printing the text as it is produced.

    Runs at most `max_new_tokens` decoding steps and stops early when the model emits the
    tokenizer's EOS token. After the first step only the newest token is fed to the model,
    together with the returned key/value cache, instead of re-encoding the whole sequence.
    Nothing is returned; the output goes to stdout.
    """
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    eos_id = tokenizer.eos_token_id

    step_input = input_ids
    past = None
    generated = []
    printed = ""

    with t.no_grad():
        for _ in range(max_new_tokens):
            out = model(step_input, past_key_values=past, use_cache=True)
            past = out.past_key_values
            next_token = out.logits[:, -1].argmax(dim=-1, keepdim=True)

            token_id = next_token.item()
            if eos_id is not None and token_id == eos_id:
                break
            generated.append(token_id)
            step_input = next_token

            # Decode the whole continuation and print only the new suffix: one token can be an
            # incomplete UTF-8 sequence that decodes to U+FFFD until the next token completes it.
            text = tokenizer.decode(generated, skip_special_tokens=True)
            if text.endswith("\ufffd"):
                continue
            print(text[len(printed):], end="", flush=True)
            printed = text

# Baseline completion from the pretrained model, before any LoRA adapters are attached.
generate_stream(model, prompt, max_new_tokens=150)

In [None]:
# only_trainable=False: count every parameter regardless of requires_grad, matching the label
# (passing True would count only trainable ones and change meaning once LoRA freezes the base).
print(f"total num params: {model.num_parameters(only_trainable=False)}")

In [None]:
class LoRAParameterization(nn.Module):
    """Low-rank (LoRA) update for a Linear weight, applied via ``torch.nn.utils.parametrize``.

    For a weight ``W`` of shape ``(out_features, in_features)`` the forward returns
    ``W + (B @ A) * alpha / rank`` with ``B: (out_features, rank)`` and ``A: (rank, in_features)``.
    While ``enabled`` is False (the initial state) the weight is returned unchanged.
    """

    def __init__(self, in_features, out_features, rank=1, alpha=1., device='cuda'):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features

        # A is random and B is zero, so B @ A == 0 and training starts exactly from the base weight.
        # Sampled on CPU then moved, so the initialisation does not depend on the device.
        self.lora_A = nn.Parameter(t.randn((rank, self.in_features)).to(device))
        self.lora_B = nn.Parameter(t.zeros((self.out_features, rank)).to(device))

        self.scale = alpha / rank
        self.enabled = False

    def forward(self, w: t.Tensor):
        if not self.enabled:
            return w

        assert w.shape == (self.out_features, self.in_features), (
            f"expected weight of shape {(self.out_features, self.in_features)}, got {tuple(w.shape)}"
        )
        delta = (self.lora_B @ self.lora_A) * self.scale
        # The adapters are created in torch's default dtype; cast the update so the parametrized
        # weight keeps the base weight's dtype (e.g. bfloat16) and F.linear sees matching dtypes.
        return w + delta.to(w.dtype)



In [None]:
from torch.nn.utils import parametrize

def apply_lora(model: nn.Module, target_modules=("q_proj",), rank=8, alpha=16):
    """Register a LoRA parametrization on the weight of matching ``nn.Linear`` layers.

    Args:
        model: module tree, modified in place.
        target_modules: name fragments; a Linear layer is wrapped when any fragment is a
            substring of its qualified name. A single string is treated as one fragment.
        rank: adapter rank, forwarded to ``LoRAParameterization``.
        alpha: adapter scaling numerator, forwarded to ``LoRAParameterization``.

    Layers whose weight already carries a ``LoRAParameterization`` are skipped, so re-running
    the cell does not stack adapters. Each adapter is created on its layer's weight device.
    """
    if isinstance(target_modules, str):
        # ("q_proj") is a plain string: iterating it would match single characters like "_".
        target_modules = (target_modules,)

    # Snapshot the module list: registering a parametrization adds new submodules.
    for name, module in list(model.named_modules()):
        if not isinstance(module, nn.Linear):
            continue

        if not any(fragment in name for fragment in target_modules):
            continue

        if parametrize.is_parametrized(module, "weight") and any(
            isinstance(existing, LoRAParameterization)
            for existing in module.parametrizations.weight
        ):
            continue

        parametrize.register_parametrization(
            module,
            "weight",
            LoRAParameterization(
                in_features=module.in_features,
                out_features=module.out_features,
                rank=rank,
                alpha=alpha,
                device=module.weight.device,
            ),
        )

def enable_lora(model: nn.Module):
    """Freeze every parameter of `model` except its LoRA adapters, and switch the adapters on.

    Works on any ``nn.Module`` (not only Hugging Face models) and only touches
    ``LoRAParameterization`` instances, so unrelated parametrizations are left alone.

    Returns:
        The number of trainable parameters afterwards, i.e. the parameters added by LoRA.
    """
    # Freeze everything first, including the original weights held by the parametrizations.
    model.requires_grad_(False)

    for module in model.modules():
        if isinstance(module, LoRAParameterization):
            module.enabled = True
            module.requires_grad_(True)

    # model.parameters() de-duplicates shared/tied tensors, so nothing is double-counted.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


In [None]:
@dataclass
class LoraArguments:
    """Hyper-parameters for the LoRA fine-tuning run.

    Each field needs a type annotation to be a dataclass field, and must not end in a comma
    (``rank=8,`` binds the tuple ``(8,)``). The defaults are also readable as class
    attributes, e.g. ``LoraArguments.rank``.
    """

    batch_size: int = 8
    rank: int = 8
    alpha: float = 1.0

# Attach adapters to the attention q/k/v projections, then freeze everything except them.
lora_targets = ("q_proj", "k_proj", "v_proj")
apply_lora(
    model,
    target_modules=lora_targets,
    rank=LoraArguments.rank,
    alpha=LoraArguments.alpha,
)
num_lora_params = enable_lora(model)

In [None]:
# Compare parameter counts with and without the LoRA adapters.
total_params = model.num_parameters(False)
base_params = total_params - num_lora_params

print(f"num params (original): {base_params}")
print(f"num params (after lora): {total_params}")

print(f"num params added by lora: {num_lora_params}")
print(f"lora params %: {num_lora_params / total_params * 100.}%")

In [None]:
from torchinfo import summary

# Per-layer parameter counts and whether each layer's parameters are trainable.
summary_columns = ["num_params", "trainable"]
summary(model, col_names=summary_columns)

In [None]:
def _collate_causal_lm(examples):
    """Right-pad tokenized rows into ``(input_ids, attention_mask, labels)`` tensors.

    Each row is a mapping with an "input_ids" sequence and optionally an "attention_mask"
    of the same length (all ones is assumed when absent). Padded positions get attention 0
    and label -100 so they are ignored by the loss.
    """
    ids = [t.as_tensor(ex["input_ids"], dtype=t.long) for ex in examples]
    masks = [
        t.as_tensor(ex["attention_mask"], dtype=t.long) if "attention_mask" in ex else t.ones_like(seq)
        for ex, seq in zip(examples, ids)
    ]
    input_ids = nn.utils.rnn.pad_sequence(ids, batch_first=True, padding_value=0)
    attention_mask = nn.utils.rnn.pad_sequence(masks, batch_first=True, padding_value=0)
    # -100 is the ignore index of F.cross_entropy and of Hugging Face causal-LM losses.
    labels = input_ids.masked_fill(attention_mask == 0, -100)
    return input_ids, attention_mask, labels


def train(model: nn.Module, trainset, epochs=1, batch_size=None, lr=1e-3):
    """Fine-tune the trainable (e.g. LoRA) parameters of a causal LM on next-token loss.

    Args:
        model: causal LM called as ``model(input_ids=..., attention_mask=..., labels=...)``
            returning an object with a ``.loss`` attribute (Hugging Face convention, where
            labels are shifted internally and -100 positions are ignored).
        trainset: tokenized dataset whose rows have "input_ids" (and optionally
            "attention_mask"); rows may have different lengths.
        epochs: number of passes over `trainset`.
        batch_size: rows per optimisation step; defaults to ``LoraArguments.batch_size``.
        lr: Adam learning rate.

    Returns:
        List of per-step training losses.
    """
    if batch_size is None:
        batch_size = LoraArguments.batch_size
    trainloader = DataLoader(
        trainset, batch_size=batch_size, shuffle=True, collate_fn=_collate_causal_lm
    )

    # Only optimise parameters left trainable (enable_lora freezes the base weights).
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = t.optim.Adam(trainable, lr=lr)
    device = next(model.parameters()).device
    loss_list = []

    was_training = model.training
    model.train()
    try:
        for epoch in range(epochs):
            pbar = tqdm.tqdm(trainloader)

            for input_ids, attention_mask, labels in pbar:
                # Move data to device, perform forward pass
                input_ids = input_ids.to(device)
                attention_mask = attention_mask.to(device)
                labels = labels.to(device)
                out = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
                loss = out.loss

                # Backward pass and optimiser step
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

                # Update logs & progress bar
                loss_list.append(loss.item())
                pbar.set_postfix(epoch=f"{epoch + 1}/{epochs}", loss=f"{loss_list[-1]:.3f}")
    finally:
        model.train(was_training)

    return loss_list
        
# Keep whatever train() returns (e.g. a loss history) instead of discarding it.
train_losses = train(model, ds["train"], epochs=1)

In [None]:
# Same prompt again, now with the LoRA adapters enabled and trained.
generate_stream(model, prompt, max_new_tokens=150)