In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Enhanced FLAVA Defect Segmentation Training with Multi-scale Features and Advanced Evaluation
# Compatible with Google Colab environment

import os
import json
import torch
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
from torchvision import transforms
from torchvision.transforms import Resize
from torch.utils.data import Dataset, DataLoader
from torch import nn
import torch.nn.functional as F
from transformers import FlavaModel, FlavaProcessor
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.manifold import TSNE
from scipy.stats import ttest_ind, pearsonr

# ==== CONFIG ====
class Config:
    base_model_path = "/content/drive/MyDrive/flava_finetuned"
    data_path = "/content/drive/MyDrive/Data12 class segmentation"
    save_path = "/content/drive/MyDrive/new_flava/attention_seg_head_enhanced"
    debug_dir = os.path.join(save_path, "debug")
    plots_dir = os.path.join(save_path, "plots")
    metrics_dir = os.path.join(save_path, "metrics")
    analysis_dir = os.path.join(save_path, "advanced_analysis")
    gradcam_dir = os.path.join(save_path, "gradcam")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    batch_size = 4
    num_epochs = 10
    lr = 2e-5
    patch_grid = 14
    mask_size = (14, 14)
    use_multi_scale = True  # Use multi-scale features from multiple transformer layers
    use_advanced_loss = True  # Use FocalDiceLoss instead of BCE
    test_split = 0.2  # Portion of data to use for validation
    synthetic_multi_scale = True  # Use synthetic multi-scale features if real ones not available

# Create necessary directories
for directory in [Config.save_path, Config.debug_dir, Config.plots_dir, Config.metrics_dir, Config.analysis_dir, Config.gradcam_dir]:
    os.makedirs(directory, exist_ok=True)

print(f"Using device: {Config.device}")

# ==== DATASET ====
class MaskedDataset(Dataset):
    def __init__(self, data_dir):
        self.imgs, self.masks = [], []
        self.img_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor()
        ])
        self.mask_resize = Resize(Config.mask_size)

        for cls in os.listdir(data_dir):
            p = os.path.join(data_dir, cls)
            if not os.path.isdir(p): continue
            for f in os.listdir(p):
                if f.endswith(".json"):
                    img = os.path.join(p, f.replace(".json", ".jpg"))
                    jsn = os.path.join(p, f)
                    if os.path.exists(img):
                        self.imgs.append(img)
                        self.masks.append(jsn)

        print(f"Found {len(self.imgs)} images with masks")

    def __len__(self): return len(self.imgs)

    def __getitem__(self, i):
        img = Image.open(self.imgs[i]).convert("RGB")
        img_tensor = self.img_transform(img)

        mask_arr = np.zeros((640, 640), dtype=np.uint8)
        with open(self.masks[i]) as f:
            data = json.load(f)
            for ann in data.get('annotations', []):
                x, y, w, h = map(int, ann['bbox'])
                x2, y2 = min(x+w, 640), min(y+h, 640)
                mask_arr[y:y2, x:x2] = 1
        mask = self.mask_resize(Image.fromarray(mask_arr * 255))
        mask_tensor = transforms.ToTensor()(mask).float().squeeze(0)

        return img_tensor, mask_tensor, self.imgs[i]

# ==== LOSS FUNCTION ====
class FocalDiceLoss(nn.Module):
    def __init__(self, alpha=0.5, gamma=2.0, beta=0.5):
        super().__init__()
        self.alpha = alpha  # Focal loss weight
        self.gamma = gamma  # Focal loss focusing parameter
        self.beta = beta    # Weight between BCE and Dice loss

    def forward(self, inputs, targets):
        # Binary cross entropy with logits
        bce_loss = nn.functional.binary_cross_entropy_with_logits(
            inputs, targets, reduction='none'
        )

        # Focal term
        probs = torch.sigmoid(inputs)
        pt = torch.where(targets == 1, probs, 1-probs)
        focal_weight = (1-pt) ** self.gamma
        focal_loss = focal_weight * bce_loss

        # Dice loss
        inputs_sigmoid = torch.sigmoid(inputs)
        intersection = (inputs_sigmoid * targets).sum((1,2))
        union = (inputs_sigmoid + targets).sum((1,2))
        dice_loss = 1 - (2. * intersection + 1e-6) / (union + 1e-6)

        # Combine losses
        combined_loss = self.beta * focal_loss.mean() + (1-self.beta) * dice_loss.mean()
        return combined_loss

# ==== SEGMENTATION HEAD ====
class FLAVASegmenter(nn.Module):
    def __init__(self, base_model_path):
        super().__init__()
        self.model = FlavaModel.from_pretrained(base_model_path)
        # Extract features from multiple transformer layers
        self.use_multi_scale = Config.use_multi_scale

        # Add projection layers for each scale
        self.projections = nn.ModuleList([
            nn.Linear(self.model.config.hidden_size, 256)
            for _ in range(4)  # Use last 4 layers
        ])

        # Fusion layer for multi-scale features
        self.fusion = nn.Conv2d(256*4, 256, kernel_size=1) if self.use_multi_scale else None

        # For synthetic multi-scale if real ones not available
        if Config.synthetic_multi_scale:
            # Projection for synthetic multi-scale features
            self.syn_projection = nn.Linear(self.model.config.hidden_size, 256)
            # Fusion for synthetic features (3 scales: original, 2x2 pooled, 4x4 pooled)
            self.syn_fusion = nn.Conv2d(256*3, 256, kernel_size=1)

        # Segmentation head with convolutional layers
        self.head = nn.Sequential(
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 1, kernel_size=1)
        )

        # Store activation maps for GradCAM
        self.activation = {}
        self.gradients = {}
        self._register_hooks()

        # Flag to indicate if we had to fall back to single-scale
        self.using_fallback = False
        self.using_synthetic = False

    def _register_hooks(self):
        # Register hooks for GradCAM
        def get_activation(name):
            def hook(module, input, output):
                self.activation[name] = output
            return hook

        def get_gradient(name):
            def hook(module, grad_in, grad_out):
                self.gradients[name] = grad_out[0]
            return hook

        # Register hooks for the last convolutional layer
        self.head[0].register_forward_hook(get_activation('conv_feature'))
        self.head[0].register_backward_hook(get_gradient('conv_feature'))

    def create_synthetic_multi_scale(self, embeddings):
        """Create synthetic multi-scale features from a single embedding layer"""
        # Original scale
        b, n, c = embeddings.shape

        # Project embeddings
        projected = self.syn_projection(embeddings)

        # Reshape to spatial dimensions (excluding CLS token)
        spatial = projected.reshape(b, Config.patch_grid, Config.patch_grid, -1).permute(0, 3, 1, 2)

        # Create multi-scale features
        features = [spatial]  # Original scale

        # Scale 2: Pooled 2x2
        pooled2 = F.avg_pool2d(spatial, kernel_size=2, stride=2)
        upsampled2 = F.interpolate(pooled2, size=(Config.patch_grid, Config.patch_grid), mode='bilinear', align_corners=False)
        features.append(upsampled2)

        # Scale 3: Pooled 4x4
        pooled4 = F.avg_pool2d(spatial, kernel_size=4, stride=4)
        upsampled4 = F.interpolate(pooled4, size=(Config.patch_grid, Config.patch_grid), mode='bilinear', align_corners=False)
        features.append(upsampled4)

        # Concatenate along channel dimension
        multi_scale = torch.cat(features, dim=1)

        # Fuse multi-scale features
        fused = self.syn_fusion(multi_scale)

        return fused

    def forward(self, pixel_inputs):
        # Get outputs with attention (for visualization)
        outputs = self.model(
            pixel_values=pixel_inputs,
            output_hidden_states=self.use_multi_scale,
            output_attentions=True
        )

        # Store attention maps for visualization if available
        self.last_attentions = outputs.image_attentions if hasattr(outputs, 'image_attentions') else None

        # Check if multi-scale is actually available in this model version
        multi_scale_available = hasattr(outputs, 'image_hidden_states') and outputs.image_hidden_states is not None

        # Log the first time we have to fall back
        if self.use_multi_scale and not multi_scale_available and not self.using_fallback:
            print("WARNING: Multi-scale features not available in this FLAVA model version.")
            print("Available outputs:", list(outputs.keys()))
            print("Falling back to single-scale features.")
            self.using_fallback = True

        if self.use_multi_scale and multi_scale_available:
            # Use last 4 layers
            hidden_states = outputs.image_hidden_states[-4:]
            multi_scale_features = []

            for i, hidden_state in enumerate(hidden_states):
                # Skip CLS token
                patches = hidden_state[:, 1:, :]
                b, n, c = patches.shape
                # Project and reshape to spatial dimensions
                projected = self.projections[i](patches)
                spatial = projected.reshape(b, Config.patch_grid, Config.patch_grid, -1).permute(0, 3, 1, 2)
                multi_scale_features.append(spatial)

            # Concatenate features along channel dimension
            fused_features = torch.cat(multi_scale_features, dim=1)
            # Fuse multi-scale features
            fused_features = self.fusion(fused_features)
            # Apply segmentation head
            seg_logits = self.head(fused_features)

        elif Config.synthetic_multi_scale and not self.using_synthetic:
            # Use synthetic multi-scale features
            print("Using synthetic multi-scale features")
            self.using_synthetic = True

            # Get image embeddings
            patches = outputs.image_embeddings[:, 1:, :]  # Skip CLS token

            # Create synthetic multi-scale features
            fused_features = self.create_synthetic_multi_scale(patches)

            # Apply segmentation head
            seg_logits = self.head(fused_features)

        else:
            # Single-scale approach using image embeddings
            patches = outputs.image_embeddings[:, 1:, :]  # Skip CLS token
            b, n, c = patches.shape
            projected = self.projections[0](patches)
            spatial = projected.reshape(b, Config.patch_grid, Config.patch_grid, -1).permute(0, 3, 1, 2)
            seg_logits = self.head(spatial)

        return seg_logits

    def get_gradcam(self, target_layer='conv_feature'):
        """Generate GradCAM heatmap for interpretability"""
        if target_layer not in self.activation or target_layer not in self.gradients:
            print(f"Warning: {target_layer} activations or gradients not found")
            return None

        # Get activations and gradients for the target layer
        activations = self.activation[target_layer]
        gradients = self.gradients[target_layer]

        # Global average pooling of gradients
        weights = torch.mean(gradients, dim=(2, 3), keepdim=True)

        # Weighted sum of activation maps
        cam = torch.sum(weights * activations, dim=1, keepdim=True)

        # Apply ReLU to focus on features that have a positive influence
        cam = F.relu(cam)

        # Normalize
        if torch.max(cam) > 0:
            cam = cam / torch.max(cam)

        return cam

