In [1]:
import os
import sys
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from collections import defaultdict
from PIL import Image
from tqdm import tqdm
import cv2
import pandas as pd

# Th√™m path ƒë·ªÉ import module
sys.path.insert(0, '/thiends/hdd2t/few_shot_model/few-shot-segmentation')

%matplotlib inline
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

  from .autonotebook import tqdm as notebook_tqdm


PyTorch version: 1.13.1+cu116
CUDA available: True


## 1. Cấu hình Dataset OTU_2D

In [2]:
# ================= CONFIG =================
# Dataset root. Set the OTU2D_ROOT environment variable to point elsewhere instead of editing this line.
DATA_ROOT = os.environ.get("OTU2D_ROOT", "/thiends/hdd2t/UniverSeg/OTU_2D")
if not os.path.isdir(DATA_ROOT):
    # Early, non-fatal hint; the files below raise only when they are actually opened.
    print(f"WARNING: DATA_ROOT does not exist: {DATA_ROOT}")
TRAIN_IMAGES = os.path.join(DATA_ROOT, "train1/Image/")
TRAIN_LABELS = os.path.join(DATA_ROOT, "train1/Label/")
VAL_IMAGES = os.path.join(DATA_ROOT, "validation1/Image/")
VAL_LABELS = os.path.join(DATA_ROOT, "validation1/Label/")
TRAIN_TXT = os.path.join(DATA_ROOT, "train.txt")  # one sample ID per line
VAL_TXT = os.path.join(DATA_ROOT, "val.txt")
TRAIN_CLS = os.path.join(DATA_ROOT, "train_cls.txt")  # parsed as "<filename> <class>" per line
VAL_CLS = os.path.join(DATA_ROOT, "val_cls.txt")

# Fixed input size for SEnet (adjustable). It is passed to PIL's Image.resize, i.e. (width, height).
RESIZE_TO = (256, 256)
NUM_CLASSES = 8
LABEL_NAMES = [f"Class {i}" for i in range(NUM_CLASSES)]

# Device
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {DEVICE}")
# ===========================================

Using device: cuda


## 2. Dataset và Utils

In [3]:
# Load class labels
def load_cls_labels(filepath):
    """Parse a classification label file into an ``{image_id: class_index}`` dict.

    Only lines with exactly two whitespace-separated fields (``<filename> <class>``)
    are used; every other line is ignored. ``.JPG`` is removed from the filename so
    the key is the bare ID that the split files list.
    """
    mapping = {}
    with open(filepath) as handle:
        for raw_line in handle:
            fields = raw_line.strip().split()
            if len(fields) != 2:
                continue
            name, cls_token = fields
            mapping[name.replace('.JPG', '')] = int(cls_token)
    return mapping

# Class label per image ID for the train and validation splits (train first, then val).
train_cls_labels, val_cls_labels = (load_cls_labels(path) for path in (TRAIN_CLS, VAL_CLS))

# ---------- Utils ----------
def process_image(image_path, resize_to):
    """Load an image as a normalised single-channel array for SEnet.

    The image is converted to grayscale ("L"), resized with bilinear interpolation
    and scaled from [0, 255] to [0, 1] as float32.

    Args:
        image_path: path of the image file.
        resize_to: target size handed to PIL's ``Image.resize``, i.e. (width, height).

    Returns:
        array of shape [1, H, W], or None if anything fails (the error is printed).
    """
    try:
        gray = Image.open(image_path).convert("L").resize(resize_to, Image.BILINEAR)
        scaled = np.array(gray, dtype=np.float32) / 255.0
    except Exception as err:
        print(f"Error loading {image_path}: {err}")
        return None
    return scaled[None, ...]  # add the channel axis -> [1, H, W]

def process_image_rgb(image_path, resize_to):
    """Load an image in RGB for visualisation.

    Args:
        image_path: path of the image file.
        resize_to: target size handed to PIL's ``Image.resize``, i.e. (width, height).

    Returns:
        float32 array of shape [3, H, W] scaled to [0, 1], or None if loading fails.
        Failures are printed (like ``process_image``) instead of being swallowed
        silently; callers still get None and can continue without the RGB copy.
    """
    try:
        img = Image.open(image_path).convert("RGB")
        img = img.resize(resize_to, Image.BILINEAR)
        img = np.array(img, dtype=np.float32) / 255.0
        return np.transpose(img, (2, 0, 1))  # HWC -> CHW: [3, H, W]
    except Exception as e:
        print(f"Error loading RGB {image_path}: {e}")
        return None

def process_mask(mask_path, resize_to):
    """Load a segmentation mask and binarise it: pixel > 0 -> 1.0, pixel == 0 -> 0.0.

    The mask is read as 8-bit grayscale ("L") and resized with nearest-neighbour
    interpolation, so resizing introduces no new intermediate values before thresholding.

    Args:
        mask_path: path of the mask file.
        resize_to: target size handed to PIL's ``Image.resize``, i.e. (width, height).

    Returns:
        float32 array of shape [H, W] with values in {0.0, 1.0}, or None if loading
        fails. Failures are printed (like ``process_image``) rather than swallowed.
    """
    try:
        mask = Image.open(mask_path).convert('L')
        mask = mask.resize(resize_to, Image.NEAREST)
        mask = np.array(mask, dtype=np.float32)
        # Binarise: any non-zero label value counts as foreground
        return (mask > 0).astype(np.float32)
    except Exception as e:
        print(f"Error loading mask {mask_path}: {e}")
        return None

