<a href="https://colab.research.google.com/github/redazakan/TextureBot/blob/main/Texture_RAG_Full.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Texture RAG (Embeddings + FAISS + Gemma 3 4B-it)

This notebook is a **fully working end-to-end RAG** demo for a texture recommender:

- **Knowledge Base**: `textures_expanded.csv`
- **Embeddings**: `google/embeddinggemma-300m` (gated → requires `HF_TOKEN`)
- **Vector Search**: FAISS `IndexFlatIP` (cosine via L2-normalized vectors)
- **LLM**: `google/gemma-3-4b-it` (gated → requires `HF_TOKEN`)
- **Output**: **strict JSON**
- **UI**: Gradio (clean + simple)


## 0) Install dependencies (Colab-safe)

We **do not upgrade** NumPy/Pandas/Torch (Colab pins them).  
We install only what the notebook needs: Transformers, FAISS, langid, Gradio, and the Hugging Face Hub client.
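
To confirm which versions Colab has pinned once the install finishes, a small standard-library check can help. This is a minimal sketch: `installed_versions` is a hypothetical helper that is not part of the notebook, and the package list is an assumption based on the libraries named above.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional

def installed_versions(packages: list[str]) -> dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    report = {}
    for name in packages:
        try:
            report[name] = version(name)
        except PackageNotFoundError:
            report[name] = None
    return report

# Assumed package list: the pinned Colab stack plus what this notebook installs.
stack = ["numpy", "pandas", "torch", "transformers", "faiss-cpu", "langid", "gradio"]
for name, ver in installed_versions(stack).items():
    print(f"{name:<14}{ver or 'not installed'}")
```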


In [None]:
!pip -q install transformers faiss-cpu langid gradio huggingface_hub

## 1) Authenticate to Hugging Face (required for gated models)

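Both `embeddinggemma-300m` and `gemma-3-4b-it` are **gated** on Hugging Face, so you must accept each model's terms on its model page and supply an access token.

Store the token as a Colab secret named `HF_TOKEN` and grant this notebook access to it. This cell reads that secret and logs in.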
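
If the notebook is run outside Colab, the `google.colab` import is unavailable. One possible fallback, sketched below, is to look for the same name in the process environment. `resolve_hf_token` is a hypothetical helper, and reading the token from an environment variable is an assumption rather than something the notebook does.

```python
import os

def resolve_hf_token(name: str = "HF_TOKEN") -> str:
    """Look up the token in Colab Secrets first, then in the environment."""
    token = None
    try:
        from google.colab import userdata  # importable only inside Colab
        token = userdata.get(name)
    except Exception:
        # ImportError outside Colab; inside Colab the lookup can also fail
        # when the secret is missing or notebook access was not granted.
        token = None
    token = token or os.environ.get(name)
    if not token:
        raise RuntimeError(f"{name} not found in Colab Secrets or the environment")
    return token
```

The result could then be passed to `login(token=...)` in place of the direct `userdata.get` call.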


In [None]:
from google.colab import userdata
from huggingface_hub import login

HF_TOKEN = userdata.get("HF_TOKEN")
assert HF_TOKEN is not None, "HF_TOKEN not found in Colab Secrets"

login(token=HF_TOKEN)
print("✅ Hugging Face login OK")


## 2) Upload / load the CSV knowledge base

Upload `textures_expanded.csv` to the Colab working directory; the cell below copies it into `data/`.

Expected columns:
`id, source, asset_name, category, use_case, tags, maps, resolution_max, tileable, free, url, synthetic`
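
Checking the header before loading makes a malformed upload fail early. The sketch below uses only the standard library; `EXPECTED_COLS` and `missing_columns` are hypothetical names, and the in-memory sample is made up for illustration.

```python
import csv
import io

EXPECTED_COLS = [
    "id", "source", "asset_name", "category", "use_case", "tags",
    "maps", "resolution_max", "tileable", "free", "url", "synthetic",
]

def missing_columns(csv_file, expected=EXPECTED_COLS) -> list[str]:
    """Return the expected column names that are absent from the CSV header row."""
    header = next(csv.reader(csv_file), [])
    present = {name.strip() for name in header}
    return [col for col in expected if col not in present]

# Demo on an in-memory CSV that lacks most of the expected columns.
sample = io.StringIO("id,asset_name,url\n1,Rusty Metal 01,\n")
print(missing_columns(sample))
```

For a real file, the same function can be called on a handle opened with `open("textures_expanded.csv", newline="", encoding="utf-8")`.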


In [None]:
import os, shutil, pandas as pd

os.makedirs("data", exist_ok=True)

src = "textures_expanded.csv"
dst = "data/textures_expanded.csv"
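if os.path.exists(src):
    shutil.copy(src, dst)
elif not os.path.exists(dst):
    # Fail with a clear message instead of an opaque read_csv error
    raise FileNotFoundError(f"Upload {src} to the working directory first")

df = pd.read_csv(dst).fillna("")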
print("Rows:", len(df))
print("Columns:", list(df.columns))
df.head(3)


## 3) Core building blocks

### 3.1 Language detection
We auto-detect **FR/EN** using `langid`.

### 3.2 Embeddings (masked mean pooling)
We generate embeddings with `google/embeddinggemma-300m` using:
- attention-mask **masked mean pooling** (ignores padding)
- L2 normalization (critical for cosine similarity with FAISS IP)
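
To make the pooling arithmetic concrete, here is a dependency-free sketch on plain Python lists. It mirrors the two steps listed above but is not the notebook's implementation, which works on PyTorch tensors. `masked_mean_pool`, `l2_normalize`, and `dot` are hypothetical helper names, and the toy vectors are made up.

```python
import math

def masked_mean_pool(hidden, mask):
    """Average token vectors, counting only positions where mask == 1.

    hidden: list of T token vectors (each a list of H floats)
    mask:   list of T ints (1 = real token, 0 = padding)
    """
    n_real = max(sum(mask), 1)  # avoid dividing by zero on an all-padding row
    dim = len(hidden[0])
    return [
        sum(tok[j] * m for tok, m in zip(hidden, mask)) / n_real
        for j in range(dim)
    ]

def l2_normalize(vec):
    """Scale a vector to unit length (left unchanged if it is all zeros)."""
    norm = math.sqrt(sum(x * x for x in vec)) or 1.0
    return [x / norm for x in vec]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

# Two real tokens followed by one padding token whose values must be ignored.
hidden = [[1.0, 0.0], [3.0, 4.0], [100.0, 100.0]]
mask = [1, 1, 0]

pooled = masked_mean_pool(hidden, mask)  # padding row contributes nothing
unit = l2_normalize(pooled)
print(pooled, dot(unit, unit))
```

Because every stored vector and every query vector has unit length, their inner product is already the cosine of the angle between them, which is why an inner-product index is enough.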

### 3.3 LLM grounded ranking
We pass only the **retrieved rows** to Gemma 3 and force:
- output language (FR/EN)
- **STRICT JSON ONLY**
- **NO hallucinated URLs** (must use retrieved rows)
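
Prompt rules alone cannot guarantee grounding, so one option is to re-check the parsed output against the retrieved rows. The sketch below is an assumption, not something the notebook does: `enforce_grounding` is a hypothetical helper that drops any recommendation whose `id` was not retrieved and overwrites `url` with the value from the knowledge base.

```python
def enforce_grounding(result: dict, rows: list[dict]) -> dict:
    """Keep only recommendations that point at retrieved rows, with KB urls."""
    by_id = {str(r.get("id", "")): r for r in rows}
    grounded = []
    for rec in result.get("recommendations") or []:
        if not isinstance(rec, dict):
            continue  # skip malformed entries
        row = by_id.get(str(rec.get("id", "")))
        if row is None:
            continue  # the model cited an id that was never retrieved
        # Trust the knowledge base, not the model, for the url field.
        grounded.append({**rec, "url": str(row.get("url", ""))})
    # Re-number ranks so they stay contiguous after filtering.
    for rank, rec in enumerate(grounded, start=1):
        rec["rank"] = rank
    return {**result, "recommendations": grounded}

rows = [{"id": "t1", "url": "https://example.com/t1"}]
result = {
    "recommendations": [
        {"rank": 1, "id": "t9", "url": "https://made-up.example/x"},
        {"rank": 2, "id": "t1", "url": "https://made-up.example/y"},
    ]
}
print(enforce_grounding(result, rows))
```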


In [None]:
import json
import numpy as np
import torch
import torch.nn.functional as F
import langid
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM

EMBED_MODEL_ID = "google/embeddinggemma-300m"
LLM_MODEL_ID   = "google/gemma-3-4b-it"

def detect_language(text: str) -> str:
    lang, _ = langid.classify(text or "")
    return "FR" if lang == "fr" else "EN"

class TextureEmbedder:
    def __init__(self, token: str, device: str | None = None):
        self.token = token
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        print(f"[Embedder] Loading {EMBED_MODEL_ID} on {self.device}")
        self.tokenizer = AutoTokenizer.from_pretrained(EMBED_MODEL_ID, token=self.token)
        self.model = AutoModel.from_pretrained(EMBED_MODEL_ID, token=self.token).to(self.device)
        self.model.eval()

    @torch.no_grad()
    def embed(self, texts: list[str]) -> np.ndarray:
        inputs = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to(self.device)
        outputs = self.model(**inputs)

        # Masked mean pooling (ignore padding tokens)
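        # Cast the integer mask to the hidden-state dtype so the division and clamp run in floating point
        mask = inputs["attention_mask"].unsqueeze(-1).to(outputs.last_hidden_state.dtype)  # (B,T,1)
        masked = outputs.last_hidden_state * mask              # (B,T,H)
        pooled = masked.sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)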

        # L2 normalize so inner product == cosine similarity
        pooled = F.normalize(pooled, p=2, dim=1)

        # FAISS expects float32, contiguous
        vecs = pooled.detach().cpu().numpy().astype("float32")
        return np.ascontiguousarray(vecs)

class TextureRankerLLM:
    def __init__(self, token: str):
        self.token = token
        print(f"[LLM] Loading {LLM_MODEL_ID}")
        self.tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_ID, token=self.token)
        dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        self.model = AutoModelForCausalLM.from_pretrained(
            LLM_MODEL_ID,
            token=self.token,
            device_map="auto",
            torch_dtype=dtype,
        )

    def generate_json(self, query: str, rows: list[dict], lang: str) -> str:
        lang_instr = "French" if lang == "FR" else "English"

        ctx_lines = []
        for i, r in enumerate(rows, start=1):
            ctx_lines.append(
                f"[{i}] id={r.get('id','')} | name={r.get('asset_name','')} | "
                f"category={r.get('category','')} | use_case={r.get('use_case','')} | "
                f"tags={r.get('tags','')} | maps={r.get('maps','')} | "
                f"res={r.get('resolution_max','')} | tileable={r.get('tileable','')} | "
                f"url={r.get('url','')} | synthetic={r.get('synthetic','')}"
            )
        context_str = "\n".join(ctx_lines) if ctx_lines else "(none)"

        prompt = f"""<start_of_turn>user
You are a PBR Texture Expert Assistant.

HARD RULES:
- Output language MUST be {lang_instr} only.
- Return STRICT JSON ONLY (no markdown, no extra text).
- Recommend ONLY from Retrieved Textures.
- Do NOT invent URLs. If url is empty or synthetic=yes, mention it in notes.
- If none fit, output empty recommendations and ask ONE clarifying question.

JSON schema:
{{
  "language": "{lang}",
  "recommendations": [
    {{
      "rank": 1,
      "id": "...",
      "name": "...",
      "url": "...",
      "notes": "1-2 sentences: why it fits; mention maps/res/tileable; mention missing url if any."
    }}
  ],
  "clarifying_question": "..."
}}

User Query: {query}

Retrieved Textures:
{context_str}
<end_of_turn>
<start_of_turn>model
"""

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        out = self.model.generate(**inputs, max_new_tokens=450, temperature=0.2, do_sample=True)

        # Decode only newly generated tokens (avoid returning the prompt)
        gen = out[0][inputs["input_ids"].shape[1]:]
        return self.tokenizer.decode(gen, skip_special_tokens=True).strip()

def safe_json(text: str) -> dict:
    text = (text or "").strip()
    try:
        return json.loads(text)
    except Exception:
        s, e = text.find("{"), text.rfind("}")
        if s != -1 and e != -1 and e > s:
            try:
                return json.loads(text[s:e+1])
            except Exception:
                pass
    return {"raw": text}


## 4) Offline indexing (CSV → embeddings → FAISS)

We transform each CSV row into a single semantic text string, embed it, then build:

- `IndexFlatIP(d)` where vectors are L2-normalized ⇒ cosine similarity search.

Artifacts saved:
- `data/textures.faiss`
- `data/textures_meta.pkl` (to map FAISS ids back to rows)
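
The indexing cell embeds every row in a single call, which is fine for a small CSV but may exhaust memory on a larger one. A minimal batching sketch follows. `iter_batches` and `embed_in_batches` are hypothetical helpers, the batch size of 32 is an arbitrary assumption, and the stand-in embedding function exists only so the example runs without a model.

```python
from typing import Callable, Iterator, Sequence

def iter_batches(items: Sequence[str], batch_size: int) -> Iterator[Sequence[str]]:
    """Yield consecutive slices of at most batch_size items."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]

def embed_in_batches(embed_fn: Callable, texts: Sequence[str], batch_size: int = 32) -> list:
    """Call embed_fn once per batch and return the per-batch results in order."""
    return [embed_fn(list(batch)) for batch in iter_batches(texts, batch_size)]

# Stand-in for a real embedder: one 1-d "vector" per text (its length).
def fake_embed(batch):
    return [[float(len(text))] for text in batch]

chunks = embed_in_batches(fake_embed, ["a", "bb", "ccc", "dddd", "eeeee"], batch_size=2)
print([len(chunk) for chunk in chunks])
```

With the notebook's embedder, the per-batch arrays could be stacked before indexing, for example `np.vstack(embed_in_batches(embedder.embed, docs))`.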


In [None]:
import faiss, pickle
import pandas as pd

CSV_PATH   = "data/textures_expanded.csv"
INDEX_PATH = "data/textures.faiss"
META_PATH  = "data/textures_meta.pkl"

REQUIRED_COLS = [
    "id","source","asset_name","category","use_case",
    "tags","maps","resolution_max","tileable","free","url","synthetic"
]

def embedding_text(row: pd.Series) -> str:
    r = row.fillna("")
    return (
        f"Name: {r['asset_name']}. "
        f"Category: {r['category']}. "
        f"Use case: {r['use_case']}. "
        f"Tags: {r['tags']}. "
        f"PBR maps: {r['maps']}. "
        f"Resolution: {r['resolution_max']}. "
        f"Tileable: {r['tileable']}. "
        f"Source: {r['source']}."
    )

# Ensure schema
df2 = pd.read_csv(CSV_PATH).fillna("")
for c in REQUIRED_COLS:
    if c not in df2.columns:
        df2[c] = ""
df2 = df2[REQUIRED_COLS]

docs = df2.apply(embedding_text, axis=1).tolist()

embedder = TextureEmbedder(token=HF_TOKEN)
vecs = embedder.embed(docs)

index = faiss.IndexFlatIP(vecs.shape[1])
index.add(vecs)

faiss.write_index(index, INDEX_PATH)
with open(META_PATH, "wb") as f:
    pickle.dump(df2.to_dict("records"), f)

print("✅ Index built. Vectors:", index.ntotal)
print("Saved:", INDEX_PATH, META_PATH)


## 5) Online query (retrieve → grounded LLM)

Pipeline:
1) detect language (FR/EN)
2) embed query
3) FAISS top-k retrieval
4) Gemma ranks + explains **only retrieved** rows
5) parse JSON
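
The five steps can be sketched as one function that takes each stage as a callable, which makes the flow testable without loading any model. `run_pipeline` and its parameter names are hypothetical; the stubs in the demo stand in for the real components.

```python
import json

def run_pipeline(query, metadata, *, detect, embed, search, rank, top_k=5):
    """Run detect -> embed -> search -> rank -> parse with injected stages."""
    query = (query or "").strip()
    if not query:
        return {"error": "empty query"}
    lang = detect(query)              # 1) language code, e.g. "FR" or "EN"
    q_vec = embed([query])            # 2) query embedding
    ids = search(q_vec, top_k)        # 3) row ids; FAISS pads missing hits with -1
    rows = [metadata[i] for i in ids if 0 <= i < len(metadata)]
    raw = rank(query, rows, lang)     # 4) LLM output restricted to retrieved rows
    try:                              # 5) parse, keeping the raw text on failure
        return json.loads(raw)
    except json.JSONDecodeError:
        return {"raw": raw}

# Demo with stub stages (no model, no index).
demo_meta = [{"id": "a"}, {"id": "b"}]
demo = run_pipeline(
    "rusty metal",
    demo_meta,
    detect=lambda q: "EN",
    embed=lambda texts: [[0.0]],
    search=lambda vec, k: [1, -1, 0],
    rank=lambda q, rows, lang: json.dumps({"language": lang, "ids": [r["id"] for r in rows]}),
)
print(demo)
```

In this notebook the stages would map to `detect_language`, `embedder.embed`, `lambda v, k: index.search(v, k)[1][0]`, and `llm.generate_json`.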


In [None]:
# Load artifacts
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM # Re-import necessary modules

index = faiss.read_index("data/textures.faiss")
with open("data/textures_meta.pkl", "rb") as f:
    metadata = pickle.load(f)

# Redefine TextureRankerLLM to use bfloat16 or float32 for model loading
class TextureRankerLLM:
    def __init__(self, token: str):
        self.token = token
        print(f"[LLM] Loading {LLM_MODEL_ID}")
        self.tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_ID, token=self.token)

        # Determine appropriate dtype: bfloat16 for Gemma if supported, else float32
        if torch.cuda.is_available():
            if torch.cuda.is_bf16_supported():
                dtype = torch.bfloat16
                print("Using torch.bfloat16 for model loading.")
            else:
                dtype = torch.float32
                print("bfloat16 not supported, falling back to torch.float32 for model loading.")
        else:
            dtype = torch.float32
            print("CUDA not available, using torch.float32 for model loading.")

        self.model = AutoModelForCausalLM.from_pretrained(
            LLM_MODEL_ID,
            token=self.token,
            device_map="auto",
            torch_dtype=dtype,
        )

    def generate_json(self, query: str, rows: list[dict], lang: str) -> str:
        lang_instr = "French" if lang == "FR" else "English"

        ctx_lines = []
        for i, r in enumerate(rows, start=1):
            ctx_lines.append(
                f"[{i}] id={r.get('id','')} | name={r.get('asset_name','')} | "
                f"category={r.get('category','')} | use_case={r.get('use_case','')} | "
                f"tags={r.get('tags','')} | maps={r.get('maps','')} | "
                f"res={r.get('resolution_max','')} | tileable={r.get('tileable','')} | "
                f"url={r.get('url','')} | synthetic={r.get('synthetic','')}"
            )
        context_str = "\n".join(ctx_lines) if ctx_lines else "(none)"

        prompt = f"""<start_of_turn>user
You are a PBR Texture Expert Assistant.

HARD RULES:
- Output language MUST be {lang_instr} only.
- Return STRICT JSON ONLY (no markdown, no extra text).
- Recommend ONLY from Retrieved Textures.
- Do NOT invent URLs. If url is empty or synthetic=yes, still list the texture but say so in notes.
- If none fit, output empty recommendations and ask ONE clarifying question.

JSON schema:
{{
  "language": "{lang}",
  "recommendations": [
    {{
      "rank": 1,
      "id": "...",
      "name": "...",
      "url": "...",
      "notes": "1-2 sentences: why it fits; mention maps/res/tileable; mention missing url if any."
    }}
  ],
  "clarifying_question": "..."
}}

User Query: {query}

Retrieved Textures:
{context_str}
<end_of_turn>
<start_of_turn>model
"""

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        out = self.model.generate(**inputs, max_new_tokens=450, temperature=0.2, do_sample=True)

        # Decode only newly generated tokens (avoid returning the prompt)
        gen = out[0][inputs["input_ids"].shape[1]:]
        return self.tokenizer.decode(gen, skip_special_tokens=True).strip()

llm = TextureRankerLLM(token=HF_TOKEN) # Instantiate with the redefined class
TOP_K = 5

def answer(query: str, top_k: int = TOP_K) -> dict:
    query = (query or "").strip()
    if not query:
        return {"error": "empty query"}

    lang = detect_language(query)
    q_vec = embedder.embed([query])
    _, idxs = index.search(q_vec, top_k)
    rows = [metadata[i] for i in idxs[0] if 0 <= i < len(metadata)]

    out = llm.generate_json(query, rows, lang)
    return safe_json(out)

# Try:
answer("concrete floor texture, modern minimalist interior")

## 6) Simple Gradio UI (share link)

This gives us a clean web UI and a temporary share link that stays live only while the Colab runtime is running.
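
The card renderer in the cell below places the model-returned `url` into an `href`. Escaping protects the attribute, but it does not restrict the URL scheme. One possible hardening step, sketched here as an assumption rather than as part of the notebook, is to accept only absolute `http`/`https` links; `safe_href` is a hypothetical helper.

```python
import html
from urllib.parse import urlparse

def safe_href(url: str) -> str:
    """Return an HTML-escaped URL if it is absolute http(s), else an empty string."""
    url = (url or "").strip()
    try:
        parsed = urlparse(url)
    except ValueError:
        return ""  # malformed input such as an unbalanced IPv6 bracket
    if parsed.scheme in ("http", "https") and parsed.netloc:
        return html.escape(url, quote=True)
    return ""

print(safe_href("https://example.com/a?x=1&y=2"))
print(repr(safe_href("javascript:alert(1)")))
```

An empty return value can be treated the same way the renderer already treats a missing URL.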


In [None]:
import gradio as gr
import html
import math

CSS = """
:root{
  --bg: #0B0F14;
  --panel: #121A23;
  --panel2:#0F1620;
  --border:#223041;
  --text:#E6EDF3;
  --muted:#9FB0C0;
  --accent:#F28C28;     /* Blender orange vibe */
  --accent2:#4DA3FF;    /* subtle blue */
  --chip:#1A2431;
}

.gradio-container { background: var(--bg) !important; color: var(--text) !important; }
#title h1 { margin-bottom: 6px; }
#subtitle { color: var(--muted); margin-top: 0; }

.bl-panel {
  background: linear-gradient(180deg, var(--panel), var(--panel2));
  border: 1px solid var(--border);
  border-radius: 16px;
  padding: 14px;
}

.bl-input textarea, .bl-input input {
  background: #0D141D !important;
  border: 1px solid var(--border) !important;
  color: var(--text) !important;
  border-radius: 14px !important;
}

.bl-btn button {
  background: var(--accent) !important;
  color: #111 !important;
  border: none !important;
  border-radius: 14px !important;
  font-weight: 800 !important;
}

.bl-btn button:hover { filter: brightness(1.05); }

.bl-header {
  display:flex; align-items:center; gap:10px;
  margin: 10px 0 14px 0;
}

.bl-pill {
  font-weight:800;
  border:1px solid var(--border);
  border-radius:999px;
  padding:2px 10px;
  font-size:12px;
  background: #0D141D;
}

.bl-muted { color: var(--muted); font-size: 13px; }

.card {
  border: 1px solid var(--border);
  border-radius: 16px;
  padding: 16px;
  margin: 12px 0;
  background: rgba(18,26,35,0.75);
  box-shadow: 0 10px 25px rgba(0,0,0,0.25);
}

.card-top { display:flex; justify-content:space-between; align-items:center; gap:10px; }
.card-title { font-size:18px; font-weight:900; line-height:1.2; }
.card-chips { margin-top:10px; display:flex; gap:8px; flex-wrap:wrap; }

.chip {
  border:1px solid var(--border);
  border-radius:999px;
  padding:2px 10px;
  font-size:12px;
  background: var(--chip);
  color: var(--text);
}

.card-notes { margin-top:10px; color: var(--text); opacity: 0.95; line-height: 1.55; font-size: 14px; }

.link {
  text-decoration:none;
  font-size:18px;
  font-weight:900;
  color: var(--accent);
}
.link:hover { text-decoration: underline; }

.hint {
  margin-top:14px;
  padding:12px;
  border-radius:14px;
  background: rgba(242,140,40,0.08);
  border: 1px solid rgba(242,140,40,0.35);
}
.hint-title { font-weight:900; margin-bottom:6px; color: var(--accent); }
"""

def auto_top_k(n_kb: int) -> int:
    """
    Auto Top-K:
    - small KB -> smaller K
    - larger KB -> increase but cap
    Currently unused: ui_search uses a fixed SEARCH_K instead.
    """
    if n_kb <= 30: return 5
    if n_kb <= 100: return 7
    if n_kb <= 400: return 10
    return 12  # cap

def build_cards_ui(result: dict, retrieved_rows: list[dict], lang: str) -> str:
    recs = result.get("recommendations", []) if isinstance(result, dict) else []
    q = (result.get("clarifying_question", "") if isinstance(result, dict) else "") or ""

    header = f"""
    <div class="bl-header">
      <span class="bl-pill">{html.escape(lang)}</span>
      <span class="bl-muted">Found {len(retrieved_rows)} candidates via FAISS</span>
    </div>
    """

    if not recs:
        return header + """
        <div class="card">
          <div class="card-title">No good match found</div>
          <div class="card-notes">Try adding: material (metal/wood/concrete), condition (rusty/wet/dirty), and where it’s used.</div>
        </div>
        """

    cards = []
    for r in recs:
        rid = str(r.get("id",""))
        row = next((x for x in retrieved_rows if str(x.get("id","")) == rid), None)

        # Title: prefer the LLM's name, then the KB asset name, then a placeholder
        name = str(r.get("name","") or (row.get("asset_name","") if row else "") or "Untitled")
        name = html.escape(name)

        notes = html.escape(str(r.get("notes","")))

        url = str(r.get("url","") or "").strip()
        url_html = f"""<a class="link" href="{html.escape(url)}" target="_blank">↗</a>""" if url else """<span class="bl-muted">↗</span>"""

        chips = []
        if row:
            cat = str(row.get("category","")).strip()
            res = str(row.get("resolution_max","")).strip()
            tile = str(row.get("tileable","")).strip().lower()
            maps_ = str(row.get("maps","")).strip()
            if cat: chips.append(cat)
            if res: chips.append(res)
            if tile in ["1","true","yes","y"]: chips.append("tileable")
            if maps_: chips.append("PBR maps")
        chip_html = "".join([f'<span class="chip">{html.escape(c)}</span>' for c in chips])

        cards.append(f"""
        <div class="card">
          <div class="card-top">
            <div class="card-title">{name}</div>
            <div>{url_html}</div>
          </div>
          <div class="card-chips">{chip_html}</div>
          <div class="card-notes">{notes}</div>
        </div>
        """
        )

    clarif = ""
    if q:
        clarif = f"""
        <div class="hint">
          <div class="hint-title">Clarifying question</div>
          <div class="card-notes">{html.escape(q)}</div>
        </div>
        """

    return header + "\n".join(cards) + clarif


def ui_search(query: str):
    query = (query or "").strip()
    if not query:
        return "", {"error": "empty query"}

    lang = detect_language(query)
    q_vec = embedder.embed([query])

    # Retrieve SEARCH_K candidates, but pass at most SHOW_MAX of them to the LLM.
    # Only the top SHOW_MAX rows are used below; the remaining candidates are currently unused.
    SEARCH_K = 12
    SHOW_MAX = 5
    SHOW_MIN = 2

    _, idxs = index.search(q_vec, SEARCH_K)
    retrieved_rows_all = [metadata[i] for i in idxs[0] if 0 <= i < len(metadata)]

    # If FAISS returned nothing, return quickly
    if len(retrieved_rows_all) == 0:
        data = {
            "language": lang,
            "recommendations": [],
            "clarifying_question": "Could you specify the material (metal/wood/concrete) and where it will be used (floor/walls/machinery)?"
        }
        return build_cards_ui(data, [], lang), data

    # Pass at most SHOW_MAX rows to the LLM (keeps the prompt and the output short)
    context_rows = retrieved_rows_all[:SHOW_MAX]

    # Fallback: if fewer than SHOW_MIN rows were selected, pass every retrieved row instead.
    # With SHOW_MAX=5 and SHOW_MIN=2 this never changes context_rows; it only matters if SHOW_MAX < SHOW_MIN.
    if len(context_rows) < SHOW_MIN:
        context_rows = retrieved_rows_all[:max(SHOW_MIN, len(retrieved_rows_all))]

    out = llm.generate_json(query, context_rows, lang)
    data = safe_json(out)

    html_cards = build_cards_ui(data, context_rows, lang)
    return html_cards, data



with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
    with gr.Column(elem_id="title"):
        gr.Markdown("# Texture RAG")
        gr.Markdown("<div id='subtitle'>FR/EN • FAISS retrieval + Gemma ranking</div>")

    with gr.Row():
        with gr.Column(scale=10):
            q = gr.Textbox(
                label="Prompt",
                placeholder="hi :)! How can I help you ^^?",
                lines=1,
                elem_classes=["bl-input", "bl-panel"]
            )
            run = gr.Button("Search", elem_classes=["bl-btn"])


    cards = gr.HTML(elem_classes=["bl-panel"])

    # Raw JSON sits in a collapsed accordion, for debugging only
    with gr.Accordion("Developer (raw JSON)", open=False):
        raw = gr.JSON()

    run.click(fn=ui_search, inputs=[q], outputs=[cards, raw])
    q.submit(fn=ui_search, inputs=[q], outputs=[cards, raw])

demo.launch(share=True)