import os

img_dir = "./testure test images/"

test_images = [
    os.path.join(img_dir, 'image1.png'),
    os.path.join(img_dir, 'image2.png'),
]

ground_truths = [
    os.path.join(img_dir, 'image1_groundtruth.png'),
    os.path.join(img_dir, 'image2_groundtruth.png'),
]
### 擴增與訓練說明
- 若要調整擴增筆數，修改第二格中的 `N_SAMPLES`.
- 增強缺陷樣式可在 `add_synthetic_defects` 中加更多型態 (例如格狀、斜切片、局部模糊區)。
- 前處理 `texture_normalize` 可調 `open_size` 與 `fft_remove_ratio` 抑制重複花紋。
- 訓練迭代數由 `EPOCHS` 控制，增大可提升擬合，但要注意過擬合。
- 若有 GPU 會自動使用。若記憶體不足可降低 `batch_size`.
- 目前 `pred_mask` 以白色表示缺陷 (便於人工檢視)。若要與其他方法統一，可改成 0=缺陷,255=非缺陷。

In [5]:
# 擴增原始 2 張灰階影像至 100 筆並訓練輕量 U-Net
# 缺陷顯示規則: 0 = 缺陷(黑), 255 = 非缺陷(白)
# -------------------------------------------------------------
import os, random, cv2, numpy as np, matplotlib.pyplot as plt
random.seed(42); np.random.seed(42)
# Resolve base image directory robustly (typo handling)
candidates = ['./texture test images', './texture test images/', './testure test images', './testure test images/']
BASE_DIR = None
for _c in candidates:
    if os.path.isdir(_c):
        BASE_DIR = _c
        break
if BASE_DIR is None:
    raise FileNotFoundError('No texture image directory found among candidates: ' + ', '.join(candidates))
IMG_FILES = ['image1.png','image2.png']
GT_FILES  = ['image1_groundtruth.png','image2_groundtruth.png']
imgs = []; gts = []
for f, g in zip(IMG_FILES, GT_FILES):
    ip = os.path.join(BASE_DIR,f)
    gp = os.path.join(BASE_DIR,g)
    im = cv2.imread(ip, cv2.IMREAD_UNCHANGED)
    if im is None: raise FileNotFoundError(ip)
    if im.ndim==3: im = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
    gt = cv2.imread(gp, cv2.IMREAD_UNCHANGED) if os.path.isfile(gp) else None
    if gt is None: gt = np.ones_like(im,dtype=np.uint8)*255
    if gt.ndim==3: gt = cv2.cvtColor(gt, cv2.COLOR_BGR2GRAY)
    gt = cv2.resize(gt,(im.shape[1],im.shape[0]),interpolation=cv2.INTER_NEAREST)
    imgs.append(im); gts.append(gt)
print('Loaded base images:', len(imgs))
# -------- 紋理正規化 (改良版: 保留缺陷特徵) --------
def texture_normalize(gray, open_size=51):
    """使用 top-hat 變換保留局部異常，避免過度均質化"""
    k = max(3, open_size|1)
    ker = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(k,k))
    # Top-hat: 突顯比背景亮的區域 (某些缺陷)
    tophat = cv2.morphologyEx(gray, cv2.MORPH_TOPHAT, ker)
    # Black-hat: 突顯比背景暗的區域 (另一些缺陷)
    blackhat = cv2.morphologyEx(gray, cv2.MORPH_BLACKHAT, ker)
    # 組合兩者，保留雙向異常
    norm = cv2.add(gray, tophat)
    norm = cv2.subtract(norm, blackhat)
    # 輕微高斯濾波去除高頻噪音但保留缺陷邊緣
    norm = cv2.GaussianBlur(norm, (3,3), 0.5)
    # CLAHE 增強對比度而不破壞局部特徵
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
    norm = clahe.apply(norm)
    return norm
