# Preparation for Visualization

## Model Loading and Environment Configuration

### Environment Setup (Drive Mount & Project Clone)

In [1]:
# ============================
# 1. Mount Google Drive (optional)
# ============================
# NOTE(review): labelled optional, but base_dir below lives under /content/drive,
# so the remaining steps rely on this mount having succeeded.
from google.colab import drive
drive.mount('/content/drive')

# ============================
# 2. Set the default working directory
# ============================
import os
from pathlib import Path

base_dir = Path("/content/drive/MyDrive/Colab Notebooks")
os.makedirs(base_dir, exist_ok=True)
# Switch the working directory to base_dir; the clone below is issued from here.
%cd {base_dir}

# ============================
# 3. Repository clone or skip
# ============================
repo_url = "https://github.com/seo-1004/cv-team5-anomaly-detection.git"
repo_dir = base_dir / "cv-team5-anomaly-detection"

# Clone only when the checkout directory is missing, so re-running the cell is cheap.
if not repo_dir.exists():
    print("Repository not found. Cloning...")
    !git clone {repo_url}
else:
    print("Repository already exists. Skipping clone.")

# From here on, relative paths are resolved against the repository root.
%cd {repo_dir}

# ============================
# 4. Check for the model checkpoint, then run setup.sh if it is missing
# ============================
# Relative path: it is checked against the working directory set by the %cd above.
model_path = "checkpoints/autoencoder/best_model_epoch_100.pth"

# setup.sh is expected to provide the checkpoint (per the message below);
# its contents are not visible in this notebook.
if not os.path.exists(model_path):
    print("Model not found. Running setup.sh to download/setup model...")
    !bash setup.sh
else:
    print("Model already exists. Skipping setup.")


Mounted at /content/drive
/content/drive/MyDrive/Colab Notebooks
Repository already exists. Skipping clone.
/content/drive/MyDrive/Colab Notebooks/cv-team5-anomaly-detection
Model already exists. Skipping setup.


### Environment & Model Initialization

In [3]:
# Environment & Model Initialization
import os
import sys
import cv2
import glob
import torch
import numpy as np
import gradio as gr
from functools import partial
import matplotlib.pyplot as plt
from src.autoencoder import load_model
from src.utils.image_io import denorm_to_uint8
from src.dataprep.transforms import NormalMapToTensor

# Report the interpreter and PyTorch versions under a horizontal rule.
print(
    '-' * 126,
    f"Python Version: {sys.version}",
    f"PyTorch Version: {torch.__version__}",
    sep="\n",
)

# Prefer CUDA when a GPU is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Device:", device)

print('-' * 126)
model_path = 'checkpoints/autoencoder/best_model_epoch_100.pth'
# load_model returns a (model, device) pair, so `device` is rebound to the
# value it returns; that is the device used by the rest of the notebook.
model, device = load_model(model_path)

------------------------------------------------------------------------------------------------------------------------------
Python Version: 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0]
PyTorch Version: 2.9.0+cpu
Device: cpu
------------------------------------------------------------------------------------------------------------------------------
[load_model] 모델 로드 완료: checkpoints/autoencoder/best_model_epoch_100.pth
[load_model] device: cpu


### Function Definitions

In [4]:
# Repository root on the mounted Drive; the notebook works from this directory.
# NOTE(review): `Path` is not imported in this notebook's import cell; it is only
# in scope because the first (environment setup) cell imported it.
BASE_DIR = Path("/content/drive/MyDrive/Colab Notebooks/cv-team5-anomaly-detection")
%cd {BASE_DIR}

# The input size used by the Dataset
TARGET_SIZE = (256, 256)

# Fixed value range for error-map display: used by infer() to normalise the
# heatmap before colour-mapping and by infer_gradio() as imshow vmin/vmax.
GLOBAL_VMIN = 0.0
GLOBAL_VMAX = 0.10

# Transform instance identical to the one used in the Dataset (resize / interpolation)
_normal_transform = NormalMapToTensor(size=TARGET_SIZE)


