In [1]:
import glob
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, Optional, Tuple, List, Any

import numpy as np
import torch
import torch.nn.functional as F
from scipy.stats import pearsonr
from tqdm import tqdm
from vit_prisma.models.base_vit import HookedSAEViT
from vit_prisma.sae import SparseAutoencoder
import torchvision # Required for the dataset with paths
from collections import Counter


# --- Configuration ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# --- Vectorized Helper Functions ---

def batch_gini(features: torch.Tensor) -> torch.Tensor:
    """
    Calculates Gini coefficients for multiple features at once.
    Args:
        features: (n_patches, n_features) tensor
    Returns:
        (n_features,) tensor of Gini coefficients
    """
    # Handle edge cases
    if features.shape[1] == 0:
        return torch.tensor([])
    
    # Sort each feature column
    sorted_features, _ = torch.sort(features.abs(), dim=0)
    n = features.shape[0]
    cumsum = torch.cumsum(sorted_features, dim=0)
    
    # Avoid division by zero
    cumsum_last = cumsum[-1].clamp(min=1e-8)
    
    return (n + 1 - 2 * cumsum.sum(dim=0) / cumsum_last) / n

def batch_dice(features: torch.Tensor, target: torch.Tensor, threshold: float = 0.1) -> torch.Tensor:
    """
    Calculates Dice coefficients for multiple features against a target.
    Args:
        features: (n_patches, n_features) tensor
        target: (n_patches,) tensor
        threshold: threshold for binarization
    Returns:
        (n_features,) tensor of Dice coefficients
    """
    feat_masks = (features > threshold).float()
    target_mask = (target > threshold).float().unsqueeze(1)
    
    intersections = (feat_masks * target_mask).sum(dim=0)
    unions = feat_masks.sum(dim=0) + target_mask.sum()
    
    return (2 * intersections / (unions + 1e-8))

