In [None]:
import os

import numpy as np
import polars as pl
import torch
from bitsandbytes.optim import PagedLion8bit
from flashml import (
    inspect_model,
    plot_tensor,
    plot_tsne,
)
from flashml.schedulers import LRConsineAnnealingWithLinearWarmup
from peft import LoraConfig, get_peft_model, PeftModel
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    set_seed,
)


# Debug aid: forces synchronous CUDA kernel launches so device-side errors are
# reported at the call that caused them. It slows execution noticeably and only
# takes effect if set before CUDA is first initialised; drop it for real runs.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"


# All tunable settings for this run, kept in one place.
HYPERPARAMS = {
    # Base checkpoint. Alternative recorded here: "tiiuae/Falcon-H1-0.5B-Base".
    "model": "Qwen/Qwen3-0.6B",
    # NOTE(review): presumably -1 means "start from scratch" — confirm against
    # the training loop that reads it.
    "continue_from_index": -1,
    "seed": 42,
    # Samples per optimizer step = batch_size * gradient_accumulation = 16.
    "batch_size": 2,
    "gradient_accumulation": 8,
    # Per-class cross-entropy weights (they sum to 1).
    # NOTE(review): looks like normalised inverse class frequencies of a
    # two-class target — confirm.
    "cross_entropy_weight": torch.tensor(
        [0.0785904383236605, 0.9214095616763395], dtype=torch.float
    ),
    "epochs": 1,
    "lr": 2e-5,
    # NOTE(review): (0.9, 0.999) are AdamW-style betas; if these are passed to
    # PagedLion8bit, Lion's own default is (0.9, 0.99) — confirm intent.
    "betas": (0.9, 0.999),
    "weight_decay": 0.005,
    # 4-bit NF4 weights with nested (double) quantization and bfloat16 compute.
    # The alternative is 8-bit loading via `load_in_8bit=True`.
    "quant_config": BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    ),
    # LoRA adapters on the attention query/value projections only.
    # lora_alpha / r = 1.0, so adapter updates are not rescaled.
    # Alternative values recorded for this config: r=8, lora_alpha=16.
    "lora_config": LoraConfig(
        r=32,
        lora_alpha=32,
        target_modules=[
            "q_proj",
            "v_proj",
        ],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    ),
}

# Apply the configured seed to Python's `random`, NumPy and torch (CPU + CUDA)
# so LoRA initialisation, dropout and any later random ops are reproducible.
set_seed(HYPERPARAMS["seed"])


In [None]:
# Load the base checkpoint named in HYPERPARAMS with its bitsandbytes
# quantization settings, place it on the CUDA device, then summarise it.
model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path=HYPERPARAMS["model"],
    device_map="cuda",
    quantization_config=HYPERPARAMS["quant_config"],
)

inspect_model(model)

In [None]:
# Demo: visualise a random (100, 333) matrix in which every entry more than 50
# positions below the main diagonal is zeroed by `torch.triu(..., diagonal=-50)`.
# A dedicated, seeded generator keeps the picture identical across re-runs, and
# the CPU fallback lets the cell run on a kernel without a GPU.
triu_generator = torch.Generator().manual_seed(HYPERPARAMS["seed"])
triu_device = "cuda" if torch.cuda.is_available() else "cpu"
triu_demo = torch.triu(
    torch.randn(100, 333, generator=triu_generator), diagonal=-50
).to(triu_device)

plot_tensor(triu_demo)

In [5]:
# Demo: t-SNE on three well-separated synthetic Gaussian clusters in 99-D.
# A seeded Generator makes the draw reproducible, and the labels are built from
# the cluster arrays' own lengths so they cannot drift out of sync with the
# sample counts.
tsne_rng = np.random.default_rng(HYPERPARAMS["seed"])
N_TSNE_FEATURES = 99

x1 = tsne_rng.normal(loc=34, scale=3, size=(120, N_TSNE_FEATURES))
x2 = tsne_rng.normal(loc=22, scale=5, size=(120, N_TSNE_FEATURES))
x3 = tsne_rng.normal(loc=-32, scale=2, size=(200, N_TSNE_FEATURES))

x = np.concatenate([x1, x2, x3], axis=0)
tsne_labels = ["A"] * len(x1) + ["B"] * len(x2) + ["C"] * len(x3)
assert len(tsne_labels) == len(x), "expected exactly one label per sample"
plot_tsne(x, labels=tsne_labels)

[t-SNE] Computing 91 nearest neighbors...
[t-SNE] Indexed 440 samples in 0.001s...
[t-SNE] Computed neighbors for 440 samples in 0.085s...
[t-SNE] Computed conditional probabilities for sample 440 / 440
[t-SNE] Mean sigma: 8.189783
[t-SNE] KL divergence after 250 iterations with early exaggeration: 56.438591
[t-SNE] KL divergence after 300 iterations: 1.633399


array([[ 0.87304693,  1.1755126 ,  4.5597153 ],
       [-0.58325976,  3.2334232 ,  5.8059535 ],
       [-0.62448126,  1.1317172 ,  3.1020374 ],
       ...,
       [ 1.016537  ,  2.0812507 , -0.02667014],
       [ 2.9923937 ,  1.9250925 , -1.1484497 ],
       [ 3.6033046 ,  1.806453  ,  2.247579  ]],
      shape=(440, 3), dtype=float32)