In [1]:
%pip install --quiet tensorflow opencv-python-headless pillow gradio psycopg2-binary python-dotenv numpy



[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip available: [0m[31;49m22.3.1[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [2]:
import os, io, json, uuid, hashlib, socket, traceback
import numpy as np
import cv2
from PIL import Image, ImageOps

import psycopg2
import psycopg2.extras
import gradio as gr

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

from dotenv import load_dotenv
load_dotenv()

LABELS = ["NM", "EX", "GD", "LP", "PL", "PO"]
LABEL2IDX = {l:i for i,l in enumerate(LABELS)}
IDX2LABEL = {i:l for l,i in LABEL2IDX.items()}

IMG_H = 352
IMG_W = 256

# ====== Multi-View ======
CROP_FRAC = 0.62
AUG_FULL = True         # flips/rotations for full
AUG_CORNERS = True      # flips/rotations for corner crops
USE_JITTER = True       # extra pseudo-viewpoints
N_JITTER = 3            # number of jittered full views
JITTER_MAX_PX = 6

# ====== Embedding ======
TOPK = 10
EMB_DIM = 128

# ====== kNN fusion ======
KNN_K = 7
USE_KNN_FUSION = True
FUSION_CONF_TH = 0.55   # if classifier conf below this -> take kNN label

# ====== Training ======
BATCH_SIZE = 32
EPOCHS_HEAD = 6
EPOCHS_FINETUNE = 6
LR_HEAD = 1e-3
LR_FINETUNE = 2e-4

MODEL_PATH = "card_back_condition_model.keras"


In [3]:
# Schritt 3: DB Verbindung (optional, falls nicht vorhanden)
import os

def get_conn():
    return psycopg2.connect(
        host=os.getenv("PGHOST", "localhost"),
        port=int(os.getenv("PGPORT", "5434")),
        dbname=os.getenv("PGDATABASE", "sam1988"),
        user=os.getenv("PGUSER", "sam1988"),
        password=os.getenv("PGPASSWORD", "Ss190488!")
    )


In [4]:
def ensure_schema():
    labels_sql = ",".join([f"'{l}'" for l in LABELS])
    with get_conn() as conn:
        with conn.cursor() as cur:
            # Stelle sicher, dass Tabelle existiert (aus Notebook 1)
            cur.execute(f"""
            CREATE TABLE IF NOT EXISTS pokemon_card_back_samples (
                id UUID PRIMARY KEY,
                created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
                label TEXT NOT NULL CHECK (label IN ({labels_sql})),
                note TEXT,
                raw_sha256 TEXT NOT NULL UNIQUE,
                raw_format TEXT NOT NULL,
                raw_w INT,
                raw_h INT,
                raw_bytes BYTEA NOT NULL,
                proc_format TEXT NOT NULL,
                proc_w INT NOT NULL,
                proc_h INT NOT NULL,
                proc_bytes BYTEA NOT NULL
            );
            """)
            # Migration: mask + debug columns (falls fehlen)
            cur.execute("ALTER TABLE pokemon_card_back_samples ADD COLUMN IF NOT EXISTS proc_mask_format TEXT;")
            cur.execute("ALTER TABLE pokemon_card_back_samples ADD COLUMN IF NOT EXISTS proc_mask_w INT;")
            cur.execute("ALTER TABLE pokemon_card_back_samples ADD COLUMN IF NOT EXISTS proc_mask_h INT;")
            cur.execute("ALTER TABLE pokemon_card_back_samples ADD COLUMN IF NOT EXISTS proc_mask_bytes BYTEA;")
            cur.execute("ALTER TABLE pokemon_card_back_samples ADD COLUMN IF NOT EXISTS proc_method TEXT;")
            cur.execute("ALTER TABLE pokemon_card_back_samples ADD COLUMN IF NOT EXISTS proc_quad_expand REAL;")

def db_counts():
    with get_conn() as conn:
        with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
            cur.execute("""
                SELECT label, COUNT(*)::int AS n
                FROM pokemon_card_back_samples
                GROUP BY label
                ORDER BY label;
            """)
            rows = cur.fetchall()
    counts = {r["label"]: r["n"] for r in rows}
    for l in LABELS:
        counts.setdefault(l, 0)
    counts["TOTAL"] = sum(counts[l] for l in LABELS)
    return counts

def fmt_pg_error(e: Exception) -> str:
    if isinstance(e, psycopg2.Error):
        parts = [f"{type(e).__name__}: {e}"]
        if getattr(e, "pgcode", None):
            parts.append(f"pgcode: {e.pgcode}")
        if getattr(e, "pgerror", None):
            parts.append(f"pgerror: {e.pgerror}")
        diag = getattr(e, "diag", None)
        if diag is not None:
            for k in ["message_detail","message_hint","schema_name","table_name","column_name","constraint_name"]:
                v = getattr(diag, k, None)
                if v:
                    parts.append(f"{k}: {v}")
        return "\n".join(parts)
    return f"{type(e).__name__}: {e}"

ensure_schema()
print("✅ Schema OK. Counts:", db_counts())


✅ Schema OK. Counts: {'EX': 10, 'GD': 10, 'LP': 5, 'NM': 10, 'PL': 5, 'PO': 10, 'TOTAL': 50}


In [5]:
def proc_png_bytes_to_np(proc_png: bytes) -> np.ndarray:
    pil = Image.open(io.BytesIO(proc_png)).convert("RGB")
    arr = np.array(pil, dtype=np.uint8)
    if arr.shape[:2] != (IMG_H, IMG_W):
        arr = cv2.resize(arr, (IMG_W, IMG_H), interpolation=cv2.INTER_AREA)
    return arr

def resize_to_base(img: np.ndarray) -> np.ndarray:
    if img.shape[:2] != (IMG_H, IMG_W):
        img = cv2.resize(img, (IMG_W, IMG_H), interpolation=cv2.INTER_AREA)
    return img

def jitter_perspective(img: np.ndarray, max_px: int = 6) -> np.ndarray:
    H, W = img.shape[:2]
    src = np.array([[0,0],[W-1,0],[W-1,H-1],[0,H-1]], dtype=np.float32)
    j = np.random.randint(-max_px, max_px+1, size=(4,2)).astype(np.float32)
    dst = np.clip(src + j, [0,0], [W-1,H-1]).astype(np.float32)
    M = cv2.getPerspectiveTransform(src, dst)
    out = cv2.warpPerspective(img, M, (W, H), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT)
    return out

def make_views(img_uint8: np.ndarray):
    """
    img_uint8: (IMG_H,IMG_W,3) uint8
    returns list of uint8 views, names parallel
    Base: full + 4 corners + 4 edges = 9
    + flips/rotations: full (+3) and corners (+12) if enabled
    + jittered fulls (+N_JITTER) if enabled
    """
    img = resize_to_base(img_uint8)
    H, W = img.shape[:2]
    ch = int(round(H * CROP_FRAC))
    cw = int(round(W * CROP_FRAC))

    def crop(y0, x0, y1, x1):
        c = img[y0:y1, x0:x1]
        return resize_to_base(c)

    # base views
    full = img
    tl = crop(0, 0, ch, cw)
    tr = crop(0, W-cw, ch, W)
    bl = crop(H-ch, 0, H, cw)
    br = crop(H-ch, W-cw, H, W)

    y_mid0 = (H - ch)//2
    x_mid0 = (W - cw)//2
    top = crop(0, x_mid0, ch, x_mid0+cw)
    bottom = crop(H-ch, x_mid0, H, x_mid0+cw)
    left = crop(y_mid0, 0, y_mid0+ch, cw)
    right = crop(y_mid0, W-cw, y_mid0+ch, W)

    base = [
        ("full", full),
        ("corner_tl", tl), ("corner_tr", tr), ("corner_bl", bl), ("corner_br", br),
        ("edge_top", top), ("edge_bottom", bottom), ("edge_left", left), ("edge_right", right),
    ]

    def aug(name, v):
        return [
            (name, v),
            (name + "_hflip", cv2.flip(v, 1)),
            (name + "_vflip", cv2.flip(v, 0)),
            (name + "_rot180", cv2.rotate(v, cv2.ROTATE_180)),
        ]

    views = []
    # base 9
    views.extend(base)

    # full augs (+3)
    if AUG_FULL:
        views.extend(aug("full", full)[1:])

    # corner augs (+12)
    if AUG_CORNERS:
        for nm, v in [("corner_tl", tl), ("corner_tr", tr), ("corner_bl", bl), ("corner_br", br)]:
            views.extend(aug(nm, v)[1:])

    # jitter (+N_JITTER)
    if USE_JITTER:
        for j in range(N_JITTER):
            views.append((f"full_jitter{j+1}", jitter_perspective(full, max_px=JITTER_MAX_PX)))

    names = [n for n,_ in views]
    imgs  = [im for _,im in views]
    return imgs, names


In [6]:
def fetch_training_samples():
    with get_conn() as conn:
        with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
            cur.execute("""
                SELECT id, label, proc_bytes
                FROM pokemon_card_back_samples
                WHERE proc_bytes IS NOT NULL AND label IS NOT NULL
            """)
            rows = cur.fetchall()

    X, y, ids = [], [], []
    for r in rows:
        lbl = r["label"]
        if lbl not in LABELS:
            continue
        arr = proc_png_bytes_to_np(bytes(r["proc_bytes"]))
        X.append(arr)
        y.append(LABEL2IDX[lbl])
        ids.append(str(r["id"]))

    if len(X) == 0:
        raise RuntimeError("Keine Trainingsdaten in DB gefunden (proc_bytes/label).")

    X = np.stack(X, axis=0)
    y = np.array(y, dtype=np.int32)
    return X, y, ids

X, y, ids = fetch_training_samples()
print("Loaded:", X.shape, y.shape)
print("Class counts:", {l:int((y==LABEL2IDX[l]).sum()) for l in LABELS})


Loaded: (50, 352, 256, 3) (50,)
Class counts: {'NM': 10, 'EX': 10, 'GD': 10, 'LP': 5, 'PL': 5, 'PO': 10}


In [7]:
def make_splits(X, y, val_ratio=0.15, seed=42):
    rng = np.random.default_rng(seed)
    idx = np.arange(len(X))
    rng.shuffle(idx)
    n_val = max(1, int(round(len(X)*val_ratio)))
    val_idx = idx[:n_val]
    tr_idx  = idx[n_val:]
    return (X[tr_idx], y[tr_idx]), (X[val_idx], y[val_idx])

(X_tr, y_tr), (X_va, y_va) = make_splits(X, y, val_ratio=0.15)

def augment_tf(img):
    img = tf.image.random_brightness(img, 0.08)
    img = tf.image.random_contrast(img, 0.90, 1.10)
    return img

def make_ds(X, y, training: bool):
    ds = tf.data.Dataset.from_tensor_slices((X, y))
    if training:
        ds = ds.shuffle(min(len(X), 2000), reshuffle_each_iteration=True)
    ds = ds.map(lambda a,b: (tf.cast(a, tf.float32), b), num_parallel_calls=tf.data.AUTOTUNE)
    if training:
        ds = ds.map(lambda a,b: (augment_tf(a), b), num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
    return ds

train_ds = make_ds(X_tr, y_tr, training=True)
val_ds   = make_ds(X_va, y_va, training=False)

print("Train/Val:", len(X_tr), len(X_va))


Train/Val: 42 8


In [8]:
def build_model():
    inp = keras.Input(shape=(IMG_H, IMG_W, 3), name="img")
    x = keras.applications.mobilenet_v2.preprocess_input(inp)

    base = keras.applications.MobileNetV2(
        include_top=False,
        weights="imagenet",
        input_tensor=x
    )
    base.trainable = False

    x = base.output
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dropout(0.2)(x)

    emb = layers.Dense(EMB_DIM, activation=None, name="embedding")(x)
    emb = layers.LayerNormalization()(emb)

    out = layers.Dense(len(LABELS), activation="softmax", name="class")(emb)
    return keras.Model(inp, out, name="card_back_condition")

if os.path.exists(MODEL_PATH):
    model = keras.models.load_model(MODEL_PATH)
    print("✅ Loaded existing model:", MODEL_PATH)
else:
    model = build_model()
    print("✅ Built new model.")

model.summary()


✅ Loaded existing model: card_back_condition_model.keras


In [9]:
# 1) Head training
model.compile(
    optimizer=keras.optimizers.Adam(LR_HEAD),
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"]
)

callbacks = [
    keras.callbacks.EarlyStopping(patience=3, restore_best_weights=True, monitor="val_accuracy"),
]

history1 = model.fit(train_ds, validation_data=val_ds, epochs=EPOCHS_HEAD, callbacks=callbacks)

# 2) Fine-tune robust
def find_mobilenet_submodel(m: keras.Model):
    for l in m.layers:
        if isinstance(l, keras.Model) and hasattr(l, "layers"):
            name = (l.name or "").lower()
            if "mobilenet" in name and len(l.layers) > 20:
                return l
    return None

def finetune_with_submodel(m: keras.Model, base: keras.Model, unfreeze_from_ratio=0.70):
    base.trainable = True
    n = len(base.layers)
    cut = int(n * unfreeze_from_ratio)

    for i, l in enumerate(base.layers):
        if isinstance(l, keras.layers.BatchNormalization):
            l.trainable = False
        else:
            l.trainable = (i >= cut)

    m.compile(
        optimizer=keras.optimizers.Adam(LR_FINETUNE),
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"]
    )
    return m

base = find_mobilenet_submodel(model)
if base is not None:
    print("✅ MobileNet submodel:", base.name, "| layers:", len(base.layers))
    finetune_with_submodel(model, base, unfreeze_from_ratio=0.70)
    history2 = model.fit(train_ds, validation_data=val_ds, epochs=EPOCHS_FINETUNE, callbacks=callbacks)
else:
    print("⚠️ Kein MobileNet submodel gefunden – Fine-tune übersprungen.")

model.save(MODEL_PATH)
print("✅ Model saved:", MODEL_PATH)


Epoch 1/6
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 1s/step - accuracy: 0.2857 - loss: 1.7475 - val_accuracy: 0.1250 - val_loss: 2.2263
Epoch 2/6
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 303ms/step - accuracy: 0.1429 - loss: 1.7650 - val_accuracy: 0.3750 - val_loss: 2.0783
Epoch 3/6
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 287ms/step - accuracy: 0.1429 - loss: 1.7720 - val_accuracy: 0.0000e+00 - val_loss: 2.2867
Epoch 4/6
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 289ms/step - accuracy: 0.2381 - loss: 1.7761 - val_accuracy: 0.1250 - val_loss: 2.4247
Epoch 5/6
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 284ms/step - accuracy: 0.1905 - loss: 1.7994 - val_accuracy: 0.0000e+00 - val_loss: 2.3405
⚠️ Kein MobileNet submodel gefunden – Fine-tune übersprungen.
✅ Model saved: card_back_condition_model.keras


In [10]:
embedder = keras.Model(
    inputs=model.input,
    outputs=model.get_layer("embedding").output,
    name="embedder"
)

def l2_normalize(v: np.ndarray, eps=1e-12):
    n = np.linalg.norm(v, axis=-1, keepdims=True)
    return v / (n + eps)

def embed_images_np(imgs_uint8: np.ndarray, batch=64) -> np.ndarray:
    embs = []
    for i in range(0, len(imgs_uint8), batch):
        x = imgs_uint8[i:i+batch].astype(np.float32)
        e = embedder.predict(x, verbose=0)
        embs.append(e)
    embs = np.concatenate(embs, axis=0)
    return l2_normalize(embs)


In [11]:
REF_EMB = None
REF_META = None

def fetch_reference_rows():
    with get_conn() as conn:
        with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
            cur.execute("""
                SELECT id, label, proc_bytes
                FROM pokemon_card_back_samples
                WHERE proc_bytes IS NOT NULL AND label IS NOT NULL
                ORDER BY created_at ASC
            """)
            return cur.fetchall()

def rebuild_reference_cache():
    global REF_EMB, REF_META
    rows = fetch_reference_rows()

    meta = []
    emb_list = []

    for r in rows:
        lbl = r["label"]
        if lbl not in LABELS:
            continue

        img = proc_png_bytes_to_np(bytes(r["proc_bytes"]))  # base proc from Notebook1
        views, _names = make_views(img)
        views_np = np.stack(views, axis=0)  # (V,H,W,3)

        e = embed_images_np(views_np)       # (V,D) l2
        e_mean = e.mean(axis=0)
        e_mean = e_mean / (np.linalg.norm(e_mean) + 1e-12)

        emb_list.append(e_mean.astype(np.float32))
        meta.append({"id": str(r["id"]), "label": lbl, "proc_bytes": bytes(r["proc_bytes"])})

    if len(meta) == 0:
        REF_EMB = np.zeros((0, EMB_DIM), dtype=np.float32)
        REF_META = []
        return

    REF_EMB = np.stack(emb_list, axis=0).astype(np.float32)
    REF_META = meta

rebuild_reference_cache()
print("✅ Reference cache:", len(REF_META), REF_EMB.shape)


✅ Reference cache: 50 (50, 128)


In [12]:
def topk_matches(query_emb: np.ndarray, k=TOPK):
    if REF_EMB is None or len(REF_EMB) == 0:
        return [], []
    sims = REF_EMB @ query_emb  # cosine similarity (l2)
    idx = np.argsort(-sims)[:k]
    return idx.tolist(), sims[idx].tolist()

def knn_vote(idxs, sims):
    scores = {l: 0.0 for l in LABELS}
    # take top KNN_K neighbors from provided idxs/sims (already sorted)
    for i, s in list(zip(idxs, sims))[:KNN_K]:
        scores[REF_META[i]["label"]] += float(max(0.0, s))
    best_lbl, best_score = max(scores.items(), key=lambda x: x[1])
    total = sum(scores.values()) + 1e-12
    conf = best_score / total
    return best_lbl, float(conf), scores

def analyze_proc_np(proc_np_uint8: np.ndarray):
    """
    proc_np_uint8: (IMG_H,IMG_W,3) uint8 from DB-like proc
    returns: status, gallery_items, state_payload, ui_text
    """
    proc_np_uint8 = resize_to_base(proc_np_uint8)

    # ---- Multi-view classifier: avg probs over views ----
    views, names = make_views(proc_np_uint8)
    views_np = np.stack(views, axis=0).astype(np.float32)
    probs_v = model.predict(views_np, verbose=0)   # (V,6)
    probs = probs_v.mean(axis=0)                   # (6,)

    cls_idx = int(np.argmax(probs))
    cls_label = IDX2LABEL[cls_idx]
    cls_conf = float(probs[cls_idx])

    # ---- Multi-view embedding: mean embedding ----
    e = embed_images_np(np.stack(views, axis=0))   # (V,D)
    q_emb = e.mean(axis=0)
    q_emb = q_emb / (np.linalg.norm(q_emb) + 1e-12)

    # ---- kNN ----
    idxs, sims = topk_matches(q_emb, k=TOPK)

    if idxs:
        knn_label, knn_conf, knn_scores = knn_vote(idxs, sims)
    else:
        knn_label, knn_conf, knn_scores = cls_label, 0.0, {}

    # ---- Fusion ----
    if USE_KNN_FUSION and cls_conf < FUSION_CONF_TH:
        final_label = knn_label
        final_source = f"kNN (cls_conf<{FUSION_CONF_TH})"
    else:
        final_label = cls_label
        final_source = "Classifier"

    # ---- gallery of similar refs ----
    gal = []
    for i, s in zip(idxs, sims):
        m = REF_META[i]
        img = Image.open(io.BytesIO(m["proc_bytes"])).convert("RGB")
        gal.append((img, f'{m["label"]} | sim={s:.3f} | {m["id"][:8]}'))

    status = {
        "final_label": final_label,
        "final_source": final_source,
        "classifier": {
            "label": cls_label,
            "conf": round(cls_conf, 4),
            "probs": {IDX2LABEL[i]: float(probs[i]) for i in range(len(LABELS))}
        },
        "knn": {
            "label": knn_label,
            "conf": round(float(knn_conf), 4),
            "scores": {k: round(float(v), 4) for k, v in (knn_scores or {}).items()}
        },
        "views_used": len(views),
        "reference_size": len(REF_META)
    }

    txt = (
        f"Zustand: {final_label} | Quelle: {final_source} | "
        f"Cls: {cls_label} ({cls_conf:.2f}) | kNN: {knn_label} ({knn_conf:.2f}) | Views: {len(views)}"
    )

    # payload for approve (we will store q_emb as ref embedding)
    state_payload = {"status": status, "q_emb": q_emb}

    return status, gal, state_payload, txt


In [13]:
def fallback_prepare_proc_only(pil_img: Image.Image):
    # Fallback: nur resize, wenn du Notebook1-prepare_for_db nicht hier drin hast
    pil_img = ImageOps.exif_transpose(pil_img).convert("RGB")
    arr = np.array(pil_img, dtype=np.uint8)
    arr = cv2.resize(arr, (IMG_W, IMG_H), interpolation=cv2.INTER_AREA)
    return arr

# >>> EMPFEHLUNG:
# Ersetze später fallback_prepare_proc_only durch dein Notebook1 prepare_for_db,
# damit Uploads exakt gleich normalisiert werden wie deine DB-Referenzen.


In [14]:
def insert_sample(label, note,
                  raw_bytes, raw_format, raw_w, raw_h,
                  proc_bytes, proc_format, proc_w, proc_h,
                  mask_bytes=None, mask_format="png", mask_w=IMG_W, mask_h=IMG_H,
                  proc_method=None, proc_quad_expand=None):
    raw_sha = hashlib.sha256(raw_bytes).hexdigest()
    sample_id = uuid.uuid4()

    try:
        with get_conn() as conn:
            with conn.cursor() as cur:
                cur.execute("""
                    INSERT INTO pokemon_card_back_samples (
                        id, label, note,
                        raw_sha256, raw_format, raw_w, raw_h, raw_bytes,
                        proc_format, proc_w, proc_h, proc_bytes,
                        proc_mask_format, proc_mask_w, proc_mask_h, proc_mask_bytes,
                        proc_method, proc_quad_expand
                    )
                    VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)
                """, (
                    str(sample_id), label, note,
                    raw_sha, raw_format, raw_w, raw_h, psycopg2.Binary(raw_bytes),
                    proc_format, proc_w, proc_h, psycopg2.Binary(proc_bytes),
                    mask_format, mask_w, mask_h, psycopg2.Binary(mask_bytes) if mask_bytes else None,
                    proc_method, float(proc_quad_expand) if proc_quad_expand is not None else None
                ))
        return True, str(sample_id), raw_sha, "✅ Gespeichert."
    except Exception as e:
        return False, None, raw_sha, "❌ DB-Fehler:\n" + fmt_pg_error(e) + "\n\n" + traceback.format_exc()

def approve_and_save(state, user_note, label_override):
    global REF_EMB, REF_META

    if state is None:
        return "❌ Kein Analyse-State. Bitte erst analysieren.", db_counts()

    predicted = state["status"]["final_label"]
    if predicted not in LABELS:
        return f"❌ Interner Fehler: predicted label ungültig: {predicted}", db_counts()

    # ✅ Label-Logik:
    # - AUTO -> predicted/final_label
    # - sonst -> override
    if label_override in LABELS:
        label_used = label_override
        was_corrected = (label_override != predicted)
    else:
        label_used = predicted
        was_corrected = False

    # Note mit allen Infos (Prediction + ggf. Korrektur)
    note_obj = {
        "user_note": (user_note or "").strip(),
        "predicted": predicted,
        "label_used": label_used,
        "was_corrected": was_corrected,
        "final_source": state["status"].get("final_source"),
        "classifier": state["status"].get("classifier"),
        "knn": state["status"].get("knn"),
        "views_used": state["status"].get("views_used"),
    }
    note = json.dumps(note_obj, ensure_ascii=False)

    # Minimal speichern (Fallback-Preprocess). Ideal: später Notebook1 prepare_for_db übernehmen.
    pil_img = state.get("_last_pil")
    if pil_img is None:
        return "❌ Interner Fehler: kein Upload-Image im State.", db_counts()

    pil_fixed = ImageOps.exif_transpose(pil_img).convert("RGB")
    raw_w, raw_h = pil_fixed.size

    raw_buf = io.BytesIO()
    pil_fixed.save(raw_buf, format="JPEG", quality=92, optimize=True)
    raw_bytes = raw_buf.getvalue()

    proc_np = state.get("_proc_np")  # RGB uint8
    if proc_np is None:
        return "❌ Interner Fehler: proc fehlt im State.", db_counts()

    ok, enc = cv2.imencode(".png", cv2.cvtColor(proc_np, cv2.COLOR_RGB2BGR))
    if not ok:
        return "❌ Konnte proc PNG nicht encodieren.", db_counts()
    proc_png = enc.tobytes()

    ok, new_id, sha, msg = insert_sample(
        label=label_used,
        note=note,
        raw_bytes=raw_bytes, raw_format="jpeg", raw_w=raw_w, raw_h=raw_h,
        proc_bytes=proc_png, proc_format="png", proc_w=IMG_W, proc_h=IMG_H,
        mask_bytes=None,
        proc_method="notebook2_fallback",
        proc_quad_expand=None
    )

    if ok:
        # Cache inkrementell updaten: Query-Embedding wird Referenz
        if REF_META is None:
            REF_META = []
        if REF_EMB is None:
            REF_EMB = np.zeros((0, EMB_DIM), dtype=np.float32)

        REF_META.append({"id": new_id, "label": label_used, "proc_bytes": proc_png})
        REF_EMB = np.vstack([REF_EMB, state["q_emb"][None].astype(np.float32)])

        extra = " (korrigiert)" if was_corrected else ""
        return f"✅ Gespeichert: {label_used}{extra}. id={new_id[:8]} sha={sha[:10]}...", db_counts()

    return msg, db_counts()



In [15]:
with gr.Blocks(title="Pokemon Condition Compare + Approve") as app2:
    gr.Markdown("# 🧠 Notebook 2: Vergleich (Multi-View + kNN) → Zustand anzeigen → Zustimmen & optional korrigieren")

    state = gr.State(None)

    with gr.Row():
        img_in = gr.Image(label="Upload (Handy): Rückseite", type="pil")
        proc_prev = gr.Image(label="Proc Preview (für Analyse)", type="pil")

    with gr.Row():
        btn_analyze = gr.Button("Analysieren & Vergleichen", variant="primary")
        btn_approve = gr.Button("✅ Zustimmen & Speichern", variant="secondary")

    predicted_txt = gr.Textbox(label="Model-Analyse (Zustand)", interactive=False)

    # ✅ Neu: Override-Label (optional)
    label_override = gr.Dropdown(
        choices=["AUTO (Model)"] + LABELS,
        value="AUTO (Model)",
        label="Falls falsch: richtiges Label auswählen (optional)"
    )

    note = gr.Textbox(label="Notiz (optional)", placeholder="z.B. 'Ecke oben rechts beschädigt'")

    status_json = gr.JSON(label="Details (Classifier + kNN + Views)")
    gallery = gr.Gallery(label=f"Top-{TOPK} ähnlichste Referenzen", columns=4, height="auto")
    save_msg = gr.Textbox(label="Speicher-Status", interactive=False)
    counts = gr.JSON(label="DB Counts")

    def on_analyze(pil_img):
        if pil_img is None:
            return None, None, [], None, "❌ Bitte Bild hochladen.", "", "AUTO (Model)", db_counts()

        # IMPORTANT: aktuell fallback. Idealerweise: Notebook1 prepare_for_db nutzen.
        proc_np = fallback_prepare_proc_only(pil_img)  # RGB uint8 (IMG_H,IMG_W,3)
        proc_prev_pil = Image.fromarray(proc_np)

        status, gal, st_payload, txt = analyze_proc_np(proc_np)

        # carry extra for approve
        st_payload["_last_pil"] = pil_img
        st_payload["_proc_np"] = proc_np

        # Override zurück auf AUTO setzen (User entscheidet aktiv, ob er korrigiert)
        return status, proc_prev_pil, gal, st_payload, "", txt, "AUTO (Model)", db_counts()

    btn_analyze.click(
        fn=on_analyze,
        inputs=[img_in],
        outputs=[status_json, proc_prev, gallery, state, save_msg, predicted_txt, label_override, counts]
    )

    def on_approve(st, user_note, override_choice):
        # Dropdown liefert "AUTO (Model)" oder ein echtes Label
        override = override_choice if override_choice in LABELS else None
        return approve_and_save(st, user_note, override)

    btn_approve.click(
        fn=on_approve,
        inputs=[state, note, label_override],
        outputs=[save_msg, counts]
    )

    counts.value = db_counts()


In [16]:
def guess_local_ip():
    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        s.connect(("10.255.255.255", 1))
        ip = s.getsockname()[0]
    except Exception:
        ip = "127.0.0.1"
    finally:
        s.close()
    return ip

def get_free_port():
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    s.bind(("", 0))
    port = s.getsockname()[1]
    s.close()
    return port

local_ip = guess_local_ip()
port = get_free_port()

print(f"👉 Öffne am Handy (gleiches WLAN): http://{local_ip}:{port}")
app2.launch(server_name="0.0.0.0", server_port=port, share=False)


👉 Öffne am Handy (gleiches WLAN): http://192.168.8.10:50124
* Running on local URL:  http://0.0.0.0:50124
* To create a public link, set `share=True` in `launch()`.


