diff --git a/Dockerfile_Strix_Halo b/Dockerfile_Strix_Halo
new file mode 100644
index 00000000..edbbbd27
--- /dev/null
+++ b/Dockerfile_Strix_Halo
@@ -0,0 +1,53 @@
+FROM rocm/pytorch:rocm6.4.4_ubuntu24.04_py3.12_pytorch_release_2.7.1
+
+# Create a new user with sudo privileges (passwordless)
+RUN apt-get update && apt-get install -y sudo && \
+ useradd -m -s /bin/bash user && \
+ usermod -aG sudo user && \
+ echo "user ALL=(ALL) NOPASSWD:ALL" >> /etc/sudoers && \
+ apt-get clean && rm -rf /var/lib/apt/lists/*
+
+WORKDIR /opt/src
+
+# bitsandbytes (ROCm)
+RUN git clone -b rocm_enabled_multi_backend https://github.com/ROCm/bitsandbytes.git
+WORKDIR /opt/src/bitsandbytes
+RUN cmake -S . -DGPU_TARGETS="gfx1151" -DBNB_ROCM_ARCH="gfx1151" -DCOMPUTE_BACKEND=hip && \
+ make -j && \
+ python -m pip install --no-cache-dir .
+
+# Python deps
+RUN python -m pip install --no-cache-dir \
+ 'datasets>=3.4.1' \
+ 'sentencepiece>=0.2.0' \
+ tqdm psutil 'wheel>=0.42.0' \
+ 'accelerate>=0.34.1' \
+ 'peft>=0.7.1,!=0.11.0' \
+ einops packaging
+
+# xformers (pinned)
+WORKDIR /opt/src
+RUN git clone https://github.com/ROCm/xformers.git
+WORKDIR /opt/src/xformers
+RUN git checkout 13c93f3 && \
+ git submodule update --init --recursive && \
+ PYTORCH_ROCM_ARCH=gfx1151 python setup.py install
+
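+# flash-attention (ROCm CK-tile branch; Triton AMD backend enabled below)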
+ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
+WORKDIR /root
+RUN git clone https://github.com/ROCm/flash-attention.git
+RUN cd flash-attention && git checkout v2.7.4-cktile && python setup.py install
+
+# Unsloth (install first), then Zoo
+WORKDIR /opt/src
+RUN git clone https://github.com/unslothai/unsloth.git
+
+WORKDIR /opt/src/unsloth
+RUN python -m pip install --no-cache-dir .
+RUN python -m pip install --no-cache-dir jupyterlab ipywidgets ipykernel tqdm
+RUN python -m pip install --no-cache-dir 'unsloth_zoo>=2025.5.7'
+
+# Set default user and working directory
+USER user
+WORKDIR /home/user
+CMD ["/bin/bash"]
diff --git a/gpt-oss-(20B)_StrixHalo-Fine-tuning.ipynb b/gpt-oss-(20B)_StrixHalo-Fine-tuning.ipynb
new file mode 100644
index 00000000..fb1d450d
--- /dev/null
+++ b/gpt-oss-(20B)_StrixHalo-Fine-tuning.ipynb
@@ -0,0 +1,839 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "md-title",
+ "metadata": {},
+ "source": [
+ "# Fine-tuning LLMs with AMD Strix Halo and Unsloth\n",
+ "\n",
+ "Tutorial on fine-tuning gpt-oss-20b (and others) on AMD Strix Halo using Unsloth. This mirrors the DGX Spark example with Strix Halo specifics."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "md-prereqs",
+ "metadata": {},
+ "source": [
+ "## Start with Unsloth Image for Strix Halo\n",
+ "\n",
+ "You can use the prebuilt toolbox image or build locally.\n",
+ "\n",
+ "Option A — Toolbox (recommended):\n",
+ "\n",
+ "```bash\n",
+ "toolbox create strix-halo-llm-finetuning \\\n",
+ " --image docker.io/kyuz0/amd-strix-halo-llm-finetuning:latest \\\n",
+ " -- --device /dev/dri --device /dev/kfd \\\n",
+ " --group-add video --group-add render --security-opt seccomp=unconfined\n",
+ "\n",
+ "toolbox enter strix-halo-llm-finetuning\n",
+ "```\n",
+ "\n",
+ "Option B — Local Docker build from this repo:\n",
+ "\n",
+ "```bash\n",
+ "docker build -f Dockerfile -t unsloth-strix-halo .\n",
+ "docker run -it --device /dev/kfd --device /dev/dri \\\n",
+ " --group-add=render --group-add=video -p 8888:8888 \\\n",
+ " -v $(pwd):/work -w /work unsloth-strix-halo\n",
+ "```\n",
+ "\n",
+ "## Start Jupyter and Run Notebooks\n",
+ "\n",
+ "Inside the container:\n",
+ "\n",
+ "```bash\n",
+ "jupyter lab --notebook-dir /work\n",
+ "```\n",
+ "\n",
+ "If using the toolbox image (see README.md):\n",
+ "\n",
+ "```bash\n",
+ "mkdir -p ~/finetuning-workspace/\n",
+ "cp -r /opt/workspace ~/finetuning-workspace/\n",
+ "jupyter lab --notebook-dir ~/finetuning-workspace/\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "46596bf2",
+ "metadata": {},
+ "source": [
+ "Dockerfile https://github.com/unslothai/notebooks/blob/main/Dockerfile_Strix_Halo"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "code-env-check",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "PyTorch: 2.7.1+git99ccf24\n",
+ "ROCm: 6.4.43484-123eb5128\n",
+ "CUDA (None on ROCm expected): None\n",
+ "torch.cuda.is_available(): True\n",
+ "Device: AMD Radeon Graphics\n",
+ "Total VRAM/Unified (GiB): 128.0\n",
+ "g++ (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0\n",
+ "Copyright (C) 2023 Free Software Foundation, Inc.\n",
+ "This is free software; see the source for copying conditions. There is NO\n",
+ "warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.\n",
+ "\n",
+ "🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.\n",
+ "🦥 Unsloth Zoo will now patch everything to make training faster!\n",
+ "Unsloth: 2025.11.3\n",
+ "Transformers: 4.57.1\n",
+ "TRL: 0.24.0\n",
+ "Env UNSLOTH_FA2_COMPUTE_DTYPE = None\n",
+ "Env UNSLOTH_ROPE_IMPL = None\n",
+ "Env UNSLOTH_DISABLE_TRITON_RMSNORM = None\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Quick environment check (ROCm + Unsloth)\n",
+ "import os, torch\n",
+ "\n",
+ "print('PyTorch:', torch.__version__)\n",
+ "print('ROCm:', getattr(torch.version, 'hip', None))\n",
+ "print('CUDA (None on ROCm expected):', torch.version.cuda)\n",
+ "print('torch.cuda.is_available():', torch.cuda.is_available())\n",
+ "\n",
+ "if torch.cuda.is_available():\n",
+ " try:\n",
+ " print('Device:', torch.cuda.get_device_name(0))\n",
+ " props = torch.cuda.get_device_properties(0)\n",
+ " print('Total VRAM/Unified (GiB):', round(props.total_memory/1024**3, 2))\n",
+ " except Exception as e:\n",
+ " print('Device info error:', e)\n",
+ "\n",
+ "try:\n",
+ " import unsloth, transformers, trl\n",
+ " print('Unsloth:', getattr(unsloth, '__version__', 'unknown'))\n",
+ " print('Transformers:', transformers.__version__)\n",
+ " print('TRL:', trl.__version__)\n",
+ "except Exception as e:\n",
+ " print('Package import error:', e)\n",
+ "\n",
+ "for k in ('UNSLOTH_FA2_COMPUTE_DTYPE','UNSLOTH_ROPE_IMPL','UNSLOTH_DISABLE_TRITON_RMSNORM'):\n",
+ " print(f'Env {k} =', os.environ.get(k))\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "md-config",
+ "metadata": {},
+ "source": [
+ "## Configuration and Hyperparameters\n",
+ "\n",
+ "Sets model name, sequence length, dtypes, 4-bit loading, and Unsloth ROCm tuning env vars."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "code-hparams",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "from datasets import load_dataset\n",
+ "from trl import SFTConfig, SFTTrainer\n",
+ "\n",
+ "from unsloth import FastLanguageModel\n",
+ "from unsloth.chat_templates import (\n",
+ " standardize_sharegpt,\n",
+ " train_on_responses_only,\n",
+ ")\n",
+ "from transformers import TextStreamer\n",
+ "\n",
+ "# Common hyperparameters\n",
+ "MODEL_NAME = \"unsloth/gpt-oss-20b\" # or \"unsloth/gpt-oss-20b-unsloth-bnb-4bit\"\n",
+ "max_seq_length = 2048\n",
+ "dtype = None # let Unsloth auto-detect\n",
+ "load_in_4bit = True # 4bit for memory\n",
+ "LR = 2e-4\n",
+ "EPOCHS = 1 # or use max_steps if you prefer\n",
+ "BATCH_SIZE = 1 # you can crank this up if memory allows\n",
+ "\n",
+ "# Set ROCm logging / Unsloth preferences\n",
+ "import os\n",
+ "# os.environ['PYTORCH_ROCM_LOG_LEVEL'] = 'DEBUG'\n",
+ "os.environ['UNSLOTH_FA2_COMPUTE_DTYPE'] = 'float16'\n",
+ "os.environ['UNSLOTH_ROPE_IMPL'] = 'slow'\n",
+ "os.environ['UNSLOTH_DISABLE_TRITON_RMSNORM'] = '1'\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "md-smoketest",
+ "metadata": {},
+ "source": [
+ "## Quick Model Smoke Test (optional)\n",
+ "\n",
+ "Verifies the base model can load with LoRA adapters under ROCm."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "code-smoketest",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Unsloth: AMD currently is not stable with 4bit bitsandbytes. Disabling for now.\n",
+ "==((====))== Unsloth 2025.11.3: Fast Gpt_Oss patching. Transformers: 4.57.1.\n",
+ " \\\\ /| AMD Radeon Graphics. Num GPUs = 1. Max memory: 128.0 GB. Platform: Linux.\n",
+ "O^O/ \\_/ \\ Torch: 2.7.1+git99ccf24. ROCm Toolkit: 6.4.43484-123eb5128. Triton: 3.3.1\n",
+ "\\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+13c93f39.d20251112. FA2 = True]\n",
+ " \"-____-\" Free license: http://github.com/unslothai/unsloth\n",
+ "Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!\n",
+ "Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA.\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "MXFP4 quantization requires Triton and kernels installed: CUDA requires Triton >= 3.4.0, XPU requires Triton >= 3.5.0, we will default to dequantizing the model to bf16\n",
+ "[transformers.quantizers.quantizer_mxfp4|WARNING]MXFP4 quantization requires Triton and kernels installed: CUDA requires Triton >= 3.4.0, XPU requires Triton >= 3.5.0, we will default to dequantizing the model to bf16\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "2c6f318b3fa24aa7bc370f6bdc9b3fdf",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Loading checkpoint shards: 0%| | 0/3 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Unsloth: Making `model.base_model.model.model` require gradients\n"
+ ]
+ }
+ ],
+ "source": [
+ "from unsloth import FastLanguageModel\n",
+ "from trl import SFTConfig, SFTTrainer\n",
+ "from datasets import load_dataset\n",
+ "\n",
+ "max_seq_length = 1024\n",
+ "dtype = None\n",
+ "\n",
+ "model, tokenizer = FastLanguageModel.from_pretrained(\n",
+ " model_name = \"unsloth/gpt-oss-20b\",\n",
+ " dtype = dtype,\n",
+ " max_seq_length = max_seq_length,\n",
+ " load_in_4bit = True,\n",
+ " full_finetuning = False,\n",
+ ")\n",
+ "\n",
+ "model = FastLanguageModel.get_peft_model(\n",
+ " model,\n",
+ " r = 8,\n",
+ " target_modules = [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
+ " \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
+ " lora_alpha = 16,\n",
+ " lora_dropout = 0,\n",
+ " bias = \"none\",\n",
+ " use_gradient_checkpointing= \"unsloth\",\n",
+ " random_state = 3407,\n",
+ " use_rslora = False,\n",
+ " loftq_config = None,\n",
+ ")\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "md-dataset",
+ "metadata": {},
+ "source": [
+ "## Dataset Preparation\n",
+ "\n",
+ "Loads a small quotes dataset, converts to chat format, and compiles Harmony-style text with the tokenizer's chat template."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "code-dataset",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "46da164c30b54331847a3879ead0c723",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Map: 0%| | 0/1000 [00:00, ? examples/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "DatasetDict({\n",
+ " train: Dataset({\n",
+ " features: ['text'],\n",
+ " num_rows: 800\n",
+ " })\n",
+ " test: Dataset({\n",
+ " features: ['text'],\n",
+ " num_rows: 200\n",
+ " })\n",
+ "})\n",
+ "{'text': \"<|start|>system<|message|>You are ChatGPT, a large language model trained by OpenAI.\\nKnowledge cutoff: 2024-06\\nCurrent date: 2025-11-13\\n\\nReasoning: medium\\n\\n# Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|><|start|>user<|message|>Give me a quote about: ['books', 'humor']<|end|><|start|>assistant<|channel|>final<|message|>“There are two motives for reading a book; one, that you enjoy it; the other, that you can boast about it.” - Bertrand Russell<|end|>\"}\n"
+ ]
+ }
+ ],
+ "source": [
+ "from datasets import load_dataset\n",
+ "\n",
+ "# 1) Load subset of quotes\n",
+ "quotes_ds = (\n",
+ " load_dataset(\"Abirate/english_quotes\", split=\"train\")\n",
+ " .shuffle(seed=42)\n",
+ " .select(range(1000))\n",
+ ")\n",
+ "\n",
+ "# 2) Turn each row into chat messages\n",
+ "def build_quotes_messages(example):\n",
+ " return {\n",
+ " \"messages\": [\n",
+ " {\n",
+ " \"role\": \"user\",\n",
+ " \"content\": f\"Give me a quote about: {example['tags']}\",\n",
+ " },\n",
+ " {\n",
+ " \"role\": \"assistant\",\n",
+ " \"content\": f\"{example['quote']} - {example['author']}\",\n",
+ " },\n",
+ " ]\n",
+ " }\n",
+ "\n",
+ "quotes_ds = quotes_ds.map(\n",
+ " build_quotes_messages,\n",
+ " remove_columns=quotes_ds.column_names,\n",
+ ")\n",
+ "\n",
+ "# 3) Convert messages → Harmony text using the *existing* tokenizer\n",
+ "def quotes_to_text(batch):\n",
+ " convos = batch[\"messages\"]\n",
+ " texts = [\n",
+ " tokenizer.apply_chat_template(\n",
+ " convo,\n",
+ " tokenize=False,\n",
+ " add_generation_prompt=False,\n",
+ " )\n",
+ " for convo in convos\n",
+ " ]\n",
+ " return {\"text\": texts}\n",
+ "\n",
+ "quotes_ds_text = quotes_ds.map(\n",
+ " quotes_to_text,\n",
+ " batched=True,\n",
+ " remove_columns=[\"messages\"], # we only keep \"text\"\n",
+ ")\n",
+ "\n",
+ "# 4) Train / test split\n",
+ "quotes_ds_split = quotes_ds_text.train_test_split(test_size=0.2, seed=42)\n",
+ "\n",
+ "print(quotes_ds_split)\n",
+ "print(quotes_ds_split[\"train\"][0])\n"
+ ]
+ },
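+ {
+ "cell_type": "markdown",
+ "id": "md-dataset-check",
+ "metadata": {},
+ "source": [
+ "Optionally, sanity-check that the formatted examples fit the context window. This is a minimal sketch using the tokenizer loaded above; the 50-example sample size is arbitrary."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "code-dataset-check",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Optional sketch: token lengths of the formatted Harmony text\n",
+ "sample = quotes_ds_split[\"train\"].select(range(50)) # arbitrary sample size\n",
+ "lengths = [len(tokenizer(ex[\"text\"]).input_ids) for ex in sample]\n",
+ "print('max tokens in sample:', max(lengths))\n",
+ "print('fits max_seq_length:', max(lengths) <= max_seq_length)\n"
+ ]
+ },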
+ {
+ "cell_type": "markdown",
+ "id": "md-model",
+ "metadata": {},
+ "source": [
+ "## Load Model and Apply LoRA\n",
+ "\n",
+ "Loads the Unsloth-optimized model and attaches LoRA adapters for memory-efficient fine-tuning."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "code-model",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Unsloth: AMD currently is not stable with 4bit bitsandbytes. Disabling for now.\n",
+ "==((====))== Unsloth 2025.11.3: Fast Gpt_Oss patching. Transformers: 4.57.1.\n",
+ " \\\\ /| AMD Radeon Graphics. Num GPUs = 1. Max memory: 128.0 GB. Platform: Linux.\n",
+ "O^O/ \\_/ \\ Torch: 2.7.1+git99ccf24. ROCm Toolkit: 6.4.43484-123eb5128. Triton: 3.3.1\n",
+ "\\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+13c93f39.d20251112. FA2 = True]\n",
+ " \"-____-\" Free license: http://github.com/unslothai/unsloth\n",
+ "Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!\n",
+ "Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "56fd64c528f344a4914840640961060f",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Loading checkpoint shards: 0%| | 0/3 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Unsloth: Making `model.base_model.model.model` require gradients\n",
+ "trainable params: 3,981,312 || all params: 20,918,738,496 || trainable%: 0.0190\n"
+ ]
+ }
+ ],
+ "source": [
+ "# ==== Load model + tokenizer from Unsloth ====\n",
+ "model, tokenizer = FastLanguageModel.from_pretrained(\n",
+ " model_name = MODEL_NAME,\n",
+ " max_seq_length = max_seq_length,\n",
+ " dtype = dtype,\n",
+ " load_in_4bit = load_in_4bit,\n",
+ " full_finetuning = False, # we want LoRA, not full finetune\n",
+ ")\n",
+ "\n",
+ "# Attach LoRA via Unsloth\n",
+ "model = FastLanguageModel.get_peft_model(\n",
+ " model,\n",
+ " r = 8,\n",
+ " target_modules = [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
+ " \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
+ " lora_alpha = 16,\n",
+ " lora_dropout = 0.0,\n",
+ " bias = \"none\",\n",
+ " use_gradient_checkpointing = \"unsloth\",\n",
+ " random_state = 3407,\n",
+ " use_rslora = False,\n",
+ " loftq_config = None,\n",
+ ")\n",
+ "\n",
+ "model.print_trainable_parameters()\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "md-train",
+ "metadata": {},
+ "source": [
+ "## Train\n",
+ "\n",
+ "Configures TRL SFTTrainer and fine-tunes for a small number of steps for validation. Increase steps/epochs for real training."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "code-train",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "b855e1ea761846dda976fad453d40272",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Unsloth: Tokenizing [\"text\"] (num_proc=36): 0%| | 0/800 [00:00, ? examples/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "7a9a7944850f4329a0fbc4a43520fe6c",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Unsloth: Tokenizing [\"text\"] (num_proc=36): 0%| | 0/200 [00:00, ? examples/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "The model is already on multiple devices. Skipping the move to device specified in `args`.\n",
+ "[transformers.trainer|WARNING]The model is already on multiple devices. Skipping the move to device specified in `args`.\n",
+ "==((====))== Unsloth - 2x faster free finetuning | Num GPUs used = 1\n",
+ " \\\\ /| Num examples = 800 | Num Epochs = 1 | Total steps = 30\n",
+ "O^O/ \\_/ \\ Batch size per device = 1 | Gradient accumulation steps = 4\n",
+ "\\ / Data Parallel GPUs = 1 | Total batch size (1 x 4 x 1) = 4\n",
+ " \"-____-\" Trainable parameters = 3,981,312 of 20,918,738,496 (0.02% trained)\n",
+ "[transformers.trainer|WARNING]==((====))== Unsloth - 2x faster free finetuning | Num GPUs used = 1\n",
+ " \\\\ /| Num examples = 800 | Num Epochs = 1 | Total steps = 30\n",
+ "O^O/ \\_/ \\ Batch size per device = 1 | Gradient accumulation steps = 4\n",
+ "\\ / Data Parallel GPUs = 1 | Total batch size (1 x 4 x 1) = 4\n",
+ " \"-____-\" Trainable parameters = 3,981,312 of 20,918,738,496 (0.02% trained)\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n",
+ " \n",
+ "
\n",
+ " [30/30 02:37, Epoch 0/1]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | Step | \n",
+ " Training Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 1 | \n",
+ " 4.703300 | \n",
+ "
\n",
+ " \n",
+ " | 2 | \n",
+ " 4.710800 | \n",
+ "
\n",
+ " \n",
+ " | 3 | \n",
+ " 4.778100 | \n",
+ "
\n",
+ " \n",
+ " | 4 | \n",
+ " 4.417100 | \n",
+ "
\n",
+ " \n",
+ " | 5 | \n",
+ " 3.999700 | \n",
+ "
\n",
+ " \n",
+ " | 6 | \n",
+ " 3.761800 | \n",
+ "
\n",
+ " \n",
+ " | 7 | \n",
+ " 3.499500 | \n",
+ "
\n",
+ " \n",
+ " | 8 | \n",
+ " 3.145700 | \n",
+ "
\n",
+ " \n",
+ " | 9 | \n",
+ " 2.823400 | \n",
+ "
\n",
+ " \n",
+ " | 10 | \n",
+ " 2.524900 | \n",
+ "
\n",
+ " \n",
+ " | 11 | \n",
+ " 2.378900 | \n",
+ "
\n",
+ " \n",
+ " | 12 | \n",
+ " 2.065000 | \n",
+ "
\n",
+ " \n",
+ " | 13 | \n",
+ " 1.847900 | \n",
+ "
\n",
+ " \n",
+ " | 14 | \n",
+ " 1.865900 | \n",
+ "
\n",
+ " \n",
+ " | 15 | \n",
+ " 1.782500 | \n",
+ "
\n",
+ " \n",
+ " | 16 | \n",
+ " 1.466100 | \n",
+ "
\n",
+ " \n",
+ " | 17 | \n",
+ " 1.639100 | \n",
+ "
\n",
+ " \n",
+ " | 18 | \n",
+ " 1.232200 | \n",
+ "
\n",
+ " \n",
+ " | 19 | \n",
+ " 1.344100 | \n",
+ "
\n",
+ " \n",
+ " | 20 | \n",
+ " 1.403800 | \n",
+ "
\n",
+ " \n",
+ " | 21 | \n",
+ " 1.313000 | \n",
+ "
\n",
+ " \n",
+ " | 22 | \n",
+ " 1.005100 | \n",
+ "
\n",
+ " \n",
+ " | 23 | \n",
+ " 1.243400 | \n",
+ "
\n",
+ " \n",
+ " | 24 | \n",
+ " 1.181800 | \n",
+ "
\n",
+ " \n",
+ " | 25 | \n",
+ " 1.021400 | \n",
+ "
\n",
+ " \n",
+ " | 26 | \n",
+ " 1.226200 | \n",
+ "
\n",
+ " \n",
+ " | 27 | \n",
+ " 1.479400 | \n",
+ "
\n",
+ " \n",
+ " | 28 | \n",
+ " 1.889800 | \n",
+ "
\n",
+ " \n",
+ " | 29 | \n",
+ " 1.196000 | \n",
+ "
\n",
+ " \n",
+ " | 30 | \n",
+ " 0.932400 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Unsloth: Will smartly offload gradients to save VRAM!\n"
+ ]
+ }
+ ],
+ "source": [
+ "from trl import SFTConfig, SFTTrainer\n",
+ "\n",
+ "quotes_args = SFTConfig(\n",
+ " output_dir = \"outputs-quotes\",\n",
+ " dataset_text_field = \"text\",\n",
+ " packing = False,\n",
+ " num_train_epochs = EPOCHS,\n",
+ " per_device_train_batch_size = BATCH_SIZE,\n",
+ " gradient_accumulation_steps = 4,\n",
+ " warmup_steps = 5,\n",
+ " max_steps = 30, # or num_train_epochs=1, max_steps=None\n",
+ " learning_rate = LR,\n",
+ " logging_steps = 1,\n",
+ " optim = \"adamw_8bit\",\n",
+ " weight_decay = 0.001,\n",
+ " lr_scheduler_type = \"linear\",\n",
+ " seed = 3407,\n",
+ " report_to = \"none\",\n",
+ ")\n",
+ "\n",
+ "quotes_trainer = SFTTrainer(\n",
+ " model = model,\n",
+ " args = quotes_args,\n",
+ " train_dataset = quotes_ds_split[\"train\"],\n",
+ " eval_dataset = quotes_ds_split[\"test\"],\n",
+ " processing_class = tokenizer,\n",
+ " dataset_num_proc = 2,\n",
+ ")\n",
+ "\n",
+ "quotes_stats = quotes_trainer.train()\n",
+ "quotes_trainer.save_model(\"finetuned_quotes\")\n"
+ ]
+ },
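+ {
+ "cell_type": "markdown",
+ "id": "md-save-lora",
+ "metadata": {},
+ "source": [
+ "Besides the full Trainer checkpoint above, you can save just the LoRA adapter weights plus the tokenizer. A minimal sketch; the directory name is arbitrary."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "code-save-lora",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Sketch: save only the LoRA adapters + tokenizer (directory name is arbitrary)\n",
+ "model.save_pretrained(\"finetuned_quotes_lora\")\n",
+ "tokenizer.save_pretrained(\"finetuned_quotes_lora\")\n"
+ ]
+ },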
+ {
+ "cell_type": "markdown",
+ "id": "md-infer",
+ "metadata": {},
+ "source": [
+ "## Inference Test\n",
+ "\n",
+ "Run a quick generation to verify the fine-tuned model responds as expected."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "code-infer",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "systemYou are ChatGPT, a large language model trained by OpenAI.\n",
+ "Knowledge cutoff: 2024-06\n",
+ "Current date: 2025-11-13\n",
+ "\n",
+ "Reasoning: medium\n",
+ "\n",
+ "# Valid channels: analysis, commentary, final. Channel must be included for every message.userGive me a short inspiring quote about persistence.assistantfinal\"Success is not final, failure is not fatal: It is the courage to continue that counts.\" — Winston Churchillassistantfinal\"Keep your face always toward the sunshine—and shadows will fall behind you.\" — Walt Whitman\n"
+ ]
+ }
+ ],
+ "source": [
+ "from transformers import TextStreamer\n",
+ "import torch\n",
+ "\n",
+ "# Prepare a simple prompt\n",
+ "messages = [\n",
+ " {\"role\": \"user\", \"content\": \"Give me a short inspiring quote about persistence.\"}\n",
+ "]\n",
+ "\n",
+ "prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
+ "inputs = tokenizer([prompt], return_tensors=\"pt\").to(model.device)\n",
+ "\n",
+ "# Switch to inference optimizations\n",
+ "model = FastLanguageModel.for_inference(model)\n",
+ "\n",
+ "with torch.no_grad():\n",
+ " outputs = model.generate(\n",
+ " input_ids=inputs[\"input_ids\"],\n",
+ " attention_mask=inputs.get(\"attention_mask\"),\n",
+ " max_new_tokens=60,\n",
+ " do_sample=True,\n",
+ " temperature=0.8,\n",
+ " top_p=0.95,\n",
+ " )\n",
+ "\n",
+ "print(tokenizer.decode(outputs[0], skip_special_tokens=True))\n"
+ ]
+ },
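+ {
+ "cell_type": "markdown",
+ "id": "md-infer-stream",
+ "metadata": {},
+ "source": [
+ "For token-by-token output, you can reuse the `TextStreamer` imported above. A minimal sketch with the same sampling settings."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "code-infer-stream",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Sketch: stream tokens to stdout as they are generated\n",
+ "streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)\n",
+ "with torch.no_grad():\n",
+ "    _ = model.generate(\n",
+ "        **inputs,\n",
+ "        streamer=streamer,\n",
+ "        max_new_tokens=60,\n",
+ "        do_sample=True,\n",
+ "        temperature=0.8,\n",
+ "        top_p=0.95,\n",
+ "    )\n"
+ ]
+ },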
+ {
+ "cell_type": "markdown",
+ "id": "md-memory",
+ "metadata": {},
+ "source": [
+ "## Unified Memory Usage\n",
+ "\n",
+ "On Strix Halo with unified memory, 4-bit LoRA fine-tuning of 20B should fit comfortably. Use the below to inspect memory."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "code-memory",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Memory allocated: 39.03 GiB\n",
+ "Max memory reserved: 84.04 GiB\n"
+ ]
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "if torch.cuda.is_available():\n",
+ " print(f'Memory allocated: {torch.cuda.memory_allocated()/1024**3:.2f} GiB')\n",
+ " print(f'Max memory reserved: {torch.cuda.max_memory_reserved()/1024**3:.2f} GiB')\n",
+ "else:\n",
+ " print('CUDA (ROCm) device not available in this environment.')\n"
+ ]
+ },
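+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "code-memory-reset",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Sketch: reset the peak counter so the next experiment is measured in isolation\n",
+ "import torch\n",
+ "if torch.cuda.is_available():\n",
+ "    torch.cuda.reset_peak_memory_stats()\n"
+ ]
+ },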
+ {
+ "cell_type": "markdown",
+ "id": "md-troubleshoot",
+ "metadata": {},
+ "source": [
+ "## Troubleshooting\n",
+ "\n",
+ "- GPU not visible: pass `--device /dev/kfd --device /dev/dri` and add user to `render`/`video`.\n",
+ "- OOM or slow: reduce `max_seq_length`, keep 4-bit, increase grad accumulation.\n",
+ "- Kernel params for unified memory (see README): `amd_iommu=off amdgpu.gttsize=131072 ttm.pages_limit=33554432`.\n",
+ "- If FA2/RMSNorm issues: set `UNSLOTH_FA2_COMPUTE_DTYPE=float16` and `UNSLOTH_DISABLE_TRITON_RMSNORM=1`."
+ ]
+ },
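+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "code-troubleshoot-check",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Minimal sketch for the first item above: confirm the ROCm device nodes are visible\n",
+ "import os\n",
+ "for dev in ('/dev/kfd', '/dev/dri'):\n",
+ "    print(dev, 'present:', os.path.exists(dev))\n"
+ ]
+ },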
+ {
+ "cell_type": "markdown",
+ "id": "md-refs",
+ "metadata": {},
+ "source": [
+ "\n",
+ "## Credits\n",
+ "\n",
+ "Special thanks to kyuz0 for their Transformers fine-tuning notebook and Dockerfile for setting up ROCm and gfx1151 drivers on Strix Halo:\n",
+ "- https://github.com/kyuz0/amd-strix-halo-llm-finetuning"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.11"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/original_template/gpt-oss-(20B)_StrixHalo-Fine-tuning.ipynb b/original_template/gpt-oss-(20B)_StrixHalo-Fine-tuning.ipynb
new file mode 100644
index 00000000..0cb91e1c
--- /dev/null
+++ b/original_template/gpt-oss-(20B)_StrixHalo-Fine-tuning.ipynb
@@ -0,0 +1,843 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "md-title",
+ "metadata": {},
+ "source": [
+ "# Fine-tuning LLMs with AMD Strix Halo and Unsloth\n",
+ "\n",
+ "Tutorial on fine-tuning gpt-oss-20b (and others) on AMD Strix Halo using Unsloth. This mirrors the DGX Spark example with Strix Halo specifics."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "md-prereqs",
+ "metadata": {},
+ "source": [
+ "## Start with Unsloth Image for Strix Halo\n",
+ "\n",
+ "You can use the prebuilt toolbox image or build locally.\n",
+ "\n",
+ "Option A — Toolbox (recommended):\n",
+ "\n",
+ "```bash\n",
+ "toolbox create strix-halo-llm-finetuning \\\n",
+ " --image docker.io/kyuz0/amd-strix-halo-llm-finetuning:latest \\\n",
+ " -- --device /dev/dri --device /dev/kfd \\\n",
+ " --group-add video --group-add render --security-opt seccomp=unconfined\n",
+ "\n",
+ "toolbox enter strix-halo-llm-finetuning\n",
+ "```\n",
+ "\n",
+ "Option B — Local Docker build from this repo:\n",
+ "\n",
+ "```bash\n",
+ "docker build -f Dockerfile -t unsloth-strix-halo .\n",
+ "docker run -it --device /dev/kfd --device /dev/dri \\\n",
+ " --group-add=render --group-add=video -p 8888:8888 \\\n",
+ " -v $(pwd):/work -w /work unsloth-strix-halo\n",
+ "```\n",
+ "\n",
+ "## Start Jupyter and Run Notebooks\n",
+ "\n",
+ "Inside the container:\n",
+ "\n",
+ "```bash\n",
+ "jupyter lab --notebook-dir /work\n",
+ "```\n",
+ "\n",
+ "If using the toolbox image (see README.md):\n",
+ "\n",
+ "```bash\n",
+ "mkdir -p ~/finetuning-workspace/\n",
+ "cp -r /opt/workspace ~/finetuning-workspace/\n",
+ "jupyter lab --notebook-dir ~/finetuning-workspace/\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "46596bf2",
+ "metadata": {},
+ "source": [
+ "Dockerfile https://github.com/unslothai/notebooks/blob/main/Dockerfile_Strix_Halo"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "code-env-check",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "PyTorch: 2.7.1+git99ccf24\n",
+ "ROCm: 6.4.43484-123eb5128\n",
+ "CUDA (None on ROCm expected): None\n",
+ "torch.cuda.is_available(): True\n",
+ "Device: AMD Radeon Graphics\n",
+ "Total VRAM/Unified (GiB): 128.0\n",
+ "g++ (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0\n",
+ "Copyright (C) 2023 Free Software Foundation, Inc.\n",
+ "This is free software; see the source for copying conditions. There is NO\n",
+ "warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.\n",
+ "\n",
+ "🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.\n",
+ "🦥 Unsloth Zoo will now patch everything to make training faster!\n",
+ "Unsloth: 2025.11.3\n",
+ "Transformers: 4.57.1\n",
+ "TRL: 0.24.0\n",
+ "Env UNSLOTH_FA2_COMPUTE_DTYPE = None\n",
+ "Env UNSLOTH_ROPE_IMPL = None\n",
+ "Env UNSLOTH_DISABLE_TRITON_RMSNORM = None\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Quick environment check (ROCm + Unsloth)\n",
+ "import os, torch\n",
+ "\n",
+ "print('PyTorch:', torch.__version__)\n",
+ "print('ROCm:', getattr(torch.version, 'hip', None))\n",
+ "print('CUDA (None on ROCm expected):', torch.version.cuda)\n",
+ "print('torch.cuda.is_available():', torch.cuda.is_available())\n",
+ "\n",
+ "if torch.cuda.is_available():\n",
+ " try:\n",
+ " print('Device:', torch.cuda.get_device_name(0))\n",
+ " props = torch.cuda.get_device_properties(0)\n",
+ " print('Total VRAM/Unified (GiB):', round(props.total_memory/1024**3, 2))\n",
+ " except Exception as e:\n",
+ " print('Device info error:', e)\n",
+ "\n",
+ "try:\n",
+ " import unsloth, transformers, trl\n",
+ " print('Unsloth:', getattr(unsloth, '__version__', 'unknown'))\n",
+ " print('Transformers:', transformers.__version__)\n",
+ " print('TRL:', trl.__version__)\n",
+ "except Exception as e:\n",
+ " print('Package import error:', e)\n",
+ "\n",
+ "for k in ('UNSLOTH_FA2_COMPUTE_DTYPE','UNSLOTH_ROPE_IMPL','UNSLOTH_DISABLE_TRITON_RMSNORM'):\n",
+ " print(f'Env {k} =', os.environ.get(k))\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "md-config",
+ "metadata": {},
+ "source": [
+ "## Configuration and Hyperparameters\n",
+ "\n",
+ "Sets model name, sequence length, dtypes, 4-bit loading, and Unsloth ROCm tuning env vars."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "code-hparams",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "from datasets import load_dataset\n",
+ "from trl import SFTConfig, SFTTrainer\n",
+ "\n",
+ "from unsloth import FastLanguageModel\n",
+ "from unsloth.chat_templates import (\n",
+ " standardize_sharegpt,\n",
+ " train_on_responses_only,\n",
+ ")\n",
+ "from transformers import TextStreamer\n",
+ "\n",
+ "# Common hyperparameters\n",
+ "MODEL_NAME = \"unsloth/gpt-oss-20b\" # or \"unsloth/gpt-oss-20b-unsloth-bnb-4bit\"\n",
+ "max_seq_length = 2048\n",
+ "dtype = None # let Unsloth auto-detect\n",
+ "load_in_4bit = True # 4bit for memory\n",
+ "LR = 2e-4\n",
+ "EPOCHS = 1 # or use max_steps if you prefer\n",
+ "BATCH_SIZE = 1 # you can crank this up if memory allows\n",
+ "\n",
+ "# Set ROCm logging / Unsloth preferences\n",
+ "import os\n",
+ "# os.environ['PYTORCH_ROCM_LOG_LEVEL'] = 'DEBUG'\n",
+ "os.environ['UNSLOTH_FA2_COMPUTE_DTYPE'] = 'float16'\n",
+ "os.environ['UNSLOTH_ROPE_IMPL'] = 'slow'\n",
+ "os.environ['UNSLOTH_DISABLE_TRITON_RMSNORM'] = '1'\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "md-smoketest",
+ "metadata": {},
+ "source": [
+ "## Quick Model Smoke Test (optional)\n",
+ "\n",
+ "Verifies the base model can load with LoRA adapters under ROCm."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "code-smoketest",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Unsloth: AMD currently is not stable with 4bit bitsandbytes. Disabling for now.\n",
+ "==((====))== Unsloth 2025.11.3: Fast Gpt_Oss patching. Transformers: 4.57.1.\n",
+ " \\\\ /| AMD Radeon Graphics. Num GPUs = 1. Max memory: 128.0 GB. Platform: Linux.\n",
+ "O^O/ \\_/ \\ Torch: 2.7.1+git99ccf24. ROCm Toolkit: 6.4.43484-123eb5128. Triton: 3.3.1\n",
+ "\\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+13c93f39.d20251112. FA2 = True]\n",
+ " \"-____-\" Free license: http://github.com/unslothai/unsloth\n",
+ "Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!\n",
+ "Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA.\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "MXFP4 quantization requires Triton and kernels installed: CUDA requires Triton >= 3.4.0, XPU requires Triton >= 3.5.0, we will default to dequantizing the model to bf16\n",
+ "[transformers.quantizers.quantizer_mxfp4|WARNING]MXFP4 quantization requires Triton and kernels installed: CUDA requires Triton >= 3.4.0, XPU requires Triton >= 3.5.0, we will default to dequantizing the model to bf16\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "2c6f318b3fa24aa7bc370f6bdc9b3fdf",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Loading checkpoint shards: 0%| | 0/3 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Unsloth: Making `model.base_model.model.model` require gradients\n"
+ ]
+ }
+ ],
+ "source": [
+ "from unsloth import FastLanguageModel\n",
+ "from trl import SFTConfig, SFTTrainer\n",
+ "from datasets import load_dataset\n",
+ "\n",
+ "max_seq_length = 1024\n",
+ "dtype = None\n",
+ "\n",
+ "model, tokenizer = FastLanguageModel.from_pretrained(\n",
+ " model_name = \"unsloth/gpt-oss-20b\",\n",
+ " dtype = dtype,\n",
+ " max_seq_length = max_seq_length,\n",
+ " load_in_4bit = True,\n",
+ " full_finetuning = False,\n",
+ ")\n",
+ "\n",
+ "model = FastLanguageModel.get_peft_model(\n",
+ " model,\n",
+ " r = 8,\n",
+ " target_modules = [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
+ " \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
+ " lora_alpha = 16,\n",
+ " lora_dropout = 0,\n",
+ " bias = \"none\",\n",
+ " use_gradient_checkpointing= \"unsloth\",\n",
+ " random_state = 3407,\n",
+ " use_rslora = False,\n",
+ " loftq_config = None,\n",
+ ")\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "md-dataset",
+ "metadata": {},
+ "source": [
+ "## Dataset Preparation\n",
+ "\n",
+ "Loads a small quotes dataset, converts to chat format, and compiles Harmony-style text with the tokenizer's chat template."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "code-dataset",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "46da164c30b54331847a3879ead0c723",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Map: 0%| | 0/1000 [00:00, ? examples/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "DatasetDict({\n",
+ " train: Dataset({\n",
+ " features: ['text'],\n",
+ " num_rows: 800\n",
+ " })\n",
+ " test: Dataset({\n",
+ " features: ['text'],\n",
+ " num_rows: 200\n",
+ " })\n",
+ "})\n",
+ "{'text': \"<|start|>system<|message|>You are ChatGPT, a large language model trained by OpenAI.\\nKnowledge cutoff: 2024-06\\nCurrent date: 2025-11-13\\n\\nReasoning: medium\\n\\n# Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|><|start|>user<|message|>Give me a quote about: ['books', 'humor']<|end|><|start|>assistant<|channel|>final<|message|>“There are two motives for reading a book; one, that you enjoy it; the other, that you can boast about it.” - Bertrand Russell<|end|>\"}\n"
+ ]
+ }
+ ],
+ "source": [
+ "from datasets import load_dataset\n",
+ "\n",
+ "# 1) Load subset of quotes\n",
+ "quotes_ds = (\n",
+ " load_dataset(\"Abirate/english_quotes\", split=\"train\")\n",
+ " .shuffle(seed=42)\n",
+ " .select(range(1000))\n",
+ ")\n",
+ "\n",
+ "# 2) Turn each row into chat messages\n",
+ "def build_quotes_messages(example):\n",
+ " return {\n",
+ " \"messages\": [\n",
+ " {\n",
+ " \"role\": \"user\",\n",
+ " \"content\": f\"Give me a quote about: {example['tags']}\",\n",
+ " },\n",
+ " {\n",
+ " \"role\": \"assistant\",\n",
+ " \"content\": f\"{example['quote']} - {example['author']}\",\n",
+ " },\n",
+ " ]\n",
+ " }\n",
+ "\n",
+ "quotes_ds = quotes_ds.map(\n",
+ " build_quotes_messages,\n",
+ " remove_columns=quotes_ds.column_names,\n",
+ ")\n",
+ "\n",
+ "# 3) Convert messages → Harmony text using the *existing* tokenizer\n",
+ "def quotes_to_text(batch):\n",
+ " convos = batch[\"messages\"]\n",
+ " texts = [\n",
+ " tokenizer.apply_chat_template(\n",
+ " convo,\n",
+ " tokenize=False,\n",
+ " add_generation_prompt=False,\n",
+ " )\n",
+ " for convo in convos\n",
+ " ]\n",
+ " return {\"text\": texts}\n",
+ "\n",
+ "quotes_ds_text = quotes_ds.map(\n",
+ " quotes_to_text,\n",
+ " batched=True,\n",
+ " remove_columns=[\"messages\"], # we only keep \"text\"\n",
+ ")\n",
+ "\n",
+ "# 4) Train / test split\n",
+ "quotes_ds_split = quotes_ds_text.train_test_split(test_size=0.2, seed=42)\n",
+ "\n",
+ "print(quotes_ds_split)\n",
+ "print(quotes_ds_split[\"train\"][0])\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "md-model",
+ "metadata": {},
+ "source": [
+ "## Load Model and Apply LoRA\n",
+ "\n",
+ "Loads the Unsloth-optimized model and attaches LoRA adapters for memory-efficient fine-tuning."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "code-model",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Unsloth: AMD currently is not stable with 4bit bitsandbytes. Disabling for now.\n",
+ "==((====))== Unsloth 2025.11.3: Fast Gpt_Oss patching. Transformers: 4.57.1.\n",
+ " \\\\ /| AMD Radeon Graphics. Num GPUs = 1. Max memory: 128.0 GB. Platform: Linux.\n",
+ "O^O/ \\_/ \\ Torch: 2.7.1+git99ccf24. ROCm Toolkit: 6.4.43484-123eb5128. Triton: 3.3.1\n",
+ "\\ / Bfloat16 = TRUE. FA [Xformers = 0.0.30+13c93f39.d20251112. FA2 = True]\n",
+ " \"-____-\" Free license: http://github.com/unslothai/unsloth\n",
+ "Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!\n",
+ "Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "56fd64c528f344a4914840640961060f",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Loading checkpoint shards: 0%| | 0/3 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Unsloth: Making `model.base_model.model.model` require gradients\n",
+ "trainable params: 3,981,312 || all params: 20,918,738,496 || trainable%: 0.0190\n"
+ ]
+ }
+ ],
+ "source": [
+ "# ==== Load model + tokenizer from Unsloth ====\n",
+ "model, tokenizer = FastLanguageModel.from_pretrained(\n",
+ " model_name = MODEL_NAME,\n",
+ " max_seq_length = max_seq_length,\n",
+ " dtype = dtype,\n",
+ " load_in_4bit = load_in_4bit,\n",
+ " full_finetuning = False, # we want LoRA, not full finetune\n",
+ ")\n",
+ "\n",
+ "# Attach LoRA via Unsloth\n",
+ "model = FastLanguageModel.get_peft_model(\n",
+ " model,\n",
+ " r = 8,\n",
+ " target_modules = [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
+ " \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
+ " lora_alpha = 16,\n",
+ " lora_dropout = 0.0,\n",
+ " bias = \"none\",\n",
+ " use_gradient_checkpointing = \"unsloth\",\n",
+ " random_state = 3407,\n",
+ " use_rslora = False,\n",
+ " loftq_config = None,\n",
+ ")\n",
+ "\n",
+ "model.print_trainable_parameters()\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "md-train",
+ "metadata": {},
+ "source": [
+ "## Train\n",
+ "\n",
+ "Configures TRL SFTTrainer and fine-tunes for a small number of steps for validation. Increase steps/epochs for real training."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "code-train",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "b855e1ea761846dda976fad453d40272",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Unsloth: Tokenizing [\"text\"] (num_proc=36): 0%| | 0/800 [00:00, ? examples/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "7a9a7944850f4329a0fbc4a43520fe6c",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Unsloth: Tokenizing [\"text\"] (num_proc=36): 0%| | 0/200 [00:00, ? examples/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "The model is already on multiple devices. Skipping the move to device specified in `args`.\n",
+ "[transformers.trainer|WARNING]The model is already on multiple devices. Skipping the move to device specified in `args`.\n",
+ "==((====))== Unsloth - 2x faster free finetuning | Num GPUs used = 1\n",
+ " \\\\ /| Num examples = 800 | Num Epochs = 1 | Total steps = 30\n",
+ "O^O/ \\_/ \\ Batch size per device = 1 | Gradient accumulation steps = 4\n",
+ "\\ / Data Parallel GPUs = 1 | Total batch size (1 x 4 x 1) = 4\n",
+ " \"-____-\" Trainable parameters = 3,981,312 of 20,918,738,496 (0.02% trained)\n",
+ "[transformers.trainer|WARNING]==((====))== Unsloth - 2x faster free finetuning | Num GPUs used = 1\n",
+ " \\\\ /| Num examples = 800 | Num Epochs = 1 | Total steps = 30\n",
+ "O^O/ \\_/ \\ Batch size per device = 1 | Gradient accumulation steps = 4\n",
+ "\\ / Data Parallel GPUs = 1 | Total batch size (1 x 4 x 1) = 4\n",
+ " \"-____-\" Trainable parameters = 3,981,312 of 20,918,738,496 (0.02% trained)\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [30/30 02:37, Epoch 0/1]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | Step | \n",
+ " Training Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 1 | \n",
+ " 4.703300 | \n",
+ "
\n",
+ " \n",
+ " | 2 | \n",
+ " 4.710800 | \n",
+ "
\n",
+ " \n",
+ " | 3 | \n",
+ " 4.778100 | \n",
+ "
\n",
+ " \n",
+ " | 4 | \n",
+ " 4.417100 | \n",
+ "
\n",
+ " \n",
+ " | 5 | \n",
+ " 3.999700 | \n",
+ "
\n",
+ " \n",
+ " | 6 | \n",
+ " 3.761800 | \n",
+ "
\n",
+ " \n",
+ " | 7 | \n",
+ " 3.499500 | \n",
+ "
\n",
+ " \n",
+ " | 8 | \n",
+ " 3.145700 | \n",
+ "
\n",
+ " \n",
+ " | 9 | \n",
+ " 2.823400 | \n",
+ "
\n",
+ " \n",
+ " | 10 | \n",
+ " 2.524900 | \n",
+ "
\n",
+ " \n",
+ " | 11 | \n",
+ " 2.378900 | \n",
+ "
\n",
+ " \n",
+ " | 12 | \n",
+ " 2.065000 | \n",
+ "
\n",
+ " \n",
+ " | 13 | \n",
+ " 1.847900 | \n",
+ "
\n",
+ " \n",
+ " | 14 | \n",
+ " 1.865900 | \n",
+ "
\n",
+ " \n",
+ " | 15 | \n",
+ " 1.782500 | \n",
+ "
\n",
+ " \n",
+ " | 16 | \n",
+ " 1.466100 | \n",
+ "
\n",
+ " \n",
+ " | 17 | \n",
+ " 1.639100 | \n",
+ "
\n",
+ " \n",
+ " | 18 | \n",
+ " 1.232200 | \n",
+ "
\n",
+ " \n",
+ " | 19 | \n",
+ " 1.344100 | \n",
+ "
\n",
+ " \n",
+ " | 20 | \n",
+ " 1.403800 | \n",
+ "
\n",
+ " \n",
+ " | 21 | \n",
+ " 1.313000 | \n",
+ "
\n",
+ " \n",
+ " | 22 | \n",
+ " 1.005100 | \n",
+ "
\n",
+ " \n",
+ " | 23 | \n",
+ " 1.243400 | \n",
+ "
\n",
+ " \n",
+ " | 24 | \n",
+ " 1.181800 | \n",
+ "
\n",
+ " \n",
+ " | 25 | \n",
+ " 1.021400 | \n",
+ "
\n",
+ " \n",
+ " | 26 | \n",
+ " 1.226200 | \n",
+ "
\n",
+ " \n",
+ " | 27 | \n",
+ " 1.479400 | \n",
+ "
\n",
+ " \n",
+ " | 28 | \n",
+ " 1.889800 | \n",
+ "
\n",
+ " \n",
+ " | 29 | \n",
+ " 1.196000 | \n",
+ "
\n",
+ " \n",
+ " | 30 | \n",
+ " 0.932400 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Unsloth: Will smartly offload gradients to save VRAM!\n"
+ ]
+ }
+ ],
+ "source": [
+ "from trl import SFTConfig, SFTTrainer\n",
+ "\n",
+ "LR = 2e-4\n",
+ "EPOCHS = 1\n",
+ "BATCH_SIZE = 1 # keep small for safety\n",
+ "\n",
+ "quotes_args = SFTConfig(\n",
+ " output_dir = \"outputs-quotes\",\n",
+ " dataset_text_field = \"text\",\n",
+ " packing = False,\n",
+ " num_train_epochs = EPOCHS,\n",
+ " per_device_train_batch_size = BATCH_SIZE,\n",
+ " gradient_accumulation_steps = 4,\n",
+ " warmup_steps = 5,\n",
+ " max_steps = 30, # or num_train_epochs=1, max_steps=None\n",
+ " learning_rate = LR,\n",
+ " logging_steps = 1,\n",
+ " optim = \"adamw_8bit\",\n",
+ " weight_decay = 0.001,\n",
+ " lr_scheduler_type = \"linear\",\n",
+ " seed = 3407,\n",
+ " report_to = \"none\",\n",
+ ")\n",
+ "\n",
+ "quotes_trainer = SFTTrainer(\n",
+ " model = model,\n",
+ " args = quotes_args,\n",
+ " train_dataset = quotes_ds_split[\"train\"],\n",
+ " eval_dataset = quotes_ds_split[\"test\"],\n",
+ " processing_class = tokenizer,\n",
+ " dataset_num_proc = 2,\n",
+ ")\n",
+ "\n",
+ "quotes_stats = quotes_trainer.train()\n",
+ "quotes_trainer.save_model(\"finetuned_quotes\")\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "md-infer",
+ "metadata": {},
+ "source": [
+ "## Inference Test\n",
+ "\n",
+ "Run a quick generation to verify the fine-tuned model responds as expected."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "code-infer",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "systemYou are ChatGPT, a large language model trained by OpenAI.\n",
+ "Knowledge cutoff: 2024-06\n",
+ "Current date: 2025-11-13\n",
+ "\n",
+ "Reasoning: medium\n",
+ "\n",
+ "# Valid channels: analysis, commentary, final. Channel must be included for every message.userGive me a short inspiring quote about persistence.assistantfinal\"Success is not final, failure is not fatal: It is the courage to continue that counts.\" — Winston Churchillassistantfinal\"Keep your face always toward the sunshine—and shadows will fall behind you.\" — Walt Whitman\n"
+ ]
+ }
+ ],
+ "source": [
+ "from transformers import TextStreamer\n",
+ "import torch\n",
+ "\n",
+ "# Prepare a simple prompt\n",
+ "messages = [\n",
+ " {\"role\": \"user\", \"content\": \"Give me a short inspiring quote about persistence.\"}\n",
+ "]\n",
+ "\n",
+ "prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
+ "inputs = tokenizer([prompt], return_tensors=\"pt\").to(model.device)\n",
+ "\n",
+ "# Switch to inference optimizations\n",
+ "model = FastLanguageModel.for_inference(model)\n",
+ "\n",
+ "with torch.no_grad():\n",
+ " outputs = model.generate(\n",
+ " input_ids=inputs[\"input_ids\"],\n",
+ " attention_mask=inputs.get(\"attention_mask\"),\n",
+ " max_new_tokens=60,\n",
+ " do_sample=True,\n",
+ " temperature=0.8,\n",
+ " top_p=0.95,\n",
+ " )\n",
+ "\n",
+ "print(tokenizer.decode(outputs[0], skip_special_tokens=True))\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "md-memory",
+ "metadata": {},
+ "source": [
+ "## Unified Memory Usage\n",
+ "\n",
+ "On Strix Halo with unified memory, 4-bit LoRA fine-tuning of 20B should fit comfortably. Use the below to inspect memory."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "code-memory",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Memory allocated: 39.03 GiB\n",
+ "Max memory reserved: 84.04 GiB\n"
+ ]
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "if torch.cuda.is_available():\n",
+ " print(f'Memory allocated: {torch.cuda.memory_allocated()/1024**3:.2f} GiB')\n",
+ " print(f'Max memory reserved: {torch.cuda.max_memory_reserved()/1024**3:.2f} GiB')\n",
+ "else:\n",
+ " print('CUDA (ROCm) device not available in this environment.')\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "md-troubleshoot",
+ "metadata": {},
+ "source": [
+ "## Troubleshooting\n",
+ "\n",
+ "- GPU not visible: pass `--device /dev/kfd --device /dev/dri` and add user to `render`/`video`.\n",
+ "- OOM or slow: reduce `max_seq_length`, keep 4-bit, increase grad accumulation.\n",
+ "- Kernel params for unified memory (see README): `amd_iommu=off amdgpu.gttsize=131072 ttm.pages_limit=33554432`.\n",
+ "- If FA2/RMSNorm issues: set `UNSLOTH_FA2_COMPUTE_DTYPE=float16` and `UNSLOTH_DISABLE_TRITON_RMSNORM=1`."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "md-refs",
+ "metadata": {},
+ "source": [
+ "\n",
+ "## Credits\n",
+ "\n",
+ "Special thanks to kyuz0 for their Transformers fine-tuning notebook and Dockerfile for setting up ROCm and gfx1151 drivers on Strix Halo:\n",
+ "- https://github.com/kyuz0/amd-strix-halo-llm-finetuning"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.11"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}