# -------- 合成缺陷 (線/孔/斑點/裂紋) --------
def synth_defects(img, mask, max_ops=3):
    h,w = img.shape
    out = img.copy(); m = mask.copy()
    for _ in range(random.randint(1,max_ops)):
        t = random.choice(['line','hole','dots','crack'])
        if t=='line':
            x1,y1 = random.randint(0,w-1), random.randint(0,h-1)
            x2,y2 = random.randint(0,w-1), random.randint(0,h-1)
            thick = random.randint(2,7)
            cv2.line(out,(x1,y1),(x2,y2),color=random.randint(0,35),thickness=thick)
            cv2.line(m,(x1,y1),(x2,y2),color=0,thickness=thick)
        elif t=='hole':
            cx,cy = random.randint(0,w-1), random.randint(0,h-1); r=random.randint(10,28)
            cv2.circle(out,(cx,cy),r,color=random.randint(0,40),thickness=-1)
            cv2.circle(m,(cx,cy),r,color=0,thickness=-1)
        elif t=='dots':
            for __ in range(random.randint(8,25)):
                dx,dy = random.randint(0,w-1), random.randint(0,h-1)
                out[dy,dx]=random.randint(0,50); m[dy,dx]=0
        elif t=='crack':
            pts=[]; steps=random.randint(5,12); x,y=random.randint(0,w-1),random.randint(0,h-1)
            for __ in range(steps):
                pts.append((x,y)); x+=random.randint(-14,14); y+=random.randint(-14,14)
                x=max(0,min(w-1,x)); y=max(0,min(h-1,y))
            for a,b in zip(pts[:-1],pts[1:]):
                cv2.line(out,a,b,color=random.randint(0,30),thickness=random.randint(1,4))
                cv2.line(m,a,b,color=0,thickness=random.randint(1,4))
    return out,m
# -------- 幾何 + 強度擴增 --------
def aug_geom_intensity(img, mask):
    h,w = img.shape; k=random.randint(0,3)
    img2=np.rot90(img,k); mask2=np.rot90(mask,k)
    if random.random()<0.5: img2=np.flipud(img2); mask2=np.flipud(mask2)
    if random.random()<0.5: img2=np.fliplr(img2); mask2=np.fliplr(mask2)
    scale=random.uniform(0.85,1.25); nh,nw=int(h*scale),int(w*scale)
    img_rs=cv2.resize(img2,(nw,nh),interpolation=cv2.INTER_LINEAR); mask_rs=cv2.resize(mask2,(nw,nh),interpolation=cv2.INTER_NEAREST)
    canvas_i=np.full((h,w),np.mean(img_rs),dtype=img.dtype); canvas_m=np.full((h,w),255,dtype=mask.dtype)
    y0=(h-nh)//2; x0=(w-nw)//2; y1=y0+nh; x1=x0+nw
    y0=max(0,y0); x0=max(0,x0); y1=min(h,y1); x1=min(w,x1)
    img_crop=img_rs[:y1-y0,:x1-x0]; mask_crop=mask_rs[:y1-y0,:x1-x0]
    canvas_i[y0:y1,x0:x1]=img_crop; canvas_m[y0:y1,x0:x1]=mask_crop
    img2,mask2=canvas_i,canvas_m
    if random.random()<0.8: # gamma
        gamma=random.uniform(0.7,1.5); g=(img2/255.0)**gamma*255; img2=np.clip(g,0,255).astype(np.uint8)
    if random.random()<0.6: # noise
        noise=np.random.normal(0,random.uniform(3,12),img2.shape); img2=np.clip(img2+noise,0,255).astype(np.uint8)
    if random.random()<0.4: img2=cv2.GaussianBlur(img2,(3,3),0)
    if random.random()<0.5: # contrast/brightness
        alpha=random.uniform(0.85,1.3); beta=random.uniform(-20,20); img2=np.clip(alpha*img2+beta,0,255).astype(np.uint8)
    return img2,mask2
# -------- 生成擴增樣本 --------
N_SAMPLES=100
augX=[]; augY=[]
base_norm=[texture_normalize(im) for im in imgs]
for i in range(N_SAMPLES):
    idx=random.randint(0,len(base_norm)-1)
    im=base_norm[idx]; gt=gts[idx]
    im_a,gt_a=aug_geom_intensity(im,gt)
    if random.random()<0.9: im_a,gt_a=synth_defects(im_a,gt_a,max_ops=random.randint(1,4))
    augX.append(im_a); augY.append(gt_a)
print('Augmented samples:', len(augX))
# -------- 建立 PyTorch 資料集 --------
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
device='cuda' if torch.cuda.is_available() else 'cpu'; print('Device:',device)
class DefectDataset(Dataset):
    def __init__(self, imgs, masks): self.imgs=imgs; self.masks=masks
    def __len__(self): return len(self.imgs)
    def __getitem__(self,i):
        im=self.imgs[i].astype(np.float32)/255.0
        m=(self.masks[i]==0).astype(np.float32) # defect 1
        return torch.from_numpy(im[None,...]), torch.from_numpy(m[None,...])