def build_attribution_aligned_feature_dictionary(
    model: HookedSAEViT,
    sae: SparseAutoencoder,
    dataloader: torch.utils.data.DataLoader,
    attribution_dir: str,
    n_samples: int = 1000,
    layer_idx: int = 9,
    patch_size: int = 16,
    min_occurrences: int = 5,
    threshold: float = 0.10,
    save_path: Optional[str] = None
) -> Dict[str, Any]:

    device = next(model.parameters()).device
    feature_occurrences = defaultdict(list)
    samples_processed = 0

    resid_hook_name = f"blocks.{layer_idx}.hook_resid_post"
    pbar = tqdm(dataloader, total=min(n_samples, len(dataloader)),
                desc="Analyzing feature alignment")

    for imgs, labels, paths in pbar:
        if samples_processed >= n_samples:
            break

        img, label, path = imgs[0:1].to(device), labels[0].item(), paths[0]

        # SAE codes
        with torch.no_grad():
            _, cache = model.run_with_cache(img, names_filter=[resid_hook_name])
            resid = cache[resid_hook_name]
            _, codes = sae.encode(resid)
        feature_activations = codes[0, 1:]  # (n_patches, n_features)

        # Attribution map
        try:
            stem    = Path(path).stem
            parts   = stem.split('_')
            prefix  = 'train'
            uuid, aug = parts[0], parts[1]
            pattern = str(Path(attribution_dir) / f"{prefix}_{uuid}_*_{aug}_attribution.npy")
            attr_files = glob.glob(pattern)
            if not attr_files:
                logging.warning(f"No attribution for pattern {pattern}")
                continue
            attr_map = np.load(attr_files[0])
            if attr_map.ndim != 2:
                logging.warning(f"Unexpected map shape {attr_map.shape} for {stem}")
                continue

            attr_tensor = F.avg_pool2d(
                torch.as_tensor(attr_map).float()
                     .unsqueeze_(0).unsqueeze_(0).to(device),
                kernel_size=patch_size, stride=patch_size
            ).flatten()  # (n_patches,)

            attr_tensor = (attr_tensor - attr_tensor.mean()) / (attr_tensor.std() + 1e-8)
        except Exception as e:
            logging.warning(f"Attribution load error for {stem}: {e}")
            continue

        # Compute PFAC only
        active_idx = (feature_activations.abs().sum(0) > 1e-6).nonzero(as_tuple=True)[0]
        if active_idx.numel() == 0:
            continue

        feats = feature_activations[:, active_idx]  # (n_patches, n_active)
        mu, sigma = feats.mean(0, keepdim=True), feats.std(0, keepdim=True)
        valid = (sigma.squeeze() > 1e-6).nonzero(as_tuple=True)[0]
        if valid.numel() == 0:
            continue

        idx_valid = active_idx[valid]
        feats_norm = (feats[:, valid] - mu[:, valid]) / (sigma[:, valid] + 1e-8)
        n_patches = attr_tensor.numel()

        pfac_vec = (attr_tensor @ feats_norm) / (n_patches - 1)  # (n_valid,)
        for i, fid in enumerate(idx_valid):
            feature_occurrences[fid.item()].append({
                'pfac_corr': pfac_vec[i].item(),
                'class':     label
            })

        samples_processed += 1
        pbar.set_postfix({"Processed": samples_processed})

    # Aggregate results
    reliable_features: Dict[int, Any] = {}
    for fid, occ in feature_occurrences.items():
        if len(occ) < min_occurrences:
            continue

        pfacs   = [o['pfac_corr'] for o in occ]
        classes = [o['class']      for o in occ]

        mean_pfac = float(np.mean(pfacs))
        class_count_map = Counter(classes)
        class_mean_pfac = {}
        for cls, cnt in class_count_map.items():
            vals = [p for p,c in zip(pfacs, classes) if c == cls]
            class_mean_pfac[cls] = float(np.mean(vals))

        reliable_features[fid] = {
            'mean_pfac_corr':  mean_pfac,
            'occurrences':     len(occ),
            'class_count_map': dict(class_count_map),
            'class_mean_pfac': class_mean_pfac,
            'raw_pfacs':       pfacs
        }

    final_dict = {
        'feature_stats': reliable_features,
        'metadata': {
            'layer_idx': layer_idx,
            'patch_size': patch_size,
            'n_samples_processed': samples_processed,
            'min_occurrences': min_occurrences,
            'attribution_dir': attribution_dir,
            'threshold': threshold
        },
        'metric_definitions': {
            'pfac_corr': "Pearson correlation between feature activations and attribution map"
        }
    }

    if save_path:
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)
        torch.save(final_dict, save_path)
        logging.info(f"Dictionary saved to {save_path}")

    return final_dict

In [None]:
import torchvision
from transmm_sfaf import get_processor_for_precached_224_images, IDX2CLS
from vit_prisma.models.weight_conversion import convert_timm_weights
import torch

def load_models():
    """Load SAE and fine-tuned model"""
    # Load SAE
    sae_path = "./models/sweep/sae_l7_k128_exp32_lr2e-05/d77c1ce8-vit_medical_sae_k_sweep/n_images_49276.pt"
    sae = SparseAutoencoder.load_from_pretrained(sae_path)
    sae.cuda().eval()

    # Load model
    model = HookedSAEViT.from_pretrained("vit_base_patch16_224")
    model.head = torch.nn.Linear(model.cfg.d_model, 6)

    checkpoint = torch.load(
        "./model/vit_b-ImageNet_class_init-frozen_False-dataset_Hyperkvasir_anatomical.pth", weights_only=False
    )
    state_dict = checkpoint['model_state_dict'].copy()

    if 'lin_head.weight' in state_dict:
        state_dict['head.weight'] = state_dict.pop('lin_head.weight')
    if 'lin_head.bias' in state_dict:
        state_dict['head.bias'] = state_dict.pop('lin_head.bias')

    converted_weights = convert_timm_weights(state_dict, model.cfg)
    model.load_state_dict(converted_weights)
    model.cuda().eval()

    return sae, model

