<a href="https://colab.research.google.com/github/unique-subedi/posttraining/blob/main/sft.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -U bitsandbytes
!pip install -U trl
!pip install -U peft

Collecting trl
  Downloading trl-0.27.1-py3-none-any.whl.metadata (11 kB)
Downloading trl-0.27.1-py3-none-any.whl (532 kB)
Installing collected packages: trl
Successfully installed trl-0.27.1


In [None]:
!nvidia-smi

Thu Jan 29 16:48:58 2026       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-80GB          Off |   00000000:00:05.0 Off |                    0 |
| N/A   33C    P0             59W /  400W |    7279MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Model

In [None]:
MODEL_NAME = "deepseek-ai/deepseek-llm-7b-chat"

In [None]:
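bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                   # quantize linear-layer weights to 4 bits
    bnb_4bit_quant_type="nf4",           # 4-bit NormalFloat data type
    bnb_4bit_use_double_quant=True,      # also quantize the per-block quantization constants
    # bfloat16 needs hardware support (e.g. the A100 above); fall back to float16 otherwise
    bnb_4bit_compute_dtype=torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16,
)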
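With this config, bitsandbytes stores each linear layer's weights as 4-bit codes with one scale per block of weights (`bnb_4bit_use_double_quant=True` additionally quantizes those scales) and dequantizes to `bnb_4bit_compute_dtype` for the matrix multiplies. The NumPy sketch below only illustrates the blockwise absmax idea and is a simplified stand-in, not the bitsandbytes implementation: it uses 15 uniform signed levels instead of the NF4 code book, keeps the codes unpacked in `int8`, and the block size of 64 is an assumption chosen for illustration.

```python
import numpy as np

def quantize_blockwise_int4(w, block_size=64):
    """Quantize a 1-D float array to signed 4-bit codes with one absmax scale per block.

    Simplified illustration: 15 uniform levels in [-7, 7], not the NF4 code book.
    """
    w = np.asarray(w, dtype=np.float32)
    pad = (-len(w)) % block_size                 # pad so the array splits into whole blocks
    blocks = np.pad(w, (0, pad)).reshape(-1, block_size)
    scales = np.abs(blocks).max(axis=1, keepdims=True) / 7.0
    scales[scales == 0] = 1.0                    # all-zero blocks: any scale works, avoid 0/0
    codes = np.clip(np.round(blocks / scales), -7, 7).astype(np.int8)  # unpacked 4-bit codes
    return codes, scales, len(w)

def dequantize_blockwise_int4(codes, scales, n):
    """Recover approximate weights, roughly what happens before each matmul."""
    return (codes.astype(np.float32) * scales).reshape(-1)[:n]

rng = np.random.default_rng(0)
w = rng.normal(size=1000).astype(np.float32)
codes, scales, n = quantize_blockwise_int4(w)
w_hat = dequantize_blockwise_int4(codes, scales, n)
print("blocks:", codes.shape[0], "max abs error:", float(np.abs(w - w_hat).max()))
```

Each weight's rounding error is at most half of its block's scale, so one outlier only degrades the precision of the 63 other weights in its block rather than the whole tensor.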


In [None]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)

In [None]:
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    quantization_config=bnb_config,
    dtype=torch.bfloat16,  # dtype for the modules that are not quantized
)


In [None]:
print(torch.cuda.memory_allocated() / 1e9, "GB")

9.864994816 GB
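As a rough sanity check on the printed figure, weight storage scales with the number of bits per parameter. The estimate below treats every parameter of a roughly 7B model as quantized, which is an assumption: bitsandbytes only replaces linear layers (embeddings and the output head stay in higher precision), it also stores quantization constants, and `memory_allocated()` counts every live tensor, so the measured number is expected to exceed the 4-bit estimate.

```python
def weight_memory_gb(n_params, bits_per_param):
    """Bytes needed to store n_params weights at the given bit width, in decimal GB."""
    return n_params * bits_per_param / 8 / 1e9

n_params = 6.9e9  # approximate parameter count of a "7b" model (assumption)
for label, bits in [("bf16", 16), ("4-bit", 4)]:
    print(f"{label}: {weight_memory_gb(n_params, bits):.2f} GB")
```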


# Pre-Finetuning Chat

In [None]:
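def chat(prompt, max_new_tokens=200):
    """Sample a reply to a single user message and print it."""
    messages = [{"role": "user", "content": prompt}]

    # Render the conversation with the model's chat template and append the prompt that starts the assistant's turn.
    text = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=False,
    )

    # The rendered text already begins with the BOS token, so don't let the tokenizer add another.
    enc = tokenizer(text, return_tensors="pt", add_special_tokens=False)

    input_ids = enc["input_ids"].to(model.device)
    attention_mask = enc["attention_mask"].to(model.device)

    with torch.no_grad():
        out = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
        )

    # generate() returns the prompt followed by the continuation; decode only the new tokens.
    print(tokenizer.decode(out[0][input_ids.shape[1]:], skip_special_tokens=True))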

chat("Explain a margin call in 2 sentences.")

 A margin call happens when your broker requires you to add funds because your position lost value. If you don’t, the broker may sell assets to reduce risk.
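`do_sample=True` with `temperature=0.7` and `top_p=0.9` means each new token is drawn from the temperature-scaled softmax, restricted to the smallest set of most likely tokens whose total probability reaches 0.9 (nucleus sampling). A NumPy sketch of one such step for a single logits vector; it mirrors the idea rather than the exact Hugging Face logits-processor code, and the toy logits are made up:

```python
import numpy as np

def sample_top_p(logits, temperature=0.7, top_p=0.9, rng=None):
    """Sample one token id using temperature scaling and nucleus (top-p) filtering."""
    rng = np.random.default_rng() if rng is None else rng
    scaled = np.asarray(logits, dtype=np.float64) / temperature
    probs = np.exp(scaled - scaled.max())        # numerically stable softmax
    probs /= probs.sum()
    order = np.argsort(probs)[::-1]              # token ids, most likely first
    cumulative = np.cumsum(probs[order])
    # Keep the smallest prefix whose mass reaches top_p (always at least one token).
    cutoff = int(np.searchsorted(cumulative, top_p)) + 1
    keep = order[:cutoff]
    filtered = np.zeros_like(probs)
    filtered[keep] = probs[keep]
    filtered /= filtered.sum()                   # renormalize over the nucleus
    return int(rng.choice(len(probs), p=filtered)), filtered

token_id, dist = sample_top_p([2.0, 1.0, 0.0, -1.0], rng=np.random.default_rng(0))
print(token_id, dist.round(3))
```

Lowering the temperature sharpens the distribution, so fewer tokens are needed to reach the `top_p` mass and sampling becomes more conservative.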


# LoRA Setup
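LoRA keeps each targeted weight matrix `W` frozen and trains a low-rank update, so an adapted linear layer computes `W x + (lora_alpha / r) * B A x`, where `A` has shape `(r, d_in)` and `B` has shape `(d_out, r)`. PEFT initializes `B` to zeros, so the adapted model starts out identical to the base model. A NumPy sketch with toy dimensions (dropout omitted; the sizes and random values are arbitrary, only `r` and `alpha` match the config used here):

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, r, alpha = 64, 48, 16, 32      # toy layer sizes; r and alpha as in the LoraConfig

W = rng.normal(size=(d_out, d_in))          # frozen pretrained weight
A = rng.normal(scale=0.01, size=(r, d_in))  # trainable down-projection
B = np.zeros((d_out, r))                    # trainable up-projection, zero-initialized

def lora_forward(x, W, A, B, alpha, r):
    """Frozen path plus scaled low-rank update: W x + (alpha / r) * B (A x)."""
    return W @ x + (alpha / r) * (B @ (A @ x))

x = rng.normal(size=d_in)
# With B = 0 the adapter is a no-op, so training starts from the base model's behaviour.
assert np.allclose(lora_forward(x, W, A, B, alpha, r), W @ x)

# After training (simulated here with random B), the update folds into one dense matrix.
B = rng.normal(size=(d_out, r))
W_merged = W + (alpha / r) * (B @ A)
assert np.allclose(lora_forward(x, W, A, B, alpha, r), W_merged @ x)
```

The last check is why an adapter can be merged back into the base weights after training, leaving no extra cost at inference time.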

In [None]:
from peft import LoraConfig, get_peft_model

In [None]:
model.config.use_cache = False  # the KV cache speeds up generation but is not needed during training

In [None]:
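lora_config = LoraConfig(
    r=16,                # rank of the low-rank update
    lora_alpha=32,       # the update is scaled by lora_alpha / r = 2
    lora_dropout=0.05,   # dropout applied to the adapter's input
    bias="none",         # keep all bias terms frozen
    task_type="CAUSAL_LM",
    # adapt every attention and MLP projection in each decoder layer
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
)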
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

trainable params: 37,478,400 || all params: 6,947,844,096 || trainable%: 0.5394
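The trainable count follows from the LoRA shapes: adapting a linear layer from `d_in` to `d_out` features adds `r * (d_in + d_out)` parameters (`A` plus `B`). The dimensions below (hidden size 4096, MLP intermediate size 11008, 30 decoder layers, full-width `k_proj`/`v_proj`) are assumptions about the model rather than values read from its config, but they reproduce the printed figure:

```python
def lora_param_count(shapes, r, n_layers):
    """Trainable LoRA parameters: each adapted (d_in, d_out) linear adds r * (d_in + d_out)."""
    return n_layers * sum(r * (d_in + d_out) for d_in, d_out in shapes)

hidden, intermediate, n_layers = 4096, 11008, 30   # assumed model dimensions
shapes = (
    [(hidden, hidden)] * 4                          # q_proj, k_proj, v_proj, o_proj
    + [(hidden, intermediate)] * 2                  # gate_proj, up_proj
    + [(intermediate, hidden)]                      # down_proj
)
print(f"{lora_param_count(shapes, r=16, n_layers=n_layers):,}")  # prints 37,478,400
```

Dividing by the printed total of 6,947,844,096 gives the reported 0.5394%.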


# Data

In [None]:
from datasets import Dataset

data = [
    {
        "messages": [
            {"role": "user", "content": "Explain what a margin call is in 2 sentences."},
            {"role": "assistant", "content": "A margin call happens when your broker requires you to add funds because your position lost value. If you don’t, the broker may sell assets to reduce risk."},
        ]
    },
    {
        "messages": [
            {"role": "user", "content": "Rewrite concisely: 'Next, we validate the conclusions of Theorem 1 with experiments.'"},
            {"role": "assistant", "content": "Next, we validate Theorem 1 experimentally."},
        ]
    },
]

ds = Dataset.from_list(data)

def render(ex):
    # Render each conversation into a single training string with the model's chat template.
    return {"text": tokenizer.apply_chat_template(ex["messages"], tokenize=False)}

train_ds = ds.map(render)

print(train_ds[0]["text"][:400])



<｜begin▁of▁sentence｜>User: Explain what a margin call is in 2 sentences.

A: A margin call happens when your broker requires you to add funds because your position lost value. If you don’t, the broker may sell assets to reduce risk.<｜end▁of▁sentence｜>


# Train

In [None]:
from trl import SFTTrainer
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,
    learning_rate=2e-4,
    max_steps=50,          # 50 steps x 1 example per step = 25 passes over the 2 training examples
    logging_steps=5,
    report_to="none",
)

trainer = SFTTrainer(
    model=model,
    train_dataset=train_ds,
    args=args,
    processing_class=tokenizer,  # use the tokenizer loaded above
)

trainer.train()





| Step | Training Loss |
|-----:|--------------:|
| 5 | 2.8537 |
| 10 | 1.2848 |
| 15 | 0.2315 |
| 20 | 0.0559 |
| 25 | 0.0256 |
| 30 | 0.0214 |
| 35 | 0.0174 |
| 40 | 0.0173 |
| 45 | 0.0170 |
| 50 | 0.0177 |


TrainOutput(global_step=50, training_loss=0.4542358377575874, metrics={'train_runtime': 20.0127, 'train_samples_per_second': 2.498, 'train_steps_per_second': 2.498, 'total_flos': 85195798732800.0, 'train_loss': 0.4542358377575874})
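The logged training loss is the mean next-token cross-entropy over the tokens that count toward the loss. A minimal NumPy sketch of that computation, assuming the Hugging Face convention that position `t` predicts label `t + 1` and that label `-100` is ignored (as for padding):

```python
import numpy as np

def causal_lm_loss(logits, labels, ignore_index=-100):
    """Mean next-token cross-entropy: the logits at position t are scored against labels[t + 1]."""
    logits = np.asarray(logits, dtype=np.float64)[:-1]   # the last position has no next token
    targets = np.asarray(labels)[1:]
    mask = targets != ignore_index                       # positions excluded from the loss
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    picked = log_probs[np.arange(len(targets)), np.where(mask, targets, 0)]
    return float(-(picked * mask).sum() / mask.sum())

# Toy check: uniform logits over a 4-token vocabulary give a loss of log(4).
print(causal_lm_loss(np.zeros((4, 4)), [1, 2, 3, 0]), np.log(4))
```

Because the loss is a mean negative log-probability, `exp(-loss)` is the geometric-mean probability assigned to the correct next token: about `exp(-2.8537) ≈ 0.058` at step 5 versus `exp(-0.0177) ≈ 0.98` at step 50. With batch size 1 and two examples, 50 steps amount to 25 passes over the data, so this near-zero loss is consistent with the two training examples being memorized.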

# Evaluate

In [None]:
test_prompts = [
    "Explain a margin call in 2 sentences.",
    "Rewrite concisely: Next, we validate the conclusions of Theorem 1 with experiments.",
    "What is overfitting in machine learning?",
]

for p in test_prompts:
    print("PROMPT:", p)
    chat(p)
    print("-" * 40)

PROMPT: Explain a margin call in 2 sentences.
 A margin call happens when your broker requires you to add funds because your position lost value. If you don’t, the broker may sell assets to reduce risk.
----------------------------------------
PROMPT: Rewrite concisely: Next, we validate the conclusions of Theorem 1 with experiments.
 Next, we validate Theorem 1 experimentally.
----------------------------------------
PROMPT: What is overfitting in machine learning?
 In machine learning, overfitting happens when a model learns the training data too well, to the detriment of its ability to generalize to new, unseen data. In other words, the model learns the training data so well that it starts to include noise and outliers in the training data in its decision boundary. As a result, it performs poorly on new data.
----------------------------------------
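The first two test prompts are near-copies of the two training examples, so agreement with the training answers mainly tests recall of those examples; only the overfitting question falls outside the training data. To score such outputs instead of reading them by eye, a hypothetical exact-match helper could look like the sketch below. It assumes `chat` is modified to return the decoded text (the version above only prints it), and its normalization rules are illustrative choices:

```python
import re

def normalize(text):
    """Unify curly apostrophes, lowercase, and collapse runs of whitespace."""
    text = text.replace("\u2019", "'").lower()
    return re.sub(r"\s+", " ", text).strip()

def exact_match(prediction, reference):
    """True when prediction and reference agree after normalization."""
    return normalize(prediction) == normalize(reference)

def exact_match_rate(pairs):
    """Fraction of (prediction, reference) pairs that match after normalization."""
    return sum(exact_match(p, r) for p, r in pairs) / len(pairs)

# Prediction copied from the output above; reference copied from the training data.
pairs = [(" Next, we validate Theorem 1 experimentally.", "Next, we validate Theorem 1 experimentally.")]
print(exact_match_rate(pairs))
```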
