In [4]:
"""
Debug version of HAM10000 Head Analysis Visualization Module
"""

import os
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.patches import Rectangle
import matplotlib.gridspec as gridspec


def debug_data_structure(data_dir: Path):
    """Debug function to inspect the data structure before visualization."""
    
    print("=== DEBUGGING DATA STRUCTURE ===")
    
    # Check if files exist
    similarities_path = data_dir / "head_direction_similarities.npy"
    importance_path = data_dir / "head_importance_analysis.npy"
    
    print(f"Similarities file exists: {similarities_path.exists()}")
    print(f"Importance file exists: {importance_path.exists()}")
    
    if not similarities_path.exists() or not importance_path.exists():
        print("❌ Required data files not found!")
        return False
    
    try:
        # Load and inspect similarities
        print("\n--- Loading similarities data ---")
        direction_similarities = np.load(similarities_path, allow_pickle=True).item()
        print(f"Number of images: {len(direction_similarities)}")
        
        # Check first image to understand structure
        first_key = list(direction_similarities.keys())[0]
        first_data = direction_similarities[first_key]
        print(f"First image key: {first_key}")
        print(f"Keys in first image data: {first_data.keys()}")
        print(f"Predicted class: {first_data['predicted_class']}")
        
        # Check similarities structure
        similarities = first_data['similarities']
        print(f"Similarity keys: {list(similarities.keys())}")
        
        if 'all' in similarities:
            sim_shape = similarities['all'].shape
            print(f"'all' similarities shape: {sim_shape}")
            print(f"Expected: (layers, heads, tokens)")
        
        # Check class distribution
        classes_found = set()
        for img_data in direction_similarities.values():
            classes_found.add(img_data['predicted_class'])
        print(f"Classes found in data: {sorted(classes_found)}")
        
        # Load and inspect importance
        print("\n--- Loading importance data ---")
        head_importance = np.load(importance_path, allow_pickle=True).item()
        print(f"Head importance keys: {list(head_importance.keys())}")
        
        if 'class_head_importance' in head_importance:
            class_imp = head_importance['class_head_importance']
            print(f"Classes in importance data: {list(class_imp.keys())}")
            
            # Check shapes for each class
            for cls_idx, imp_data in class_imp.items():
                print(f"Class {cls_idx} importance shape: {imp_data.shape}")
        
        if 'class_image_counts' in head_importance:
            counts = head_importance['class_image_counts']
            print(f"Image counts per class: {counts}")
        
        # Check token patterns
        print("\n--- Checking token patterns ---")
        token_patterns = {}
        for class_idx in range(7):
            pattern_path = data_dir / f"token_patterns_class_{class_idx}.npy"
            if pattern_path.exists():
                try:
                    patterns = np.load(pattern_path, allow_pickle=True).item()
                    token_patterns[class_idx] = patterns
                    print(f"Token patterns for class {class_idx}: ✓")
                    print(f"  Keys: {list(patterns.keys())}")
                    if 'n_images' in patterns:
                        print(f"  Images: {patterns['n_images']}")
                except Exception as e:
                    print(f"Token patterns for class {class_idx}: ❌ ({e})")
            else:
                print(f"Token patterns for class {class_idx}: File not found")
        
        return True
        
    except Exception as e:
        print(f"❌ Error loading data: {e}")
        import traceback
        traceback.print_exc()
        return False


def safe_visualize_head_importance(head_importance: Dict, save_dir: Path):
    """Safer version of head importance visualization with error handling."""
    
    print("\n=== TESTING HEAD IMPORTANCE VISUALIZATION ===")
    
    try:
        importance_data = head_importance['class_head_importance']
        available_classes = list(importance_data.keys())
        print(f"Available classes: {available_classes}")
        
        # Determine actual number of classes and shape
        first_class = available_classes[0]
        shape = importance_data[first_class].shape
        num_layers, num_heads = shape
        print(f"Data shape per class: {shape} (layers x heads)")
        
        # Create figure based on actual number of classes
        n_classes = len(available_classes)
        if n_classes <= 3:
            fig, axes = plt.subplots(1, n_classes, figsize=(6*n_classes, 6))
            if n_classes == 1:
                axes = [axes]
        elif n_classes <= 4:
            fig, axes = plt.subplots(1, 4, figsize=(24, 6))
        else:
            # For more than 4 classes, use 2 rows
            n_cols = min(4, n_classes)
            n_rows = (n_classes + n_cols - 1) // n_cols
            fig, axes = plt.subplots(n_rows, n_cols, figsize=(6*n_cols, 6*n_rows))
            axes = axes.flatten() if n_classes > 1 else [axes]
        
        class_names = {
            0: 'akiec', 1: 'bcc', 2: 'bkl', 3: 'df', 4: 'mel', 5: 'nv', 6: 'vasc'
        }
        
        class_colors = {
            0: '#FF6B6B', 1: '#4ECDC4', 2: '#45B7D1', 3: '#96CEB4',
            4: '#FFEAA7', 5: '#DDA0DD', 6: '#FFA07A'
        }
        
        for i, class_idx in enumerate(available_classes):
            if i >= len(axes):
                break
                
            ax = axes[i]
            heatmap_data = importance_data[class_idx]
            
            # Create colormap
            colors = ['white', class_colors.get(class_idx, '#888888')]
            cmap = sns.blend_palette(colors, as_cmap=True)
            
            # Plot heatmap
            im = ax.imshow(heatmap_data, cmap=cmap, aspect='auto', vmin=0, vmax=1)
            
            # Customize
            ax.set_xlabel('Head Index', fontsize=12)
            ax.set_ylabel('Layer', fontsize=12)
            class_name = class_names.get(class_idx, f'Class_{class_idx}')
            ax.set_title(f'{class_name.upper()} Head Importance', fontsize=14, fontweight='bold')
            
            # Set ticks
            ax.set_xticks(range(num_heads))
            ax.set_yticks(range(num_layers))
            ax.set_yticklabels([f'L{7 + i}' for i in range(num_layers)])  # Assuming last 5 layers
            
            # Add colorbar
            cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
            cbar.set_label('Importance Score', fontsize=10)
        
        # Hide unused subplots
        for i in range(len(available_classes), len(axes)):
            axes[i].set_visible(False)
        
        plt.suptitle('Head Importance Analysis (Debug Version)', fontsize=16, fontweight='bold')
        plt.tight_layout()
        
        # Save
        save_path = save_dir / 'debug_head_importance.png'
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        plt.close()
        
        print(f"✓ Head importance visualization saved to {save_path}")
        return True
        
    except Exception as e:
        print(f"❌ Error in head importance visualization: {e}")
        import traceback
        traceback.print_exc()
        return False