# ==== EVALUATION FUNCTIONS ====
def calculate_metrics(preds, masks):
    """Calculate comprehensive evaluation metrics"""
    # IoU
    intersection = (preds * masks).sum((1,2))
    union = ((preds + masks) >= 1).float().sum((1,2))
    batch_iou = (intersection / (union + 1e-6))

    # Dice
    dice = (2 * intersection) / (preds.sum((1,2)) + masks.sum((1,2)) + 1e-6)

    # Precision, Recall, F1
    tp = (preds * masks).sum((1,2))
    fp = (preds * (1-masks)).sum((1,2))
    fn = ((1-preds) * masks).sum((1,2))
    tn = ((1-preds) * (1-masks)).sum((1,2))

    precision = tp / (tp + fp + 1e-6)
    recall = tp / (tp + fn + 1e-6)
    f1 = 2 * precision * recall / (precision + recall + 1e-6)
    accuracy = (tp + tn) / (tp + tn + fp + fn + 1e-6)
    specificity = tn / (tn + fp + 1e-6)

    metrics = {
        'iou': batch_iou.mean().item(),
        'dice': dice.mean().item(),
        'precision': precision.mean().item(),
        'recall': recall.mean().item(),
        'f1': f1.mean().item(),
        'accuracy': accuracy.mean().item(),
        'specificity': specificity.mean().item()
    }

    return metrics

def evaluate_model(model, dataloader, processor):
    """Evaluate model on a dataset"""
    model.eval()
    all_metrics = []

    with torch.no_grad():
        for imgs, masks, _ in tqdm(dataloader, desc="Evaluating"):
            imgs, masks = imgs.to(Config.device), masks.to(Config.device)
            pixel_inputs = processor(images=imgs, return_tensors="pt")["pixel_values"].to(Config.device)

            logits = model(pixel_inputs).squeeze(1)
            preds = (torch.sigmoid(logits) > 0.5).float()

            batch_metrics = calculate_metrics(preds, masks)
            all_metrics.append(batch_metrics)

    # Calculate mean metrics across all batches
    results = {}
    for metric in all_metrics[0].keys():
        results[metric] = np.mean([m[metric] for m in all_metrics])

    return results

# ==== VISUALIZATION FUNCTIONS ====
def visualize_attention(model, image_tensor, processor, save_path, img_path=None):
    """
    Visualize and analyze attention maps from FLAVA for interpretability analysis

    Args:
        model: The FLAVA segmentation model
        image_tensor: Input image tensor
        processor: FLAVA processor
        save_path: Path to save visualizations
        img_path: Optional original image path for naming

    Returns:
        Dictionary with attention statistics
    """
    model.eval()
    with torch.no_grad():
        # Ensure batch dimension
        if len(image_tensor.shape) == 3:
            image_tensor = image_tensor.unsqueeze(0)

        # Get pixel values using the processor
        pixel_values = processor(images=image_tensor, return_tensors="pt")["pixel_values"].to(Config.device)

        # Forward pass with attention outputs
        outputs = model.model(pixel_values=pixel_values, output_attentions=True)

        # Check if attention maps are available
        if not hasattr(outputs, 'image_attentions') or outputs.image_attentions is None:
            print("WARNING: No attention maps available for visualization.")
            attention_stats = {
                'mean_entropy': 0.0,
                'min_entropy': 0.0,
                'max_entropy': 0.0,
                'num_patches_50pct': 0,
                'num_patches_90pct': 0,
                'attention_concentration': 0.0
            }
            return attention_stats

        # Get segmentation prediction
        logits = model(pixel_values).squeeze(1)
        pred = (torch.sigmoid(logits) > 0.5).float()

        # Extract attention patterns
        attentions = outputs.image_attentions  # List of tensors [batch, heads, seq_len, seq_len]

        # Get number of attention layers and heads
        num_layers = len(attentions)
        num_heads = attentions[0].shape[1]

        # Create base filename
        base_filename = os.path.basename(img_path) if img_path else "attention_analysis"

        # Create figure for attention visualization across layers
        plt.figure(figsize=(15, num_layers * 3))

        # 1. Analyze attention from CLS token to patches across layers
        cls_attentions = []
        for layer_idx, layer_attn in enumerate(attentions):
            # Average across heads for CLS token attention
            cls_attn = layer_attn[0, :, 0, 1:].mean(0)  # [seq_len-1]
            cls_attn_map = cls_attn.reshape(Config.patch_grid, Config.patch_grid).cpu().numpy()
            cls_attentions.append(cls_attn_map)

            plt.subplot(num_layers, 3, layer_idx*3 + 1)
            plt.imshow(cls_attn_map, cmap='viridis')
            plt.title(f"Layer {layer_idx+1}: CLS Attention")
            plt.colorbar(fraction=0.046, pad=0.04)
            plt.axis('off')

        # 2. Create attention heatmap overlaid on the original image
        image_np = image_tensor[0].cpu().permute(1, 2, 0).numpy()

        # Use last layer's attention for overlay
        last_layer_cls_attn = cls_attentions[-1]

        # Resize attention map to match image size
        h, w = image_np.shape[:2]
        attn_resized = np.array(Image.fromarray(last_layer_cls_attn).resize((w, h)))

        # Normalize attention for visualization
        attn_normalized = (attn_resized - attn_resized.min()) / (attn_resized.max() - attn_resized.min() + 1e-8)

        # Create attention heatmap overlay
        plt.subplot(num_layers, 3, 2)
        plt.imshow(image_np)
        plt.imshow(attn_normalized, cmap='hot', alpha=0.5)
        plt.title("Attention Overlay")
        plt.axis('off')

        # 3. Compare attention with segmentation prediction
        plt.subplot(num_layers, 3, 3)
        plt.imshow(pred[0].cpu().numpy(), cmap='gray')
        plt.title("Segmentation Prediction")
        plt.axis('off')

        # Save the combined visualization
        plt.tight_layout()
        plt.savefig(os.path.join(save_path, f"{base_filename}_attention_layers.png"), dpi=300)
        plt.close()

        # 4. Analyze attention heads in the last layer
        plt.figure(figsize=(15, 10))
        n_cols = 4
        n_rows = (num_heads + n_cols - 1) // n_cols

        for head_idx in range(num_heads):
            head_attn = attentions[-1][0, head_idx, 0, 1:].reshape(Config.patch_grid, Config.patch_grid).cpu().numpy()

            plt.subplot(n_rows, n_cols, head_idx + 1)
            plt.imshow(head_attn, cmap='viridis')
            plt.title(f"Head {head_idx + 1}")
            plt.axis('off')

        plt.tight_layout()
        plt.savefig(os.path.join(save_path, f"{base_filename}_attention_heads.png"), dpi=300)
        plt.close()

        # 5. Calculate attention statistics
        attention_stats = {}

        # Mean attention entropy (measures attention distribution)
        entropy_values = []
        for layer_idx in range(num_layers):
            for head_idx in range(num_heads):
                # Get attention distribution from CLS token
                attn_dist = attentions[layer_idx][0, head_idx, 0, 1:].cpu().numpy()
                # Normalize to get probability distribution
                attn_dist = attn_dist / (attn_dist.sum() + 1e-10)
                # Calculate entropy: -sum(p * log(p))
                entropy = -np.sum(attn_dist * np.log(attn_dist + 1e-10))
                entropy_values.append(entropy)

        attention_stats['mean_entropy'] = np.mean(entropy_values)
        attention_stats['min_entropy'] = np.min(entropy_values)
        attention_stats['max_entropy'] = np.max(entropy_values)

        # Attention concentration (how much attention focuses on top k% of patches)
        last_layer_attn = attentions[-1][0, :, 0, 1:].mean(0).cpu().numpy()  # Average across heads
        sorted_attn = np.sort(last_layer_attn)[::-1]  # Sort in descending order
        cumsum_attn = np.cumsum(sorted_attn) / np.sum(sorted_attn)

        # Find how many patches capture 50% and 90% of attention
        attention_stats['num_patches_50pct'] = np.argmax(cumsum_attn >= 0.5) + 1
        attention_stats['num_patches_90pct'] = np.argmax(cumsum_attn >= 0.9) + 1
        attention_stats['attention_concentration'] = attention_stats['num_patches_50pct'] / len(last_layer_attn)

        # 6. Plot attention concentration curve
        plt.figure(figsize=(10, 6))
        plt.plot(range(1, len(cumsum_attn) + 1), cumsum_attn, 'b-')
        plt.axhline(y=0.5, color='r', linestyle='--', label='50% attention')
        plt.axhline(y=0.9, color='g', linestyle='--', label='90% attention')
        plt.axvline(x=attention_stats['num_patches_50pct'], color='r', linestyle=':')
        plt.axvline(x=attention_stats['num_patches_90pct'], color='g', linestyle=':')
        plt.xlabel('Number of Patches (sorted by attention)')
        plt.ylabel('Cumulative Attention')
        plt.title('Attention Concentration Analysis')
        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.savefig(os.path.join(save_path, f"{base_filename}_attention_concentration.png"), dpi=300)
        plt.close()

        # Save statistics to JSON
        with open(os.path.join(save_path, f"{base_filename}_attention_stats.json"), 'w') as f:
            json.dump(attention_stats, f, indent=2)

        return attention_stats