# === VERIFY MASK LOADING ===
# Sanity-check process_mask on one training mask (first file in sorted order, so the
# check is repeatable; os.listdir order is arbitrary).
print("üîç Ki·ªÉm tra process_mask...")
if os.path.exists(TRAIN_LABELS):
    mask_files = sorted(os.listdir(TRAIN_LABELS))
    if not mask_files:
        print(f"   No files found in {TRAIN_LABELS}")
    else:
        test_mask_path = os.path.join(TRAIN_LABELS, mask_files[0])
        test_mask = process_mask(test_mask_path, RESIZE_TO)
        print(f"   Mask path: {test_mask_path}")
        if test_mask is None:
            # process_mask returns None on failure; .shape would raise AttributeError
            print("   process_mask returned None (mask could not be loaded)")
        else:
            print(f"   Mask unique values: {np.unique(test_mask)}")
            print(f"   Mask shape: {test_mask.shape}")
else:
    print(f"‚ö†Ô∏è TRAIN_LABELS path kh√¥ng t·ªìn t·∫°i: {TRAIN_LABELS}")

üîç Ki·ªÉm tra process_mask...
   Mask path: /thiends/hdd2t/UniverSeg/OTU_2D/train1/Label/1279.PNG
   Mask unique values: [0. 1.]
   Mask shape: (256, 256)