def test_basic_plot():
    """Test if basic matplotlib works."""
    try:
        fig, ax = plt.subplots(figsize=(6, 4))
        ax.plot([1, 2, 3], [1, 4, 2])
        ax.set_title("Basic Test Plot")
        plt.close()
        print("✓ Basic matplotlib works")
        return True
    except Exception as e:
        print(f"❌ Basic matplotlib test failed: {e}")
        return False


def debug_and_test_visualization(data_dir: Path):
    """Main debug function."""
    
    print("STARTING VISUALIZATION DEBUG")
    print("="*50)
    
    # Test 1: Basic matplotlib
    if not test_basic_plot():
        return
    
    # Test 2: Data structure
    if not debug_data_structure(data_dir):
        return
    
    # Test 3: Try basic visualization
    try:
        print("\n=== ATTEMPTING SAFE VISUALIZATION ===")
        
        # Load data
        importance_path = data_dir / "head_importance_analysis.npy"
        head_importance = np.load(importance_path, allow_pickle=True).item()
        
        # Create viz directory
        viz_dir = data_dir / "debug_visualizations"
        os.makedirs(viz_dir, exist_ok=True)
        
        # Try safe head importance visualization
        success = safe_visualize_head_importance(head_importance, viz_dir)
        
        if success:
            print("\n✅ DEBUG VISUALIZATION SUCCESSFUL!")
            print(f"Check results in: {viz_dir}")
        else:
            print("\n❌ Visualization failed even in safe mode")
            
    except Exception as e:
        print(f"\n❌ Critical error in debug visualization: {e}")
        import traceback
        traceback.print_exc()


# Usage
if __name__ == "__main__":
    data_directory = Path("./results/train/head_analysis")  # Update this path
    debug_and_test_visualization(data_directory)

STARTING VISUALIZATION DEBUG
✓ Basic matplotlib works
=== DEBUGGING DATA STRUCTURE ===
Similarities file exists: True
Importance file exists: True

--- Loading similarities data ---
Number of images: 6409
First image key: bcc_ISIC_0031789
Keys in first image data: dict_keys(['predicted_class', 'similarities'])
Predicted class: 5
Similarity keys: [0, 1, 2, 3, 4, 6, 'all']
'all' similarities shape: (5, 12, 197)
Expected: (layers, heads, tokens)
Classes found in data: [0, 1, 2, 3, 4, 5, 6]