def visualize_batch(model, imgs, masks, paths, epoch, step, save_dir):
    """Create debug visualizations for a batch"""
    # Forward pass to get predictions - using no_grad to avoid gradient tracking
    with torch.no_grad():
        logits = model(imgs)
        preds = torch.sigmoid(logits)
        preds_binary = (preds > 0.5).float()

        # Get attention maps if available
        attention_maps = None
        if hasattr(model, 'last_attentions') and model.last_attentions is not None:
            try:
                # Last layer, mean over heads, first token (CLS)
                attention_maps = model.last_attentions[-1].mean(1)[:, 0, 1:].reshape(-1, Config.patch_grid, Config.patch_grid)
            except (IndexError, AttributeError) as e:
                print(f"Warning: Could not extract attention maps: {e}")

    # Create directory if it doesn't exist
    os.makedirs(save_dir, exist_ok=True)

    for i in range(min(2, imgs.shape[0])):  # Visualize up to 2 samples
        try:
            name = os.path.basename(paths[i])

            # Make sure everything is detached from computation graph
            mask_np = masks[i].detach().cpu().numpy()
            pred_np = preds[i].detach().cpu().squeeze().numpy()
            pred_binary_np = preds_binary[i].detach().cpu().squeeze().numpy()
            img_np = imgs[i].detach().cpu().permute(1, 2, 0).numpy()

            # Calculate metrics
            with torch.no_grad():
                batch_metrics = calculate_metrics(preds_binary[i:i+1], masks[i:i+1])

            # Create figure with appropriate number of subplots
            n_plots = 5 if attention_maps is not None else 4
            plt.figure(figsize=(n_plots*3, 3))

            # Input image
            plt.subplot(1, n_plots, 1)
            plt.imshow(img_np)
            plt.title("Input Image")
            plt.axis('off')

            # Ground truth mask
            plt.subplot(1, n_plots, 2)
            plt.imshow(mask_np, cmap='gray')
            plt.title("Ground Truth")
            plt.axis('off')

            # Prediction probability
            plt.subplot(1, n_plots, 3)
            plt.imshow(pred_np, cmap='hot')
            plt.title("Prediction (Prob)")
            plt.axis('off')

            # Binary prediction
            plt.subplot(1, n_plots, 4)
            plt.imshow(pred_binary_np, cmap='gray')
            plt.title(f"Binary Pred (IoU: {batch_metrics['iou']:.2f})")
            plt.axis('off')

            # Attention map (if available)
            if attention_maps is not None:
                attn_np = attention_maps[i].detach().cpu().numpy()
                plt.subplot(1, n_plots, 5)
                plt.imshow(attn_np, cmap='viridis')
                plt.title("Attention Map")
                plt.axis('off')

            plt.tight_layout()
            plt.savefig(os.path.join(save_dir, f"epoch{epoch+1}_step{step}_{name}"))
            plt.close()

        except Exception as e:
            print(f"Warning: Error in visualization for sample {i}: {str(e)}")
            continue

def plot_training_curves(metrics_history, save_path):
    """Plot training and validation metrics over epochs"""
    epochs = range(1, len(metrics_history['train_loss']) + 1)

    # Create subplots for each metric
    metric_groups = [
        ['loss'],
        ['iou', 'dice'],
        ['precision', 'recall', 'f1'],
        ['accuracy', 'specificity']
    ]

    fig, axes = plt.subplots(len(metric_groups), 1, figsize=(10, 4*len(metric_groups)))

    for i, metrics in enumerate(metric_groups):
        ax = axes[i]

        for metric in metrics:
            if f'train_{metric}' in metrics_history:
                ax.plot(epochs, metrics_history[f'train_{metric}'], 'b-', label=f'Train {metric}')
            if f'val_{metric}' in metrics_history:
                ax.plot(epochs, metrics_history[f'val_{metric}'], 'r-', label=f'Val {metric}')

        ax.set_xlabel('Epoch')
        ax.set_ylabel(' / '.join(metrics).capitalize())
        ax.set_title(f"{' / '.join(metrics).capitalize()} over Epochs")
        ax.legend()
        ax.grid(True, linestyle='--', alpha=0.6)

    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()

# ==== TRAINING ====
def train():
    # متغیرهای مربوط به بهترین مدل
    best_val_iou = 0
    best_epoch = 0

    # Load dataset
    full_dataset = MaskedDataset(Config.data_path)

    # Split into train and validation sets
    train_indices, val_indices = train_test_split(
        range(len(full_dataset)),
        test_size=Config.test_split,
        random_state=42
    )

    train_dataset = torch.utils.data.Subset(full_dataset, train_indices)
    val_dataset = torch.utils.data.Subset(full_dataset, val_indices)

    train_loader = DataLoader(train_dataset, batch_size=Config.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=Config.batch_size, shuffle=False)

    print(f"Training on {len(train_dataset)} samples, validating on {len(val_dataset)} samples")

    # Initialize model and optimizer
    processor = FlavaProcessor.from_pretrained(Config.base_model_path)
    processor.image_processor.do_rescale = False

    model = FLAVASegmenter(Config.base_model_path).to(Config.device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=Config.lr)

    # Initialize loss function
    if Config.use_advanced_loss:
        loss_fn = FocalDiceLoss()
        print("Using FocalDiceLoss")
    else:
        loss_fn = nn.BCEWithLogitsLoss()
        print("Using BCEWithLogitsLoss")

    # Initialize metrics history
    metrics_history = {
        'train_loss': [], 'val_loss': [],
        'train_iou': [], 'val_iou': [],
        'train_dice': [], 'val_dice': [],
        'train_precision': [], 'val_precision': [],
        'train_recall': [], 'val_recall': [],
        'train_f1': [], 'val_f1': [],
        'train_accuracy': [], 'val_accuracy': [],
        'train_specificity': [], 'val_specificity': []
    }

    # Create directory for attention analysis
    attention_dir = os.path.join(Config.save_path, "attention_analysis")
    os.makedirs(attention_dir, exist_ok=True)

    # Check if model has attention capabilities
    # Try a test forward pass to see if we get attention maps
    sample_img, _, _ = full_dataset[0]
    sample_img = sample_img.unsqueeze(0).to(Config.device)
    pixel_inputs = processor(images=sample_img, return_tensors="pt")["pixel_values"].to(Config.device)

    with torch.no_grad():
        outputs = model.model(pixel_values=pixel_inputs, output_attentions=True)
        has_attention = hasattr(outputs, 'image_attentions') and outputs.image_attentions is not None

    if not has_attention:
        print("WARNING: This FLAVA model doesn't provide attention maps. Attention visualization will be disabled.")

    # Training loop
    for epoch in range(Config.num_epochs):
        # Training phase
        model.train()
        train_losses = []
        train_metrics_list = []

        for step, (imgs, masks, paths) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1} Training")):
            imgs, masks = imgs.to(Config.device), masks.to(Config.device)
            pixel_inputs = processor(images=imgs, return_tensors="pt")["pixel_values"].to(Config.device)

            # Forward pass
            logits = model(pixel_inputs).squeeze(1)
            loss = loss_fn(logits, masks)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Calculate metrics
            preds = (torch.sigmoid(logits) > 0.5).float()
            batch_metrics = calculate_metrics(preds, masks)

            train_losses.append(loss.item())
            train_metrics_list.append(batch_metrics)

            # Debug visualization
            if step % 20 == 0:
                try:
                    visualize_batch(model, imgs, masks, paths, epoch, step, Config.debug_dir)
                except Exception as e:
                    print(f"Warning: Error in batch visualization: {e}")

                # Run detailed attention analysis on one sample every 50 steps
                if has_attention and step % 50 == 0 and step > 0:
                    try:
                        attention_stats = visualize_attention(
                            model,
                            imgs[0],
                            processor,
                            attention_dir,
                            img_path=paths[0]
                        )
                        print(f"  Attention stats - Concentration: {attention_stats['attention_concentration']:.4f}, Entropy: {attention_stats['mean_entropy']:.4f}")
                    except Exception as e:
                        print(f"Warning: Error in attention visualization: {e}")

        # Calculate average training metrics
        avg_train_loss = np.mean(train_losses)
        avg_train_metrics = {}
        for metric in train_metrics_list[0].keys():
            avg_train_metrics[metric] = np.mean([m[metric] for m in train_metrics_list])

        # Validation phase
        model.eval()
        val_losses = []
        val_metrics_list = []

        with torch.no_grad():
            for step, (imgs, masks, paths) in enumerate(tqdm(val_loader, desc=f"Epoch {epoch+1} Validation")):
                imgs, masks = imgs.to(Config.device), masks.to(Config.device)
                pixel_inputs = processor(images=imgs, return_tensors="pt")["pixel_values"].to(Config.device)

                # Forward pass
                logits = model(pixel_inputs).squeeze(1)
                loss = loss_fn(logits, masks)

                # Calculate metrics
                preds = (torch.sigmoid(logits) > 0.5).float()
                batch_metrics = calculate_metrics(preds, masks)

                val_losses.append(loss.item())
                val_metrics_list.append(batch_metrics)

                # Debug visualization for validation
                if step % 10 == 0:
                    try:
                        visualize_batch(model, imgs, masks, paths, epoch, f"val_{step}", Config.debug_dir)
                    except Exception as e:
                        print(f"Warning: Error in validation visualization: {e}")

                    # Run detailed attention analysis on validation samples
                    if has_attention and step == 0:
                        try:
                            attention_stats = visualize_attention(
                                model,
                                imgs[0],
                                processor,
                                attention_dir,
                                img_path=f"val_epoch{epoch+1}_{os.path.basename(paths[0])}"
                            )
                        except Exception as e:
                            print(f"Warning: Error in validation attention visualization: {e}")

        # Calculate average validation metrics
        avg_val_loss = np.mean(val_losses)
        avg_val_metrics = {}
        for metric in val_metrics_list[0].keys():
            avg_val_metrics[metric] = np.mean([m[metric] for m in val_metrics_list])

        # Update metrics history
        metrics_history['train_loss'].append(avg_train_loss)
        metrics_history['val_loss'].append(avg_val_loss)

        for metric in avg_train_metrics:
            metrics_history[f'train_{metric}'].append(avg_train_metrics[metric])
            metrics_history[f'val_{metric}'].append(avg_val_metrics[metric])

        # ذخیره بهترین مدل بر اساس IoU
        if epoch == 0 or avg_val_metrics['iou'] > best_val_iou:
            best_val_iou = avg_val_metrics['iou']
            best_epoch = epoch + 1
            torch.save(model.state_dict(), os.path.join(Config.save_path, "flava_seg_head_enhanced_best.pth"))
            print(f"  New best model saved! IoU: {best_val_iou:.4f}")

        # Print epoch summary
        print(f"\n[Epoch {epoch+1}/{Config.num_epochs}] Summary:")
        print(f"  Train - Loss: {avg_train_loss:.4f}, IoU: {avg_train_metrics['iou']:.4f}, Dice: {avg_train_metrics['dice']:.4f}")
        print(f"  Validation - Loss: {avg_val_loss:.4f}, IoU: {avg_val_metrics['iou']:.4f}, Dice: {avg_val_metrics['dice']:.4f}")

        # Plot training curves
        plot_training_curves(metrics_history, os.path.join(Config.plots_dir, f"training_curves_epoch{epoch+1}.png"))

        # Save metrics as CSV
        metrics_df = pd.DataFrame(metrics_history)
        metrics_df.to_csv(os.path.join(Config.metrics_dir, "training_metrics.csv"))

        # Save model checkpoint
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': avg_train_loss,
            'val_loss': avg_val_loss,
            'metrics': metrics_history
        }, os.path.join(Config.save_path, f"model_checkpoint_epoch{epoch+1}.pth"))

    # Save final model
    torch.save(model.state_dict(), os.path.join(Config.save_path, "flava_seg_head_enhanced_final.pth"))

    # Final plots and visualizations
    plot_training_curves(metrics_history, os.path.join(Config.plots_dir, "final_training_curves.png"))

    # Generate final attention analysis for paper-ready visualizations
    if has_attention:
        print("\nGenerating final attention analysis visualizations...")
        # Select a few representative samples from validation set
        num_samples = min(5, len(val_dataset))
        for i in range(num_samples):
            try:
                img, mask, img_path = val_dataset[i]
                attention_stats = visualize_attention(
                    model,
                    img.to(Config.device),
                    processor,
                    attention_dir,
                    img_path=f"final_sample_{i+1}_{os.path.basename(img_path)}"
                )
            except Exception as e:
                print(f"Warning: Error in final attention visualization for sample {i}: {e}")

    print(f"\n✅ Model training completed! Results saved to: {Config.save_path}")
    print(f"  Best validation IoU: {best_val_iou:.4f} (Epoch {best_epoch})")
    print(f"  Final validation IoU: {metrics_history['val_iou'][-1]:.4f}")
    print(f"  Best model saved to: {os.path.join(Config.save_path, 'flava_seg_head_enhanced_best.pth')}")

    if has_attention:
        print(f"  Attention analysis and visualizations saved to: {attention_dir}")

    return model, metrics_history