# split train/val
indices=list(range(N_SAMPLES)); random.shuffle(indices)
val_n=int(N_SAMPLES*0.15); val_set=set(indices[:val_n])
trainX=[augX[i] for i in indices if i not in val_set]; trainY=[augY[i] for i in indices if i not in val_set]
valX=[augX[i] for i in indices if i in val_set]; valY=[augY[i] for i in indices if i in val_set]
train_ds=DefectDataset(trainX,trainY); val_ds=DefectDataset(valX,valY)
train_dl=DataLoader(train_ds,batch_size=8,shuffle=True); val_dl=DataLoader(val_ds,batch_size=8)
# -------- 輕量 U-Net --------
class Block(nn.Module):
    def __init__(self,i,o): super().__init__(); self.c=nn.Sequential(nn.Conv2d(i,o,3,padding=1),nn.BatchNorm2d(o),nn.ReLU(True),nn.Conv2d(o,o,3,padding=1),nn.BatchNorm2d(o),nn.ReLU(True))
    def forward(self,x): return self.c(x)
class TinyUNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.b1 = Block(1,32)
        self.b2 = Block(32,64)
        self.b3 = Block(64,128)
        self.p = nn.MaxPool2d(2)
        self.up2 = nn.ConvTranspose2d(128,64,2,2)
        self.d2 = Block(128,64)
        self.up1 = nn.ConvTranspose2d(64,32,2,2)
        self.d1 = Block(64,32)
        self.out = nn.Conv2d(32,1,1)
    def forward(self,x):
        e1 = self.b1(x)
        e2 = self.b2(self.p(e1))
        e3 = self.b3(self.p(e2))
        u2 = self.up2(e3)
        u2 = torch.cat([u2,e2],1)
        u2 = self.d2(u2)
        u1 = self.up1(u2)
        u1 = torch.cat([u1,e1],1)
        u1 = self.d1(u1)
        return self.out(u1)
def dice_loss(logits, targets, eps=1e-6):
    probs=torch.sigmoid(logits); num=2*(probs*targets).sum(); den=probs.sum()+targets.sum()+eps; return 1-num/den
model=TinyUNet().to(device); opt=torch.optim.Adam(model.parameters(),lr=1e-3); bce=nn.BCEWithLogitsLoss()
EPOCHS=10; train_hist=[]; val_hist=[]
for ep in range(1,EPOCHS+1):
    model.train(); tb=0; td=0
    for x,y in train_dl: x=x.to(device); y=y.to(device); opt.zero_grad(); o=model(x); lb=bce(o,y); ld=dice_loss(o,y); loss=lb+ld; loss.backward(); opt.step(); tb+=lb.item(); td+=ld.item()
    model.eval(); vb=0; vd=0
    with torch.no_grad():
        for x,y in val_dl: x=x.to(device); y=y.to(device); o=model(x); vb+=bce(o,y).item(); vd+=dice_loss(o,y).item()
    train_hist.append((tb/len(train_dl), td/len(train_dl))); val_hist.append((vb/len(val_dl), vd/len(val_dl)))
    print(f'Epoch {ep}/{EPOCHS} Train BCE {train_hist[-1][0]:.4f} Dice {train_hist[-1][1]:.4f} | Val BCE {val_hist[-1][0]:.4f} Dice {val_hist[-1][1]:.4f}')
# -------- 後處理: 去除噪點並膨脹大區域邊緣 --------
def postprocess_mask(mask_binary, min_area=50, dilate_kernel=5):
    """
    mask_binary: (H,W) 0/1 mask, 1=缺陷
    返回: 清理後的 mask (0/1)
    """
    # 連通域分析
    num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(mask_binary, connectivity=8)
    cleaned = np.zeros_like(mask_binary)
    
    for i in range(1, num_labels):  # 跳過背景 (label 0)
        area = stats[i, cv2.CC_STAT_AREA]
        if area < min_area:
            # 小於閾值視為噪點，移除
            continue
        else:
            # 保留大區域並進行邊緣膨脹
            component = (labels == i).astype(np.uint8)
            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (dilate_kernel, dilate_kernel))
            dilated = cv2.dilate(component, kernel, iterations=1)
            cleaned = cv2.bitwise_or(cleaned, dilated)
    
    return cleaned