# ---------------------------------------------------------
# 1. Preprocessing identical to the Dataset pipeline
# ---------------------------------------------------------
def preprocess_np_for_model(image_np):
    """
    Turn a Gradio image into the tensor the model consumes.

    Args:
        image_np: (H, W, 3) uint8 array with values in [0, 255].

    Returns:
        The output of the shared NormalMapToTensor transform, documented by
        the Dataset pipeline as a (3, H, W) torch.Tensor in [-1, 1].

    The Dataset loads .npy normal maps that are already in [-1, 1]. A PNG
    arrives as [0, 255] instead, so it is rescaled to [0, 1] and then to
    [-1, 1] before going through the very same transform (which also resizes).
    """
    # Same two-step rescale as the Dataset's value range: /255 first, then *2 - 1.
    signed = (image_np.astype(np.float32) / 255.0) * 2.0 - 1.0

    # Reuse the module-level transform so resize/interpolation match the Dataset.
    return _normal_transform(signed)


# ---------------------------------------------------------
# 2. Error-map computation identical to evaluator.py
# ---------------------------------------------------------
def forward_and_error(model, input_tensor, device):
    """
    Reconstruct one input with the model and measure the per-pixel error.

    Args:
        model: callable torch module; switched to eval mode here.
        input_tensor: (3, H, W) tensor, or an already batched (N, 3, H, W)
            tensor. Only the first sample of a batch is reported.
        device: device the input is moved to before the forward pass.

    Returns:
        recon_np: (H, W, 3) float array, the reconstruction of the first sample.
        heatmap:  (H, W) float array, mean absolute error over the channel axis.
    """
    def _first_as_hwc(batch):
        # (N, C, H, W) tensor -> (H, W, C) numpy array of sample 0.
        return batch[0].detach().cpu().permute(1, 2, 0).numpy()

    model.eval()

    with torch.no_grad():
        # Add the batch axis when a single (3, H, W) sample was given.
        batch = input_tensor.unsqueeze(0) if input_tensor.dim() == 3 else input_tensor
        batch = batch.to(device)
        recon_batch = model(batch)

        source_np = _first_as_hwc(batch)
        recon_np = _first_as_hwc(recon_batch)

        # Absolute difference per channel, then averaged across channels.
        heatmap = np.abs(source_np - recon_np).mean(axis=2)

    return recon_np, heatmap


# ---------------------------------------------------------
# 3. Wrapper for visualization
# ---------------------------------------------------------
def infer(model, input_tensor, device):
    """
    Run the model and build display-ready reconstruction and error images.

    Args:
        model: network forwarded to forward_and_error.
        input_tensor: (3, H, W) tensor in [-1, 1].
        device: device forwarded to forward_and_error.

    Returns:
        recon_vis: (512, 512, 3) uint8 reconstruction, upsampled for the UI.
        heat_vis:  (512, 512, 3) uint8 RGB JET colormap of the error map,
                   upsampled for the UI.
        heatmap:   (H, W) float raw error values (not resized, not clipped).
    """
    # 1) Forward pass + error-map computation (same as evaluate_model)
    recon_np, heatmap = forward_and_error(model, input_tensor, device)

    # 2) Convert reconstructed output into uint8 for visualization
    recon_vis = denorm_to_uint8(recon_np)  # (H,W,3) uint8

    # 3) Map the error onto the fixed [GLOBAL_VMIN, GLOBAL_VMAX] display range.
    h_min, h_max = GLOBAL_VMIN, GLOBAL_VMAX
    if h_max > h_min:
        heat_norm = (heatmap - h_min) / (h_max - h_min)
        # Errors outside the display range must saturate. Without the clip,
        # values above h_max exceed 255 after scaling and the uint8 cast is
        # out of range (typically wrapping around), so the strongest anomalies
        # would be drawn with low/mid colours instead of the top of the map.
        heat_norm = np.clip(heat_norm, 0.0, 1.0)
    else:
        heat_norm = np.zeros_like(heatmap)

    heat_uint8 = (heat_norm * 255).astype(np.uint8)
    heat_color = cv2.applyColorMap(heat_uint8, cv2.COLORMAP_JET)
    # applyColorMap yields BGR; convert so the image is RGB like the other outputs.
    heat_vis = cv2.cvtColor(heat_color, cv2.COLOR_BGR2RGB)

    # Upsample for Gradio UI display (does not affect the returned raw heatmap)
    recon_vis = cv2.resize(recon_vis, (512, 512), interpolation=cv2.INTER_LINEAR)
    heat_vis = cv2.resize(heat_vis, (512, 512), interpolation=cv2.INTER_NEAREST)

    return recon_vis, heat_vis, heatmap