def visualize_gradcam(model, image_tensor, mask_tensor, processor, save_path=None):
    """Visualize GradCAM to understand which parts of the image influence predictions"""
    model.eval()

    # Ensure batch dimension
    if len(image_tensor.shape) == 3:
        image_tensor = image_tensor.unsqueeze(0)
    if len(mask_tensor.shape) == 2:
        mask_tensor = mask_tensor.unsqueeze(0)

    # Clear gradients
    model.zero_grad()

    # Forward pass
    pixel_values = processor(images=image_tensor, return_tensors="pt")["pixel_values"].to(Config.device)
    logits = model(pixel_values)

    # Calculate prediction
    pred = torch.sigmoid(logits)

    # Backward pass to get gradients
    pred.mean().backward()

    # Get GradCAM
    cam = model.get_gradcam()

    if cam is None:
        print("Warning: Could not generate GradCAM.")
        return None

    # Resize CAM to match image size
    cam_resized = F.interpolate(cam, size=(224, 224), mode='bilinear', align_corners=False)

    # Convert tensors to numpy for visualization
    image_np = image_tensor[0].cpu().permute(1, 2, 0).numpy()
    mask_np = mask_tensor[0].cpu().numpy()
    pred_np = pred[0, 0].detach().cpu().numpy()
    cam_np = cam_resized[0, 0].detach().cpu().numpy()

    if save_path:
        # Save visualization
        plt.figure(figsize=(15, 5))

        plt.subplot(1, 4, 1)
        plt.imshow(image_np)
        plt.title("Original Image")
        plt.axis('off')

        plt.subplot(1, 4, 2)
        plt.imshow(mask_np, cmap='gray')
        plt.title("Ground Truth")
        plt.axis('off')

        plt.subplot(1, 4, 3)
        plt.imshow(pred_np, cmap='gray')
        plt.title("Prediction")
        plt.axis('off')

        plt.subplot(1, 4, 4)
        plt.imshow(image_np)
        plt.imshow(cam_np, cmap='jet', alpha=0.5)
        plt.title("GradCAM Overlay")
        plt.axis('off')

        plt.tight_layout()
        plt.savefig(save_path, dpi=300)
        plt.close()

    return cam_np

def analyze_feature_embeddings(model, dataset, processor, save_dir, num_samples=100):
    """Analyze embedding space to understand model behavior"""
    # Limit the number of samples to analyze
    num_samples = min(num_samples, len(dataset))
    indices = np.random.choice(len(dataset), num_samples, replace=False)

    # Collect embeddings and metadata
    embeddings = []
    labels = []
    filenames = []
    metrics = []

    model.eval()
    with torch.no_grad():
        for i in tqdm(indices, desc="Extracting embeddings"):
            img, mask, img_path = dataset[i]

            # Get class from path
            try:
                class_name = img_path.split('/')[-2]  # Assuming folder name is class name
            except:
                class_name = "unknown"

            # Process image
            pixel_values = processor(images=img.unsqueeze(0), return_tensors="pt")["pixel_values"].to(Config.device)

            # Get embeddings
            outputs = model.model(pixel_values)
            embedding = outputs.image_embeddings[:, 0, :].cpu().numpy()  # CLS token

            # Get prediction and calculate metrics
            logits = model(pixel_values).squeeze()
            pred = (torch.sigmoid(logits) > 0.5).float()
            metric = calculate_metrics(pred.unsqueeze(0), mask.unsqueeze(0).to(Config.device))

            embeddings.append(embedding)
            labels.append(class_name)
            filenames.append(os.path.basename(img_path))
            metrics.append(metric)

    # Convert to numpy arrays
    embeddings = np.vstack(embeddings)

    # Run t-SNE to visualize embedding space
    tsne = TSNE(n_components=2, random_state=42)
    embeddings_2d = tsne.fit_transform(embeddings)

    # Create DataFrame for easier plotting
    df = pd.DataFrame({
        'x': embeddings_2d[:, 0],
        'y': embeddings_2d[:, 1],
        'class': labels,
        'filename': filenames,
        'iou': [m['iou'] for m in metrics],
        'dice': [m['dice'] for m in metrics]
    })

    # Save the data for future analysis
    df.to_csv(os.path.join(save_dir, "embedding_analysis.csv"), index=False)

    # Plot by class
    plt.figure(figsize=(12, 10))
    unique_classes = df['class'].unique()

    # Create a colormap
    cmap = plt.cm.get_cmap('tab10', len(unique_classes))

    for i, cls in enumerate(unique_classes):
        subset = df[df['class'] == cls]
        plt.scatter(subset['x'], subset['y'], c=[cmap(i)], label=cls, alpha=0.7)

    plt.title('t-SNE Visualization of FLAVA Embeddings by Class')
    plt.xlabel('t-SNE Dimension 1')
    plt.ylabel('t-SNE Dimension 2')
    plt.legend()
    plt.savefig(os.path.join(save_dir, "tsne_by_class.png"), dpi=300)
    plt.close()

    # Plot by IoU performance
    plt.figure(figsize=(12, 10))
    scatter = plt.scatter(df['x'], df['y'], c=df['iou'], cmap='viridis', alpha=0.7)
    plt.colorbar(scatter, label='IoU Score')
    plt.title('t-SNE Visualization of FLAVA Embeddings by Segmentation Performance')
    plt.xlabel('t-SNE Dimension 1')
    plt.ylabel('t-SNE Dimension 2')
    plt.savefig(os.path.join(save_dir, "tsne_by_performance.png"), dpi=300)
    plt.close()

    return df