# -------- 生成預測與視覺化 --------
def predict(im):
    im_n=texture_normalize(im); t=torch.from_numpy(im_n.astype(np.float32)/255.0)[None,None,...].to(device)
    with torch.no_grad(): p=torch.sigmoid(model(t))[0,0].cpu().numpy()
    mask_raw=(p>0.5).astype(np.uint8) # 1=缺陷
    # 應用後處理
    mask_clean = postprocess_mask(mask_raw, min_area=30, dilate_kernel=5)
    mask_out=np.ones_like(im_n,dtype=np.uint8)*255; mask_out[mask_clean==1]=0
    return mask_out, p, mask_raw

fig,axs=plt.subplots(len(imgs),5,figsize=(22,5*len(imgs)))
if len(imgs)==1: axs=axs.reshape(1,5)
for i,(im,gt) in enumerate(zip(imgs,gts)):
    pm,prob,mask_raw=predict(im); im_norm=texture_normalize(im)
    gt_def=(gt==0); pred_def=(pm==0)
    raw_def=(mask_raw==1)
    tp=int(np.logical_and(gt_def,pred_def).sum()); fp=int(np.logical_and(~gt_def,pred_def).sum()); fn=int(np.logical_and(gt_def,~pred_def).sum())
    prec=tp/(tp+fp) if (tp+fp)>0 else 0; rec=tp/(tp+fn) if (tp+fn)>0 else 0; f1=(2*prec*rec)/(prec+rec) if (prec+rec)>0 else 0
    axs[i,0].imshow(im,cmap='gray'); axs[i,0].set_title(f'原始 {i+1}'); axs[i,0].axis('off')
    axs[i,1].imshow(im_norm,cmap='gray'); axs[i,1].set_title('改良紋理正規化'); axs[i,1].axis('off')
    # 顯示原始預測 (後處理前)
    mask_raw_display=np.ones_like(im,dtype=np.uint8)*255; mask_raw_display[raw_def]=0
    axs[i,2].imshow(mask_raw_display,cmap='gray',vmin=0,vmax=255); axs[i,2].set_title('原始預測(有噪點)'); axs[i,2].axis('off')
    # 顯示後處理結果
    axs[i,3].imshow(pm,cmap='gray',vmin=0,vmax=255); axs[i,3].set_title('後處理(黑=缺陷)'); axs[i,3].axis('off')
    overlay=np.stack([im]*3,axis=-1)
    ov=overlay.copy(); ov[gt_def]=[0,255,0]; ov[pred_def]=[255,0,0]
    axs[i,4].imshow(ov); axs[i,4].set_title(f'Overlay F1={f1:.3f}'); axs[i,4].axis('off')
plt.tight_layout(); plt.show()
# Loss 曲線
tr_b=[x[0] for x in train_hist]; tr_d=[x[1] for x in train_hist]; va_b=[x[0] for x in val_hist]; va_d=[x[1] for x in val_hist]
plt.figure(figsize=(10,4)); plt.subplot(1,2,1); plt.plot(tr_b,label='Train BCE'); plt.plot(va_b,label='Val BCE'); plt.legend(); plt.title('BCE Loss');
plt.subplot(1,2,2); plt.plot(tr_d,label='Train Dice'); plt.plot(va_d,label='Val Dice'); plt.legend(); plt.title('Dice Loss'); plt.tight_layout(); plt.show()
print('完成: 產生擴增資料並訓練模型。缺陷以黑色顯示。已加入後處理去除噪點並膨脹大區域邊緣。')

Loaded base images: 2
Augmented samples: 100
Device: cpu
Augmented samples: 100
Device: cpu
Epoch 1/10 Train BCE 0.5634 Dice 0.9035 | Val BCE 0.9049 Dice 0.9363
Epoch 1/10 Train BCE 0.5634 Dice 0.9035 | Val BCE 0.9049 Dice 0.9363
Epoch 2/10 Train BCE 0.4332 Dice 0.8792 | Val BCE 0.4231 Dice 0.9187
Epoch 2/10 Train BCE 0.4332 Dice 0.8792 | Val BCE 0.4231 Dice 0.9187


KeyboardInterrupt: 

In [None]:
# 實際測試資料推論並輸出結果 + 圖像化
import os, glob
from pathlib import Path
import cv2, numpy as np
import matplotlib.pyplot as plt

TEST_DIR = './real_test_images'
OUT_DIR = './results_real'
os.makedirs(OUT_DIR, exist_ok=True)