sae, model = load_models()
# Build new dictionary
label_map = {2: 3, 3: 2}

def custom_target_transform(target):
    return label_map.get(target, target)

class ImageFolderWithPaths(torchvision.datasets.ImageFolder):
    def __getitem__(self, index):
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None: sample = self.transform(sample)
        # Assuming you have a target_transform defined elsewhere
        # if self.target_transform is not None: target = self.target_transform(target)
        return sample, target, path

train_dataset = ImageFolderWithPaths(
    "./hyper-kvasir_imagefolder/train", 
    get_processor_for_precached_224_images()
)
dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=256, shuffle=True)
layer_id = 7
stealth_dict = build_attribution_aligned_feature_dictionary(
    model, sae, dataloader, 
    n_samples=50000,
    attribution_dir="./results/train/attributions",
    layer_idx=layer_id,
    min_occurrences=1,           # Must appear 3+ times
    save_path=f"./sae_dictionaries/sfaf_stealth_l{layer_id}_alignment_min1_32k64.pt"
)

2025-07-13 14:17:59 INFO:root: get_activation_fn received: activation_fn=topk, kwargs={'k': 128}
2025-07-13 14:18:00 DEBUG:urllib3.connectionpool: Starting new HTTPS connection (1): huggingface.co:443
2025-07-13 14:18:00 DEBUG:urllib3.connectionpool: https://huggingface.co:443 "HEAD /timm/vit_base_patch16_224.augreg2_in21k_ft_in1k/resolve/main/config.json HTTP/1.1" 307 0
2025-07-13 14:18:00 DEBUG:urllib3.connectionpool: https://huggingface.co:443 "HEAD /api/resolve-cache/models/timm/vit_base_patch16_224.augreg2_in21k_ft_in1k/063c6c38a5d8510b2e57df480445e94b231dad2c/config.json HTTP/1.1" 200 0


ln_pre not set