def statistical_analysis(model, dataset, processor, save_dir):
    """Perform advanced statistical analysis of model performance"""
    metrics_by_class = {}
    size_metrics = {}  # Correlate defect size with performance
    location_metrics = {}  # Analyze performance by defect location

    # Process each sample
    for i in tqdm(range(len(dataset)), desc="Performing statistical analysis"):
        img, mask, img_path = dataset[i]

        # Get class from path
        try:
            class_name = img_path.split('/')[-2]  # Assuming folder name is class name
        except:
            class_name = "unknown"

        if class_name not in metrics_by_class:
            metrics_by_class[class_name] = []

        # Get prediction
        model.eval()
        with torch.no_grad():
            pixel_values = processor(images=img.unsqueeze(0), return_tensors="pt")["pixel_values"].to(Config.device)
            logits = model(pixel_values).squeeze()
            pred = (torch.sigmoid(logits) > 0.5).float()

            # Calculate metrics
            metrics = calculate_metrics(pred.unsqueeze(0), mask.unsqueeze(0).to(Config.device))
            metrics_by_class[class_name].append(metrics)

            # Analyze defect size correlation
            defect_size = mask.sum().item() / mask.numel()
            size_bin = int(defect_size * 10)  # Create size bins (0-10%)
            if size_bin not in size_metrics:
                size_metrics[size_bin] = []
            size_metrics[size_bin].append(metrics['iou'])

            # Analyze defect location
            # Calculate center of mass of the defect
            if mask.sum() > 0:
                indices = torch.nonzero(mask)
                center_y = indices[:, 0].float().mean() / mask.shape[0]
                center_x = indices[:, 1].float().mean() / mask.shape[1]

                # Divide image into 3x3 grid
                grid_y = min(2, int(center_y * 3))
                grid_x = min(2, int(center_x * 3))
                grid_pos = (grid_y * 3) + grid_x

                if grid_pos not in location_metrics:
                    location_metrics[grid_pos] = []
                location_metrics[grid_pos].append(metrics['iou'])

    # Generate class-wise performance report
    class_performance = {cls: {metric: np.mean([m[metric] for m in metrics_list])
                             for metric in ['iou', 'dice', 'precision', 'recall']}
                       for cls, metrics_list in metrics_by_class.items()}

    # Save as CSV
    cls_df = pd.DataFrame(class_performance).T
    cls_df.to_csv(os.path.join(save_dir, "class_performance.csv"))

    # Generate class performance visualization
    plt.figure(figsize=(12, 8))
    cls_df[['iou', 'dice', 'precision', 'recall']].plot(kind='bar')
    plt.title('Performance Metrics by Class')
    plt.ylabel('Score')
    plt.grid(axis='y', linestyle='--', alpha=0.3)
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, "class_performance.png"), dpi=300)
    plt.close()

    # Generate size correlation plot
    plt.figure(figsize=(10, 6))
    sizes = []
    ious = []
    errors = []

    for size_bin, iou_values in sorted(size_metrics.items()):
        if len(iou_values) > 0:  # Only include bins with data
            sizes.append(size_bin / 10)  # Convert bin back to percentage
            ious.append(np.mean(iou_values))
            errors.append(np.std(iou_values) / np.sqrt(len(iou_values)))  # Standard error

    plt.errorbar(sizes, ious, yerr=errors, fmt='o-', capsize=5)
    plt.xlabel('Defect Size (% of image)')
    plt.ylabel('Mean IoU')
    plt.title('Relationship Between Defect Size and Segmentation Performance')
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.savefig(os.path.join(save_dir, "size_performance_correlation.png"), dpi=300)
    plt.close()

    # Generate location performance visualization (3x3 grid)
    plt.figure(figsize=(8, 8))
    grid_values = np.zeros((3, 3))

    for pos, values in location_metrics.items():
        if len(values) > 0:  # Only include positions with data
            row = pos // 3
            col = pos % 3
            grid_values[row, col] = np.mean(values)

    ax = sns.heatmap(grid_values, annot=True, cmap='viridis', fmt='.3f', cbar_kws={'label': 'Mean IoU'})
    ax.set_title('Segmentation Performance by Defect Location')
    ax.set_xlabel('Horizontal Position')
    ax.set_ylabel('Vertical Position')
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, "location_performance.png"), dpi=300)
    plt.close()

    # Perform statistical tests
    # 1. Compare performance between different defect sizes
    size_stats = {}
    if len(size_metrics) > 1:
        size_keys = sorted(size_metrics.keys())
        for i in range(len(size_keys)):
            for j in range(i+1, len(size_keys)):
                if len(size_metrics[size_keys[i]]) > 5 and len(size_metrics[size_keys[j]]) > 5:
                    t_stat, p_val = ttest_ind(size_metrics[size_keys[i]], size_metrics[size_keys[j]], equal_var=False)
                    comparison = f"{size_keys[i]/10:.1f}% vs {size_keys[j]/10:.1f}%"
                    size_stats[comparison] = {
                        't_statistic': t_stat,
                        'p_value': p_val,
                        'significant': p_val < 0.05
                    }

    # 2. Compare performance between different classes
    class_stats = {}
    class_keys = list(metrics_by_class.keys())
    for i in range(len(class_keys)):
        for j in range(i+1, len(class_keys)):
            if len(metrics_by_class[class_keys[i]]) > 5 and len(metrics_by_class[class_keys[j]]) > 5:
                iou_i = [m['iou'] for m in metrics_by_class[class_keys[i]]]
                iou_j = [m['iou'] for m in metrics_by_class[class_keys[j]]]
                t_stat, p_val = ttest_ind(iou_i, iou_j, equal_var=False)
                comparison = f"{class_keys[i]} vs {class_keys[j]}"
                class_stats[comparison] = {
                    't_statistic': t_stat,
                    'p_value': p_val,
                    'significant': p_val < 0.05
                }

    # Save statistical test results
    with open(os.path.join(save_dir, "statistical_tests.txt"), 'w') as f:
        f.write("Statistical Analysis of FLAVA Defect Segmentation\n")
        f.write("===============================================\n\n")

        f.write("Size Comparison Tests (t-tests)\n")
        f.write("-------------------------------\n")
        for comparison, stats in size_stats.items():
            f.write(f"{comparison}: t={stats['t_statistic']:.4f}, p={stats['p_value']:.4f}")
            if stats['significant']:
                f.write(" (significant)\n")
            else:
                f.write(" (not significant)\n")

        f.write("\nClass Comparison Tests (t-tests)\n")
        f.write("-------------------------------\n")
        for comparison, stats in class_stats.items():
            f.write(f"{comparison}: t={stats['t_statistic']:.4f}, p={stats['p_value']:.4f}")
            if stats['significant']:
                f.write(" (significant)\n")
            else:
                f.write(" (not significant)\n")

    return {
        'class_performance': class_performance,
        'size_metrics': size_metrics,
        'location_metrics': location_metrics,
        'size_stats': size_stats,
        'class_stats': class_stats
    }