In [4]:
# ---------- Dataset cho SEnet ----------
class OTU2DDatasetSEnet:
    """
    OTU_2D dataset for SEnet few-shot segmentation.

    Tensors produced for SEnet:
    - Query input:   [1, H, W] grayscale image
    - Support input: [2, H, W] = concat(grayscale_image, binary_mask)

    All samples are loaded into memory in ``__init__``. An ID from ``ids_file`` is
    kept only if ``<id>.JPG`` and ``<id>.PNG`` both exist, it has an entry in
    ``cls_labels``, the grayscale image and the mask load, and the mask has at
    least one foreground pixel. The number of IDs dropped per reason is printed.
    """
    def __init__(self, images_dir, labels_dir, ids_file, cls_labels, resize_to=RESIZE_TO):
        self.samples = []
        self.cls_labels = cls_labels
        self.images_dir = images_dir
        self.labels_dir = labels_dir
        self.resize_to = resize_to

        print("=" * 70)
        # normpath removes the trailing '/' the configured paths carry; os.path.basename
        # of ".../Image/" is "", which left the directory name out of this header.
        print(f"Loading OTU2DDatasetSEnet from {os.path.normpath(images_dir)}...")
        print("=" * 70)

        with open(ids_file, 'r') as f:
            ids = [line.strip() for line in f if line.strip()]

        # Reason -> number of IDs dropped, so filtering is visible instead of silent.
        skipped = defaultdict(int)

        for id_ in ids:
            img_path = os.path.join(images_dir, f"{id_}.JPG")
            mask_path = os.path.join(labels_dir, f"{id_}.PNG")

            if not os.path.exists(img_path) or not os.path.exists(mask_path):
                skipped['missing file'] += 1
                continue

            cls = self.cls_labels.get(id_, None)
            if cls is None:
                skipped['no class label'] += 1
                continue

            # Grayscale image: the actual SEnet input
            img_gray = process_image(img_path, resize_to)
            if img_gray is None:
                skipped['unreadable image'] += 1
                continue

            # RGB copy for visualisation only; None is tolerated
            img_rgb = process_image_rgb(img_path, resize_to)

            mask = process_mask(mask_path, resize_to)
            if mask is None:
                skipped['unreadable mask'] += 1
                continue

            if np.sum(mask) < 1:  # no foreground pixel
                skipped['empty mask'] += 1
                continue

            self.samples.append({
                'img_gray': img_gray,      # [1, H, W]
                'img_rgb': img_rgb,        # [3, H, W] or None
                'mask': mask,              # [H, W]
                'cls': cls,
                'img_path': img_path
            })

        print(f"Loaded {len(self.samples)} valid samples.")
        if skipped:
            print("Skipped: " + ", ".join(f"{count} ({reason})" for reason, count in sorted(skipped.items())))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return sample ``idx`` with its arrays converted to float tensors (img_rgb may be None)."""
        sample = self.samples[idx]
        return {
            'img_gray': torch.from_numpy(sample['img_gray']).float(),
            'img_rgb': torch.from_numpy(sample['img_rgb']).float() if sample['img_rgb'] is not None else None,
            'mask': torch.from_numpy(sample['mask']).float(),
            'cls': sample['cls'],
            'img_path': sample['img_path']
        }

    def get_samples_by_class(self, cls_idx):
        """Return the indices of the samples whose class is ``cls_idx``."""
        return [i for i, s in enumerate(self.samples) if s['cls'] == cls_idx]

In [5]:
# Supports are drawn from the train split; queries come from the validation split.
support_pool = OTU2DDatasetSEnet(
    images_dir=TRAIN_IMAGES,
    labels_dir=TRAIN_LABELS,
    ids_file=TRAIN_TXT,
    cls_labels=train_cls_labels,
)
test_set = OTU2DDatasetSEnet(
    images_dir=VAL_IMAGES,
    labels_dir=VAL_LABELS,
    ids_file=VAL_TXT,
    cls_labels=val_cls_labels,
)

print(f"\nT·ªïng s·ªë ·∫£nh trong support pool: {len(support_pool)}")
print(f"T·ªïng s·ªë ·∫£nh trong test set: {len(test_set)}")

Loading OTU2DDatasetSEnet from ...
Loaded 1000 valid samples.
Loading OTU2DDatasetSEnet from ...
Loaded 469 valid samples.

T·ªïng s·ªë ·∫£nh trong support pool: 1000
T·ªïng s·ªë ·∫£nh trong test set: 469


In [6]:
# Group sample indices by class label for each split, then print a per-class histogram.
label_indices_support = defaultdict(list)
label_indices_test = defaultdict(list)

for split_set, split_index in ((support_pool, label_indices_support), (test_set, label_indices_test)):
    for idx, sample in enumerate(split_set.samples):
        split_index[sample['cls']].append(idx)

for header, split_index, split_set in (
    ("\n[INFO] Ph√¢n b·ªë nh√£n trong SUPPORT POOL:", label_indices_support, support_pool),
    ("\n[INFO] Ph√¢n b·ªë nh√£n trong TEST SET:", label_indices_test, test_set),
):
    print(header)
    print(f"{'Class':<10} {'Name':<15} {'Count':<10} {'Percentage':<10}")
    print("-" * 50)
    total = len(split_set)
    for cls_idx in range(NUM_CLASSES):
        count = len(split_index[cls_idx])
        pct = 100 * count / total if total > 0 else 0
        print(f"{cls_idx:<10} {LABEL_NAMES[cls_idx]:<15} {count:<10} {pct:.1f}%")


[INFO] Ph√¢n b·ªë nh√£n trong SUPPORT POOL:
Class      Name            Count      Percentage
--------------------------------------------------
0          Class 0         226        22.6%
1          Class 1         153        15.3%
2          Class 2         228        22.8%
3          Class 3         57         5.7%
4          Class 4         47         4.7%
5          Class 5         180        18.0%
6          Class 6         71         7.1%
7          Class 7         38         3.8%

[INFO] Ph√¢n b·ªë nh√£n trong TEST SET:
Class      Name            Count      Percentage
--------------------------------------------------
0          Class 0         110        23.5%
1          Class 1         66         14.1%
2          Class 2         108        23.0%
3          Class 3         31         6.6%
4          Class 4         19         4.1%
5          Class 5         87         18.6%
6          Class 6         33         7.0%
7          Class 7         15         3.2%


## 3. Load Model SEnet

In [8]:
# Import the SEnet model (few_shot_segmentor comes from the repo added to sys.path in the first cell)
from few_shot_segmentor import FewShotSegmentorDoubleSDnet

# ============ CHOOSE MODEL PATH ============
# Path to a trained checkpoint
MODEL_PATH = "saved_models/sne_position_all_type_spatial_fold2.pth.tar"  # Example

# Or create a fresh (untrained) model
CREATE_NEW_MODEL = True  # Set False to load the pretrained checkpoint above
# =========================================

# Model parameters (per settings.ini)
net_params = {
    'num_class': 2,
    'num_channels': 1,
    'num_filters': 64,
    'kernel_h': 5,
    'kernel_w': 5,
    'kernel_c': 1,
    'stride_conv': 1,
    'pool': 2,
    'stride_pool': 2,
    'se_block': 'NONE',
    'drop_out': 0
}

model_is_trained = False
if CREATE_NEW_MODEL:
    print("üî® T·∫°o model SEnet m·ªõi (ch∆∞a train)...")
    model = FewShotSegmentorDoubleSDnet(net_params)
    print("‚úÖ Model created successfully!")
else:
    print(f"üìÇ Loading pretrained model t·ª´: {MODEL_PATH}")
    if os.path.exists(MODEL_PATH):
        # SECURITY: torch.load unpickles arbitrary Python objects - only load trusted checkpoints.
        checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
        if isinstance(checkpoint, nn.Module):
            model = checkpoint
        elif isinstance(checkpoint, dict):
            # NOTE(review): assumes a dict checkpoint is either a plain state_dict or wraps one
            # under 'state_dict' - confirm against the script that saved MODEL_PATH.
            model = FewShotSegmentorDoubleSDnet(net_params)
            model.load_state_dict(checkpoint.get('state_dict', checkpoint))
        else:
            raise TypeError(f"Unsupported checkpoint type in {MODEL_PATH}: {type(checkpoint).__name__}")
        model_is_trained = True
        print("‚úÖ Model loaded successfully!")
    else:
        print(f"‚ùå Model file kh√¥ng t·ªìn t·∫°i: {MODEL_PATH}")
        print("   ‚Üí T·∫°o model m·ªõi thay th·∫ø...")
        model = FewShotSegmentorDoubleSDnet(net_params)

if not model_is_trained:
    # Every metric computed below would come from random weights.
    print("WARNING: model weights are randomly initialised (untrained); the metrics below are not meaningful.")

model = model.to(DEVICE)
model.eval()

# Model info
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\nüìä Model Info:")
print(f"   Total parameters: {total_params:,}")
print(f"   Trainable parameters: {trainable_params:,}")

SyntaxError: unmatched ']' (modules.py, line 526)

## 4. Inference Functions

In [None]:
def prepare_senet_inputs(query_sample, support_samples):
    """
    Build SEnet's two inputs from a query sample and its support samples.

    Args:
        query_sample: dict with 'img_gray' of shape [1, H, W].
        support_samples: non-empty iterable of dicts with 'img_gray' [1, H, W]
            and 'mask' [H, W].

    Returns:
        query_input: [1, 1, H, W] - the query image with a batch dim added.
        condition_input: [1, 2, H, W] - concat(support_image, support_mask) for each
            support, averaged element-wise over all supports. It is always a single
            averaged pair, never an [N, 2, H, W] stack.

    Raises:
        ValueError: if support_samples is empty (torch.stack would otherwise fail
            with a less helpful message).
    """
    # Query input: add the batch dimension
    query_input = query_sample['img_gray'].unsqueeze(0)  # [1, 1, H, W]

    # One [2, H, W] condition per support: channel 0 = image, channel 1 = binary mask
    condition_inputs = [
        torch.cat([sup['img_gray'], sup['mask'].unsqueeze(0)], dim=0)
        for sup in support_samples
    ]
    if not condition_inputs:
        raise ValueError("prepare_senet_inputs requires at least one support sample")

    condition_stack = torch.stack(condition_inputs, dim=0)  # [N, 2, H, W]

    # senet_inference passes one condition tensor to the model, so N supports are fused by
    # averaging. The mean covers BOTH channels: the support images are blended too, and for
    # N > 1 the mask channel becomes a soft (non-binary) map.
    avg_condition = condition_stack.mean(dim=0, keepdim=True)  # [1, 2, H, W]

    return query_input, avg_condition


def senet_inference(model, query_sample, support_samples, device=DEVICE):
    """
    Run one SEnet forward pass and return the foreground map of the query.

    The query image and the averaged support condition come from
    ``prepare_senet_inputs``. Both are moved to ``device`` and the model is called
    as ``model(condition_input, query_input)`` under ``torch.no_grad()``.

    Returns:
        [H, W] CPU tensor: channel 1 of the first batch item of the model output.

    NOTE(review): callers threshold this at 0.5 as a foreground probability. That only
    holds if the model's output is already softmaxed over dim 1; if it returns logits,
    apply ``torch.softmax(output, dim=1)`` before indexing - confirm.
    """
    query_input, condition_input = (
        tensor.to(device) for tensor in prepare_senet_inputs(query_sample, support_samples)
    )

    with torch.no_grad():
        foreground = model(condition_input, query_input)[0, 1]  # batch item 0, class 1

    return foreground.cpu()


def compute_metrics(pred, gt, threshold=0.5, smooth=1e-6):
    """
    Compute Dice, IoU, precision and recall of a thresholded prediction against a mask.

    Args:
        pred: tensor of foreground scores; binarised as ``pred > threshold``.
        gt: tensor of the same shape holding the {0, 1} ground-truth mask.
        threshold: cut-off used to binarise ``pred``.
        smooth: small constant added to numerator and denominator so the ratios stay
            finite; with both masks empty, Dice and IoU come out as 1.0.

    Returns:
        dict with Python-float values under 'dice', 'iou', 'precision', 'recall'.

    Precision is undefined when nothing is predicted, and recall is undefined when the
    ground truth is empty. Both are reported as 1.0 if both masks are empty (perfect
    agreement) and 0.0 otherwise. They are special-cased rather than smoothed because
    (0 + smooth) / (0 + smooth) = 1.0 would give an all-background prediction of a
    non-empty mask a perfect precision.
    """
    pred_bin = (pred > threshold).float()
    gt = gt.float()

    TP = (pred_bin * gt).sum()
    FP = (pred_bin * (1 - gt)).sum()
    FN = ((1 - pred_bin) * gt).sum()

    dice = (2 * TP + smooth) / (2 * TP + FP + FN + smooth)
    iou = (TP + smooth) / (TP + FP + FN + smooth)

    pred_empty = (TP + FP).item() == 0  # no pixel predicted as foreground
    gt_empty = (TP + FN).item() == 0    # no foreground pixel in the ground truth

    if pred_empty:
        precision = 1.0 if gt_empty else 0.0
    else:
        precision = ((TP + smooth) / (TP + FP + smooth)).item()

    if gt_empty:
        recall = 1.0 if pred_empty else 0.0
    else:
        recall = ((TP + smooth) / (TP + FN + smooth)).item()

    return {
        'dice': dice.item(),
        'iou': iou.item(),
        'precision': precision,
        'recall': recall,
    }

## 5. Test với 1 sample

In [None]:
# Single-sample smoke test of the inference pipeline
print("üß™ Test inference v·ªõi 1 sample...\n")

# Pick the query from the test set
query_idx = 0
NUM_DEMO_SUPPORTS = 5  # number of same-class supports used for this demo
query_sample = test_set[query_idx]
target_class = query_sample['cls']

print(f"Query image: {query_sample['img_path']}")
print(f"Target class: {target_class} ({LABEL_NAMES[target_class]})")

# First NUM_DEMO_SUPPORTS support images of the same class (deterministic).
# .get avoids inserting an empty entry into the defaultdict for a class with no supports.
support_indices = label_indices_support.get(target_class, [])[:NUM_DEMO_SUPPORTS]
if not support_indices:
    raise ValueError(f"No support images of class {target_class} in the support pool")
support_samples = [support_pool[i] for i in support_indices]

print(f"Number of support samples: {len(support_samples)}")

# Inference
pred = senet_inference(model, query_sample, support_samples)

# Metrics
gt = query_sample['mask']
metrics = compute_metrics(pred, gt)

print(f"\nüìä Metrics:")
print(f"   Dice:      {metrics['dice']:.4f}")
print(f"   IoU:       {metrics['iou']:.4f}")
print(f"   Precision: {metrics['precision']:.4f}")
print(f"   Recall:    {metrics['recall']:.4f}")

In [None]:
# Five panels for the single-sample result: query | ground truth | score map | binarised prediction | overlay.
pred_bin = (pred > 0.5).float()

if query_sample['img_rgb'] is None:
    overlay = np.stack([query_sample['img_gray'][0]]*3, axis=-1)
else:
    overlay = query_sample['img_rgb'].permute(1, 2, 0).numpy()

fig, axes = plt.subplots(1, 5, figsize=(20, 4))

axes[0].imshow(query_sample['img_gray'][0], cmap='gray')
axes[1].imshow(gt.numpy(), cmap='gray')
axes[2].imshow(pred.numpy(), cmap='hot', vmin=0, vmax=1)
axes[3].imshow(pred_bin.numpy(), cmap='gray')
axes[4].imshow(overlay)
axes[4].contour(gt.numpy(), colors='red', linewidths=2, levels=[0.5])
axes[4].contour(pred_bin.numpy(), colors='lime', linewidths=2, levels=[0.5])

panel_titles = [
    f"Query Image\nClass: {LABEL_NAMES[target_class]}",
    f"Ground Truth\nPixels: {gt.sum():.0f}",
    f"Prediction\nmax: {pred.max():.3f}",
    f"Binary Pred (>0.5)\nPixels: {pred_bin.sum():.0f}",
    f"Overlay\nDice: {metrics['dice']:.3f}\nRed=GT, Green=Pred",
]
for ax, title in zip(axes, panel_titles):
    ax.set_title(title)
    ax.axis('off')

plt.tight_layout()
plt.show()

## 6. Đánh giá theo số lượng Support (N-shot)

In [None]:
# ================== EVALUATION BY N (number of support images) ==================
N_LIST = [1, 2, 4, 8, 16, 32]
NUM_TEST_SAMPLES = min(100, len(test_set))  # Cap for a quick run: the first NUM_TEST_SAMPLES test images
THRESHOLD = 0.5
SEED_BY_N = 0  # Seeds support sampling so the table is reproducible across runs

print("=" * 80)
print("ƒê√ÅNH GI√Å METRICS TRUNG B√åNH THEO S·ªê L∆Ø·ª¢NG SUPPORT (N)")
print("=" * 80)

# Dedicated generator: results do not depend on earlier use of the global np.random state.
rng_by_n = np.random.default_rng(SEED_BY_N)
results_by_N = {N: {'dice': [], 'iou': [], 'precision': [], 'recall': []} for N in N_LIST}
n_evaluated = 0

print(f"\nƒêang ƒë√°nh gi√° tr√™n {NUM_TEST_SAMPLES} ·∫£nh test...")

# eval_* loop names keep this cell from overwriting query_sample / gt / pred / metrics,
# which the single-sample visualisation cell reads.
for eval_idx in tqdm(range(NUM_TEST_SAMPLES), desc="Evaluating"):
    eval_query = test_set[eval_idx]
    eval_gt = eval_query['mask']

    # Skip queries whose mask is too small
    if eval_gt.sum() < 10:
        continue

    # Support candidates of the same class (.get does not insert into the defaultdict)
    eval_pool = label_indices_support.get(eval_query['cls'], [])
    if len(eval_pool) == 0:
        continue

    n_evaluated += 1
    for N in N_LIST:
        # N random supports without replacement; a class with fewer than N uses all of them.
        K = min(N, len(eval_pool))
        eval_support_idx = rng_by_n.choice(eval_pool, size=K, replace=False)
        eval_supports = [support_pool[i] for i in eval_support_idx]

        eval_pred = senet_inference(model, eval_query, eval_supports)
        eval_metrics = compute_metrics(eval_pred, eval_gt, threshold=THRESHOLD)

        for key, value in eval_metrics.items():
            results_by_N[N][key].append(value)

print(f"Evaluated {n_evaluated} query images (the rest were skipped: mask too small or no supports).")

# Mean and std per N
avg_results = []
for N in N_LIST:
    row = {'N': N}
    for metric in ['dice', 'iou', 'precision', 'recall']:
        values = results_by_N[N][metric]
        row[metric] = np.mean(values) if values else 0
        row[f'{metric}_std'] = np.std(values) if values else 0
    avg_results.append(row)

df_avg = pd.DataFrame(avg_results)

# Display table
print("\n" + "=" * 80)
print("K·∫æT QU·∫¢ TRUNG B√åNH THEO N:")
print("=" * 80)
print(f"\n{'N':>4} | {'Dice':>14} | {'IoU':>14} | {'Precision':>14} | {'Recall':>14}")
print("-" * 70)
for _, row in df_avg.iterrows():
    print(f"{int(row['N']):>4} | {row['dice']:.4f}¬±{row['dice_std']:.3f} | "
          f"{row['iou']:.4f}¬±{row['iou_std']:.3f} | "
          f"{row['precision']:.4f}¬±{row['precision_std']:.3f} | "
          f"{row['recall']:.4f}¬±{row['recall_std']:.3f}")

In [None]:
# Mean ± std of every metric versus the number of supports N (log2 x-axis); saves PNG and CSV.
metrics_names = ['dice', 'iou', 'precision', 'recall']
colors = ['#2ecc71', '#3498db', '#e74c3c', '#9b59b6']
fig, axes = plt.subplots(1, 4, figsize=(20, 5))

for ax, metric, color in zip(axes, metrics_names, colors):
    means = df_avg[metric].values
    stds = df_avg[f'{metric}_std'].values
    pretty = metric.capitalize()

    ax.plot(N_LIST, means, 'o-', color=color, linewidth=2, markersize=8)
    ax.fill_between(N_LIST, means - stds, means + stds, alpha=0.2, color=color)

    ax.set_xscale('log', base=2)
    ax.set_xticks(N_LIST)
    ax.set_xticklabels(N_LIST)
    ax.set_ylim(0, 1)
    ax.grid(True, alpha=0.3)

    ax.set_xlabel('Number of Support Images (N)', fontsize=12)
    ax.set_ylabel(pretty, fontsize=12)
    ax.set_title(f'SEnet - {pretty} vs N', fontsize=14, fontweight='bold')

plt.suptitle('SEnet Few-Shot Segmentation - Performance by Number of Supports',
             fontsize=16, fontweight='bold', y=1.02)
plt.tight_layout()
plt.savefig('SEnet_metrics_by_N.png', dpi=200, bbox_inches='tight')
plt.show()

# Persist the table next to the figure
df_avg.to_csv('SEnet_metrics_by_N.csv', index=False)
print(f"\n‚úÖ ƒê√£ l∆∞u k·∫øt qu·∫£ v√†o: SEnet_metrics_by_N.csv v√† SEnet_metrics_by_N.png")

## 7. Đánh giá từng Class

In [None]:
# ================== PER-CLASS EVALUATION ==================
N_SELECTED = 8  # Number of support images per query
NUM_SAMPLES_PER_CLASS = 30  # Max number of test images evaluated per class
SEED_PER_CLASS = 0  # Seeds support sampling so the per-class table is reproducible

print("=" * 80)
print(f"ƒê√ÅNH GI√Å METRICS CHO T·ª™NG CLASS T·∫†I N = {N_SELECTED}")
print("=" * 80)

# Dedicated generator: results do not depend on earlier use of the global np.random state.
rng_per_class = np.random.default_rng(SEED_PER_CLASS)
class_metrics = {c: {'dice': [], 'iou': [], 'precision': [], 'recall': [], 'count': 0}
                 for c in range(NUM_CLASSES)}

# cls_* loop names keep this cell from overwriting query_sample / gt / pred / metrics used by
# the single-sample section. THRESHOLD is defined in the N-shot evaluation cell.
for cls_eval_idx in tqdm(range(len(test_set)), desc="Evaluating per class"):
    cls_query = test_set[cls_eval_idx]
    cls_idx = cls_query['cls']
    cls_gt = cls_query['mask']

    # Cap the number of evaluated images per class
    if class_metrics[cls_idx]['count'] >= NUM_SAMPLES_PER_CLASS:
        continue

    # Skip queries whose mask is too small
    if cls_gt.sum() < 10:
        continue

    # .get does not insert an empty entry into the defaultdict
    class_support_indices = label_indices_support.get(cls_idx, [])
    if len(class_support_indices) == 0:
        continue

    K = min(N_SELECTED, len(class_support_indices))
    cls_support_idx = rng_per_class.choice(class_support_indices, size=K, replace=False)
    cls_supports = [support_pool[i] for i in cls_support_idx]

    cls_pred = senet_inference(model, cls_query, cls_supports)
    cls_result = compute_metrics(cls_pred, cls_gt, threshold=THRESHOLD)

    for key, value in cls_result.items():
        class_metrics[cls_idx][key].append(value)
    class_metrics[cls_idx]['count'] += 1

# Results DataFrame (classes without any evaluated image are omitted)
results_per_class = []
for cls_idx in range(NUM_CLASSES):
    if class_metrics[cls_idx]['count'] == 0:
        continue

    row = {
        'Class': cls_idx,
        'Name': LABEL_NAMES[cls_idx],
        'Samples': class_metrics[cls_idx]['count'],
        'Dice': np.mean(class_metrics[cls_idx]['dice']),
        'Dice_std': np.std(class_metrics[cls_idx]['dice']),
        'IoU': np.mean(class_metrics[cls_idx]['iou']),
        'Precision': np.mean(class_metrics[cls_idx]['precision']),
        'Recall': np.mean(class_metrics[cls_idx]['recall']),
    }
    results_per_class.append(row)

df_class = pd.DataFrame(results_per_class)

if len(df_class) == 0:
    # An empty DataFrame has no 'Dice'/'IoU' columns; the summary below would raise KeyError.
    print("No class had an evaluable test image - per-class table skipped.")
else:
    # Display
    print("\n" + "=" * 90)
    print(f"K·∫æT QU·∫¢ METRICS CHO T·ª™NG CLASS (N = {N_SELECTED}):")
    print("=" * 90)
    print(f"\n{'Class':>6} | {'Name':>12} | {'Samples':>8} | {'Dice':>10} | {'IoU':>10} | {'Precision':>10} | {'Recall':>10}")
    print("-" * 90)

    for _, row in df_class.iterrows():
        print(f"{int(row['Class']):>6} | {row['Name']:>12} | {int(row['Samples']):>8} | "
              f"{row['Dice']:.4f} | {row['IoU']:.4f} | {row['Precision']:.4f} | {row['Recall']:.4f}")

    # Summary: unweighted mean over classes
    print("-" * 90)
    avg_dice = df_class['Dice'].mean()
    avg_iou = df_class['IoU'].mean()
    print(f"{'AVG':>6} | {'':>12} | {'':>8} | {avg_dice:.4f} | {avg_iou:.4f} | "
          f"{df_class['Precision'].mean():.4f} | {df_class['Recall'].mean():.4f}")

In [None]:
# Per-class bar charts (one panel per metric) with the class-average line; saves PNG and CSV.
if len(df_class) == 0:
    print("‚ö†Ô∏è Kh√¥ng c√≥ d·ªØ li·ªáu ƒë·ªÉ v·∫Ω bi·ªÉu ƒë·ªì")
else:
    metrics_to_plot = ['Dice', 'IoU', 'Precision', 'Recall']
    colors = ['#2ecc71', '#3498db', '#e74c3c', '#9b59b6']
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    # Same x positions and tick labels for every panel
    x = np.arange(len(df_class))
    class_ticks = [f"C{int(c)}" for c in df_class['Class']]

    for ax, metric, color in zip(axes.flatten(), metrics_to_plot, colors):
        means = df_class[metric].values
        bars = ax.bar(x, means, color=color, alpha=0.7, edgecolor='black')

        ax.axhline(y=df_class[metric].mean(), color='red', linestyle='--', linewidth=2,
                   label=f'Avg: {df_class[metric].mean():.3f}')
        ax.legend(loc='lower right')

        ax.set_xticks(x)
        ax.set_xticklabels(class_ticks, fontsize=10)
        ax.set_xlabel('Class', fontsize=12)
        ax.set_ylabel(metric, fontsize=12)
        ax.set_title(f'SEnet - {metric} per Class (N={N_SELECTED})', fontsize=14, fontweight='bold')
        ax.set_ylim(0, 1)
        ax.grid(True, alpha=0.3, axis='y')

        # Value label just above each bar
        for bar, val in zip(bars, means):
            ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.02,
                    f'{val:.2f}', ha='center', va='bottom', fontsize=9)

    plt.suptitle(f'SEnet Metrics per Class (N = {N_SELECTED})', fontsize=16, fontweight='bold', y=1.02)
    plt.tight_layout()
    plt.savefig(f'SEnet_metrics_per_class_N{N_SELECTED}.png', dpi=200, bbox_inches='tight')
    plt.show()

    # Persist the table next to the figure
    df_class.to_csv(f'SEnet_metrics_per_class_N{N_SELECTED}.csv', index=False)
    print(f"\n‚úÖ ƒê√£ l∆∞u k·∫øt qu·∫£ v√†o: SEnet_metrics_per_class_N{N_SELECTED}.csv")

## 8. Visualize kết quả cho từng Class

In [None]:
# Visualise one test image per class, predicted with different numbers of supports.
N_VIS = [1, 4, 16]  # values of N to compare
# .get avoids inserting empty entries into the defaultdict while probing classes
active_classes = [c for c in range(NUM_CLASSES) if len(label_indices_test.get(c, [])) > 0]

if len(active_classes) > 0:
    n_rows = len(N_VIS) + 2  # query row, ground-truth row, then one row per N
    fig, axes = plt.subplots(n_rows, len(active_classes),
                             figsize=(4*len(active_classes), 4*n_rows))

    # With a single column plt.subplots returns a 1-D array; make it 2-D so axes[row, col] works.
    if len(active_classes) == 1:
        axes = axes[:, np.newaxis]

    # vis_* names keep this cell from overwriting query_sample / gt / pred / metrics,
    # which the single-sample visualisation cell reads.
    for col, cls_idx in enumerate(active_classes):
        # First test image of this class (deterministic)
        vis_query = test_set[label_indices_test[cls_idx][0]]
        vis_gt = vis_query['mask']

        # Row 0: query image
        if vis_query['img_rgb'] is not None:
            axes[0, col].imshow(vis_query['img_rgb'].permute(1, 2, 0).numpy())
        else:
            axes[0, col].imshow(vis_query['img_gray'][0], cmap='gray')
        axes[0, col].set_title(f"{LABEL_NAMES[cls_idx]}\nQuery Image", fontsize=10)
        axes[0, col].axis('off')

        # Row 1: ground truth
        axes[1, col].imshow(vis_gt.numpy(), cmap='gray')
        axes[1, col].set_title(f"Ground Truth\nPixels: {vis_gt.sum():.0f}", fontsize=10)
        axes[1, col].axis('off')

        # Rows 2+: predictions using the first K supports of the class, so the support set
        # for a smaller N is a prefix of the one for a larger N.
        class_support_indices = label_indices_support.get(cls_idx, [])
        for row, N in enumerate(N_VIS, start=2):
            ax = axes[row, col]
            if len(class_support_indices) > 0:
                K = min(N, len(class_support_indices))
                vis_supports = [support_pool[i] for i in class_support_indices[:K]]

                vis_pred = senet_inference(model, vis_query, vis_supports)
                vis_metrics = compute_metrics(vis_pred, vis_gt)

                ax.imshow(vis_pred.numpy(), cmap='hot', vmin=0, vmax=1)
                ax.contour(vis_gt.numpy(), colors='lime', linewidths=1, levels=[0.5])
                ax.set_title(f"N={N}, Dice={vis_metrics['dice']:.3f}", fontsize=10)
            else:
                ax.text(0.5, 0.5, 'No support', ha='center', va='center', fontsize=12)
            ax.axis('off')

    plt.suptitle('SEnet Few-Shot Segmentation - Visualization per Class',
                 fontsize=16, fontweight='bold', y=1.02)
    plt.tight_layout()
    plt.savefig('SEnet_visualization_per_class.png', dpi=200, bbox_inches='tight')
    plt.show()
else:
    print("‚ö†Ô∏è Kh√¥ng c√≥ class n√†o c√≥ d·ªØ li·ªáu test")

## 9. Tổng kết

In [None]:
# Final summary: dataset sizes, model info, the best N and the best/worst class.
separator = "=" * 80

print("\n" + separator)
print("üìä T·ªîNG K·∫æT SEnet Few-Shot Segmentation tr√™n OTU_2D")
print(separator)

print(f"\nüìÅ Dataset:")
print(f"   ‚îú‚îÄ Support pool: {len(support_pool)} ·∫£nh")
print(f"   ‚îî‚îÄ Test set: {len(test_set)} ·∫£nh")

print(f"\nüîß Model:")
print(f"   ‚îú‚îÄ Type: FewShotSegmentorDoubleSDnet (SEnet)")
print(f"   ‚îú‚îÄ Parameters: {total_params:,}")
print(f"   ‚îî‚îÄ Input size: {RESIZE_TO}")

if len(df_avg) > 0:
    # Row of the N with the highest mean Dice
    best_row = df_avg.loc[df_avg['dice'].idxmax()]
    print(f"\nüìà Best Performance:")
    print(f"   ‚îú‚îÄ Best N: {int(best_row['N'])}")
    print(f"   ‚îú‚îÄ Dice: {best_row['dice']:.4f}")
    print(f"   ‚îú‚îÄ IoU: {best_row['iou']:.4f}")
    print(f"   ‚îú‚îÄ Precision: {best_row['precision']:.4f}")
    print(f"   ‚îî‚îÄ Recall: {best_row['recall']:.4f}")

if len(df_class) > 0:
    dice_by_class = df_class['Dice']
    best_class = df_class.loc[dice_by_class.idxmax()]
    worst_class = df_class.loc[dice_by_class.idxmin()]
    print(f"\nüèÜ Best class: {best_class['Name']} (Dice={best_class['Dice']:.4f})")
    print(f"‚ö†Ô∏è  Worst class: {worst_class['Name']} (Dice={worst_class['Dice']:.4f})")

print("\n" + separator)