--- Loading importance data ---
Head importance keys: ['class_head_importance', 'class_ranked_heads', 'images_by_class', 'class_image_counts', 'similarity_stats', 'top_heads_per_class']
Classes in importance data: [0, 1, 2, 3, 4, 5, 6]
Class 0 importance shape: (5, 12)
Class 1 importance shape: (5, 12)
Class 2 importance shape: (5, 12)
Class 3 importance shape: (5, 12)
Class 4 importance shape: (5, 12)
Class 5 importance shape: (5, 12)
Class 6 importance shape: (5, 12)
Image counts per class: {0: 211, 

In [None]:
import head_analysis_visualization
from pathlib import Path

data_directory = Path("./results/train/head_analysis")
head_analysis_visualization.create_comprehensive_visualizations(data_directory)

Loading HAM10000 analysis data...
Generating comprehensive visualizations for HAM10000 analysis...
1/8 - Head Importance Heatmaps...


  plt.tight_layout()


2/8 - Similarity Distributions...


KeyboardInterrupt: 

In [1]:
import ham10k

config = ham10k.Config()
ham10k.evaluate_saved_model("./models/ham10000/best_ham10000_vit_model_2.pth", config)


🔍 Loading model from: ./models/ham10000/best_ham10000_vit_model_2.pth
📋 Classes: ['akiec', 'bcc', 'bkl', 'df', 'mel', 'nv', 'vasc']
📊 Dataset splits - Train: 6,409, Val: 1,603, Test: 2,003
📁 Organizing images into ham10k/preprocessed
📂 Found 6409 images already organized in train folder
📋 Using existing organized images
⚖️  Class weights:
  akiec: 4.381
  bcc: 2.783
  bkl: 1.302
  df: 12.373
  mel: 1.286
  nv: 0.213
  vasc: 10.061
🤖 Creating ViT model...
Loading weights from ./models/ham10000/best_ham10000_vit_model_2.pth
Checkpoint keys: ['epoch', 'model_state_dict', 'optimizer_state_dict', 'val_f1', 'class_names']
✅ Using model_state_dict from checkpoint
Epoch: 19
Validation F1: 0.7720340927620811
Classes: ['akiec', 'bcc', 'bkl', 'df', 'mel', 'nv', 'vasc']
✅ Model weights loaded successfully
Loaded ImageNet pretrained ViT-B


Validating: 100%|██████████| 63/63 [00:05<00:00, 11.03it/s, Loss=0.4361, Acc=86.2%]

              precision    recall  f1-score   support

       akiec       0.57      0.71      0.63        65
         bcc       0.81      0.79      0.80       103
         bkl       0.82      0.67      0.74       220
          df       0.86      0.83      0.84        23
         mel       0.63      0.67      0.65       223
          nv       0.93      0.94      0.93      1341
        vasc       0.81      0.89      0.85        28

    accuracy                           0.86      2003
   macro avg       0.78      0.78      0.78      2003
weighted avg       0.86      0.86      0.86      2003






In [2]:

import ham10k

config = ham10k.Config()
ham10k.evaluate_saved_model("./models/ham10000/best_ham10000_vit_model.pth", config)


🔍 Loading model from: ./models/ham10000/best_ham10000_vit_model.pth
📋 Classes: ['akiec', 'bcc', 'bkl', 'df', 'mel', 'nv', 'vasc']
📊 Dataset splits - Train: 6,409, Val: 1,603, Test: 2,003
📁 Organizing images into ham10k/preprocessed
📂 Found 6409 images already organized in train folder
📋 Using existing organized images
⚖️  Class weights:
  akiec: 4.381
  bcc: 2.783
  bkl: 1.302
  df: 12.373
  mel: 1.286
  nv: 0.213
  vasc: 10.061
🤖 Creating ViT model...
Loading weights from ./models/ham10000/best_ham10000_vit_model_2.pth
Checkpoint keys: ['epoch', 'model_state_dict', 'optimizer_state_dict', 'val_f1', 'class_names']
✅ Using model_state_dict from checkpoint
Epoch: 19
Validation F1: 0.7720340927620811
Classes: ['akiec', 'bcc', 'bkl', 'df', 'mel', 'nv', 'vasc']
✅ Model weights loaded successfully
Loaded ImageNet pretrained ViT-B


Validating: 100%|██████████| 63/63 [00:05<00:00, 11.45it/s, Loss=0.5580, Acc=79.8%]

              precision    recall  f1-score   support

       akiec       0.74      0.48      0.58        65
         bcc       0.77      0.83      0.80       103
         bkl       0.60      0.79      0.68       220
          df       0.73      0.83      0.78        23
         mel       0.46      0.78      0.58       223
          nv       0.97      0.81      0.89      1341
        vasc       0.83      0.86      0.84        28

    accuracy                           0.80      2003
   macro avg       0.73      0.77      0.73      2003
weighted avg       0.85      0.80      0.81      2003






In [4]:
import torch

# Path to the downloaded fine-tuned weights
weights_path = './model/vit_b-ImageNet_class_init-frozen_False-dataset_Hyperkvasir_anatomical.pth'
# Make sure the path is correct!

# --- CONFIRMED NUMBER OF CLASSES ---
# Based on your provided table "Anatomical landmark recognition"
# Cecum, Ileum, Retroflex-rectum, Pylorus, Retroflex-stomach, Z-line
# There are 6 classes.
CONFIRMED_NUM_CLASSES = 6
print(f"Using CONFIRMED_NUM_CLASSES = {CONFIRMED_NUM_CLASSES}")
num_classes_hyperkvasir_anatomical = CONFIRMED_NUM_CLASSES # Use this directly

try:
    # Set weights_only=True if you only want to load weights and not pickled code.
    # However, since this is a full checkpoint, weights_only=False might be needed
    # if their model saving relies on pickled custom classes, though less likely for state_dict.
    # For safety, if you trust the source, weights_only=False is what you used.
    # If you only need the state_dict, and it's just tensors, weights_only=True is safer.
    # Let's try to be robust.
    try:
        loaded_checkpoint = torch.load(weights_path, map_location='cpu', weights_only=False)
        print("Loaded checkpoint with weights_only=True")
    except Exception as e_true:
        print(f"Could not load with weights_only=True ({e_true}), trying weights_only=False.")
        loaded_checkpoint = torch.load(weights_path, map_location='cpu', weights_only=False)
        print("Loaded checkpoint with weights_only=False")


    actual_model_state_dict = None
    if isinstance(loaded_checkpoint, dict) and 'model_state_dict' in loaded_checkpoint:
        actual_model_state_dict = loaded_checkpoint['model_state_dict']
        print("Found 'model_state_dict' in the loaded checkpoint.")
    elif isinstance(loaded_checkpoint, dict) and 'model' in loaded_checkpoint: # Fallback for other common patterns
        actual_model_state_dict = loaded_checkpoint['model']
        print("Found 'model' in the loaded checkpoint (using as state_dict).")
    elif isinstance(loaded_checkpoint, dict) and 'state_dict' in loaded_checkpoint: # Another fallback
        actual_model_state_dict = loaded_checkpoint['state_dict']
        print("Found 'state_dict' in the loaded checkpoint (using as state_dict).")
    elif isinstance(loaded_checkpoint, dict): # If it's a dict but not one of the known keys, maybe it IS the state_dict
        actual_model_state_dict = loaded_checkpoint
        print("Loaded checkpoint is a dictionary, assuming it is the state_dict directly.")
    else:
        print("Error: Loaded checkpoint is not a dictionary or recognized structure.")


    if actual_model_state_dict:
        # Their VisionTransformer_from_Any uses self.lin_head
        # It's also possible that due to DataParallel during training, keys might be prefixed with 'module.'
        head_weight_key_found = None
        if 'lin_head.weight' in actual_model_state_dict:
            head_weight_key_found = 'lin_head.weight'
        elif 'module.lin_head.weight' in actual_model_state_dict:
            head_weight_key_found = 'module.lin_head.weight'
        elif 'head.weight' in actual_model_state_dict: # Fallback to standard timm naming
            head_weight_key_found = 'head.weight'
        elif 'module.head.weight' in actual_model_state_dict: # Fallback with module prefix
            head_weight_key_found = 'module.head.weight'

        if head_weight_key_found:
            num_classes_from_weights = actual_model_state_dict[head_weight_key_found].shape[0]
            print(f"Inferred num_classes from '{head_weight_key_found}': {num_classes_from_weights}")
            if num_classes_from_weights != CONFIRMED_NUM_CLASSES:
                print(f"WARNING: Number of classes from weights ({num_classes_from_weights}) "
                      f"does NOT match CONFIRMED_NUM_CLASSES ({CONFIRMED_NUM_CLASSES}). "
                      "Please double-check!")
                # You might want to trust the weights file if there's a discrepancy
                # num_classes_hyperkvasir_anatomical = num_classes_from_weights
            else:
                print("Number of classes from weights matches confirmed value. Good.")
        else:
            print("Could not find 'lin_head.weight' or 'head.weight' (with or without 'module.' prefix) in the model_state_dict.")
            print("Keys found in model_state_dict (last 10):", list(actual_model_state_dict.keys())[-10:])
            print(f"Proceeding with CONFIRMED_NUM_CLASSES = {CONFIRMED_NUM_CLASSES}. Ensure this is correct.")
    else:
        print("Could not extract 'model_state_dict' for inspection.")
        print(f"Proceeding with CONFIRMED_NUM_CLASSES = {CONFIRMED_NUM_CLASSES}. Ensure this is correct.")


except FileNotFoundError:
    print(f"ERROR: Weights file not found at {weights_path}")
    print("Please download the file and update the path.")
    # num_classes_hyperkvasir_anatomical is already set to CONFIRMED_NUM_CLASSES
except Exception as e:
    print(f"An error occurred during loading or inspection: {e}")
    # num_classes_hyperkvasir_anatomical is already set to CONFIRMED_NUM_CLASSES


Using CONFIRMED_NUM_CLASSES = 6
Loaded checkpoint with weights_only=True
Found 'model_state_dict' in the loaded checkpoint.
Inferred num_classes from 'lin_head.weight': 6
Number of classes from weights matches confirmed value. Good.


In [11]:
import torch
import torch.nn as nn
# Assuming ssl4gie_models.py contains their VisionTransformer_from_Any class
from ssl4gie_models import VisionTransformer_from_Any # Make sure this import works

# --- Configuration ---
NUM_CLASSES_HYPERKVASIR_ANATOMICAL = 6 # CORRECTED
PATH_TO_FINETUNED_WEIGHTS = './model/vit_b-ImageNet_class_init-frozen_False-dataset_Hyperkvasir_anatomical.pth'

# Instantiate their model structure
ssl_model = VisionTransformer_from_Any(
    head=True,
    num_classes=NUM_CLASSES_HYPERKVASIR_ANATOMICAL,
    frozen=False,
    dense=False,
    det=False,
    fixed_size=224,
    embed_dim=768,
    depth=12,
    num_heads=12,
    out_token='cls',
    ImageNet_weights=False
)
ssl_model.eval()

print(f"Loading fine-tuned weights from: {PATH_TO_FINETUNED_WEIGHTS}")
actual_model_state_dict_from_checkpoint = None # Initialize

try:
    try:
        checkpoint = torch.load(PATH_TO_FINETUNED_WEIGHTS, map_location='cpu', weights_only=True)
        print("Loaded checkpoint with weights_only=True")
    except Exception:
        checkpoint = torch.load(PATH_TO_FINETUNED_WEIGHTS, map_location='cpu', weights_only=False)
        print("Loaded checkpoint with weights_only=False (fallback)")

    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        actual_model_state_dict_from_checkpoint = checkpoint['model_state_dict']
    elif isinstance(checkpoint, dict): # If it's a dict but no 'model_state_dict', assume it IS the state_dict
         actual_model_state_dict_from_checkpoint = checkpoint
         print("Checkpoint is a dict, but no 'model_state_dict' key. Assuming the checkpoint itself is the state_dict.")
    else:
        print("ERROR: Loaded checkpoint is not a dictionary or 'model_state_dict' key not found.")
        raise ValueError("Invalid checkpoint structure")


    if actual_model_state_dict_from_checkpoint:
        # Create a new state_dict without 'module.' prefix if necessary.
        new_state_dict_for_ssl_model = {}
        has_module_prefix = any(key.startswith('module.') for key in actual_model_state_dict_from_checkpoint.keys())

        if has_module_prefix:
            print("Detected 'module.' prefix in checkpoint keys. Removing it for SSL model.")
        for k, v in actual_model_state_dict_from_checkpoint.items():
            name = k[7:] if has_module_prefix and k.startswith('module.') else k
            new_state_dict_for_ssl_model[name] = v
        
        missing_keys, unexpected_keys = ssl_model.load_state_dict(new_state_dict_for_ssl_model, strict=False)
        print("Loading into SSL4GIE model structure:")
        if missing_keys: print(f"  Missing keys in SSL model: {missing_keys}")
        if unexpected_keys: print(f"  Unexpected keys in SSL model: {unexpected_keys}")
        
        # Check common success patterns
        is_successful_load = True
        if not missing_keys and not unexpected_keys:
            print("  Successfully loaded all weights into SSL4GIE model structure!")
        elif not missing_keys and unexpected_keys == ['head.weight', 'head.bias'] and hasattr(ssl_model, 'lin_head'):
            # This case means their VisionTransformer_from_Any base (timm.ViT) has 'head', but their subclass uses 'lin_head'.
            # The new_state_dict_for_ssl_model would have 'lin_head' if 'module.lin_head' was in checkpoint,
            # or 'head' if 'module.head' was in checkpoint.
            # This specific condition might need adjustment based on actual keys.
            print("  Successfully loaded backbone into SSL model. Head layer name might differ (e.g., 'head' vs 'lin_head'). This is usually okay if we map it later.")
        elif not unexpected_keys and all(k.startswith("lin_head") for k in missing_keys) and hasattr(ssl_model, 'head') and not hasattr(ssl_model, 'lin_head'):
            print(" SSL model expects 'lin_head' but got 'head' or vice versa, this will be handled in mapping to your model")
        else:
            print("  Check missing/unexpected keys during SSL4GIE model load. Some might be critical.")
            if missing_keys or unexpected_keys: # Be more conservative if any mismatch
                 is_successful_load = False

        if not is_successful_load:
             print("  Review log for SSL model loading issues.")
             # Consider exiting if critical weights are missing

    else:
        print("ERROR: Could not extract a valid model_state_dict from the loaded checkpoint.")
        # exit()

except FileNotFoundError:
    print(f"ERROR: Fine-tuned weights file not found at {PATH_TO_FINETUNED_WEIGHTS}")
    # exit()
except Exception as e:
    print(f"An error occurred while loading weights into SSL4GIE model: {e}")
    # exit()

# Now, ssl_model contains the loaded fine-tuned weights.
# Get its state_dict for transferring to your model if loading was successful.
if 'is_successful_load' in locals() and is_successful_load and actual_model_state_dict_from_checkpoint is not None:
    ssl_model_state_dict = ssl_model.state_dict() # Get the correctly formatted state_dict from ssl_model
    print("\nProceeding to transfer weights to your custom model.")
    # Assuming your VisionTransformer code is in 'your_vit_arch.py'
    from translrp.ViT_new import vit_base_patch16_224 # Or your general VisionTransformer class

    # Number of output classes for YOUR project.
    # This can be the same as NUM_CLASSES_HYPERKVASIR_ANATOMICAL
    # or different if you are adapting the fine-tuned model further.
    # For a direct load, it should be the same.
    NUM_CLASSES_FOR_YOUR_PROJECT = NUM_CLASSES_HYPERKVASIR_ANATOMICAL

    print(f"\nInstantiating your ViT-B model with {NUM_CLASSES_FOR_YOUR_PROJECT} classes...")
    my_vit_model = vit_base_patch16_224(pretrained=False, num_classes=NUM_CLASSES_FOR_YOUR_PROJECT)
    my_vit_model.eval()

    # Prepare state_dict for your model:
    # The main difference will be the head layer name
    # SSL4GIE: lin_head.weight, lin_head.bias
    # Your model: head.weight, head.bias
    final_state_dict_for_my_model = {}
    for key, value in ssl_model_state_dict.items():
        if key == "lin_head.weight":
            final_state_dict_for_my_model["head.weight"] = value
        elif key == "lin_head.bias":
            final_state_dict_for_my_model["head.bias"] = value
        else:
            final_state_dict_for_my_model[key] = value

    print("\nLoading state_dict into your model architecture:")
    missing_keys, unexpected_keys = my_vit_model.load_state_dict(final_state_dict_for_my_model, strict=False)

    if missing_keys:
        print("  Missing keys in your model (expected from SSL4GIE model):", missing_keys)
    if unexpected_keys:
        print("  Unexpected keys in your model (not in SSL4GIE model's state_dict):", unexpected_keys)

    if not missing_keys and not unexpected_keys:
        print("  Successfully loaded all weights into your model architecture!")
    else:
        print("  Potential issues in weight loading. Review missing/unexpected keys.")
        print("  Common ViT-B keys should match. Ensure embed_dim, depth, num_heads are the same.")

    # --- Verification (Optional) ---
    # If you have a sample input image for anatomical landmark recognition and its label
    # you could try a forward pass.
    # For now, a dummy input:
    dummy_input = torch.randn(1, 3, 224, 224) # Batch of 1 image
    try:
        with torch.no_grad():
            output = my_vit_model(dummy_input)
        print(f"\nDummy input passed through your loaded model. Output shape: {output.shape}")
        assert output.shape == (1, NUM_CLASSES_FOR_YOUR_PROJECT)
        print("Output shape is correct.")
    except Exception as e:
        print(f"\nError passing dummy input through your loaded model: {e}")
else:
    print("\nCritical error during SSL model weight loading. Cannot proceed to transfer to your model.")
    ssl_model_state_dict_for_transfer = None # Indicate failure
    

Loading fine-tuned weights from: ./model/vit_b-ImageNet_class_init-frozen_False-dataset_Hyperkvasir_anatomical.pth
Loaded checkpoint with weights_only=False (fallback)
Loading into SSL4GIE model structure:
  Successfully loaded all weights into SSL4GIE model structure!

Proceeding to transfer weights to your custom model.

Instantiating your ViT-B model with 6 classes...

Loading state_dict into your model architecture:
  Successfully loaded all weights into your model architecture!

Dummy input passed through your loaded model. Output shape: torch.Size([1, 6])
Output shape is correct.


In [8]:
# Assuming your VisionTransformer code is in 'your_vit_arch.py'
from translrp.ViT_new import vit_base_patch16_224 # Or your general VisionTransformer class

# Number of output classes for YOUR project.
# This can be the same as NUM_CLASSES_HYPERKVASIR_ANATOMICAL
# or different if you are adapting the fine-tuned model further.
# For a direct load, it should be the same.
NUM_CLASSES_FOR_YOUR_PROJECT = NUM_CLASSES_HYPERKVASIR_ANATOMICAL

print(f"\nInstantiating your ViT-B model with {NUM_CLASSES_FOR_YOUR_PROJECT} classes...")
my_vit_model = vit_base_patch16_224(pretrained=False, num_classes=NUM_CLASSES_FOR_YOUR_PROJECT)
my_vit_model.eval()

# Prepare state_dict for your model:
# The main difference will be the head layer name
# SSL4GIE: lin_head.weight, lin_head.bias
# Your model: head.weight, head.bias
final_state_dict_for_my_model = {}
for key, value in ssl_model_state_dict.items():
    if key == "lin_head.weight":
        final_state_dict_for_my_model["head.weight"] = value
    elif key == "lin_head.bias":
        final_state_dict_for_my_model["head.bias"] = value
    else:
        final_state_dict_for_my_model[key] = value

print("\nLoading state_dict into your model architecture:")
missing_keys, unexpected_keys = my_vit_model.load_state_dict(final_state_dict_for_my_model, strict=False)

if missing_keys:
    print("  Missing keys in your model (expected from SSL4GIE model):", missing_keys)
if unexpected_keys:
    print("  Unexpected keys in your model (not in SSL4GIE model's state_dict):", unexpected_keys)

if not missing_keys and not unexpected_keys:
    print("  Successfully loaded all weights into your model architecture!")
else:
    print("  Potential issues in weight loading. Review missing/unexpected keys.")
    print("  Common ViT-B keys should match. Ensure embed_dim, depth, num_heads are the same.")

# --- Verification (Optional) ---
# If you have a sample input image for anatomical landmark recognition and its label
# you could try a forward pass.
# For now, a dummy input:
dummy_input = torch.randn(1, 3, 224, 224) # Batch of 1 image
try:
    with torch.no_grad():
        output = my_vit_model(dummy_input)
    print(f"\nDummy input passed through your loaded model. Output shape: {output.shape}")
    assert output.shape == (1, NUM_CLASSES_FOR_YOUR_PROJECT)
    print("Output shape is correct.")
except Exception as e:
    print(f"\nError passing dummy input through your loaded model: {e}")


Instantiating your ViT-B model with 6 classes...


NameError: name 'ssl_model_state_dict' is not defined

In [13]:
import torch
import torch.nn as nn
from timm.models.vision_transformer import VisionTransformer as TimmVisionTransformer
from translrp.ViT_new import vit_base_patch16_224

# --- SSL4GIE's VisionTransformer_from_Any (Simplified for this use case) ---
# We only need the parts relevant to head=True, dense=False, det=False
class SSL4GIE_ViT(TimmVisionTransformer):
    def __init__(
        self,
        num_classes,
        embed_dim=768,
        depth=12,
        num_heads=12,
        # Unused for this specific loading but kept for signature compatibility if strict loading was used
        head=True, 
        frozen=False,
        dense=False,
        det=False,
        fixed_size=224,
        out_token='cls',
        ImageNet_weights=False, # Will be False when loading their fine-tuned checkpoint
    ):
        super().__init__(
            patch_size=16, embed_dim=embed_dim, depth=depth, num_heads=num_heads, num_classes=0 # Base timm ViT's head is Identity if num_classes=0
        )
        # Original SSL4GIE model sets self.head = nn.Identity() and then adds self.lin_head
        # Replicating that structure for faithful intermediate loading:
        self.head = nn.Identity() # This is the timm.ViT's head
        self.lin_head = nn.Linear(embed_dim, num_classes) # This is SSL4GIE's specific head
        
        self.head_bool = head # For their forward logic
        self.out_token = out_token # For their forward logic

        # Unused for this specific task but part of their class structure
        self.frozen = frozen
        self.dense = dense
        self.det = det
        if ImageNet_weights: # This path won't be taken when loading their fine-tuned ckpt
            # This is a placeholder, as the actual download URL is for their ImageNet_class init
            # which we are bypassing by loading their finetuned weights directly.
            print("WARNING: ImageNet_weights=True in SSL4GIE_ViT init, but we load custom weights.")


    def forward(self, x): # Simplified forward for classification
        x = self.forward_features(x) # From TimmVisionTransformer
        if self.out_token == "cls":
            x = x[:, 0]
        elif self.out_token == "spatial": # Not typical for their classification ViT-B
            x = x[:, 1:].mean(1)
        # If self.head_bool is True (which it is for classification)
        x = self.lin_head(x) # Use their specific linear head
        return x


def transform_and_save_weights(
    input_checkpoint_path,
    output_weights_path,
    num_classes,
    verbose=True):
    """
    Loads weights from an SSL4GIE checkpoint, transforms them for a
    translrp.ViT_new architecture, and saves the transformed state_dict.

    Args:
        input_checkpoint_path (str): Path to the SSL4GIE .pth checkpoint file.
        output_weights_path (str): Path to save the transformed model's state_dict.
        num_classes (int): Number of output classes for the final model.
        verbose (bool): Whether to print detailed loading messages.
    """
    if verbose: print(f"--- Starting weight transformation ---")
    if verbose: print(f"Input checkpoint: {input_checkpoint_path}")
    if verbose: print(f"Output for transformed weights: {output_weights_path}")
    if verbose: print(f"Number of classes: {num_classes}")

    # 1. Instantiate SSL4GIE's model structure
    ssl_temp_model = SSL4GIE_ViT(
        num_classes=num_classes,
        embed_dim=768, depth=12, num_heads=12, # Standard ViT-B params
        # Other params are defaults or not affecting weight loading structure for this case
    )
    ssl_temp_model.eval()

    # 2. Load the downloaded fine-tuned weights into ssl_temp_model
    actual_model_state_dict_from_checkpoint = None
    try:
        try:
            checkpoint = torch.load(input_checkpoint_path, map_location='cpu', weights_only=True)
            if verbose: print("Loaded checkpoint with weights_only=True")
        except Exception:
            checkpoint = torch.load(input_checkpoint_path, map_location='cpu', weights_only=False)
            if verbose: print("Loaded checkpoint with weights_only=False (fallback)")

        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            actual_model_state_dict_from_checkpoint = checkpoint['model_state_dict']
        elif isinstance(checkpoint, dict):
            actual_model_state_dict_from_checkpoint = checkpoint
            if verbose: print("Checkpoint is a dict, but no 'model_state_dict' key. Assuming it's the state_dict.")
        else:
            raise ValueError("Loaded checkpoint is not a dictionary or 'model_state_dict' key not found.")

        if actual_model_state_dict_from_checkpoint:
            new_state_dict_for_ssl_model = {}
            has_module_prefix = any(key.startswith('module.') for key in actual_model_state_dict_from_checkpoint.keys())
            if has_module_prefix and verbose:
                print("Detected 'module.' prefix in checkpoint. Removing it.")
            for k, v in actual_model_state_dict_from_checkpoint.items():
                name = k[7:] if has_module_prefix and k.startswith('module.') else k
                new_state_dict_for_ssl_model[name] = v
            
            ssl_temp_model.load_state_dict(new_state_dict_for_ssl_model, strict=True) # Be strict here
            if verbose: print("Successfully loaded weights into intermediate SSL4GIE model structure.")
        else:
            raise ValueError("Could not extract a valid model_state_dict from the checkpoint.")

    except Exception as e:
        print(f"ERROR loading weights into intermediate SSL4GIE model: {e}")
        return

    # 3. Get the state_dict from the populated ssl_temp_model
    ssl_model_clean_state_dict = ssl_temp_model.state_dict()

    # 4. Instantiate your target model structure
    if verbose: print(f"Instantiating your target ViT-B model with {num_classes} classes...")
    target_model = vit_base_patch16_224(pretrained=False, num_classes=num_classes)
    target_model.eval()

    # 5. Prepare state_dict for your target model (mapping SSL4GIE's 'lin_head' to your 'head')
    final_state_dict_for_target_model = {}
    for key, value in ssl_model_clean_state_dict.items():
        if key == "lin_head.weight": # SSL4GIE_ViT uses lin_head
            final_state_dict_for_target_model["head.weight"] = value # Your ViT_new uses head
        elif key == "lin_head.bias":
            final_state_dict_for_target_model["head.bias"] = value
        else:
            final_state_dict_for_target_model[key] = value
    
    # 6. Load into your target model
    try:
        target_model.load_state_dict(final_state_dict_for_target_model, strict=True) # Be strict
        if verbose: print("Successfully loaded transformed weights into your target model architecture!")
    except Exception as e:
        print(f"ERROR loading transformed weights into your target model: {e}")
        # You might want to try strict=False here for debugging, but True is preferred for final check
        # missing_keys, unexpected_keys = target_model.load_state_dict(final_state_dict_for_target_model, strict=False)
        # print(f"  Missing keys: {missing_keys}")
        # print(f"  Unexpected keys: {unexpected_keys}")
        return

    # 7. Save the state_dict of your transformed model
    try:
        torch.save(target_model.state_dict(), output_weights_path)
        if verbose: print(f"Successfully saved transformed model state_dict to: {output_weights_path}")
    except Exception as e:
        print(f"ERROR saving transformed model state_dict: {e}")
        return

    if verbose: print(f"--- Weight transformation complete ---")


if __name__ == '__main__':
    # --- Configuration ---
    NUM_CLASSES = 6 # For Hyperkvasir anatomical landmarks
    
    # Path to the downloaded SSL4GIE fine-tuned weights
    INPUT_SSL4GIE_CHECKPOINT_PATH = './model/vit_b-ImageNet_class_init-frozen_False-dataset_Hyperkvasir_anatomical.pth'
    
    # Path where you want to save the transformed weights compatible with your translrp.ViT_new
    OUTPUT_TRANSFORMED_WEIGHTS_PATH = './model/vit_b_hyperkvasir_anatomical_for_translrp.pth'

    transform_and_save_weights(
        input_checkpoint_path=INPUT_SSL4GIE_CHECKPOINT_PATH,
        output_weights_path=OUTPUT_TRANSFORMED_WEIGHTS_PATH,
        num_classes=NUM_CLASSES,
        verbose=True
    )

    # --- Optional: Test loading the saved transformed weights ---
    print("\n--- Testing loading the newly saved transformed weights ---")
    try:
        test_model = vit_base_patch16_224(pretrained=False, num_classes=NUM_CLASSES)
        test_model.load_state_dict(torch.load(OUTPUT_TRANSFORMED_WEIGHTS_PATH, map_location='cpu'))
        test_model.eval()
        print(f"Successfully loaded saved transformed weights into a new instance of your model.")
        
        # Dummy input test
        dummy_input = torch.randn(1, 3, 224, 224)
        with torch.no_grad():
            output = test_model(dummy_input)
        print(f"Dummy input passed through test_model. Output shape: {output.shape}")
        assert output.shape == (1, NUM_CLASSES)
        print("Output shape is correct.")

    except FileNotFoundError:
        print(f"Test load failed: Transformed weights file not found at {OUTPUT_TRANSFORMED_WEIGHTS_PATH}")
    except Exception as e:
        print(f"Test load failed: {e}")



--- Starting weight transformation ---
Input checkpoint: ./model/vit_b-ImageNet_class_init-frozen_False-dataset_Hyperkvasir_anatomical.pth
Output for transformed weights: ./model/vit_b_hyperkvasir_anatomical_for_translrp.pth
Number of classes: 6
Loaded checkpoint with weights_only=False (fallback)
Successfully loaded weights into intermediate SSL4GIE model structure.
Instantiating your target ViT-B model with 6 classes...
Successfully loaded transformed weights into your target model architecture!
Successfully saved transformed model state_dict to: ./model/vit_b_hyperkvasir_anatomical_for_translrp.pth
--- Weight transformation complete ---

--- Testing loading the newly saved transformed weights ---
Successfully loaded saved transformed weights into a new instance of your model.
Dummy input passed through test_model. Output shape: torch.Size([1, 6])
Output shape is correct.


In [14]:
from pathlib import Path
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import classification_report, accuracy_score, f1_score # For evaluation
from tqdm import tqdm # For progress bars

# Assuming your vit.preprocessing and other imports are available
import vit.preprocessing as preprocessing
# Import your model loading and PipelineConfig
import vit.model as model_module # Assuming vit.model contains your CLASSES, CLS2IDX, IDX2CLS and model loading
from config import PipelineConfig


# --- Modified HyperkvasirDataset ---
# This version will use the CLS2IDX imported from vit.model for the current run
class HyperkvasirDataset(Dataset):
    def __init__(self, image_dir, current_cls2idx, current_idx2cls, transform=None):
        self.image_dir = Path(image_dir)
        self.image_paths = sorted(list(self.image_dir.glob("*.jpg")))
        self.transform = transform
        
        # Use the mappings provided for THIS RUN
        self.cls_to_idx = current_cls2idx
        self.idx_to_cls = current_idx2cls # For potentially deriving ordered class names
        
        # Derive class_names for reports based on the numerical order of indices in idx_to_cls
        # This ensures the report labels match the 0-5 order of the model output
        self.class_names_for_report = [self.idx_to_cls[i] for i in range(len(self.idx_to_cls))]

        if not self.image_paths:
            print(f"Warning: No images found in {self.image_dir}")
        self._verify_filename_suffixes() # Important sanity check

    def _verify_filename_suffixes(self):
        unmappable_files = []
        for path in self.image_paths:
            class_name_from_file = path.stem.split('_')[-1]
            if class_name_from_file not in self.cls_to_idx:
                unmappable_files.append(path)
        if unmappable_files:
            print(f"ERROR: The following files have suffixes that cannot be mapped using the current CLS2IDX:")
            for f_path in unmappable_files[:5]:
                print(f"  - {f_path} (suffix: {f_path.stem.split('_')[-1]})")
            if len(unmappable_files) > 5:
                print(f"  ... and {len(unmappable_files) - 5} more.")
            print(f"Keys in current CLS2IDX: {list(self.cls_to_idx.keys())}")
            raise ValueError("Mismatch between filename suffixes and CLS2IDX keys. Check vit/model.py or your manual definition.")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        try:
            image = Image.open(image_path).convert('RGB')
        except Exception as e:
            print(f"Error loading image {image_path}: {e}")
            raise

        class_name_from_file = image_path.stem.split('_')[-1]
        
        # This check should pass if _verify_filename_suffixes passed
        if class_name_from_file not in self.cls_to_idx:
             raise ValueError(f"Class '{class_name_from_file}' from file {image_path} not in CLS2IDX map.")
        
        class_idx_numerical = self.cls_to_idx[class_name_from_file]

        if self.transform:
            image = self.transform(image)
        return image, class_idx_numerical, str(image_path) # Return path for debugging if needed

# --- Modified create_data_loaders ---
# It will now implicitly use the mappings imported from vit.model at the time of the call
def create_data_loaders(config, batch_size=32, source_dir_root: Path = Path("./hyper-kvasir/preprocessed/")):
    # These will be the mappings currently defined in vit.model when this function is called
    current_classes_list = model_module.CLASSES
    current_cls2idx = model_module.CLS2IDX
    current_idx2cls = model_module.IDX2CLS

    # Sanity check: Ensure CLS2IDX and IDX2CLS are consistent with CLASSES
    if len(current_classes_list) != len(current_cls2idx) or len(current_classes_list) != len(current_idx2cls):
        raise ValueError("Mismatch in lengths of CLASSES, CLS2IDX, IDX2CLS in vit.model.py")
    for i, name in enumerate(current_classes_list):
        if current_cls2idx.get(name) != i or current_idx2cls.get(i) != name:
            raise ValueError(f"Inconsistency in mappings for class '{name}' at index {i} in vit.model.py. "
                             f"Expected CLS2IDX['{name}'] == {i} and IDX2CLS[{i}] == '{name}'. "
                             f"Got CLS2IDX.get('{name}') = {current_cls2idx.get(name)}, IDX2CLS.get({i}) = {current_idx2cls.get(i)}")


    processor = preprocessing.get_processor_for_precached_224_images()

    train_dataset = HyperkvasirDataset(source_dir_root / "train", current_cls2idx, current_idx2cls, transform=processor)
    val_dataset = HyperkvasirDataset(source_dir_root / "val", current_cls2idx, current_idx2cls, transform=processor)
    test_dataset = HyperkvasirDataset(source_dir_root / "test", current_cls2idx, current_idx2cls, transform=processor)

    class_names_for_report = test_dataset.class_names_for_report

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

    return train_loader, val_loader, test_loader, class_names_for_report


# --- Your main evaluation function (train_vit_on_hyperkvasir or equivalent) ---
# This is where you'll call create_data_loaders and the evaluation logic
def run_evaluation_with_current_mapping(config, model_path, num_classes):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # --- Crucial: Print the mapping being used for this run ---
    print(f"\n--- Running evaluation with the following mapping from vit.model.py ---")
    print(f"CLASSES: {model_module.CLASSES}")
    print(f"CLS2IDX: {model_module.CLS2IDX}")
    print(f"IDX2CLS: {model_module.IDX2CLS}")
    # ---

    # Create data loaders - they will pick up the current mappings from vit.model
    # Assuming batch_size is defined, e.g., config.eval.batch_size or hardcoded
    eval_batch_size = getattr(config.eval, 'batch_size', 16) # Example: get from config or default to 16
    _, _, test_loader, class_names_for_report = create_data_loaders(config, batch_size=eval_batch_size)
    
    print(f"Order for classification_report: {class_names_for_report}")

    # Load your ViT model
    # This assumes model_module.load_vit_model uses the num_classes and loads weights properly
    vit_model = model_module.load_vit_model(device=device, model_path=model_path, num_classes=num_classes).to(device)
    vit_model.eval()

    criterion = torch.nn.CrossEntropyLoss() # Not strictly needed if only evaluating, but good practice

    # Detailed test evaluation loop (similar to your validate_epoch)
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for images, labels, _ in tqdm(test_loader, desc="Evaluating"):
            images, labels = images.to(device), labels.to(device)
            outputs = vit_model(images)
            preds = outputs.argmax(dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    print("\nDetailed Test Results for the current mapping:")
    # The class_names_for_report from create_data_loaders ensures labels in the report match the model's output indices
    print(classification_report(all_labels, all_preds, target_names=class_names_for_report, zero_division=0))
    
    accuracy = accuracy_score(all_labels, all_preds)
    f1_macro = f1_score(all_labels, all_preds, average='macro', zero_division=0)
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Macro F1-score: {f1_macro:.4f}")
    print(f"--- Evaluation with current mapping complete ---\n")


# --- How you'd use it ---
# In your main script (e.g., the one with if __name__ == "__main__":)

# 1. BEFORE RUNNING: Manually edit vit/model.py to set the desired order in CLASSES.
#    Then ensure CLS2IDX and IDX2CLS are correctly derived from it.
#    Example content of vit/model.py:
#    ```python
#    # In vit/model.py
#    # Experiment 1: Alphabetical (default)
#    CLASSES = ["cecum", "ileum", "pylorus", "retroflex-rectum", "retroflex-stomach", "z-line"]
#    # Experiment 2: Table Order
#    #CLASSES = ["cecum", "ileum", "retroflex-rectum", "pylorus", "retroflex-stomach", "z-line"]
#
#    CLS2IDX = {cls: i for i, cls in enumerate(CLASSES)} # Corrected: class name to index
#    IDX2CLS = {i: cls for i, cls in enumerate(CLASSES)} # Corrected: index to class name
#
#    # ... your load_vit_model function and ViT class definition ...
#    def load_vit_model(device, model_path, num_classes):
#        # ... your actual model instantiation and weight loading ...
#        from translrp.ViT_new import vit_base_patch16_224 # Assuming this is your model class
#        m = vit_base_patch16_224(pretrained=False, num_classes=num_classes)
#        m.load_state_dict(torch.load(model_path, map_location=device))
#        return m
#    ```
#
# 2. Then run your main evaluation script:
if __name__ == "__main__":
    # Setup
    pipeline_config = PipelineConfig() # Load your configuration
    
    # Ensure translrp.ViT_new is importable by adding its parent to sys.path if necessary
    import sys
    import os
    # Example: if translrp is in the project root and this script is in a subfolder
    # SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) # e.g. /path/to/project/scripts
    # PROJECT_ROOT = os.path.dirname(SCRIPT_DIR) # e.g. /path/to/project
    # if PROJECT_ROOT not in sys.path:
    #    sys.path.insert(0, PROJECT_ROOT)

    # Path to your transformed model weights
    MODEL_PATH = './model/vit_b_hyperkvasir_anatomical_for_translrp.pth'
    NUM_MODEL_CLASSES = 6

    # The script will now use whatever CLASSES, CLS2IDX, IDX2CLS are currently
    # defined in vit/model.py when it's imported.
    run_evaluation_with_current_mapping(
        config=pipeline_config,
        model_path=MODEL_PATH,
        num_classes=NUM_MODEL_CLASSES
    )

    print("To test a different mapping, edit vit/model.py (CLASSES, CLS2IDX, IDX2CLS) and re-run this script.")



Using device: cuda

--- Running evaluation with the following mapping from vit.model.py ---
CLASSES: ['cecum', 'ileum', 'pylorus', 'retroflex-rectum', 'retroflex-stomach', 'z-line']
CLS2IDX: {0: 'cecum', 1: 'ileum', 2: 'pylorus', 3: 'retroflex-rectum', 4: 'retroflex-stomach', 5: 'z-line'}
IDX2CLS: {'cecum': 0, 'ileum': 1, 'pylorus': 2, 'retroflex-rectum': 3, 'retroflex-stomach': 4, 'z-line': 5}


AttributeError: 'PipelineConfig' object has no attribute 'eval'