if __name__ == "__main__":
    try:
        model, metrics_history = train()

        # Get validation dataset
        full_dataset = MaskedDataset(Config.data_path)
        _, val_indices = train_test_split(range(len(full_dataset)), test_size=Config.test_split, random_state=42)
        val_dataset = torch.utils.data.Subset(full_dataset, val_indices)

        # Load processor
        processor = FlavaProcessor.from_pretrained(Config.base_model_path)
        processor.image_processor.do_rescale = False

        # Check if we can generate publication visualizations (depends on attention availability)
        with torch.no_grad():
            sample_img, _, _ = val_dataset[0]
            sample_img = sample_img.unsqueeze(0).to(Config.device)
            outputs = model.model(
                pixel_values=processor(images=sample_img, return_tensors="pt")["pixel_values"].to(Config.device),
                output_attentions=True
            )
            has_attention = hasattr(outputs, 'image_attentions') and outputs.image_attentions is not None

        # Function to generate publication-quality visualizations of predictions only (no attention)
        def generate_basic_visualizations(model, dataset, processor, save_dir, num_samples=5):
            """Generate basic prediction visualizations suitable for publication"""
            pub_viz_dir = os.path.join(save_dir, "publication_figures")
            os.makedirs(pub_viz_dir, exist_ok=True)

            # Select representative samples
            indices = np.random.choice(len(dataset), num_samples, replace=False)

            # Create a figure with multiple rows for samples and columns for visualizations
            fig = plt.figure(figsize=(12, num_samples * 3))
            gs = fig.add_gridspec(num_samples, 3, hspace=0.2, wspace=0.1)

            for i, idx in enumerate(indices):
                img, mask, img_path = dataset[idx]
                img_name = os.path.basename(img_path)

                # Process image and get predictions
                model.eval()
                with torch.no_grad():
                    pixel_values = processor(images=img.unsqueeze(0), return_tensors="pt")["pixel_values"].to(Config.device)
                    logits = model(pixel_values).squeeze()
                    pred = (torch.sigmoid(logits) > 0.5).float()

                    # Calculate IoU and Dice
                    iou = calculate_metrics(pred.unsqueeze(0), mask.unsqueeze(0).to(Config.device))['iou']
                    dice = calculate_metrics(pred.unsqueeze(0), mask.unsqueeze(0).to(Config.device))['dice']

                    # Error map (difference between prediction and ground truth)
                    error_map = torch.abs(pred - mask.to(Config.device)).cpu().numpy()

                # Original image
                ax = fig.add_subplot(gs[i, 0])
                ax.imshow(img.permute(1, 2, 0).cpu().numpy())
                ax.set_title("Input Image" if i == 0 else None)
                ax.axis('off')

                # Ground truth
                ax = fig.add_subplot(gs[i, 1])
                ax.imshow(mask.cpu().numpy(), cmap='gray')
                ax.set_title("Ground Truth" if i == 0 else None)
                ax.axis('off')

                # Model prediction with metrics
                ax = fig.add_subplot(gs[i, 2])
                pred_np = pred.cpu().numpy()
                ax.imshow(pred_np, cmap='gray')
                ax.set_title(f"Prediction (IoU: {iou:.2f}, Dice: {dice:.2f})" if i == 0 else f"IoU: {iou:.2f}, Dice: {dice:.2f}")
                ax.axis('off')

            plt.tight_layout()
            plt.savefig(os.path.join(pub_viz_dir, "prediction_visualization.png"), dpi=300, bbox_inches='tight')
            plt.savefig(os.path.join(pub_viz_dir, "prediction_visualization.pdf"), format='pdf', bbox_inches='tight')
            plt.close()

            # Create a metrics table for all samples
            all_metrics = []
            for i in range(min(20, len(dataset))):  # Get metrics for up to 20 samples
                img, mask, _ = dataset[i]
                with torch.no_grad():
                    pixel_values = processor(images=img.unsqueeze(0), return_tensors="pt")["pixel_values"].to(Config.device)
                    logits = model(pixel_values).squeeze()
                    pred = (torch.sigmoid(logits) > 0.5).float()
                    metrics = calculate_metrics(pred.unsqueeze(0), mask.unsqueeze(0).to(Config.device))
                    all_metrics.append(metrics)

            # Calculate overall statistics
            metrics_df = pd.DataFrame(all_metrics)
            metrics_summary = {
                'metric': [],
                'mean': [],
                'std': [],
                'min': [],
                'max': []
            }

            for metric in metrics_df.columns:
                metrics_summary['metric'].append(metric)
                metrics_summary['mean'].append(metrics_df[metric].mean())
                metrics_summary['std'].append(metrics_df[metric].std())
                metrics_summary['min'].append(metrics_df[metric].min())
                metrics_summary['max'].append(metrics_df[metric].max())

            summary_df = pd.DataFrame(metrics_summary)
            summary_df.to_csv(os.path.join(pub_viz_dir, "metrics_summary.csv"), index=False)

            # Create a metrics visualization
            plt.figure(figsize=(10, 6))
            metrics_to_plot = ['iou', 'dice', 'precision', 'recall', 'f1']
            means = [summary_df[summary_df['metric'] == m]['mean'].values[0] for m in metrics_to_plot]
            stds = [summary_df[summary_df['metric'] == m]['std'].values[0] for m in metrics_to_plot]

            x = np.arange(len(metrics_to_plot))
            plt.bar(x, means, yerr=stds, align='center', alpha=0.7, capsize=10)
            plt.xticks(x, [m.capitalize() for m in metrics_to_plot])
            plt.ylabel('Score')
            plt.title('Performance Metrics')
            plt.ylim(0, 1.0)
            plt.grid(axis='y', linestyle='--', alpha=0.3)

            for i, v in enumerate(means):
                plt.text(i, v + 0.02, f"{v:.3f}", ha='center')

            plt.tight_layout()
            plt.savefig(os.path.join(pub_viz_dir, "metrics_visualization.png"), dpi=300)
            plt.close()

            print(f"Basic visualizations saved to {pub_viz_dir}")

        # Generate publication-quality visualizations
        print("\nGenerating publication-ready visualizations...")

        # If we have attention capabilities, use the advanced function
        if has_attention:
            # Function to generate publication-quality visualizations with attention analysis
            def generate_publication_visualizations(model, dataset, processor, save_dir, num_samples=5):
                """Generate high-quality visualizations for scientific publication"""
                pub_viz_dir = os.path.join(save_dir, "publication_figures")
                os.makedirs(pub_viz_dir, exist_ok=True)

                # Select representative samples
                indices = np.random.choice(len(dataset), num_samples, replace=False)

                # Create a figure with multiple rows for samples and columns for visualizations
                fig = plt.figure(figsize=(15, num_samples * 4))
                gs = fig.add_gridspec(num_samples, 4, hspace=0.2, wspace=0.1)

                for i, idx in enumerate(indices):
                    img, mask, img_path = dataset[idx]
                    img_name = os.path.basename(img_path)

                    # Process image and get predictions
                    model.eval()
                    with torch.no_grad():
                        pixel_values = processor(images=img.unsqueeze(0), return_tensors="pt")["pixel_values"].to(Config.device)
                        outputs = model.model(pixel_values=pixel_values, output_attentions=True)
                        logits = model(pixel_values).squeeze()
                        pred = (torch.sigmoid(logits) > 0.5).float()

                        # Get attention map (CLS token to patches)
                        attn = outputs.image_attentions[-1]  # Last layer
                        attn = attn.mean(1)[0]  # Average across heads, first batch
                        cls_attn = attn[0, 1:].reshape(Config.patch_grid, Config.patch_grid).cpu().numpy()

                        # Calculate error map
                        error_map = torch.abs(pred - mask.to(Config.device)).cpu().numpy()

                    # Original image
                    ax = fig.add_subplot(gs[i, 0])
                    ax.imshow(img.permute(1, 2, 0).cpu().numpy())
                    ax.set_title("Input Image" if i == 0 else None)
                    ax.axis('off')

                    # Ground truth
                    ax = fig.add_subplot(gs[i, 1])
                    ax.imshow(mask.cpu().numpy(), cmap='gray')
                    ax.set_title("Ground Truth" if i == 0 else None)
                    ax.axis('off')

                    # Model prediction
                    ax = fig.add_subplot(gs[i, 2])
                    pred_np = pred.cpu().numpy()
                    ax.imshow(pred_np, cmap='gray')
                    ax.set_title("Model Prediction" if i == 0 else None)
                    ax.axis('off')

                    # Attention map
                    ax = fig.add_subplot(gs[i, 3])
                    im = ax.imshow(cls_attn, cmap='viridis')
                    ax.set_title("Attention Map" if i == 0 else None)
                    ax.axis('off')

                    # Add colorbar to the last attention map
                    if i == 0:
                        plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

                plt.tight_layout()
                plt.savefig(os.path.join(pub_viz_dir, "multi_sample_visualization.png"), dpi=300, bbox_inches='tight')
                plt.savefig(os.path.join(pub_viz_dir, "multi_sample_visualization.pdf"), format='pdf', bbox_inches='tight')
                plt.close()

                # Generate attention concentration analysis across samples
                plt.figure(figsize=(10, 6))

                attention_concentrations = []
                sample_labels = []

                for i, idx in enumerate(indices):
                    img, mask, img_path = dataset[idx]
                    sample_labels.append(f"Sample {i+1}")

                    # Process image
                    model.eval()
                    with torch.no_grad():
                        pixel_values = processor(images=img.unsqueeze(0), return_tensors="pt")["pixel_values"].to(Config.device)
                        outputs = model.model(pixel_values=pixel_values, output_attentions=True)

                        # Get attention from last layer
                        attn = outputs.image_attentions[-1][0]  # Last layer, first batch

                        # For each attention head
                        head_concentrations = []
                        for head_idx in range(attn.shape[0]):
                            # Get CLS token attention distribution
                            attn_dist = attn[head_idx, 0, 1:].cpu().numpy()
                            # Sort in descending order
                            sorted_attn = np.sort(attn_dist)[::-1]
                            # Calculate cumulative sum
                            cumsum_attn = np.cumsum(sorted_attn) / np.sum(sorted_attn)
                            # Find number of patches for 50% attention
                            num_patches_50 = np.argmax(cumsum_attn >= 0.5) + 1
                            # Calculate concentration
                            concentration = num_patches_50 / len(attn_dist)
                            head_concentrations.append(concentration)

                        attention_concentrations.append(head_concentrations)

                # Convert to numpy array for easier manipulation
                attention_concentrations = np.array(attention_concentrations)

                # Plot boxplot of attention concentration across heads for each sample
                plt.boxplot(attention_concentrations.T, labels=sample_labels)
                plt.ylabel('Attention Concentration (patches for 50% attention / total patches)')
                plt.title('Attention Concentration Analysis Across Samples')
                plt.grid(axis='y', linestyle='--', alpha=0.7)
                plt.savefig(os.path.join(pub_viz_dir, "attention_concentration_analysis.png"), dpi=300)
                plt.close()

                # Generate statistics summary
                stats_file = os.path.join(pub_viz_dir, "attention_statistics.txt")
                with open(stats_file, 'w') as f:
                    f.write("Attention Concentration Statistics\n")
                    f.write("================================\n\n")
                    f.write("Lower values indicate more focused attention\n\n")

                    f.write("Per Sample Statistics:\n")
                    for i, sample in enumerate(sample_labels):
                        mean_conc = np.mean(attention_concentrations[i])
                        min_conc = np.min(attention_concentrations[i])
                        max_conc = np.max(attention_concentrations[i])
                        f.write(f"{sample}: Mean={mean_conc:.4f}, Min={min_conc:.4f}, Max={max_conc:.4f}\n")

                    f.write("\nOverall Statistics:\n")
                    overall_mean = np.mean(attention_concentrations)
                    overall_std = np.std(attention_concentrations)
                    f.write(f"Mean concentration: {overall_mean:.4f} ± {overall_std:.4f}\n")

                print(f"Publication visualizations saved to {pub_viz_dir}")

            try:
                # Use the advanced visualization with attention if available
                generate_publication_visualizations(model, val_dataset, processor, Config.save_path)
            except Exception as e:
                print(f"Error generating advanced visualizations: {e}")
                print("Falling back to basic visualizations...")
                generate_basic_visualizations(model, val_dataset, processor, Config.save_path)
        else:
            # Use the basic visualization if attention is not available
            generate_basic_visualizations(model, val_dataset, processor, Config.save_path)

        # Generate GradCAM visualizations for interpretability
        print("\nGenerating GradCAM visualizations for interpretability...")
        try:
            # Create a subset for GradCAM analysis
            num_samples = min(10, len(val_dataset))
            indices = np.random.choice(len(val_dataset), num_samples, replace=False)

            for i, idx in enumerate(indices):
                img, mask, img_path = val_dataset[idx]
                img_name = os.path.basename(img_path)

                # Generate and save GradCAM
                gradcam_path = os.path.join(Config.gradcam_dir, f"gradcam_{img_name}")
                visualize_gradcam(model, img, mask, processor, gradcam_path)

            print(f"GradCAM visualizations saved to {Config.gradcam_dir}")
        except Exception as e:
            print(f"Error generating GradCAM visualizations: {e}")

        # Perform feature embedding analysis
        print("\nAnalyzing feature embeddings...")
        try:
            embedding_df = analyze_feature_embeddings(model, val_dataset, processor, Config.analysis_dir)
            print(f"Feature embedding analysis saved to {Config.analysis_dir}")
        except Exception as e:
            print(f"Error analyzing feature embeddings: {e}")

        # Perform statistical analysis
        print("\nPerforming comprehensive statistical analysis...")
        try:
            stats = statistical_analysis(model, val_dataset, processor, Config.analysis_dir)
            print(f"Statistical analysis saved to {Config.analysis_dir}")
        except Exception as e:
            print(f"Error performing statistical analysis: {e}")

        print("\n✅ All analyses completed successfully!")
        print(f"  All results saved to: {Config.save_path}")
        print("  The enhanced analysis provides:")
        print("  - Synthetic multi-scale features for improved segmentation")
        print("  - GradCAM visualizations for model interpretability")
        print("  - Feature embedding analysis for understanding model behavior")
        print("  - Comprehensive statistical analysis by defect class, size, and location")
        print("  - Publication-ready visualizations and metrics")

    except Exception as e:
        print(f"An error occurred during training: {e}")
        import traceback
        traceback.print_exc()

