# KUx All-in-One Colab Notebook

Run the entire KUx workflow—from environment setup to the multimodal chatbot—in a single place. Execute the cells sequentially (or rerun specific sections) to rebuild the RAG store, optionally fine-tune Qwen3-Omni, and launch the Gradio demo on Google Colab Pro+.


> **Tip:** Adjust the configuration cell below before running the notebook. Every subsequent step consumes the values you set there, so you can quickly toggle crawling, ingestion, fine-tuning, or change model defaults without editing later cells.


In [None]:
# Master configuration for the KUx workflow. Update values as needed before running the rest of the notebook.
CONFIG = {
    "repo_url": "https://github.com/themistymoon/KUx.git",
    "repo_dir": "/content/KUx",
    # Data collection toggles
    "enable_crawl": False,
    "crawl_seed_urls": [
        "https://cs.sci.ku.ac.th/",
    ],
    "crawl_max_depth": 1,
    "crawl_max_pages": 10,
    # Document ingestion
    "enable_ingest": True,
    "ingest_sources": [
        "data/sample_documents",
        "data/crawled",
    ],
    # Optional LoRA fine-tuning
    "enable_finetune": False,
    "finetune_dataset": "data/train.jsonl",
    "finetune_epochs": 2,
    # Chatbot defaults
    "default_model_key": "qwen3-omni-30b",
    "default_system_prompt": "",
    "launch_share": True,
    "launch_preload": True,
    # Storage locations (relative to the repo root)
    "vector_db_dir": "storage/vectorstore",
    "adapter_dir": "outputs/finetuned-qwen",
}
CONFIG
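
A quick sanity check can catch typos in `CONFIG` before later cells consume it. The sketch below uses a hypothetical helper, `validate_config`, which is not part of KUx. Its set of valid model keys is copied by hand from `MODEL_OPTIONS` (defined further down), so keep the two in sync.

```python
from typing import Any, Dict, List

# Hypothetical: mirrors the keys in MODEL_OPTIONS further down; update both together.
KNOWN_MODEL_KEYS = {"qwen3-omni-30b", "gpt-oss-120b"}


def validate_config(config: Dict[str, Any]) -> List[str]:
    """Return human-readable problems found in a KUx-style CONFIG dict (empty list means OK)."""
    problems: List[str] = []
    if config.get("default_model_key") not in KNOWN_MODEL_KEYS:
        problems.append(f"default_model_key must be one of {sorted(KNOWN_MODEL_KEYS)}")
    # Depth 0 (seeds only) is valid; a crawl of 0 pages or 0 epochs of training is not.
    minimums = {"crawl_max_depth": 0, "crawl_max_pages": 1, "finetune_epochs": 1}
    for key, minimum in minimums.items():
        value = config.get(key)
        if not isinstance(value, int) or value < minimum:
            problems.append(f"{key} must be an integer >= {minimum}, got {value!r}")
    if config.get("enable_crawl") and not config.get("crawl_seed_urls"):
        problems.append("enable_crawl is True but crawl_seed_urls is empty")
    if config.get("enable_ingest") and not config.get("ingest_sources"):
        problems.append("enable_ingest is True but ingest_sources is empty")
    return problems


if "CONFIG" in globals():
    for problem in validate_config(CONFIG):
        print("CONFIG problem:", problem)
```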


## 1. Verify the runtime GPU

Confirm that the Colab session is attached to an **A100 80 GB**.


In [None]:
!nvidia-smi
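
If you prefer a programmatic check over reading the table, the sketch below asks `nvidia-smi` for each GPU's name and total memory in CSV form. The helper names and the 80 000 MiB threshold are assumptions for illustration (80 GB A100 cards typically report about 81920 MiB).

```python
import shutil
import subprocess
from typing import List, Tuple

# These flags print one "name, memory_MiB" line per GPU, without header or units.
QUERY = ["nvidia-smi", "--query-gpu=name,memory.total", "--format=csv,noheader,nounits"]


def parse_gpu_query(output: str) -> List[Tuple[str, int]]:
    """Parse 'name, memory_MiB' lines into (name, MiB) tuples."""
    gpus: List[Tuple[str, int]] = []
    for line in output.strip().splitlines():
        name, memory = (part.strip() for part in line.rsplit(",", 1))
        gpus.append((name, int(memory)))
    return gpus


def has_a100_80gb(gpus: List[Tuple[str, int]], min_mib: int = 80_000) -> bool:
    """True when at least one A100 reports at least `min_mib` of memory."""
    return any("A100" in name and memory >= min_mib for name, memory in gpus)


if shutil.which("nvidia-smi"):
    result = subprocess.run(QUERY, capture_output=True, text=True, check=True)
    gpus = parse_gpu_query(result.stdout)
    print(gpus, "-> A100 80 GB attached:", has_a100_80gb(gpus))
else:
    print("nvidia-smi not found; is a GPU runtime attached?")
```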


## 2. Clone the KUx repository

Change `repo_url` in the configuration cell if you are working from a fork. The cell is idempotent—it only clones when the target directory is missing.


In [None]:
from pathlib import Path
from IPython import get_ipython

REPO_URL = CONFIG["repo_url"]
TARGET_DIR = Path(CONFIG["repo_dir"])

if not TARGET_DIR.exists():
    TARGET_DIR.parent.mkdir(parents=True, exist_ok=True)
    !git clone {REPO_URL} {TARGET_DIR}

ioshell = get_ipython()
ioshell.run_line_magic('cd', str(TARGET_DIR))
print('Working directory:', Path.cwd())


## 3. Install Python dependencies

Install the runtime packages required for crawling, RAG ingestion, fine-tuning, and the chatbot UI.


In [None]:
!pip install -U pip


In [None]:
!pip install -r requirements.txt
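
The Qwen3-Omni generator later in this notebook imports `Qwen3OmniMoeForConditionalGeneration` and `Qwen3OmniMoeProcessor` from `transformers`, and its error message suggests installing transformers from source when they are missing. A minimal post-install check, assuming only the standard library; `missing_symbols` is a hypothetical helper:

```python
import importlib
import importlib.util
from typing import List, Sequence


def missing_symbols(module_name: str, names: Sequence[str]) -> List[str]:
    """Return the names `module_name` does not expose (all of them if it cannot be imported)."""
    if importlib.util.find_spec(module_name) is None:
        return list(names)
    try:
        module = importlib.import_module(module_name)
    except Exception:  # a broken install counts as missing everything
        return list(names)
    missing: List[str] = []
    for name in names:
        try:
            getattr(module, name)
        except Exception:  # lazily loaded modules may raise errors other than AttributeError
            missing.append(name)
    return missing


# Classes that QwenOmniGenerator imports later in this notebook.
missing = missing_symbols("transformers", ["Qwen3OmniMoeForConditionalGeneration", "Qwen3OmniMoeProcessor"])
if missing:
    print("This transformers build lacks:", missing)
    print("Try: pip install git+https://github.com/huggingface/transformers")
else:
    print("transformers exposes the Qwen3-Omni classes.")
```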


> **Optional:**
>
> * Authenticate with Hugging Face if the Qwen model is gated.
> * Mount Google Drive to persist vector stores (`storage/vectorstore/`) or LoRA adapters (`outputs/`).


In [None]:
# from huggingface_hub import notebook_login
# notebook_login()

# from google.colab import drive
# drive.mount('/content/drive')


## 4. Configure project paths

These directories are created automatically inside the cloned repository. Adjust the configuration cell if you want to point at a different location (e.g., a Drive mount).


In [None]:
from pathlib import Path

PROJECT_ROOT = Path.cwd()
DATA_DIR = PROJECT_ROOT / 'data'
VECTOR_DB_DIR = PROJECT_ROOT / CONFIG['vector_db_dir']
ADAPTER_DIR = PROJECT_ROOT / CONFIG['adapter_dir']

for path in (DATA_DIR, VECTOR_DB_DIR, ADAPTER_DIR):
    path.mkdir(parents=True, exist_ok=True)

CONFIG['project_root'] = str(PROJECT_ROOT)
CONFIG['data_dir'] = str(DATA_DIR)
CONFIG['vector_db_dir'] = str(VECTOR_DB_DIR)
CONFIG['adapter_dir'] = str(ADAPTER_DIR)

print('Project root:', PROJECT_ROOT)
print('Data dir:', DATA_DIR)
print('Vector DB dir:', VECTOR_DB_DIR)
print('Adapter dir:', ADAPTER_DIR)
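
The cell joins each setting onto `PROJECT_ROOT` with `/`. In `pathlib`, joining an absolute path discards everything before it, so setting `vector_db_dir` or `adapter_dir` to an absolute Drive path redirects storage without editing this cell. The same rule makes it harmless to rerun the cell after it has written absolute paths back into `CONFIG`. The Drive folder below is a made-up example.

```python
from pathlib import PurePosixPath

project_root = PurePosixPath("/content/KUx")

# A relative setting stays inside the repository...
relative_store = project_root / "storage/vectorstore"
# ...while an absolute one (hypothetical Drive folder) replaces the root entirely.
drive_store = project_root / "/content/drive/MyDrive/KUx/vectorstore"

print(relative_store)  # /content/KUx/storage/vectorstore
print(drive_store)     # /content/drive/MyDrive/KUx/vectorstore
```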


### Sample documents bundled with KUx

Use these starter files to test the pipeline before adding your own PDFs/CSVs.


In [None]:
for path in sorted((DATA_DIR / 'sample_documents').glob('*')):
    size_kb = path.stat().st_size / 1024
    print(f"{path.relative_to(PROJECT_ROOT)} — {size_kb:.1f} KiB")
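
To see which of these files ingestion will actually pick up, you can group them by extension. This sketch assumes the extension list from `RAGConfig.allowed_document_types` (defined later and copied here by hand). `summarise_by_suffix` is a hypothetical helper that matches extensions case-insensitively, as `DocumentIngestor` does.

```python
from collections import Counter
from pathlib import Path
from typing import Iterable, List, Tuple


def summarise_by_suffix(paths: Iterable[Path], allowed: Iterable[str]) -> Tuple[Counter, List[Path]]:
    """Count files per lowercase extension and list those the ingestor would skip."""
    allowed_set = {suffix.lower() for suffix in allowed}
    counts: Counter = Counter()
    skipped: List[Path] = []
    for path in paths:
        suffix = path.suffix.lower()
        counts[suffix] += 1
        if suffix not in allowed_set:
            skipped.append(path)
    return counts, skipped


if "DATA_DIR" in globals():
    files = [p for p in (DATA_DIR / "sample_documents").rglob("*") if p.is_file()]
    # Mirrors RAGConfig.allowed_document_types further down; update both together.
    counts, skipped = summarise_by_suffix(files, allowed=(".pdf", ".csv", ".txt"))
    print(dict(counts))
    for path in skipped:
        print("Will be skipped by ingestion:", path.name)
```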


## KUx core logic within this notebook

The next cells inline the configuration dataclasses, crawler, RAG ingestion, multimodal pipeline, fine-tuning helper, and chatbot UI so you can run everything directly without importing the `kux` Python package.


In [None]:
"""Central configuration dataclasses for KUx project."""
from __future__ import annotations

from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional


@dataclass(frozen=True)
class ModelOption:
    """Definition for a selectable base model in the chatbot UI."""

    key: str
    label: str
    model_name: str
    multimodal: bool = False


@dataclass
class TrainConfig:
    """Configuration for supervised fine-tuning of Qwen."""

    model_name: str = "Qwen/Qwen3-Omni-30B-A3B-Instruct"
    dataset_path: str = "data/train.jsonl"
    output_dir: str = "outputs/finetuned-qwen"
    learning_rate: float = 2e-4
    num_train_epochs: int = 3
    per_device_train_batch_size: int = 1
    gradient_accumulation_steps: int = 8
    max_seq_length: int = 4096
    warmup_ratio: float = 0.03
    weight_decay: float = 0.0
    logging_steps: int = 10
    save_steps: int = 200
    eval_steps: Optional[int] = None
    seed: int = 42
    lora_r: int = 64
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    load_in_4bit: bool = True
    use_gradient_checkpointing: bool = True
    bf16: bool = True
    dataset_text_field: str = "text"


@dataclass
class RAGConfig:
    """Configuration for retrieval augmented generation."""

    vector_db_path: Path = Path("storage/vectorstore")
    chunk_size: int = 1024
    chunk_overlap: int = 80
    allowed_document_types: List[str] = field(default_factory=lambda: [".pdf", ".csv", ".txt"])
    embedding_model_name: str = "sentence-transformers/all-MiniLM-L6-v2"
    collection_name: str = "kasetsart_documents"
    max_retrieval_docs: int = 6


@dataclass
class CrawlerConfig:
    """Configuration for crawling approved Kasetsart CS resources."""

    allowed_domains: List[str] = field(
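        default_factory=lambda: [
            "www.ku.ac.th",
            "www.cs.ku.ac.th",
            "cs.ku.ac.th",
            # Host of the default seed in CONFIG["crawl_seed_urls"]; hosts must match exactly.
            "cs.sci.ku.ac.th",
            "registrar.ku.ac.th",
            "admission.ku.ac.th",
        ]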
    )
    user_agent: str = "KUxBot/1.0 (+https://www.cs.ku.ac.th)"
    request_timeout: int = 20
    max_depth: int = 1
    max_pages: int = 20
    cache_dir: Path = Path("storage/crawler_cache")


MODEL_OPTIONS: List[ModelOption] = [
    ModelOption(
        key="qwen3-omni-30b",
        label="Qwen3-Omni-30B-A3B-Instruct (multimodal)",
        model_name="Qwen/Qwen3-Omni-30B-A3B-Instruct",
        multimodal=True,
    ),
    ModelOption(
        key="gpt-oss-120b",
        label="gpt-oss-120b (text-only)",
        model_name="openai/gpt-oss-120b",
        multimodal=False,
    ),
]


__all__ = [
    "TrainConfig",
    "RAGConfig",
    "CrawlerConfig",
    "ModelOption",
    "MODEL_OPTIONS",
]
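
The notebook-level `CONFIG` dict and these dataclasses use different names for the same settings (for example `crawl_max_depth` versus `CrawlerConfig.max_depth`). The following minimal sketch bridges them with `dataclasses.replace`. It uses a trimmed, hypothetical stand-in class (`CrawlerSettings`) so the snippet runs on its own; with the real classes you would pass `CrawlerConfig()` as the base. The key-to-field mapping is an assumption to extend as needed.

```python
from dataclasses import dataclass, replace
from typing import Any, Dict


@dataclass
class CrawlerSettings:
    """Hypothetical stand-in sharing two field names with CrawlerConfig."""

    max_depth: int = 1
    max_pages: int = 20


# Notebook key -> dataclass field (assumed mapping).
CRAWL_FIELD_MAP = {"crawl_max_depth": "max_depth", "crawl_max_pages": "max_pages"}


def apply_overrides(base, config: Dict[str, Any], field_map: Dict[str, str]):
    """Return a copy of `base` with fields taken from whichever mapped keys exist in `config`."""
    overrides = {attr: config[key] for key, attr in field_map.items() if key in config}
    return replace(base, **overrides)


settings = apply_overrides(CrawlerSettings(), {"crawl_max_depth": 2, "crawl_max_pages": 10}, CRAWL_FIELD_MAP)
print(settings)  # CrawlerSettings(max_depth=2, max_pages=10)
```

`replace` returns a new instance, so it also works with frozen dataclasses such as `ModelOption`.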


In [None]:
"""Simple focused crawler for Kasetsart University domains."""
from __future__ import annotations

import hashlib
import logging
from collections import deque
from pathlib import Path
from typing import Dict, Iterable, List, Set
from urllib.parse import urljoin, urlparse

import requests
from bs4 import BeautifulSoup


LOGGER = logging.getLogger(__name__)


class SiteCrawler:
    """Breadth-first crawler constrained to approved Kasetsart domains."""

    def __init__(self, config: CrawlerConfig | None = None) -> None:
        self.config = config or CrawlerConfig()
        self.session = requests.Session()
        self.session.headers.update({"User-Agent": self.config.user_agent})
        self.cache_dir = Path(self.config.cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    # ------------------------------------------------------------------
    # Fetching helpers
    # ------------------------------------------------------------------
    def _is_allowed(self, url: str) -> bool:
        domain = urlparse(url).netloc
        return domain in self.config.allowed_domains

    def _cache_path(self, url: str) -> Path:
        digest = hashlib.sha256(url.encode("utf-8")).hexdigest()
        return self.cache_dir / f"{digest}.html"

    def fetch(self, url: str) -> str:
        if not self._is_allowed(url):
            raise ValueError(f"URL domain not allowed: {url}")
        cache_path = self._cache_path(url)
        if cache_path.exists():
            return cache_path.read_text(encoding="utf-8", errors="ignore")
        response = self.session.get(url, timeout=self.config.request_timeout)
        response.raise_for_status()
        cache_path.write_text(response.text, encoding="utf-8")
        return response.text

    # ------------------------------------------------------------------
    # Parsing helpers
    # ------------------------------------------------------------------
    def extract_text(self, html: str) -> str:
        soup = BeautifulSoup(html, "html.parser")
        for tag in soup(["script", "style", "noscript"]):
            tag.decompose()
        text = "\n".join(chunk.strip() for chunk in soup.stripped_strings)
        return text

    def extract_links(self, base_url: str, html: str) -> List[str]:
        soup = BeautifulSoup(html, "html.parser")
        links: List[str] = []
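        for tag in soup.find_all("a", href=True):
            # Resolve relative links and drop "#fragment" parts so in-page anchors
            # do not turn one page into several queue entries.
            absolute = urljoin(base_url, tag["href"]).split("#", 1)[0]
            if self._is_allowed(absolute) and absolute not in links:
                links.append(absolute)
        return links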

    # ------------------------------------------------------------------
    # Crawling orchestration
    # ------------------------------------------------------------------
    def crawl(self, seeds: Iterable[str]) -> Dict[str, str]:
        """Crawl starting from the seed URLs and return url->text mapping."""

        visited: Set[str] = set()
        queue: deque[tuple[str, int]] = deque((seed, 0) for seed in seeds)
        results: Dict[str, str] = {}
        while queue and len(results) < self.config.max_pages:
            url, depth = queue.popleft()
            if url in visited or depth > self.config.max_depth:
                continue
            # Mark the URL before fetching so a failing page is not retried every time it is linked.
            visited.add(url)
            try:
                html = self.fetch(url)
                text = self.extract_text(html)
            except Exception as exc:  # pragma: no cover - network issues
                LOGGER.warning("Failed to crawl %s: %s", url, exc)
                continue
            results[url] = text
            if depth < self.config.max_depth:
                for link in self.extract_links(url, html):
                    if link not in visited:
                        queue.append((link, depth + 1))
        return results


__all__ = ["SiteCrawler"]
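
`_is_allowed` compares the exact host (`netloc`) against `allowed_domains`, so subdomains are not matched automatically and a URL with an explicit port (`host:443`) does not match either. A seed on an unlisted host is rejected inside `crawl`, logged as a warning, and contributes no pages. A small pre-flight check (hypothetical helper, not part of KUx) applies the same rule before crawling:

```python
from typing import Iterable, List
from urllib.parse import urlparse


def disallowed_urls(urls: Iterable[str], allowed_domains: Iterable[str]) -> List[str]:
    """Return URLs whose netloc is not exactly one of `allowed_domains` (same test as SiteCrawler._is_allowed)."""
    allowed = set(allowed_domains)
    return [url for url in urls if urlparse(url).netloc not in allowed]


if "CONFIG" in globals() and "CrawlerConfig" in globals():
    rejected = disallowed_urls(CONFIG["crawl_seed_urls"], CrawlerConfig().allowed_domains)
    if rejected:
        print("Seeds outside CrawlerConfig.allowed_domains:", rejected)
```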


In [None]:
"""Document ingestion utilities for KUx retrieval augmented generation."""
from __future__ import annotations

import logging
from pathlib import Path
from typing import Iterable

from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import CSVLoader, PyPDFLoader, TextLoader
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings


LOGGER = logging.getLogger(__name__)


class DocumentIngestor:
    """Ingest PDFs, CSVs and text files into a FAISS vector store."""

    def __init__(self, config: RAGConfig | None = None) -> None:
        self.config = config or RAGConfig()
        self.vector_db_path = Path(self.config.vector_db_path)
        self.vector_db_path.mkdir(parents=True, exist_ok=True)
        self.embeddings = HuggingFaceEmbeddings(model_name=self.config.embedding_model_name)
        self.splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.config.chunk_size,
            chunk_overlap=self.config.chunk_overlap,
        )

    def _resolve_loader(self, file_path: Path):
        suffix = file_path.suffix.lower()
        if suffix == ".pdf":
            return PyPDFLoader(str(file_path))
        if suffix == ".csv":
            return CSVLoader(str(file_path))
        if suffix in {".txt", ".md"}:
            return TextLoader(str(file_path))
        raise ValueError(f"Unsupported file type for ingestion: {suffix}")

    def _load_documents(self, path: Path):
        loader = self._resolve_loader(path)
        documents = loader.load()
        LOGGER.info("Loaded %s documents from %s", len(documents), path)
        return documents

    def ingest(self, sources: Iterable[str]) -> FAISS:
        """Ingest the provided sources into the FAISS vector store."""

        docs = []
        for source in sources:
            path = Path(source)
            if not path.exists():
                # e.g. a crawl output folder that has not been created yet
                LOGGER.warning("Skipping missing source: %s", path)
            elif path.is_dir():
                for child in sorted(path.rglob("*")):
                    if child.is_file() and child.suffix.lower() in self.config.allowed_document_types:
                        docs.extend(self._load_documents(child))
            elif path.suffix.lower() in self.config.allowed_document_types:
                docs.extend(self._load_documents(path))
            else:
                LOGGER.warning("Skipping unsupported file: %s", path)
        if not docs:
            raise RuntimeError("No documents were ingested. Check your source paths.")

        LOGGER.info("Splitting %s documents into chunks", len(docs))
        chunks = self.splitter.split_documents(docs)
        LOGGER.info("Creating vector store with %s chunks", len(chunks))

        vector_store = FAISS.from_documents(chunks, embedding=self.embeddings)
        vector_store.save_local(str(self.vector_db_path))
        LOGGER.info("Vector store saved to %s", self.vector_db_path)
        return vector_store


__all__ = ["DocumentIngestor"]
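
`chunk_size` and `chunk_overlap` are measured in characters (the splitter's default length function is `len`). `RecursiveCharacterTextSplitter` prefers to break on separators such as paragraphs, lines, and spaces, so real chunk boundaries vary. The simplified fixed-window sketch below is not LangChain's algorithm; it only illustrates the overlap arithmetic: each window starts `chunk_size - chunk_overlap` characters after the previous one.

```python
from typing import List


def fixed_window_chunks(text: str, size: int, overlap: int) -> List[str]:
    """Split `text` into windows of `size` characters, each sharing `overlap` characters with the previous one."""
    if not 0 <= overlap < size:
        raise ValueError("overlap must be >= 0 and smaller than size")
    step = size - overlap
    chunks: List[str] = []
    for start in range(0, len(text), step):
        chunks.append(text[start : start + size])
        if start + size >= len(text):  # this window already reaches the end of the text
            break
    return chunks


# With the RAGConfig defaults (1024/80), windows start at 0, 944, 1888, ...
print([len(chunk) for chunk in fixed_window_chunks("x" * 2000, size=1024, overlap=80)])  # [1024, 1024, 112]
```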


In [None]:
"""Retrieval augmented generation pipeline for KUx."""
from __future__ import annotations

import logging
from dataclasses import dataclass, field
from pathlib import Path
from textwrap import dedent
from typing import Iterable, List, Optional, Sequence, Tuple

import torch
from langchain_core.documents import Document
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline


LOGGER = logging.getLogger(__name__)

DEFAULT_SYSTEM_PROMPT = dedent(
    """
    You are KUx, an omniscient assistant for Kasetsart University Computer Science students.
    Answer with verified facts from Kasetsart University sources. If unsure, state that you do not know.
    """
).strip()

MODEL_OPTION_MAP = {option.key: option for option in MODEL_OPTIONS}


@dataclass(slots=True)
class MediaInput:
    """Container for user-provided multimodal attachments."""

    images: List[str] = field(default_factory=list)
    audio: List[str] = field(default_factory=list)
    video: List[str] = field(default_factory=list)

    @classmethod
    def from_payload(
        cls,
        images: Optional[Sequence[str]] = None,
        audio: Optional[Sequence[str]] = None,
        video: Optional[Sequence[str]] = None,
    ) -> "MediaInput":
        def _clean(items: Optional[Sequence[str]]) -> List[str]:
            if not items:
                return []
            return [item for item in items if item]

        return cls(images=_clean(images), audio=_clean(audio), video=_clean(video))

    def is_empty(self) -> bool:
        return not (self.images or self.audio or self.video)


class LocalHFGenerator:
    """Wrapper around a local Hugging Face causal LM for inference."""

    def __init__(
        self,
        model_path: str,
        max_new_tokens: int = 1024,
        temperature: float = 0.2,
        adapter_path: Optional[str] = None,
    ) -> None:
        LOGGER.info("Loading generator from %s", model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        base_model = AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
        self.model = self._maybe_apply_adapter(base_model, adapter_path)
        self.pipe = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            do_sample=temperature > 0,
        )

    @staticmethod
    def _maybe_apply_adapter(model: AutoModelForCausalLM, adapter_path: Optional[str]):
        if not adapter_path:
            return model
        adapter_dir = Path(adapter_path)
        if not adapter_dir.exists():
            LOGGER.warning("Adapter path %s does not exist; continuing with the base model.", adapter_dir)
            return model
        try:  # pragma: no cover - requires peft at runtime
            from peft import PeftModel
        except ImportError as exc:  # pragma: no cover
            raise ImportError(
                "peft is required to load LoRA adapters. Install with `pip install peft`."
            ) from exc

        LOGGER.info("Applying LoRA adapters from %s", adapter_dir)
        adapted_model = PeftModel.from_pretrained(model, adapter_dir)
        return adapted_model

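    def generate(self, prompt: str) -> str:
        # return_full_text=False makes the pipeline return only the completion,
        # so the echoed prompt does not need to be sliced off.
        outputs = self.pipe(prompt, return_full_text=False)
        return outputs[0]["generated_text"].strip()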


class QwenOmniGenerator:
    """Inference helper for Qwen3-Omni multimodal dialogue."""

    def __init__(
        self,
        model_path: str,
        max_new_tokens: int = 2048,
        temperature: float = 0.2,
        use_audio_in_video: bool = True,
        adapter_path: Optional[str] = None,
    ) -> None:
        try:  # pragma: no cover - heavy dependency only available at runtime
            from transformers import (
                Qwen3OmniMoeForConditionalGeneration,
                Qwen3OmniMoeProcessor,
            )
        except ImportError as exc:  # pragma: no cover - import validated in runtime environment
            raise ImportError(
                "Qwen3-Omni dependencies are missing. Install the latest transformers from "
                "source (pip install git+https://github.com/huggingface/transformers)."
            ) from exc

        LOGGER.info("Loading Qwen3-Omni model from %s", model_path)
        base_model = Qwen3OmniMoeForConditionalGeneration.from_pretrained(
            model_path,
            dtype="auto",
            device_map="auto",
            attn_implementation="flash_attention_2",
        )
        if hasattr(base_model, "disable_talker"):
            base_model.disable_talker()
        self.model = self._maybe_apply_adapter(base_model, adapter_path)
        self.processor = Qwen3OmniMoeProcessor.from_pretrained(model_path)
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature
        self.use_audio_in_video = use_audio_in_video

    @staticmethod
    def _maybe_apply_adapter(model, adapter_path: Optional[str]):
        if not adapter_path:
            return model
        adapter_dir = Path(adapter_path)
        if not adapter_dir.exists():
            LOGGER.warning("Adapter path %s does not exist; continuing with the base model.", adapter_dir)
            return model
        try:  # pragma: no cover - requires peft during runtime
            from peft import PeftModel
        except ImportError as exc:  # pragma: no cover
            raise ImportError(
                "peft is required to load Qwen3-Omni LoRA adapters. Install with `pip install peft`."
            ) from exc

        LOGGER.info("Applying Qwen3-Omni LoRA adapters from %s", adapter_dir)
        adapted_model = PeftModel.from_pretrained(model, adapter_dir)
        return adapted_model

    @staticmethod
    def _collect_media(messages: Sequence[dict]) -> tuple[Optional[List[str]], Optional[List[str]], Optional[List[str]]]:
        images: List[str] = []
        audio: List[str] = []
        videos: List[str] = []
        for message in messages:
            content = message.get("content", [])
            if isinstance(content, dict):
                content = [content]
            for item in content:
                if not isinstance(item, dict):
                    continue
                item_type = item.get("type")
                if item_type == "image" and item.get("image") is not None:
                    images.append(item["image"])
                elif item_type == "audio" and item.get("audio") is not None:
                    audio.append(item["audio"])
                elif item_type == "video" and item.get("video") is not None:
                    videos.append(item["video"])
        return (audio or None, images or None, videos or None)

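    def generate(self, messages: Sequence[dict]) -> str:
        text = self.processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=False,
        )
        try:  # pragma: no cover - optional helper published with Qwen-Omni (pip install qwen-omni-utils)
            from qwen_omni_utils import process_mm_info
        except ImportError:
            process_mm_info = None
        if process_mm_info is not None:
            # Loads file paths/URLs into the arrays and frames the processor expects.
            audios, images, videos = process_mm_info(messages, use_audio_in_video=self.use_audio_in_video)
        else:
            # Fallback: pass the raw attachment references straight to the processor.
            audios, images, videos = self._collect_media(messages)
        inputs = self.processor(
            text=text,
            audio=audios,
            images=images,
            videos=videos,
            return_tensors="pt",
            padding=True,
            use_audio_in_video=self.use_audio_in_video,
        )
        # Move tensors to the model device and cast floating-point features to the model dtype.
        inputs = inputs.to(self.model.device).to(self.model.dtype)
        outputs = self.model.generate(
            **inputs,
            return_audio=False,
            thinker_return_dict_in_generate=True,
            use_audio_in_video=self.use_audio_in_video,
            # The thinker_ prefix routes these settings to the text-generating thinker.
            thinker_max_new_tokens=self.max_new_tokens,
            thinker_temperature=self.temperature,
            thinker_do_sample=self.temperature > 0,
        )
        # generate() can return a (text_output, audio) pair even when audio is disabled.
        if isinstance(outputs, tuple):
            outputs = outputs[0]
        sequences = getattr(outputs, "sequences", outputs)
        offset = inputs["input_ids"].shape[1]
        text = self.processor.batch_decode(
            sequences[:, offset:],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]
        return text.strip()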


class RAGPipeline:
    """High level helper for running retrieval augmented generation."""

    def __init__(
        self,
        rag_config: Optional[RAGConfig] = None,
        train_config: Optional[TrainConfig] = None,
        system_prompt: Optional[str] = None,
        model_key: Optional[str] = None,
        use_finetuned: bool = True,
    ) -> None:
        self.rag_config = rag_config or RAGConfig()
        self.train_config = train_config or TrainConfig()
        self.system_prompt = system_prompt or DEFAULT_SYSTEM_PROMPT
        self.model_key = model_key or MODEL_OPTIONS[0].key
        self.model_option = self._resolve_model_option(self.model_key)
        self.use_finetuned = use_finetuned
        self.embeddings = HuggingFaceEmbeddings(model_name=self.rag_config.embedding_model_name)
        self.vector_store = self._load_vector_store(self.rag_config.vector_db_path)
        base_model, adapter_path = self._resolve_model_sources()
        if self.model_option.multimodal:
            self.generator = QwenOmniGenerator(base_model, adapter_path=adapter_path)
        else:
            self.generator = LocalHFGenerator(base_model, adapter_path=adapter_path)

    def _resolve_model_option(self, model_key: str) -> ModelOption:
        try:
            return MODEL_OPTION_MAP[model_key]
        except KeyError as exc:  # pragma: no cover - defensive programming
            raise ValueError(f"Unknown model selection: {model_key}") from exc

    def _resolve_model_sources(self) -> Tuple[str, Optional[str]]:
        base_model = self.model_option.model_name
        if not self.use_finetuned:
            return base_model, None

        if self.model_option.key != "qwen3-omni-30b":
            LOGGER.warning(
                "Fine-tuned adapters are currently only supported for Qwen3-Omni. Continuing with base %s.",
                base_model,
            )
            return base_model, None

        adapter_dir = Path(self.train_config.output_dir)
        if not adapter_dir.exists():
            LOGGER.warning(
                "Requested fine-tuned model but no adapters found at %s. Continuing with base model.",
                adapter_dir,
            )
            return base_model, None

        if not self._contains_adapter_weights(adapter_dir):
            LOGGER.warning(
                "Adapter directory %s does not contain LoRA weights. Continuing with base model.",
                adapter_dir,
            )
            return base_model, None

        return base_model, str(adapter_dir)

    @staticmethod
    def _contains_adapter_weights(adapter_dir: Path) -> bool:
        expected_config = adapter_dir / "adapter_config.json"
        if expected_config.exists():
            return True
        # Accept common safetensors/bin outputs
        for pattern in ("adapter_model.bin", "adapter_model.safetensors"):
            if (adapter_dir / pattern).exists():
                return True
        return False

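    def _load_vector_store(self, path: Path | str) -> Optional[FAISS]:
        path = Path(path)
        # The directory may already exist but be empty (the notebook's paths cell and
        # DocumentIngestor both create it), so check for the file FAISS.save_local writes.
        if not (path / "index.faiss").exists():
            LOGGER.warning(
                "Vector store not found at %s. Responses will be generated without retrieval until you run the ingestion pipeline.",
                path,
            )
            return None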
        return FAISS.load_local(
            str(path),
            embeddings=self.embeddings,
            allow_dangerous_deserialization=True,
        )

    @staticmethod
    def _chatml(role: str, content: str) -> str:
        """Wrap a single turn in Qwen ChatML markers."""
        return f"<|im_start|>{role}\n{content}<|im_end|>"

    def _format_context(self, documents: Iterable[Document]) -> str:
        # Plain concatenation instead of dedent(f"..."): once multi-line page content is
        # interpolated, dedent() usually finds no common indentation and keeps the template spaces.
        context_blocks: List[str] = []
        for idx, doc in enumerate(documents, start=1):
            source = doc.metadata.get("source", "unknown")
            context_blocks.append(f"[Document {idx} | Source: {source}]\n{doc.page_content.strip()}")
        return "\n\n".join(context_blocks)

    def _build_text_history(self, history: Optional[Sequence[Tuple[str, str]]]) -> str:
        if not history:
            return ""
        turns: List[str] = []
        for user_text, assistant_text in history:
            if user_text:
                turns.append(self._chatml("user", user_text))
            if assistant_text:
                turns.append(self._chatml("assistant", assistant_text))
        return "\n".join(turns)

    def build_prompt(
        self,
        question: str,
        documents: List[Document],
        history: Optional[Sequence[Tuple[str, str]]] = None,
    ) -> str:
        context = self._format_context(documents)
        prompt_blocks = [self._chatml("system", self.system_prompt)]
        history_block = self._build_text_history(history)
        if history_block:
            prompt_blocks.append(history_block)
        user_text = f"Question: {question}"
        if context:
            user_text += f"\n\nUse the following context to ground your answer:\n{context}"
        prompt_blocks.append(self._chatml("user", user_text))
        # Leave the assistant turn open (header plus newline) for the model to complete.
        prompt_blocks.append("<|im_start|>assistant\n")
        return "\n".join(prompt_blocks)

    def _build_multimodal_messages(
        self,
        question: str,
        documents: List[Document],
        media: MediaInput,
        history: Optional[Sequence[Tuple[str, str]]],
    ) -> List[dict]:
        messages: List[dict] = [
            {
                "role": "system",
                "content": [{"type": "text", "text": self.system_prompt}],
            }
        ]
        if documents:
            context = self._format_context(documents)
            messages.append(
                {
                    "role": "system",
                    "content": [
                        {
                            "type": "text",
                            "text": (
                                "Kasetsart University supporting material:\n"
                                f"{context}"
                            ),
                        }
                    ],
                }
            )
        if history:
            for user_turn, assistant_turn in history:
                if user_turn:
                    messages.append(
                        {
                            "role": "user",
                            "content": [{"type": "text", "text": user_turn}],
                        }
                    )
                if assistant_turn:
                    messages.append(
                        {
                            "role": "assistant",
                            "content": [{"type": "text", "text": assistant_turn}],
                        }
                    )

        user_content: List[dict] = []
        for image_path in media.images:
            user_content.append({"type": "image", "image": image_path})
        for audio_path in media.audio:
            user_content.append({"type": "audio", "audio": audio_path})
        for video_path in media.video:
            user_content.append({"type": "video", "video": video_path})

        request_text = question.strip() or (
            "Please analyse the uploaded media and explain how it relates to Kasetsart University's "
            "Computer Science programme."
        )
        user_content.append({"type": "text", "text": request_text})
        messages.append({"role": "user", "content": user_content})
        return messages

    def answer(
        self,
        question: str,
        top_k: Optional[int] = None,
        media: Optional[MediaInput] = None,
        history: Optional[Sequence[Tuple[str, str]]] = None,
    ) -> str:
        media = media or MediaInput()
        question_text = question.strip()
        if not question_text and media.is_empty():
            return "Please provide a question or upload audio, image, or video content."

        documents: List[Document] = []
        if question_text and self.vector_store is not None:
            top_k = top_k or self.rag_config.max_retrieval_docs
            # similarity_search avoids the deprecated retriever.get_relevant_documents API.
            documents = self.vector_store.similarity_search(question_text, k=top_k)
            if not documents and not self.model_option.multimodal:
                return "I could not find supporting documents for that question."
        elif question_text and self.vector_store is None:
            LOGGER.warning("Vector store unavailable; continuing without retrieved context.")

        if self.model_option.multimodal:
            messages = self._build_multimodal_messages(question_text, documents, media, history)
            return self.generator.generate(messages)

        if not media.is_empty():
            # Text-only models cannot consume attachments; report it instead of dropping them silently.
            LOGGER.warning("%s is text-only; ignoring uploaded media.", self.model_option.label)
            if not question_text:
                return "The selected model is text-only. Please type a question or switch to the Qwen3-Omni model for media."
        prompt = self.build_prompt(question_text, documents, history)
        return self.generator.generate(prompt)


__all__ = ["RAGPipeline", "LocalHFGenerator", "QwenOmniGenerator", "MediaInput", "DEFAULT_SYSTEM_PROMPT"]
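
`build_prompt` targets Qwen's ChatML layout: each turn is `<|im_start|>{role}\n{content}<|im_end|>`, and the prompt ends with an open `<|im_start|>assistant` header for the model to complete. A standalone sketch of that layout (helper names are illustrative) is handy for checking prompts by eye. Other model families, such as gpt-oss, expect their own chat formats, so a tokenizer's `apply_chat_template` is the safer route for non-Qwen models.

```python
from typing import Sequence, Tuple


def chatml_turn(role: str, content: str) -> str:
    """Wrap one message in ChatML start/end markers."""
    return f"<|im_start|>{role}\n{content}<|im_end|>"


def chatml_prompt(system: str, question: str, history: Sequence[Tuple[str, str]] = ()) -> str:
    """Join system, past turns, and the new question, leaving the assistant turn open."""
    turns = [chatml_turn("system", system)]
    for user_text, assistant_text in history:
        turns.append(chatml_turn("user", user_text))
        turns.append(chatml_turn("assistant", assistant_text))
    turns.append(chatml_turn("user", question))
    turns.append("<|im_start|>assistant\n")
    return "\n".join(turns)


print(chatml_prompt("You are KUx.", "When does registration open?"))
```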


In [None]:
"""Supervised fine-tuning pipeline for Qwen3-Omni-30B."""
from __future__ import annotations

import json
from dataclasses import asdict
from pathlib import Path
from typing import Any, Dict, List, Optional

import torch
from datasets import Dataset, load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
    set_seed,
)


try:  # peft is required for LoRA training; fail early with an install hint
    from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
except ImportError as exc:  # pragma: no cover - only raised when deps missing
    raise ImportError(
        "peft is required for LoRA training. Install with `pip install peft`."
    ) from exc


ChatMessages = List[Dict[str, str]]


def _normalise_sample(example: Dict[str, Any]) -> str:
    """Normalise a dataset row into a chat-style plain-text sample."""

    # JSON datasets share one schema across all rows, so absent fields arrive as None.
    messages: Optional[ChatMessages] = example.get("messages")
    if messages:
        return messages_to_text(messages)

    if example.get("instruction") and example.get("response"):
        system_prompt = example.get("system") or "You are a helpful assistant."
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": example["instruction"]},
            {"role": "assistant", "content": example["response"]},
        ]
        if example.get("input"):
            # Optional extra context becomes a second user turn before the answer.
            messages.insert(2, {"role": "user", "content": example["input"]})
        return messages_to_text(messages)

    text_field = example.get("text")
    if text_field:
        return str(text_field)

    raise ValueError(
        "Unsupported dataset schema. Expected `messages`, `text` or `instruction`/`response` columns."
    )


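def messages_to_text(messages: ChatMessages, add_generation_prompt: bool = False) -> str:
    """Convert chat messages into ChatML text using Qwen's `<|im_start|>`/`<|im_end|>` markers.

    Training samples should end with the assistant's `<|im_end|>`; set
    `add_generation_prompt` only when building an inference prompt.
    """

    formatted: List[str] = []
    for message in messages:
        role = message.get("role") or "user"
        content = (message.get("content") or "").strip()
        formatted.append(f"<|im_start|>{role}\n{content}<|im_end|>")
    if add_generation_prompt:
        formatted.append("<|im_start|>assistant\n")
    return "\n".join(formatted)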


class SupervisedFineTuner:
    """LoRA supervised fine-tuning helper for the Qwen 3 Omni family."""

    def __init__(self, config: Optional[TrainConfig] = None) -> None:
        self.config = config or TrainConfig()
        set_seed(self.config.seed)
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.config.model_name, trust_remote_code=True
        )
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model: Optional[torch.nn.Module] = None
        self.train_dataset: Optional[Dataset] = None
        self.eval_dataset: Optional[Dataset] = None

    # ------------------------------------------------------------------
    # Dataset utilities
    # ------------------------------------------------------------------
    def prepare_datasets(
        self, eval_split: Optional[float] = 0.05, streaming: bool = False
    ) -> None:
        """Load datasets, normalise formatting and tokenize them."""

        dataset_path = self.config.dataset_path
        data_files: Dict[str, str]
        path = Path(dataset_path)
        if path.is_dir():
            data_files = {"train": str(path / "train.jsonl")}
        else:
            data_files = {"train": str(path)}

        if streaming:
            # Fail fast instead of opening a stream that is never used.
            raise NotImplementedError("Streaming datasets are not supported in this release.")

        dataset_dict = load_dataset("json", data_files=data_files)
        field_name = self.config.dataset_text_field
        raw_train = dataset_dict["train"]
        # Keep only the normalised text column so raw columns such as `messages`
        # do not survive into the tokenized dataset.
        train_ds = raw_train.map(
            lambda example: {field_name: _normalise_sample(example)},
            remove_columns=raw_train.column_names,
        )
        if eval_split:
            split = train_ds.train_test_split(test_size=eval_split, seed=self.config.seed)
            self.train_dataset = self._tokenize(split["train"])
            self.eval_dataset = self._tokenize(split["test"])
        else:
            self.train_dataset = self._tokenize(train_ds)
            self.eval_dataset = None

    def _tokenize(self, dataset: Dataset) -> Dataset:
        """Tokenize dataset into model inputs."""

        def tokenize_function(examples: Dict[str, Any]) -> Dict[str, Any]:
            return self.tokenizer(
                examples[self.config.dataset_text_field],
                truncation=True,
                max_length=self.config.max_seq_length,
            )

        return dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=[self.config.dataset_text_field],
        )

    # ------------------------------------------------------------------
    # Model utilities
    # ------------------------------------------------------------------
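    def _load_model(self) -> None:
        torch_dtype = torch.bfloat16 if self.config.bf16 else torch.float16
        model_kwargs: Dict[str, Any] = {}
        if self.config.load_in_4bit:
            # Recent transformers releases expect a BitsAndBytesConfig, not loose bnb_* kwargs.
            from transformers import BitsAndBytesConfig

            model_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch_dtype,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True,
            )
        self.model = AutoModelForCausalLM.from_pretrained(
            self.config.model_name,
            trust_remote_code=True,
            device_map="auto",
            torch_dtype=torch_dtype,
            **model_kwargs,
        )
        if self.config.load_in_4bit:
            # Only for quantized weights: this helper also upcasts fp16/bf16 parameters
            # to fp32, which would double memory use for a full-precision 30B model.
            self.model = prepare_model_for_kbit_training(
                self.model,
                use_gradient_checkpointing=self.config.use_gradient_checkpointing,
            )
        lora_config = LoraConfig(
            r=self.config.lora_r,
            lora_alpha=self.config.lora_alpha,
            lora_dropout=self.config.lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
        )
        self.model = get_peft_model(self.model, lora_config)
        if self.config.use_gradient_checkpointing:
            self.model.gradient_checkpointing_enable()
            # With frozen embeddings, inputs must require grad for checkpointed LoRA layers to train.
            self.model.enable_input_require_grads()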

    # ------------------------------------------------------------------
    # Training orchestration
    # ------------------------------------------------------------------
    def train(self) -> None:
        if self.train_dataset is None:
            self.prepare_datasets()
        if self.model is None:
            self._load_model()

        training_args = TrainingArguments(
            output_dir=self.config.output_dir,
            per_device_train_batch_size=self.config.per_device_train_batch_size,
            gradient_accumulation_steps=self.config.gradient_accumulation_steps,
            learning_rate=self.config.learning_rate,
            num_train_epochs=self.config.num_train_epochs,
            logging_steps=self.config.logging_steps,
            save_steps=self.config.save_steps,
            warmup_ratio=self.config.warmup_ratio,
            weight_decay=self.config.weight_decay,
            bf16=self.config.bf16,
            eval_strategy="steps" if self.eval_dataset is not None else "no",
            eval_steps=self.config.eval_steps,
            report_to=["tensorboard"],
        )

        data_collator = DataCollatorForLanguageModeling(
            self.tokenizer, mlm=False, pad_to_multiple_of=8
        )

        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=self.train_dataset,
            eval_dataset=self.eval_dataset,
            processing_class=self.tokenizer,
            data_collator=data_collator,
        )
        trainer.train()
        trainer.save_model()
        self.tokenizer.save_pretrained(self.config.output_dir)
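        config_path = Path(self.config.output_dir) / "train_config.json"
        with open(config_path, "w", encoding="utf-8") as fp:
            # default=str serialises Path values and other non-JSON types in the config.
            json.dump(asdict(self.config), fp, indent=2, default=str)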


__all__ = ["SupervisedFineTuner", "TrainConfig"]
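
`messages_to_text` writes the ChatML markers by hand instead of calling the tokenizer's chat template. The sketch below reproduces that layout for an `instruction`/`response` row so you can see exactly what the model is trained on. `to_chatml` is a hypothetical helper; whether its output matches the Qwen3-Omni template byte for byte is an assumption, and `tokenizer.apply_chat_template(messages, tokenize=False)` is the authoritative way to compare.

```python
from typing import Dict, List


def to_chatml(messages: List[Dict[str, str]], add_generation_prompt: bool = False) -> str:
    """Render messages with ChatML markers (hypothetical helper mirroring messages_to_text)."""
    parts = [
        f"<|im_start|>{m.get('role') or 'user'}\n{(m.get('content') or '').strip()}<|im_end|>"
        for m in messages
    ]
    # Only inference prompts should end with an open assistant turn.
    if add_generation_prompt:
        parts.append("<|im_start|>assistant\n")
    return "\n".join(parts)


row = {"instruction": "Say hello.", "response": "Hello!"}  # illustrative training row
sample = to_chatml(
    [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": row["instruction"]},
        {"role": "assistant", "content": row["response"]},
    ]
)
print(sample)
```

A training sample ends with the assistant's closing `<|im_end|>`; an extra trailing `<|im_start|>assistant` header would add an assistant turn that is never answered.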


In [None]:
"""Gradio chat interface for the KUx assistant."""
from __future__ import annotations

import logging
from dataclasses import dataclass, replace
from functools import lru_cache
from pathlib import Path
from typing import Any, List, Optional, Sequence, Tuple

import gradio as gr


LOGGER = logging.getLogger(__name__)

MODEL_LABEL_TO_KEY = {option.label: option.key for option in MODEL_OPTIONS}
MODEL_KEY_TO_LABEL = {option.key: option.label for option in MODEL_OPTIONS}


@dataclass(frozen=True)
class LaunchState:
    """Runtime configuration supplied when launching the Gradio app."""

    vector_db_path: Optional[str] = None
    adapter_dir: Optional[str] = None
    default_model_key: str = MODEL_OPTIONS[0].key
    default_system_prompt: str = DEFAULT_SYSTEM_PROMPT


_LAUNCH_STATE = LaunchState()


@lru_cache(maxsize=8)
def _load_pipeline_cached(
    model_key: str,
    use_finetuned: bool,
    system_prompt: str,
    vector_db_path: Optional[str],
    adapter_dir: Optional[str],
) -> RAGPipeline:
    LOGGER.info(
        "Initialising RAG pipeline (model=%s, finetuned=%s)",
        model_key,
        use_finetuned,
    )
    default_rag = RAGConfig()
    rag_config = RAGConfig(
        vector_db_path=Path(vector_db_path) if vector_db_path else default_rag.vector_db_path
    )
    default_train = TrainConfig()
    train_config = TrainConfig(output_dir=adapter_dir if adapter_dir else default_train.output_dir)
    return RAGPipeline(
        rag_config=rag_config,
        train_config=train_config,
        system_prompt=system_prompt,
        model_key=model_key,
        use_finetuned=use_finetuned,
    )


def load_pipeline(model_key: str, use_finetuned: bool, system_prompt: str) -> RAGPipeline:
    base_prompt = _LAUNCH_STATE.default_system_prompt or DEFAULT_SYSTEM_PROMPT
    prompt = system_prompt.strip() if system_prompt and system_prompt.strip() else base_prompt
    return _load_pipeline_cached(
        model_key,
        use_finetuned,
        prompt,
        _LAUNCH_STATE.vector_db_path,
        _LAUNCH_STATE.adapter_dir,
    )


def _extract_paths(payload: Any) -> List[str]:
    if not payload:
        return []
    if isinstance(payload, (str, Path)):
        return [str(payload)]
    items: Sequence[Any]
    if isinstance(payload, Sequence):
        items = payload
    else:  # single temp file or dict
        items = [payload]
    paths: List[str] = []
    for item in items:
        if not item:
            continue
        if isinstance(item, (str, Path)):
            paths.append(str(item))
        elif isinstance(item, dict):
            for key in ("path", "name"):
                value = item.get(key)
                if value:
                    paths.append(str(value))
                    break
        else:
            path = getattr(item, "name", None)
            if path:
                paths.append(str(path))
    return paths


def _format_user_display(message: str, media: MediaInput) -> str:
    parts: List[str] = []
    text = message.strip()
    if text:
        parts.append(text)
    for label, files in (
        ("Image", media.images),
        ("Audio", media.audio),
        ("Video", media.video),
    ):
        for file_path in files:
            parts.append(f"[{label}] {Path(file_path).name}")
    return "\n".join(parts) if parts else "[No user content]"


def respond(
    message: str,
    history: List[Tuple[str, str]],
    model_label: str,
    use_finetuned: bool,
    system_prompt: str,
    image_payload: Any,
    audio_payload: Any,
    video_payload: Any,
) -> Tuple[List[Tuple[str, str]], Any, Any, Any, Any]:
    history = history or []  # defensive: treat a missing Chatbot value as an empty history
    model_key = MODEL_LABEL_TO_KEY.get(model_label, MODEL_OPTIONS[0].key)
    pipeline = load_pipeline(model_key, use_finetuned, system_prompt)
    media = MediaInput.from_payload(
        images=_extract_paths(image_payload),
        audio=_extract_paths(audio_payload),
        video=_extract_paths(video_payload),
    )
    display_text = _format_user_display(message, media)
    prompt_history = [(user, bot) for user, bot in history if user or bot]
    answer = pipeline.answer(message, media=media, history=prompt_history)
    updated_history = history + [(display_text, answer)]
    # gr.update works in Gradio 3.x and 4.x; component-level .update() was removed in 4.0.
    return (
        updated_history,
        gr.update(value=""),
        gr.update(value=None),
        gr.update(value=None),
        gr.update(value=None),
    )


def launch(
    *,
    vector_db_path: Optional[str] = None,
    adapter_dir: Optional[str] = None,
    default_model_key: Optional[str] = None,
    default_system_prompt: Optional[str] = None,
    share: bool = False,
    server_name: str = "0.0.0.0",
    server_port: int = 7860,
    preload_default: bool = True,
) -> None:
    global _LAUNCH_STATE
    desired_model_key = default_model_key or MODEL_OPTIONS[0].key
    _LAUNCH_STATE = replace(
        _LAUNCH_STATE,
        vector_db_path=str(vector_db_path) if vector_db_path else None,
        adapter_dir=str(adapter_dir) if adapter_dir else None,
        default_model_key=(
            desired_model_key if desired_model_key in MODEL_KEY_TO_LABEL else MODEL_OPTIONS[0].key
        ),
        default_system_prompt=(
            default_system_prompt.strip()
            if default_system_prompt and default_system_prompt.strip()
            else DEFAULT_SYSTEM_PROMPT
        ),
    )
    system_prompt_default = _LAUNCH_STATE.default_system_prompt or DEFAULT_SYSTEM_PROMPT
    _load_pipeline_cached.cache_clear()
    if preload_default:
        LOGGER.info(
            "Preloading default model %s (finetuned=%s)",
            _LAUNCH_STATE.default_model_key,
            True,
        )
        try:
            load_pipeline(_LAUNCH_STATE.default_model_key, True, system_prompt_default)
        except Exception as exc:  # pragma: no cover - defensive for runtime errors
            LOGGER.exception("Failed to preload the default model: %s", exc)
            raise
    description = (
        "KUx is a retrieval-augmented assistant for Kasetsart University Computer Science students."
    )
    theme = gr.themes.Default(primary_hue=gr.themes.colors.green)
    default_model_label = MODEL_KEY_TO_LABEL.get(_LAUNCH_STATE.default_model_key, MODEL_OPTIONS[0].label)
    with gr.Blocks(theme=theme) as demo:
        gr.Markdown(
            "# KUx – Kasetsart CS Assistant\n"
            f"<span style='color: #0f5132'>{description}</span>",
            elem_id="kux-header",
        )
        chatbot = gr.Chatbot(label="Conversation", height=520)
        with gr.Row():
            with gr.Column(scale=3):
                message_box = gr.Textbox(
                    label="Ask KUx",
                    placeholder="Ask a question or leave blank to analyse uploaded media",
                    lines=4,
                )
                send_btn = gr.Button("Send", variant="primary")
                clear_btn = gr.Button("Clear conversation", variant="secondary")
            with gr.Column(scale=2):
                image_files = gr.File(
                    label="Upload images (OCR, object grounding, image math)",
                    file_types=["image"],
                    file_count="multiple",
                )
                audio_files = gr.File(
                    label="Upload audio (speech recognition, translation, captioning)",
                    file_types=["audio"],
                    file_count="multiple",
                )
                video_files = gr.File(
                    label="Upload videos (audio-visual QA/interactions)",
                    file_types=["video"],
                    file_count="multiple",
                )
        with gr.Accordion("Assistant settings", open=False):
            model_dropdown = gr.Dropdown(
                label="Base model",
                choices=list(MODEL_LABEL_TO_KEY.keys()),
                value=default_model_label,
                interactive=True,
            )
            finetune_checkbox = gr.Checkbox(
                label="Use fine-tuned adapters (LoRA)",
                value=True,
                interactive=True,
            )
            system_prompt_box = gr.Textbox(
                label="System prompt",
                value=system_prompt_default,
                lines=4,
                interactive=True,
            )

        send_btn.click(
            fn=respond,
            inputs=[
                message_box,
                chatbot,
                model_dropdown,
                finetune_checkbox,
                system_prompt_box,
                image_files,
                audio_files,
                video_files,
            ],
            outputs=[chatbot, message_box, image_files, audio_files, video_files],
        )

        clear_btn.click(
            fn=lambda: (
                [],
                gr.update(value=""),
                gr.update(value=None),
                gr.update(value=None),
                gr.update(value=None),
            ),
            inputs=None,
            outputs=[chatbot, message_box, image_files, audio_files, video_files],
        )

    demo.queue().launch(server_name=server_name, server_port=server_port, share=share)


__all__ = ["launch"]
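
`load_pipeline` normalises the system prompt and then delegates to an `lru_cache`-wrapped builder, so every distinct `(model_key, use_finetuned, prompt)` combination builds its own `RAGPipeline`. Whether that also reloads model weights depends on `RAGPipeline`, which is not shown here. The sketch below isolates the caching behaviour with a hypothetical `build_pipeline` stub and a hypothetical default prompt.

```python
from functools import lru_cache
from typing import List, Tuple

DEFAULT_PROMPT = "You are KUx."  # hypothetical default prompt
BUILDS: List[Tuple[str, bool, str]] = []  # records every cache miss


@lru_cache(maxsize=8)
def build_pipeline(model_key: str, use_finetuned: bool, prompt: str) -> object:
    # Stand-in for _load_pipeline_cached: a real call would construct a RAGPipeline.
    BUILDS.append((model_key, use_finetuned, prompt))
    return object()


def get_pipeline(model_key: str, use_finetuned: bool, system_prompt: str) -> object:
    # Same normalisation as load_pipeline: blank prompts fall back to the default.
    prompt = system_prompt.strip() if system_prompt and system_prompt.strip() else DEFAULT_PROMPT
    return build_pipeline(model_key, use_finetuned, prompt)


first = get_pipeline("qwen3-omni-30b", True, "")
second = get_pipeline("qwen3-omni-30b", True, "   ")  # cache hit: same key as `first`
third = get_pipeline("qwen3-omni-30b", True, "Answer in Thai.")  # new key, new build
print(len(BUILDS))  # prints 2
```

Editing the system prompt in the UI therefore costs one pipeline construction per new prompt, and up to eight pipelines stay referenced by the cache at once.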


## 5. Optional: crawl official Kasetsart CS sources

Set `enable_crawl` to `True` in the configuration cell to fetch fresh pages. Each page is saved as a plain-text file under `data/crawled/`.
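
`SiteCrawler` is not shown in this part of the notebook, so the exact meaning of `crawl_max_depth` and `crawl_max_pages` is an assumption here. A common reading is a breadth-first walk that stops following links more than `max_depth` hops from a seed and stops entirely after `max_pages` pages. The sketch below uses a hypothetical in-memory link graph instead of HTTP requests.

```python
from collections import deque
from typing import Dict, List


def bounded_crawl(
    graph: Dict[str, List[str]], seeds: List[str], max_depth: int, max_pages: int
) -> List[str]:
    """Breadth-first visit order under depth and page limits (hypothetical helper)."""
    visited: List[str] = []
    seen = set(seeds)
    queue = deque((url, 0) for url in seeds)
    while queue and len(visited) < max_pages:
        url, depth = queue.popleft()
        visited.append(url)
        if depth >= max_depth:
            continue  # page is kept, but its links are not followed
        for link in graph.get(url, []):
            if link not in seen:
                seen.add(link)
                queue.append((link, depth + 1))
    return visited


links = {"/": ["/a", "/b"], "/a": ["/a1"], "/b": []}  # hypothetical site map
print(bounded_crawl(links, ["/"], max_depth=1, max_pages=10))  # prints ['/', '/a', '/b']
```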


In [None]:
if CONFIG['enable_crawl']:
    crawl_seed_urls = CONFIG['crawl_seed_urls']
    if not crawl_seed_urls:
        raise ValueError('No seed URLs specified. Update CONFIG["crawl_seed_urls"].')

    crawler_config = CrawlerConfig(
        max_depth=CONFIG.get('crawl_max_depth', 1),
        max_pages=CONFIG.get('crawl_max_pages', 10),
        cache_dir=DATA_DIR / 'crawled_cache',
    )
    crawler = SiteCrawler(crawler_config)
    crawled = crawler.crawl(crawl_seed_urls)
    output_dir = DATA_DIR / 'crawled'
    output_dir.mkdir(parents=True, exist_ok=True)

    for idx, (url, text) in enumerate(crawled.items(), start=1):
        target = output_dir / f'page_{idx:03d}.txt'
        target.write_text(text, encoding='utf-8')
        print(f'Saved {target.relative_to(PROJECT_ROOT)} ← {url}')
else:
    print('Skipping crawl (CONFIG["enable_crawl"] is False).')


## 6. Build or refresh the FAISS vector store

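When `enable_ingest` is `True`, the ingestor walks the directories listed in `ingest_sources` (resolved relative to the repo root), rebuilds the embeddings, and saves the FAISS index to the configured vector store directory. Sources that do not exist are skipped; the cell raises an error only when none of them exist.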
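
Silently skipping missing sources is convenient when `data/crawled` has not been created yet, but it can also hide a typo in `ingest_sources`. A small sketch, using a hypothetical `resolve_sources` helper, that applies the same repo-root resolution and also reports what was skipped:

```python
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import List, Tuple


def resolve_sources(repo_root: Path, configured: List[str]) -> Tuple[List[str], List[str]]:
    """Split configured source paths into existing and missing ones (hypothetical helper)."""
    found: List[str] = []
    missing: List[str] = []
    for rel in configured:
        # Relative entries are anchored at the repo root; absolute entries replace it.
        path = repo_root / rel
        (found if path.exists() else missing).append(str(path))
    return found, missing


# Demo against a throwaway directory that only contains data/sample_documents.
with TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "data" / "sample_documents").mkdir(parents=True)
    found, missing = resolve_sources(root, ["data/sample_documents", "data/crawled"])
    print("found:", [Path(p).relative_to(root).as_posix() for p in found])
    print("missing:", [Path(p).relative_to(root).as_posix() for p in missing])
```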


In [None]:
if CONFIG['enable_ingest']:
    sources = [PROJECT_ROOT / Path(path) for path in CONFIG.get('ingest_sources', [])]
    resolved_sources = [str(path) for path in sources if path.exists()]
    if not resolved_sources:
        raise RuntimeError('No sources found. Add PDFs/CSVs/text files under data/ or update CONFIG["ingest_sources"].')

    rag_config = RAGConfig(vector_db_path=VECTOR_DB_DIR)
    ingestor = DocumentIngestor(rag_config)
    vector_store = ingestor.ingest(resolved_sources)
    print('Vector store saved to', VECTOR_DB_DIR)
else:
    print('Skipping ingestion (CONFIG["enable_ingest"] is False). Ensure a vector store exists before launching the chatbot.')


## 7. Optional: fine-tune Qwen with LoRA adapters

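Set `enable_finetune` to `True` and provide a JSONL dataset at the path given by `finetune_dataset` (relative to the repo root; a directory is read as `<dir>/train.jsonl`). Each row needs a `messages` list, `instruction`/`response` fields, or a `text` field. The LoRA adapters are written to the directory configured in `adapter_dir`.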
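
Fine-tuning time scales with the number of optimizer steps. The back-of-the-envelope sketch below assumes the 5% evaluation split that `prepare_datasets` uses by default, a single GPU, and hypothetical batch settings. The real values live in `TrainConfig`, which is not shown here, and the exact step rounding inside `Trainer` can differ between transformers versions.

```python
import math


def estimate_training_steps(
    num_rows: int,
    eval_split: float,
    per_device_batch: int,
    grad_accum: int,
    epochs: int,
    num_devices: int = 1,
) -> dict:
    """Rough optimizer-step estimate (hypothetical helper)."""
    # datasets' train_test_split rounds the test share up.
    n_eval = math.ceil(num_rows * eval_split) if eval_split else 0
    n_train = num_rows - n_eval
    batches_per_epoch = math.ceil(n_train / (per_device_batch * num_devices))
    # One optimizer step per `grad_accum` micro-batches.
    steps_per_epoch = max(batches_per_epoch // grad_accum, 1)
    return {
        "train_rows": n_train,
        "eval_rows": n_eval,
        "effective_batch": per_device_batch * grad_accum * num_devices,
        "steps_per_epoch": steps_per_epoch,
        "total_steps": steps_per_epoch * epochs,
    }


# Hypothetical settings: 2,000 rows, batch size 1, accumulation 16, 2 epochs.
print(estimate_training_steps(2000, 0.05, 1, 16, 2))
```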


In [None]:
if CONFIG['enable_finetune']:
    dataset_path = PROJECT_ROOT / CONFIG['finetune_dataset']
    if not dataset_path.exists():
        raise FileNotFoundError(f'Dataset not found at {dataset_path}. Upload your training data or update CONFIG["finetune_dataset"].')

    train_config = TrainConfig(
        dataset_path=str(dataset_path),
        output_dir=str(ADAPTER_DIR),
        num_train_epochs=CONFIG.get('finetune_epochs', 2),
    )
    trainer = SupervisedFineTuner(train_config)
    trainer.prepare_datasets()
    trainer.train()
else:
    print('Skipping fine-tuning (CONFIG["enable_finetune"] is False).')


## 8. Launch the multimodal KUx chatbot

This cell blocks while the Gradio app is running. Gradio prints the local URL and, when `launch_share` is `True`, a public share URL; open it in a new tab to chat with KUx. Interrupt the cell to shut down the chatbot.


In [None]:
model_keys = {option.key for option in MODEL_OPTIONS}
default_model_key = CONFIG.get('default_model_key') or MODEL_OPTIONS[0].key
if default_model_key not in model_keys:
    raise ValueError(f"Unknown model key '{default_model_key}'. Choose from: {sorted(model_keys)}")

system_prompt = CONFIG.get('default_system_prompt') or DEFAULT_SYSTEM_PROMPT

launch(
    # Use the resolved paths the ingestion and fine-tuning cells wrote to, not the raw relative strings.
    vector_db_path=str(VECTOR_DB_DIR),
    adapter_dir=str(ADAPTER_DIR),
    default_model_key=default_model_key,
    default_system_prompt=system_prompt,
    share=CONFIG.get('launch_share', False),
    preload_default=CONFIG.get('launch_preload', True),
)