# ---------------------------------------------------------
# 4. Function called by Gradio
# ---------------------------------------------------------
def infer_gradio(model, device, image_np):
    """
    Gradio callback: reconstruct the uploaded image and plot its error map.

    Args:
        model: network forwarded to infer.
        device: device forwarded to infer.
        image_np: (H, W, 3) uint8 array in [0, 255], or None when the user
            pressed Submit without providing an image.

    Returns:
        (recon_vis, fig): the uint8 reconstruction image and a matplotlib
        Figure showing the raw error map on the fixed
        [GLOBAL_VMIN, GLOBAL_VMAX] colour scale. Returns (None, None) when
        no image was supplied, which leaves both outputs empty.
    """
    # Submit with an empty input component delivers None; there is nothing to run.
    if image_np is None:
        return None, None

    # 1) Preprocess input identically to the Dataset pipeline
    input_tensor = preprocess_np_for_model(image_np)  # (3,H,W), [-1,1]

    # 2) Forward pass + error-map computation
    recon_vis, _, heatmap = infer(model, input_tensor, device)

    # 3) Create matplotlib figure
    fig, ax = plt.subplots()
    im = ax.imshow(
        heatmap,
        cmap="jet",
        vmin=GLOBAL_VMIN,
        vmax=GLOBAL_VMAX,
    )
    # Attach the colorbar through the figure itself rather than pyplot's
    # "current figure", which is global state shared by every request.
    fig.colorbar(im, ax=ax)
    ax.set_title("Error Map")
    ax.axis("off")

    # pyplot keeps a reference to every figure it creates; release it so the
    # long-running demo does not accumulate one figure per request. The Figure
    # object itself stays valid for the caller to render.
    plt.close(fig)

    return recon_vis, fig


/content/drive/MyDrive/Colab Notebooks/cv-team5-anomaly-detection


# UI Construction and Execution

In [5]:

# Build the demo UI: one row with input / reconstruction / error plot,
# one row with the Clear and Submit buttons.
with gr.Blocks() as demo:
    gr.Markdown("## AutoEncoder-based Anomaly Detection Demo")

    with gr.Row():
        with gr.Column(scale=2):
            input_img = gr.Image(type="numpy", label="Input", width=512, height=512)
        with gr.Column(scale=2):
            # Small spacer above the reconstruction image.
            gr.HTML("<div style='height:5px'></div>")
            # NOTE(review): height is 412 while the input uses 512 — confirm this
            # is an intentional layout tweak and not a typo.
            recon_img = gr.Image(label="Reconstructed", width=512, height=412)
        with gr.Column(scale=2):
            # Larger spacer above the plot.
            gr.HTML("<div style='height:30px'></div>")
            error_plot = gr.Plot(label="Error Map")

    with gr.Row():
        clear_btn = gr.Button("Clear", scale=1)
        submit_btn = gr.Button("Submit", scale=1, variant="primary")

    # Bind the loaded model and device so the callback only receives the image.
    infer_fn = partial(infer_gradio, model, device)

    # Submit: run inference on the input image and fill both output components.
    submit_btn.click(
        fn=infer_fn,
        inputs=input_img,
        outputs=[recon_img, error_plot],
    )

    # Clear: reset the input and both outputs by sending None to each.
    clear_btn.click(
        fn=lambda: (None, None, None),
        inputs=None,
        outputs=[input_img, recon_img, error_plot],
    )

# NOTE(review): share=True publishes the demo on a public gradio.live URL (see
# the cell output); anyone with the link can run inference while it is live.
demo.launch(
    share=True,
    debug=False,
    quiet=True,
    inline=False,
    inbrowser=True,
)


* Running on public URL: https://02a95f5048087be74e.gradio.live