Using device: cuda
Found 522 images with masks
Training on 417 samples, validating on 105 samples
Using FocalDiceLoss


Epoch 1 Training:   0%|          | 0/105 [00:00<?, ?it/s]

Available outputs: ['image_embeddings', 'image_output']
Falling back to single-scale features.
Using synthetic multi-scale features


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
Epoch 1 Training: 100%|██████████| 105/105 [00:27<00:00,  3.82it/s]
Epoch 1 Validation: 100%|██████████| 27/27 [00:04<00:00,  6.17it/s]


  New best model saved! IoU: 0.6077

[Epoch 1/10] Summary:
  Train - Loss: 0.3981, IoU: 0.4440, Dice: 0.5323
  Validation - Loss: 0.3484, IoU: 0.6077, Dice: 0.6656


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
Epoch 2 Training: 100%|██████████| 105/105 [00:34<00:00,  3.08it/s]
Epoch 2 Validation: 100%|██████████| 27/27 [00:04<00:00,  5.79it/s]


  New best model saved! IoU: 0.6622

[Epoch 2/10] Summary:
  Train - Loss: 0.3416, IoU: 0.5964, Dice: 0.6603
  Validation - Loss: 0.3274, IoU: 0.6622, Dice: 0.7070


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
Epoch 3 Training: 100%|██████████| 105/105 [00:46<00:00,  2.27it/s]
Epoch 3 Validation: 100%|██████████| 27/27 [00:06<00:00,  4.26it/s]


  New best model saved! IoU: 0.6786

[Epoch 3/10] Summary:
  Train - Loss: 0.3176, IoU: 0.6561, Dice: 0.6972
  Validation - Loss: 0.3107, IoU: 0.6786, Dice: 0.7162


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
Epoch 4 Training: 100%|██████████| 105/105 [00:30<00:00,  3.48it/s]
Epoch 4 Validation: 100%|██████████| 27/27 [00:04<00:00,  5.51it/s]


  New best model saved! IoU: 0.6884

[Epoch 4/10] Summary:
  Train - Loss: 0.3031, IoU: 0.6884, Dice: 0.7167
  Validation - Loss: 0.3137, IoU: 0.6884, Dice: 0.7178


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
Epoch 5 Training: 100%|██████████| 105/105 [00:30<00:00,  3.44it/s]
Epoch 5 Validation: 100%|██████████| 27/27 [00:04<00:00,  5.55it/s]


  New best model saved! IoU: 0.6950

[Epoch 5/10] Summary:
  Train - Loss: 0.2927, IoU: 0.7089, Dice: 0.7389
  Validation - Loss: 0.3068, IoU: 0.6950, Dice: 0.7278


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
Epoch 6 Training: 100%|██████████| 105/105 [00:44<00:00,  2.34it/s]
Epoch 6 Validation: 100%|██████████| 27/27 [00:04<00:00,  5.83it/s]


  New best model saved! IoU: 0.6980

[Epoch 6/10] Summary:
  Train - Loss: 0.2832, IoU: 0.7287, Dice: 0.7481
  Validation - Loss: 0.3162, IoU: 0.6980, Dice: 0.7343


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
Epoch 7 Training: 100%|██████████| 105/105 [00:28<00:00,  3.74it/s]
Epoch 7 Validation: 100%|██████████| 27/27 [00:04<00:00,  5.62it/s]


  New best model saved! IoU: 0.7063

[Epoch 7/10] Summary:
  Train - Loss: 0.2769, IoU: 0.7276, Dice: 0.7570
  Validation - Loss: 0.3029, IoU: 0.7063, Dice: 0.7433


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
Epoch 8 Training: 100%|██████████| 105/105 [00:26<00:00,  3.97it/s]
Epoch 8 Validation: 100%|██████████| 27/27 [00:04<00:00,  5.72it/s]


  New best model saved! IoU: 0.7103

[Epoch 8/10] Summary:
  Train - Loss: 0.2732, IoU: 0.7409, Dice: 0.7625
  Validation - Loss: 0.2988, IoU: 0.7103, Dice: 0.7384


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
Epoch 9 Training: 100%|██████████| 105/105 [00:26<00:00,  3.90it/s]
Epoch 9 Validation: 100%|██████████| 27/27 [00:05<00:00,  5.06it/s]



[Epoch 9/10] Summary:
  Train - Loss: 0.2660, IoU: 0.7498, Dice: 0.7663
  Validation - Loss: 0.2794, IoU: 0.7086, Dice: 0.7360


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
Epoch 10 Training: 100%|██████████| 105/105 [00:27<00:00,  3.87it/s]
Epoch 10 Validation: 100%|██████████| 27/27 [00:05<00:00,  5.23it/s]


  New best model saved! IoU: 0.7153

[Epoch 10/10] Summary:
  Train - Loss: 0.2635, IoU: 0.7477, Dice: 0.7673
  Validation - Loss: 0.2956, IoU: 0.7153, Dice: 0.7385

✅ Model training completed! Results saved to: /content/drive/MyDrive/new_flava/attention_seg_head_enhanced
  Best validation IoU: 0.7153 (Epoch 10)
  Final validation IoU: 0.7153
  Best model saved to: /content/drive/MyDrive/new_flava/attention_seg_head_enhanced/flava_seg_head_enhanced_best.pth
Found 522 images with masks

Generating publication-ready visualizations...


  plt.tight_layout()


Basic visualizations saved to /content/drive/MyDrive/new_flava/attention_seg_head_enhanced/publication_figures

Generating GradCAM visualizations for interpretability...


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)


GradCAM visualizations saved to /content/drive/MyDrive/new_flava/attention_seg_head_enhanced/gradcam

Analyzing feature embeddings...


Extracting embeddings:   0%|          | 0/100 [00:00<?, ?it/s]


Error analyzing feature embeddings: Wrong shape for input_ids (shape torch.Size([1, 3, 224, 224])) or attention_mask (shape torch.Size([1, 3, 224, 224]))

Performing comprehensive statistical analysis...


Performing statistical analysis: 100%|██████████| 105/105 [00:03<00:00, 28.52it/s]


Statistical analysis saved to /content/drive/MyDrive/new_flava/attention_seg_head_enhanced/advanced_analysis

✅ All analyses completed successfully!
  All results saved to: /content/drive/MyDrive/new_flava/attention_seg_head_enhanced
  The enhanced analysis provides:
  - Synthetic multi-scale features for improved segmentation
  - GradCAM visualizations for model interpretability
  - Feature embedding analysis for understanding model behavior
  - Comprehensive statistical analysis by defect class, size, and location
  - Publication-ready visualizations and metrics


<Figure size 1200x800 with 0 Axes>

In [None]:
import os
import torch
import shutil
from transformers import FlavaModel, FlavaProcessor