2025-07-13 14:18:01 INFO:timm.models._builder: Loading pretrained weights from Hugging Face hub (timm/vit_base_patch16_224.augreg2_in21k_ft_in1k)
2025-07-13 14:18:01 DEBUG:urllib3.connectionpool: https://huggingface.co:443 "HEAD /timm/vit_base_patch16_224.augreg2_in21k_ft_in1k/resolve/main/model.safetensors HTTP/1.1" 302 0
2025-07-13 14:18:01 INFO:timm.models._hub: [timm/vit_base_patch16_224.augreg2_in21k_ft_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.
2025-07-13 14:18:01 INFO:root: Filling in 2 missing keys with default initialization
2025-07-13 14:18:01 INFO:root: Loaded pretrained model vit_base_patch16_224 into HookedTransformer


Converting the weights of a timm model to a Prisma ViT
LayerNorm folded.
Centered weights writing to residual stream
Converting the weights of a timm model to a Prisma ViT


Analyzing feature alignment:   5%|▍         | 3/65 [00:00<00:17,  3.48it/s, Processed=3]


KeyboardInterrupt: 

In [17]:
import pandas as pd
from transmm_sfaf import load_models, get_processor_for_precached_224_images, IDX2CLS
from transmm_sfaf import load_or_build_sf_af_dictionary

sae, model = load_models()
layer_idx = 10
stealth_dict = load_or_build_sf_af_dictionary(
    model,
    sae,
    n_samples=50000,
    layer_idx=layer_idx,
    dict_path=f"./sae_dictionaries/sfaf_stealth_l{layer_idx}_alignment_min10_32k512.pt",
    rebuild=False
)

print("Dictionary loaded successfully!")
print("Metadata:", stealth_dict['metadata'])
print("-" * 50)

# --- Convert to a Pandas DataFrame for easy analysis ---
feature_stats = stealth_dict.get('feature_stats', {})
if not feature_stats:
    print("No feature statistics found in the dictionary.")
else:
    # We convert the dictionary of features into a list of dictionaries, then to a DataFrame
    feature_list = []
    for feat_id, stats in feature_stats.items():
        # Create a flat dictionary for the DataFrame row
        row = {'feature_id': feat_id}
        # Add all stats except the bulky raw_metrics
        row.update({k: v for k, v in stats.items() if k != 'raw_metrics'})
        feature_list.append(row)

    df = pd.DataFrame(feature_list)

    # Set display options for better readability
    pd.set_option('display.max_rows', 50)
    pd.set_option('display.width', 120)
    
    print("--- Top 20 Features by Mean PFAC (Alignment) ---")
    # PFAC = Patch-Feature Attribution Correlation
    print(df.sort_values(by='mean_pfac_corr', ascending=False).head(100).tail(50))
    print("\n" + "="*80 + "\n")

2025-07-09 14:54:38 INFO:root: get_activation_fn received: activation_fn=topk, kwargs={'k': 1024}
2025-07-09 14:54:38 DEBUG:urllib3.connectionpool: Starting new HTTPS connection (17): huggingface.co:443


ln_pre not set


2025-07-09 14:54:39 INFO:timm.models._builder: Loading pretrained weights from Hugging Face hub (timm/vit_base_patch16_224.augreg2_in21k_ft_in1k)
2025-07-09 14:54:39 DEBUG:urllib3.connectionpool: Starting new HTTPS connection (18): huggingface.co:443
2025-07-09 14:54:39 INFO:timm.models._hub: [timm/vit_base_patch16_224.augreg2_in21k_ft_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.
2025-07-09 14:54:39 INFO:root: Filling in 2 missing keys with default initialization
2025-07-09 14:54:39 INFO:root: Loaded pretrained model vit_base_patch16_224 into HookedTransformer


Converting the weights of a timm model to a Prisma ViT
LayerNorm folded.
Centered weights writing to residual stream
Converting the weights of a timm model to a Prisma ViT
Loading existing S_f/A_f dictionary from ./sae_dictionaries/sfaf_stealth_l10_alignment_min10_32k512.pt
Dictionary loaded successfully!
Metadata: {'layer_idx': 10, 'patch_size': 16, 'n_samples_processed': 65, 'min_occurrences': 10, 'attribution_dir': './results/train/attributions'}
--------------------------------------------------
--- Top 20 Features by Mean PFAC (Alignment) ---
      feature_id  mean_pfac_corr  std_pfac_corr  cv_pfac_corr  mean_awa_score  ...  mean_gini_score  mean_dice_score  \
1516       19927        0.113865       0.193980      1.703573        0.550749  ...         0.989824         0.041581   
1684       21725        0.113816       0.202648      1.780465        1.010488  ...         0.968784         0.065548   
1628       12579        0.113499       0.220535      1.943038        1.105776  ...    

In [18]:
import pandas as pd
import numpy as np
from typing import Dict, Any, List, Tuple

# Assuming these are already loaded from your setup code
# from transmm_sfaf import load_models, IDX2CLS, load_or_build_sf_af_dictionary
# sae, model = load_models()
# layer_idx = 7
# stealth_dict = load_or_build_sf_af_dictionary(...)
# print("Dictionary loaded successfully!")

# --- Configuration ---
TOP_K_PER_CLASS = 20
# Add a minimum occurrence filter to ensure statistical significance
MIN_OCCURRENCES_FOR_CLASS = 10

# --- 1. Process Raw Data into a Per-Class Structure ---

feature_stats = stealth_dict.get('feature_stats', {})
if not feature_stats:
    print("No feature statistics found in the dictionary.")
else:
    per_class_feature_data = []

    # Iterate through each feature in the dictionary
    for feat_id, stats in feature_stats.items():
        raw_metrics = stats.get('raw_metrics', {})
        pfac_corrs = raw_metrics.get('pfac_corrs', [])
        classes = raw_metrics.get('classes', [])

        if not pfac_corrs or not classes:
            continue

        # Group correlations by class for the current feature
        corrs_by_class = {}
        for corr, cls_id in zip(pfac_corrs, classes):
            if cls_id not in corrs_by_class:
                corrs_by_class[cls_id] = []
            corrs_by_class[cls_id].append(corr)

        # Calculate stats for each class this feature appeared in
        for cls_id, corrs_list in corrs_by_class.items():
            valid_corrs = [c for c in corrs_list if not np.isnan(c)]
            occurrences = len(valid_corrs)

            # --- NEW: Filter based on minimum occurrences ---
            if occurrences < MIN_OCCURRENCES_FOR_CLASS:
                continue  # Skip if this feature-class pair is too rare

            avg_corr = np.mean(valid_corrs)
            
            per_class_feature_data.append({
                'feature_id': feat_id,
                'class_id': cls_id,
                'class_name': IDX2CLS.get(cls_id, 'Unknown'), # Still useful for display
                'avg_pfac_corr': avg_corr, # The key metric for ranking
                'occurrences': occurrences, # How often it activated for this class
            })

    # --- 2. Create the DataFrame and Perform Analysis ---
    
    if not per_class_feature_data:
        print(f"No valid per-class feature data could be extracted with min occurrences >= {MIN_OCCURRENCES_FOR_CLASS}.")
    else:
        df_per_class = pd.DataFrame(per_class_feature_data)

        # Set display options for better readability
        pd.set_option('display.max_rows', 500)
        pd.set_option('display.max_columns', 10)
        pd.set_option('display.width', 150)
        
        print("=" * 80)
        print(f"      TOP {TOP_K_PER_CLASS} CORRELATED FEATURES PER CLASS (ranked by avg_pfac_corr) for Layer {layer_idx}")
        print(f"      (Minimum Occurrences per Class: {MIN_OCCURRENCES_FOR_CLASS})")
        print("=" * 80)
        print("PFAC = Patch-Feature Attribution Correlation\n")

        # --- 3. Group by Class ID and Display Results ---
        
        # Sort by correlation, then group by class ID and take the top N
        top_features_per_class = (
            df_per_class
            .sort_values(by='avg_pfac_corr', ascending=False)
            .groupby('class_id', sort=False) # Group by class_id instead of class_name
            .head(TOP_K_PER_CLASS)
        )
        
        # Set the multi-index using class_id and feature_id for clear grouping
        top_features_per_class = top_features_per_class.set_index(['class_id', 'feature_id'])

        print(top_features_per_class)

      TOP 20 CORRELATED FEATURES PER CLASS (ranked by avg_pfac_corr) for Layer 10
      (Minimum Occurrences per Class: 10)
PFAC = Patch-Feature Attribution Correlation

                           class_name  avg_pfac_corr  occurrences
class_id feature_id                                              
5        12609                 z-line       0.340186           11
         13673                 z-line       0.336493           11
         14977                 z-line       0.285648           10
         12153                 z-line       0.270660           11
         10860                 z-line       0.270145           11
         5286                  z-line       0.270063           10
         465                   z-line       0.268521           11
         8735                  z-line       0.265743           10
         2605                  z-line       0.261874           11
         14400                 z-line       0.258170           11
         2203                  z-line 

In [23]:
# 1. Define the list of feature IDs you want to inspect
specific_feature_ids = [8, 9, 12, 23, 26]

print("\n" + "="*80)
print(f"--- Statistics for Specific Features: {specific_feature_ids} ---")

# 2. Filter the DataFrame to get only the rows for these features
# .isin() checks if the value in the 'feature_id' column is present in your list
specific_features_df = df[df['feature_id'].isin(specific_feature_ids)]

# 3. Print the resulting filtered DataFrame
if not specific_features_df.empty:
    # We can transpose (.T) the DataFrame for better readability when there are few features
    # This turns columns into rows, which is often easier to read.
    print(specific_features_df.set_index('feature_id').T)
else:
    print("None of the specified features were found in the dictionary.")
    print("This might be because they didn't meet the min_occurrences threshold.")

print("="*80 + "\n")


--- Statistics for Specific Features: [8, 9, 12, 23, 26] ---
feature_id                      12                  23        9             8                   26
mean_pfac_corr            -0.01718           -0.018236  0.073728      0.006976            0.005308
std_pfac_corr             0.085807            0.071139  0.162014      0.086522            0.094934
cv_pfac_corr              4.994277            3.900749  2.197426     12.400353           17.882137
mean_awa_score           -1.019369           -0.569488  1.565944      0.211732            0.099672
std_awa_score             3.391973            2.094591  3.699006      2.186467            1.626569
mean_gini_score           0.967501             0.97359  0.990795      0.985377            0.991927
mean_dice_score           0.043068            0.040191  0.046066      0.034196            0.016611
consistency_score         0.561156            0.429977 -0.498712     -0.534286           -0.647333
occurrences                   3559             

In [7]:
import torchvision
from transmm_sfaf import load_models, get_processor_for_precached_224_images, IDX2CLS
import torch


sae, model = load_models()
# Build new dictionary
label_map = {2: 3, 3: 2}

def custom_target_transform(target):
    return label_map.get(target, target)

class ImageFolderWithPaths(torchvision.datasets.ImageFolder):
    def __getitem__(self, index):
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None: sample = self.transform(sample)
        # Assuming you have a target_transform defined elsewhere
        # if self.target_transform is not None: target = self.target_transform(target)
        return sample, target, path

train_dataset = ImageFolderWithPaths(
    "./hyper-kvasir_imagefolder/dev", 
    get_processor_for_precached_224_images()
)
dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True)
layer_id = 9
stealth_dict = build_comprehensive_feature_dictionary(
    model, sae, dataloader, 
    n_samples=50000,
    attribution_dir="./results/dev/attributions",
    layer_idx=layer_id,
    min_occurrences=3,           # Must appear 3+ times
    save_path=f"./sae_dictionaries/sfaf_stealth_l{layer_id}_alignment_comp.pt"
)

2025-07-07 14:47:40 INFO:root: get_activation_fn received: activation_fn=topk, kwargs={'k': 128}
2025-07-07 14:47:41 DEBUG:urllib3.connectionpool: Resetting dropped connection: huggingface.co
2025-07-07 14:47:41 DEBUG:urllib3.connectionpool: https://huggingface.co:443 "HEAD /timm/vit_base_patch16_224.augreg2_in21k_ft_in1k/resolve/main/config.json HTTP/1.1" 307 0
2025-07-07 14:47:41 DEBUG:urllib3.connectionpool: https://huggingface.co:443 "HEAD /api/resolve-cache/models/timm/vit_base_patch16_224.augreg2_in21k_ft_in1k/063c6c38a5d8510b2e57df480445e94b231dad2c/config.json HTTP/1.1" 200 0


ln_pre not set


2025-07-07 14:47:42 INFO:timm.models._builder: Loading pretrained weights from Hugging Face hub (timm/vit_base_patch16_224.augreg2_in21k_ft_in1k)
2025-07-07 14:47:42 DEBUG:urllib3.connectionpool: https://huggingface.co:443 "HEAD /timm/vit_base_patch16_224.augreg2_in21k_ft_in1k/resolve/main/model.safetensors HTTP/1.1" 302 0
2025-07-07 14:47:42 INFO:timm.models._hub: [timm/vit_base_patch16_224.augreg2_in21k_ft_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.
2025-07-07 14:47:42 INFO:root: Filling in 2 missing keys with default initialization
2025-07-07 14:47:42 INFO:root: Loaded pretrained model vit_base_patch16_224 into HookedTransformer


Converting the weights of a timm model to a Prisma ViT
LayerNorm folded.
Centered weights writing to residual stream
Converting the weights of a timm model to a Prisma ViT


Building comprehensive dict: 100%|██████████| 750/750 [01:39<00:00,  7.57it/s, Processed=750, Features Found=936]
2025-07-07 14:49:22 INFO:root: Comprehensive dictionary saved to ./sae_dictionaries/sfaf_stealth_l9_alignment_comp.pt
