In [1]:
# Install the app's runtime dependencies into *this kernel's* environment
# (%pip targets the running kernel; !pip may resolve to a different interpreter).
# streamlit/pyngrok are pinned to the versions this notebook was last run with.
%pip install -q streamlit==1.54.0 opencv-python torch torchvision pillow pyngrok==7.5.0
# Fetch the cloudflared binary used to expose the Streamlit app publicly.
# -O overwrites on re-run; without it wget saves cloudflared-linux-amd64.1, .2, ...
# copies that are never made executable or used.
!wget -q -O cloudflared-linux-amd64 https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64
!chmod +x cloudflared-linux-amd64

Collecting streamlit
  Downloading streamlit-1.54.0-py3-none-any.whl.metadata (9.8 kB)
Collecting pyngrok
  Downloading pyngrok-7.5.0-py3-none-any.whl.metadata (8.1 kB)
Collecting cachetools<7,>=5.5 (from streamlit)
  Downloading cachetools-6.2.6-py3-none-any.whl.metadata (5.6 kB)
Collecting pydeck<1,>=0.8.0b4 (from streamlit)
  Downloading pydeck-0.9.1-py2.py3-none-any.whl.metadata (4.1 kB)
Downloading streamlit-1.54.0-py3-none-any.whl (9.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.1/9.1 MB[0m [31m31.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pyngrok-7.5.0-py3-none-any.whl (24 kB)
Downloading cachetools-6.2.6-py3-none-any.whl (11 kB)
Downloading pydeck-0.9.1-py2.py3-none-any.whl (6.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.9/6.9 MB[0m [31m55.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pyngrok, cachetools, pydeck, streamlit
  Attempting uninstall: cachetools
    Found existing installation: cachet

In [2]:
# Mount Google Drive; the trained weights are copied out of it in the next cell.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [3]:
# Copy the trained EfficientNet weights out of Drive. app.py loads them via the
# relative path "efficientnet_pcb.pth" (assumes the working directory is /content).
MODEL_SRC = "/content/drive/MyDrive/PCB_Dataset/efficientnet_pcb.pth"
MODEL_DST = "/content/efficientnet_pcb.pth"
!cp -r {MODEL_SRC} {MODEL_DST}

In [52]:
%%writefile app.py
import streamlit as st
import cv2
import torch
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
from torchvision.models import efficientnet_b0
import torch.nn as nn
import tempfile
import os
import time

# ================= CONFIG =================
# Defect class names, in the index order of the classifier's outputs
# (a prediction index is looked up as CLASSES[pred]).
CLASSES = [
    "Missing_hole",
    "Mouse_bite",
    "Open_circuit",
    "Short",
    "Spur",
    "Spurious_copper",
]
# Trained weights, resolved relative to the directory streamlit is started from.
MODEL_PATH = "efficientnet_pcb.pth"
# Run on the GPU whenever torch can see one.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

@st.cache_resource
def load_model():
    """Build the EfficientNet-B0 defect classifier and load its trained weights.

    The torchvision EfficientNet-B0 architecture is created without pretrained
    weights. Its final linear layer is replaced with one sized to
    ``len(CLASSES)``, and the state dict stored at ``MODEL_PATH`` is loaded
    onto ``device``. ``st.cache_resource`` keeps a single instance across
    Streamlit reruns instead of rebuilding it on every interaction.

    Returns:
        The model, moved to ``device`` and switched to eval mode.

    Raises:
        FileNotFoundError: if ``MODEL_PATH`` does not exist. Previously the
            model was silently returned with a randomly initialised head,
            which produced meaningless predictions with no visible error.
    """
    if not os.path.exists(MODEL_PATH):
        raise FileNotFoundError(
            f"Model weights not found at '{MODEL_PATH}'. Copy the trained "
            "weights next to app.py before starting the app."
        )
    model = efficientnet_b0()
    # Swap the default classification head for one output per defect class.
    model.classifier[1] = nn.Linear(model.classifier[1].in_features, len(CLASSES))
    # weights_only=True: the file only needs to contain tensors, so refuse to
    # unpickle arbitrary objects (which could execute code) from it.
    state_dict = torch.load(MODEL_PATH, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model

model = load_model()

# Preprocessing applied to every ROI crop before classification:
# replicate the single gray channel into 3 channels, resize to 224x224,
# convert to a tensor, then normalize with the standard ImageNet mean/std.
_roi_preprocess_steps = [
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform = transforms.Compose(_roi_preprocess_steps)

def align_images(template, test):
    """Warp ``test`` onto ``template``'s pixel grid with an ORB + RANSAC homography.

    ORB keypoints (up to 5000 per image) are matched with a cross-checked
    Hamming brute-force matcher. The 150 closest matches feed a RANSAC
    homography (5.0 px reprojection threshold) mapping test -> template.

    Returns the warped test image with ``template``'s width and height, or
    ``test`` unchanged when either image yields no descriptors, fewer than 15
    matches are available, or no homography can be estimated.
    """
    detector = cv2.ORB_create(5000)
    template_kps, template_desc = detector.detectAndCompute(template, None)
    test_kps, test_desc = detector.detectAndCompute(test, None)
    if template_desc is None or test_desc is None:
        return test

    matcher = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)
    ranked = list(matcher.match(template_desc, test_desc))
    ranked.sort(key=lambda match: match.distance)
    best = ranked[:150]
    if len(best) < 15:
        return test

    # Template descriptors were passed first to match(), so queryIdx indexes
    # template keypoints and trainIdx indexes test keypoints.
    test_pts = np.float32([test_kps[m.trainIdx].pt for m in best]).reshape(-1, 1, 2)
    template_pts = np.float32([template_kps[m.queryIdx].pt for m in best]).reshape(-1, 1, 2)
    homography, _inliers = cv2.findHomography(test_pts, template_pts, cv2.RANSAC, 5.0)
    if homography is None:
        return test

    height, width = template.shape[0], template.shape[1]
    return cv2.warpPerspective(test, homography, (width, height))

def detect_defects(template, test, min_area=100, pad=20, blur_ksize=(5, 5)):
    """Find candidate defect regions where ``test`` differs from ``template``.

    Both images are Gaussian-blurred and absolutely differenced. The difference
    is binarised with an Otsu threshold and cleaned with a 3x3 morphological
    opening. Each external contour whose area exceeds ``min_area`` becomes a
    candidate. Otsu thresholding requires single-channel 8-bit input, so both
    images are expected to be grayscale uint8.

    NOTE(review): if ``test`` was warped by ``align_images``, areas outside the
    warped content are filled with black and may show up as differences.

    Args:
        template: Reference image; resized to ``test``'s width and height.
        test: Image to inspect, ideally already aligned to ``template``.
        min_area: Contours with area <= this many pixels are ignored as noise.
        pad: Context margin (pixels) added around each box when cropping the
            ROI, clamped to the image bounds.
        blur_ksize: Gaussian kernel size applied to both images before
            differencing.

    Returns:
        ``(rois, boxes, mask)``. ``rois`` are padded crops of ``test``.
        ``boxes`` are the matching unpadded ``(x, y, w, h)`` rectangles in
        ``test`` coordinates. ``mask`` is the cleaned binary difference mask.
    """
    height, width = test.shape[0], test.shape[1]
    template = cv2.resize(template, (width, height))
    t_blur = cv2.GaussianBlur(template, blur_ksize, 0)
    s_blur = cv2.GaussianBlur(test, blur_ksize, 0)
    diff = cv2.absdiff(s_blur, t_blur)
    # Otsu chooses the threshold automatically from the diff histogram.
    _, mask = cv2.threshold(diff, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    # Opening removes isolated speckles before contour extraction.
    kernel = np.ones((3, 3), np.uint8)
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel, iterations=1)
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    rois, boxes = [], []
    for contour in contours:
        if not cv2.contourArea(contour) > min_area:
            continue
        x, y, w, h = cv2.boundingRect(contour)
        roi = test[max(0, y - pad):min(height, y + h + pad),
                   max(0, x - pad):min(width, x + w + pad)]
        if roi.size > 0:
            rois.append(roi)
            boxes.append((x, y, w, h))
    return rois, boxes, mask

# ================= PIPELINE WITH LABELED BACKGROUNDS =================
def run_pipeline(template_path, test_path, conf_level):
    """Detect, classify and annotate PCB defects in a test image.

    The pipeline:
    1. Loads both images as grayscale and aligns the test image to the template.
    2. Extracts candidate regions with ``detect_defects`` and classifies each
       region with the EfficientNet model.
    3. For every prediction whose softmax confidence exceeds ``conf_level``,
       draws a green box and a "<class> <confidence>" label on a red background.

    Args:
        template_path: Path to the reference image.
        test_path: Path to the image to inspect.
        conf_level: Exclusive lower bound on softmax confidence for a region
            to be reported.

    Returns:
        ``(vis, count, mask, seconds)``:
        - ``vis``: the annotated BGR image.
        - ``count``: the number of reported defects.
        - ``mask``: the binary difference mask.
        - ``seconds``: wall-clock duration of the whole call.

    Raises:
        ValueError: if either path cannot be decoded as an image. cv2.imread
            returns None instead of raising, which previously surfaced later
            as an obscure OpenCV error inside the ORB alignment.
    """
    start_time = time.time()
    template = cv2.imread(template_path, cv2.IMREAD_GRAYSCALE)
    test = cv2.imread(test_path, cv2.IMREAD_GRAYSCALE)
    if template is None:
        raise ValueError(f"Could not decode template image: {template_path}")
    if test is None:
        raise ValueError(f"Could not decode test image: {test_path}")

    test_aligned = align_images(template, test)
    rois, boxes, mask = detect_defects(template, test_aligned)
    vis = cv2.cvtColor(test_aligned, cv2.COLOR_GRAY2BGR)
    image_width = vis.shape[1]

    # Label styling is the same for every detection.
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.7
    thickness = 2

    count = 0
    with torch.no_grad():
        for roi, (x, y, w, h) in zip(rois, boxes):
            img = transform(Image.fromarray(roi)).unsqueeze(0).to(device)
            prob = torch.softmax(model(img), dim=1)
            conf_t, pred_t = torch.max(prob, 1)
            conf, pred = conf_t.item(), pred_t.item()
            if not conf > conf_level:
                continue

            count += 1
            label = f"{CLASSES[pred]} {conf:.2f}"

            # --- Bounding box (green in BGR) ---
            cv2.rectangle(vis, (x, y), (x + w, y + h), (0, 255, 0), 3)

            # --- Label background ---
            (label_w, label_h), baseline = cv2.getTextSize(label, font, font_scale, thickness)
            # Above the box; if there is no room above, just inside its top edge.
            text_y = y - 10 if y - 10 > label_h else y + label_h + 10
            # Shift left so labels on boxes near the right edge are not clipped.
            text_x = max(0, min(x, image_width - label_w))
            # Solid red (BGR) background behind the text for legibility.
            cv2.rectangle(vis, (text_x, text_y - label_h - baseline),
                          (text_x + label_w, text_y + baseline), (0, 0, 255), cv2.FILLED)

            # White text on the red background.
            cv2.putText(vis, label, (text_x, text_y), font, font_scale, (255, 255, 255), thickness)

    return vis, count, mask, time.time() - start_time

# ================= UI =================
st.set_page_config(page_title="PCB Defect Analysis", layout="wide")
st.title("PCB Defect Detection System")

conf_level = st.sidebar.slider("Model Confidence Threshold", 0.1, 0.9, 0.3)

col1, col2 = st.columns(2)
with col1:
    template_file = st.file_uploader("Upload TEMPLATE", type=["jpg","png","jpeg"])
with col2:
    test_file = st.file_uploader("Upload TEST", type=["jpg","png","jpeg"])


def _upload_to_tempfile(uploaded_file):
    """Write an uploaded file's bytes to a named temp file and return its path.

    run_pipeline reads images from disk paths, so uploads are spilled to disk.
    The caller owns the file and must delete it.
    """
    with tempfile.NamedTemporaryFile(delete=False) as tmp:
        # getvalue() returns the whole buffer regardless of the read position.
        tmp.write(uploaded_file.getvalue())
        return tmp.name


if test_file and template_file:
    if st.button("Run Analysis"):
        # Temp files are created only when the analysis actually runs and are
        # always removed afterwards. Previously two undeleted temp files were
        # written on every script rerun (every widget interaction).
        temp_paths = []
        try:
            t_path = _upload_to_tempfile(template_file)
            temp_paths.append(t_path)
            s_path = _upload_to_tempfile(test_file)
            temp_paths.append(s_path)
            res, total, mask, duration = run_pipeline(t_path, s_path, conf_level)
        except (ValueError, cv2.error) as exc:
            st.error(f"Analysis failed: {exc}")
            st.stop()
        finally:
            for path in temp_paths:
                try:
                    os.remove(path)
                except OSError:
                    pass  # best-effort cleanup; never mask the real result/error

        st.success(f"Analysis complete in {duration:.3f}s | Defects: {total}")

        c1, c2 = st.columns(2)
        c1.image(mask, caption="Difference Mask", use_container_width=True)
        c2.image(cv2.cvtColor(res, cv2.COLOR_BGR2RGB), caption="Final Result", use_container_width=True)

        # Encode the OpenCV BGR image as PNG bytes for the download button.
        ok, encoded_img = cv2.imencode('.png', res)
        if ok:
            # NOTE(review): clicking a download button reruns the script by
            # default, and st.button is False on that rerun, so these results
            # disappear afterwards; persist them in st.session_state if needed.
            st.download_button(
                label="Download Result",
                data=encoded_img.tobytes(),
                file_name="pcb_defect_result.png",
                mime="image/png"
            )
        else:
            st.warning("Could not encode the result image for download.")

Overwriting app.py


In [53]:
import subprocess
import time
import re
import os
import requests

# --- PRE-FLIGHT CHECK ---
MODEL_PATH = "efficientnet_pcb.pth"
if not os.path.exists(MODEL_PATH):
    print(f"ERROR: {MODEL_PATH} not found! The app will crash.")
    print("Ensure you have run the '!cp' command from your Drive first.")

STREAMLIT_URL = "http://localhost:8501"
CLOUDFLARED_LOG = "cloudflared_log.txt"
TUNNEL_TIMEOUT_S = 60

# 1. Kill any existing processes to prevent port conflicts
!pkill streamlit
!pkill cloudflared
# pkill only sends the signal; give the old server a moment to release the port.
time.sleep(1)

# 2. Start Streamlit in the background (its output goes to streamlit_log.txt)
print("Starting Streamlit...")
with open("streamlit_log.txt", "w") as f:
    streamlit_proc = subprocess.Popen(
        ["streamlit", "run", "app.py", "--server.port=8501", "--server.address=0.0.0.0"],
        stdout=f,
        stderr=f
    )

# 3. Wait for Streamlit to actually be "Live" (Prevents 502)
max_retries = 30
success = False
print("Waiting for Streamlit to initialize...")
for i in range(max_retries):
    try:
        # timeout: a hung connection must not block this loop indefinitely.
        response = requests.get(STREAMLIT_URL, timeout=2)
        if response.status_code == 200:
            print("Streamlit is UP!")
            success = True
            break
    except requests.exceptions.RequestException:
        pass  # server not accepting connections yet
    if streamlit_proc.poll() is not None:
        print(f"Streamlit exited early (code {streamlit_proc.returncode}).")
        break
    # Back off after every failed attempt, including non-200 responses
    # (previously those retried immediately and burned all retries at once).
    time.sleep(2)
    if i % 5 == 0: print(f"   ...still waiting ({i}/{max_retries})")

if not success:
    print("Streamlit failed to start. Check streamlit_log.txt")
else:
    # 4. Start Cloudflare Tunnel. Its output goes to a file, not a PIPE: a pipe
    # that is no longer read after the URL is found eventually fills, blocking
    # cloudflared's log writes and potentially stalling the tunnel.
    print("Starting Cloudflare Tunnel...")
    with open(CLOUDFLARED_LOG, "w") as cf_log:
        cloudflared = subprocess.Popen(
            ["./cloudflared-linux-amd64", "tunnel", "--url", STREAMLIT_URL, "--no-autoupdate"],
            stdout=cf_log,
            stderr=subprocess.STDOUT,
        )

    # 5. Poll the log for the public URL, bounded so a failed tunnel can't hang
    # the cell. The lookahead skips api.trycloudflare.com, which can appear in
    # cloudflared's error messages.
    url_pattern = re.compile(r"https://(?!api\.)[-a-z0-9]+\.trycloudflare\.com")
    public_url = None
    deadline = time.time() + TUNNEL_TIMEOUT_S
    while public_url is None and time.time() < deadline:
        with open(CLOUDFLARED_LOG) as cf_log:
            match = url_pattern.search(cf_log.read())
        if match:
            public_url = match.group(0)
        elif cloudflared.poll() is not None:
            break  # cloudflared exited without publishing a URL
        else:
            time.sleep(1)

    if public_url:
        print("\n" + "═"*50)
        print("  YOUR PUBLIC APP LINK:")
        print(f"  {public_url}")
        print("═"*50 + "\n")
    else:
        print(f"Could not obtain a tunnel URL. Check {CLOUDFLARED_LOG}")

Starting Streamlit...
Waiting for Streamlit to initialize...
   ...still waiting (0/30)
Streamlit is UP!
Starting Cloudflare Tunnel...

══════════════════════════════════════════════════
  YOUR PUBLIC APP LINK:
  https://cartoons-printing-governments-clothes.trycloudflare.com
══════════════════════════════════════════════════