# مسیرها و تنظیمات - مطمئن شوید با تنظیمات شما مطابقت دارد
save_path = "/content/drive/MyDrive/new_flava/attention_seg_head_enhanced"
base_model_path = "/content/drive/MyDrive//flava_finetuned"
best_epoch = 10  # بهترین epoch که از پیام قبلی مشخص شده است
device = "cuda" if torch.cuda.is_available() else "cpu"

# تعریف کلاس مدل - مطمئن شوید با همان کلاسی که برای آموزش استفاده کردید مطابقت دارد
class FLAVASegmenter(nn.Module):
    def __init__(self, base_model_path):
        super().__init__()
        self.model = FlavaModel.from_pretrained(base_model_path)
        # Extract features from multiple transformer layers
        self.use_multi_scale = Config.use_multi_scale

        # Add projection layers for each scale
        self.projections = nn.ModuleList([
            nn.Linear(self.model.config.hidden_size, 256)
            for _ in range(4)  # Use last 4 layers
        ])

        # Fusion layer for multi-scale features
        self.fusion = nn.Conv2d(256*4, 256, kernel_size=1) if self.use_multi_scale else None

        # For synthetic multi-scale if real ones not available
        if Config.synthetic_multi_scale:
            # Projection for synthetic multi-scale features
            self.syn_projection = nn.Linear(self.model.config.hidden_size, 256)
            # Fusion for synthetic features (3 scales: original, 2x2 pooled, 4x4 pooled)
            self.syn_fusion = nn.Conv2d(256*3, 256, kernel_size=1)

        # Segmentation head with convolutional layers
        self.head = nn.Sequential(
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 1, kernel_size=1)
        )

        # Store activation maps for GradCAM
        self.activation = {}
        self.gradients = {}
        self._register_hooks()

        # Flag to indicate if we had to fall back to single-scale
        self.using_fallback = False
        self.using_synthetic = False

    def _register_hooks(self):
        # Register hooks for GradCAM
        def get_activation(name):
            def hook(module, input, output):
                self.activation[name] = output
            return hook

        def get_gradient(name):
            def hook(module, grad_in, grad_out):
                self.gradients[name] = grad_out[0]
            return hook

        # Register hooks for the last convolutional layer
        self.head[0].register_forward_hook(get_activation('conv_feature'))
        self.head[0].register_backward_hook(get_gradient('conv_feature'))

    def create_synthetic_multi_scale(self, embeddings):
        """Create synthetic multi-scale features from a single embedding layer"""
        # Original scale
        b, n, c = embeddings.shape

        # Project embeddings
        projected = self.syn_projection(embeddings)

        # Reshape to spatial dimensions (excluding CLS token)
        spatial = projected.reshape(b, Config.patch_grid, Config.patch_grid, -1).permute(0, 3, 1, 2)

        # Create multi-scale features
        features = [spatial]  # Original scale

        # Scale 2: Pooled 2x2
        pooled2 = F.avg_pool2d(spatial, kernel_size=2, stride=2)
        upsampled2 = F.interpolate(pooled2, size=(Config.patch_grid, Config.patch_grid), mode='bilinear', align_corners=False)
        features.append(upsampled2)

        # Scale 3: Pooled 4x4
        pooled4 = F.avg_pool2d(spatial, kernel_size=4, stride=4)
        upsampled4 = F.interpolate(pooled4, size=(Config.patch_grid, Config.patch_grid), mode='bilinear', align_corners=False)
        features.append(upsampled4)

        # Concatenate along channel dimension
        multi_scale = torch.cat(features, dim=1)

        # Fuse multi-scale features
        fused = self.syn_fusion(multi_scale)

        return fused

    def forward(self, pixel_inputs):
        # Get outputs with attention (for visualization)
        outputs = self.model(
            pixel_values=pixel_inputs,
            output_hidden_states=self.use_multi_scale,
            output_attentions=True
        )

        # Store attention maps for visualization if available
        self.last_attentions = outputs.image_attentions if hasattr(outputs, 'image_attentions') else None

        # Check if multi-scale is actually available in this model version
        multi_scale_available = hasattr(outputs, 'image_hidden_states') and outputs.image_hidden_states is not None

        # Log the first time we have to fall back
        if self.use_multi_scale and not multi_scale_available and not self.using_fallback:
            print("WARNING: Multi-scale features not available in this FLAVA model version.")
            print("Available outputs:", list(outputs.keys()))
            print("Falling back to single-scale features.")
            self.using_fallback = True

        if self.use_multi_scale and multi_scale_available:
            # Use last 4 layers
            hidden_states = outputs.image_hidden_states[-4:]
            multi_scale_features = []

            for i, hidden_state in enumerate(hidden_states):
                # Skip CLS token
                patches = hidden_state[:, 1:, :]
                b, n, c = patches.shape
                # Project and reshape to spatial dimensions
                projected = self.projections[i](patches)
                spatial = projected.reshape(b, Config.patch_grid, Config.patch_grid, -1).permute(0, 3, 1, 2)
                multi_scale_features.append(spatial)

            # Concatenate features along channel dimension
            fused_features = torch.cat(multi_scale_features, dim=1)
            # Fuse multi-scale features
            fused_features = self.fusion(fused_features)
            # Apply segmentation head
            seg_logits = self.head(fused_features)

        elif Config.synthetic_multi_scale and not self.using_synthetic:
            # Use synthetic multi-scale features
            print("Using synthetic multi-scale features")
            self.using_synthetic = True

            # Get image embeddings
            patches = outputs.image_embeddings[:, 1:, :]  # Skip CLS token

            # Create synthetic multi-scale features
            fused_features = self.create_synthetic_multi_scale(patches)

            # Apply segmentation head
            seg_logits = self.head(fused_features)

        else:
            # Single-scale approach using image embeddings
            patches = outputs.image_embeddings[:, 1:, :]  # Skip CLS token
            b, n, c = patches.shape
            projected = self.projections[0](patches)
            spatial = projected.reshape(b, Config.patch_grid, Config.patch_grid, -1).permute(0, 3, 1, 2)
            seg_logits = self.head(spatial)

        return seg_logits

    def get_gradcam(self, target_layer='conv_feature'):
        """Generate GradCAM heatmap for interpretability"""
        if target_layer not in self.activation or target_layer not in self.gradients:
            print(f"Warning: {target_layer} activations or gradients not found")
            return None

        # Get activations and gradients for the target layer
        activations = self.activation[target_layer]
        gradients = self.gradients[target_layer]

        # Global average pooling of gradients
        weights = torch.mean(gradients, dim=(2, 3), keepdim=True)

        # Weighted sum of activation maps
        cam = torch.sum(weights * activations, dim=1, keepdim=True)

        # Apply ReLU to focus on features that have a positive influence
        cam = F.relu(cam)

        # Normalize
        if torch.max(cam) > 0:
            cam = cam / torch.max(cam)

        return cam

# ایجاد مدل خالی
model = FLAVASegmenter(base_model_path).to(device)

# بارگذاری بهترین checkpoint
print(f"در حال بارگذاری بهترین مدل از epoch {best_epoch}...")
best_checkpoint_path = os.path.join(save_path, f"model_checkpoint_epoch{best_epoch}.pth")

# بارگذاری با weights_only=False برای جلوگیری از خطای PyTorch 2.6
checkpoint = torch.load(best_checkpoint_path, weights_only=False)

# بارگذاری وزن‌های مدل
model.load_state_dict(checkpoint['model_state_dict'])
print("وزن‌های مدل با موفقیت بارگذاری شدند!")

# ذخیره به عنوان بهترین مدل
best_model_path = os.path.join(save_path, "best_model.pth")
torch.save({
    'epoch': best_epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': checkpoint['optimizer_state_dict'] if 'optimizer_state_dict' in checkpoint else None,
    'val_iou': checkpoint['metrics']['val_iou'][best_epoch-1] if 'metrics' in checkpoint and 'val_iou' in checkpoint['metrics'] else None,
    'val_dice': checkpoint['metrics']['val_dice'][best_epoch-1] if 'metrics' in checkpoint and 'val_dice' in checkpoint['metrics'] else None,
}, best_model_path)
print(f"بهترین مدل در مسیر زیر ذخیره شد: {best_model_path}")

# ایجاد کپی از بهترین checkpoint به عنوان مرجع
best_link_path = os.path.join(save_path, "best_checkpoint.pth")
if os.path.exists(best_link_path):
    os.remove(best_link_path)
shutil.copy2(best_checkpoint_path, best_link_path)
print(f"کپی بهترین checkpoint در مسیر زیر ایجاد شد: {best_link_path}")

# اکنون می‌توانید مدل را برای استفاده‌های بعدی ذخیره کنید
model_only_path = os.path.join(save_path, "best_model_weights_only.pth")
torch.save(model.state_dict(), model_only_path)
print(f"وزن‌های بهترین مدل به صورت جداگانه در مسیر زیر ذخیره شدند: {model_only_path}")

print("\nتمام عملیات با موفقیت انجام شد!")

در حال بارگذاری بهترین مدل از epoch 10...
وزن‌های مدل با موفقیت بارگذاری شدند!
بهترین مدل در مسیر زیر ذخیره شد: /content/drive/MyDrive/new_flava/attention_seg_head_enhanced/best_model.pth
کپی بهترین checkpoint در مسیر زیر ایجاد شد: /content/drive/MyDrive/new_flava/attention_seg_head_enhanced/best_checkpoint.pth
وزن‌های بهترین مدل به صورت جداگانه در مسیر زیر ذخیره شدند: /content/drive/MyDrive/new_flava/attention_seg_head_enhanced/best_model_weights_only.pth

تمام عملیات با موفقیت انجام شد!