# 讀取所有影像 (支援常見副檔名)
exts = ('*.png','*.jpg','*.jpeg','*.bmp','*.tif','*.tiff')
files = []
for ext in exts:
    files += glob.glob(os.path.join(TEST_DIR, ext))
files = sorted(files)
print(f'Test images found: {len(files)} in {TEST_DIR}')

# 推論並儲存結果
summary = []
for fp in files:
    im = cv2.imread(fp, cv2.IMREAD_UNCHANGED)
    if im is None:
        print('Skip unreadable:', fp)
        continue
    if im.ndim==3:
        im = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
    mask_out, prob, mask_raw = predict(im)

    name = Path(fp).stem
    # 儲存後處理後的二值遮罩 (黑=缺陷, 白=背景)
    cv2.imwrite(os.path.join(OUT_DIR, f'{name}_mask.png'), mask_out)
    # 儲存原始預測 (0/1) 便於除錯
    raw_vis = np.ones_like(im, dtype=np.uint8)*255
    raw_vis[mask_raw==1] = 0
    cv2.imwrite(os.path.join(OUT_DIR, f'{name}_raw.png'), raw_vis)
    # 儲存機率圖 (0-255 灰階)
    prob_u8 = np.clip(prob*255, 0, 255).astype(np.uint8)
    cv2.imwrite(os.path.join(OUT_DIR, f'{name}_prob.png'), prob_u8)
    # 儲存疊圖
    overlay = np.stack([im]*3, axis=-1)
    ov = overlay.copy()
    ov[mask_out==0] = [255,0,0]
    cv2.imwrite(os.path.join(OUT_DIR, f'{name}_overlay.png'), ov)

    summary.append((name, fp))

# 視覺化：每張圖顯示 原始/原始預測/後處理遮罩/機率圖
n_show = len(summary)
if n_show>0:
    fig, axs = plt.subplots(n_show, 4, figsize=(16, 4*n_show))
    if n_show==1:
        axs = axs.reshape(1,4)
    for i in range(n_show):
        name, fp = summary[i]
        im = cv2.imread(fp, cv2.IMREAD_UNCHANGED)
        if im.ndim==3:
            im = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
        mask_out = cv2.imread(os.path.join(OUT_DIR, f'{name}_mask.png'), cv2.IMREAD_GRAYSCALE)
        raw_vis = cv2.imread(os.path.join(OUT_DIR, f'{name}_raw.png'), cv2.IMREAD_GRAYSCALE)
        prob_u8 = cv2.imread(os.path.join(OUT_DIR, f'{name}_prob.png'), cv2.IMREAD_GRAYSCALE)
        axs[i,0].imshow(im, cmap='gray'); axs[i,0].set_title(f'原始 {name}'); axs[i,0].axis('off')
        axs[i,1].imshow(raw_vis, cmap='gray', vmin=0, vmax=255); axs[i,1].set_title('原始預測'); axs[i,1].axis('off')
        axs[i,2].imshow(mask_out, cmap='gray', vmin=0, vmax=255); axs[i,2].set_title('後處理遮罩'); axs[i,2].axis('off')
        axs[i,3].imshow(prob_u8, cmap='gray'); axs[i,3].set_title('機率圖'); axs[i,3].axis('off')
    plt.tight_layout(); plt.show()

# 額外總覽：疊圖畫廊（最多 12 張），便於快速巡檢
max_gallery = min(12, n_show)
if max_gallery>0:
    cols = 4
    rows = int(np.ceil(max_gallery/cols))
    fig, axs = plt.subplots(rows, cols, figsize=(4*cols, 3*rows))
    axs = np.array(axs)
    axs = axs.reshape(rows, cols)
    for i in range(rows*cols):
        r = i//cols; c = i%cols
        if i < max_gallery:
            name, fp = summary[i]
            overlay = cv2.imread(os.path.join(OUT_DIR, f'{name}_overlay.png'), cv2.IMREAD_UNCHANGED)
            if overlay is None:
                axs[r,c].axis('off'); continue
            overlay_rgb = overlay[:,:,::-1] if overlay.ndim==3 else overlay
            axs[r,c].imshow(overlay_rgb)
            axs[r,c].set_title(name)
            axs[r,c].axis('off')
        else:
            axs[r,c].axis('off')
    plt.tight_layout(); plt.show()

print(f'Inference done. Outputs saved to {OUT_DIR}')

Test images found: 2 in ./real_test_images


NameError: name 'predict' is not defined