# Evaluating your SAE

Code based off Rob Graham's ([themachinefan](https://github.com/themachinefan)) SAE evaluation code.

In [1]:
import os
os.getcwd()

'/workspace/ViT-Prisma/src/vit_prisma/sae/evals'

In [2]:
import einops
import torch
import torchvision

import plotly.express as px

from tqdm import tqdm

import numpy as np
import os
import requests

# Setup

In [3]:
from dataclasses import dataclass
from vit_prisma.sae.config import VisionModelSAERunnerConfig


@dataclass
class EvalConfig(VisionModelSAERunnerConfig):
    sae_path: str = '/workspace/sae_checkpoints/sparse-autoencoder-clip-b-32-sae-vanilla-x64-layer-10-hook_mlp_out-l1-0.0001/n_images_2600058.pt'
    model_name: str = "open-clip:laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K"
    model_type: str =  "clip"
    patch_size: str = 32

    dataset_path = "/workspace"
    dataset_train_path: str = "/workspace/ILSVRC/Data/CLS-LOC/train"
    dataset_val_path: str = "/workspace/ILSVRC/Data/CLS-LOC/val"

    verbose: bool = True

    device: bool = 'cuda'

    eval_max: int = 50_000 # 50_000
    batch_size: int = 32

    # make the max image output folder a subfolder of the sae path


    @property
    def max_image_output_folder(self) -> str:
        # Get the base directory of sae_checkpoints
        sae_base_dir = os.path.dirname(os.path.dirname(self.sae_path))
        
        # Get the name of the original SAE checkpoint folder
        sae_folder_name = os.path.basename(os.path.dirname(self.sae_path))
        
        # Create a new folder path in sae_checkpoints/images with the original name
        output_folder = os.path.join(sae_base_dir, 'max_images', sae_folder_name)
        output_folder = os.path.join(output_folder, f"layer_{self.hook_point_layer}") # Add layer number

        
        # Ensure the directory exists
        os.makedirs(output_folder, exist_ok=True)
        
        return output_folder

cfg = EvalConfig()

n_tokens_per_buffer (millions): 0.032
Lower bound: n_contexts_per_buffer (millions): 0.00064
Total training steps: 158691
Total training images: 13000000
Total wandb updates: 15869
Expansion factor: 16
n_tokens_per_feature_sampling_window (millions): 204.8
n_tokens_per_dead_feature_window (millions): 1024.0
Using Ghost Grads.
We will reset the sparsity calculation 158 times.
Number tokens in sparsity calculation window: 4.10e+06
Gradient clipping with max_norm=1.0
Using SAE initialization method: encoder_transpose_decoder


In [4]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7ae7318a3d30>

## Load model

In [5]:
from vit_prisma.models.base_vit import HookedViT

model_name = "open-clip:laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K"
model = HookedViT.from_pretrained(model_name, is_timm=False, is_clip=True).to(cfg.device)
 

model_id download_pretrained_from_hf: laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K
Official model name open-clip:laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K
Converting OpenCLIP weights
model_id download_pretrained_from_hf: laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K
visual projection shape torch.Size([768, 512])
Setting center_writing_weights to False for OpenCLIP
Setting fold_ln to False for OpenCLIP
Loaded pretrained model open-clip:laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K into HookedTransformer


## Load datasets

In [6]:
import importlib
import vit_prisma
# importlib.reload(vit_prisma.dataloaders.imagenet_dataset)

In [7]:
# load dataset
import open_clip
from vit_prisma.utils.data_utils.imagenet_utils import setup_imagenet_paths
from vit_prisma.dataloaders.imagenet_dataset import get_imagenet_transforms_clip, ImageNetValidationDataset

from torchvision import transforms
from transformers import CLIPProcessor

og_model_name = "hf-hub:laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K"
og_model, _, preproc = open_clip.create_model_and_transforms(og_model_name)
processor = preproc

size=224

data_transforms = transforms.Compose([
    transforms.Resize((size, size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                     std=[0.26862954, 0.26130258, 0.27577711]),
])
    
imagenet_paths = setup_imagenet_paths(cfg.dataset_path)
imagenet_paths["train"] = "/workspace/ILSVRC/Data/CLS-LOC/train"
imagenet_paths["val"] = "/workspace/ILSVRC/Data/CLS-LOC/val"
imagenet_paths["val_labels"] = "/workspace/LOC_val_solution.csv"
imagenet_paths["label_strings"] = "/workspace/LOC_synset_mapping.txt"
print()
train_data = torchvision.datasets.ImageFolder(cfg.dataset_train_path, transform=data_transforms)
val_data = ImageNetValidationDataset(cfg.dataset_val_path, 
                                imagenet_paths['label_strings'], 
                                imagenet_paths['val_labels'], 
                                data_transforms,
                                return_index=True,
)
val_data_visualize = ImageNetValidationDataset(cfg.dataset_val_path, 
                                imagenet_paths['label_strings'], 
                                imagenet_paths['val_labels'],
                                torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),]), return_index=True)

print(f"Validation data length: {len(val_data)}") if cfg.verbose else None



Validation data length: 50000


In [8]:
from vit_prisma.sae.training.activations_store import VisionActivationsStore
# import dataloader
from torch.utils.data import DataLoader

# activations_loader = VisionActivationsStore(cfg, model, train_data, eval_dataset=val_data)
val_dataloader = DataLoader(val_data, batch_size=cfg.batch_size, shuffle=False, num_workers=4)


## Load pretrained SAE to evaluate

In [9]:
from vit_prisma.sae.sae import SparseAutoencoder
sparse_autoencoder = SparseAutoencoder(cfg).load_from_pretrained("/workspace/sae_checkpoints/sparse-autoencoder-clip-b-32-sae-vanilla-x64-layer-10-hook_mlp_out-l1-0.0001/n_images_2600058.pt")
sparse_autoencoder.to(cfg.device)
sparse_autoencoder.eval()  # prevents error if we're expecting a dead neuron mask for who 


get_activation_fn received: activation_fn=relu, kwargs={}
n_tokens_per_buffer (millions): 0.032
Lower bound: n_contexts_per_buffer (millions): 0.00064
Total training steps: 158691
Total training images: 13000000
Total wandb updates: 1586
Expansion factor: 64
n_tokens_per_feature_sampling_window (millions): 204.8
n_tokens_per_dead_feature_window (millions): 1024.0
Using Ghost Grads.
We will reset the sparsity calculation 158 times.
Number tokens in sparsity calculation window: 4.10e+06
Gradient clipping with max_norm=1.0
Using SAE initialization method: encoder_transpose_decoder
get_activation_fn received: activation_fn=relu, kwargs={}


SparseAutoencoder(
  (hook_sae_in): HookPoint()
  (hook_hidden_pre): HookPoint()
  (hook_hidden_post): HookPoint()
  (hook_sae_out): HookPoint()
  (activation_fn): ReLU()
)

## Clip Labeling AutoInterp

In [10]:
# all_imagenet_class_names

In [11]:
from vit_prisma.dataloaders.imagenet_dataset import get_imagenet_index_to_name
ind_to_name = get_imagenet_index_to_name()

all_imagenet_class_names = []
for i in range(len(ind_to_name)):
    all_imagenet_class_names.append(ind_to_name[str(i)][1])

In [12]:
cfg.max_image_output_folder

'/workspace/sae_checkpoints/max_images/sparse-autoencoder-clip-b-32-sae-vanilla-x64-layer-10-hook_mlp_out-l1-0.0001/layer_9'

## Feature steering

In [13]:
def standard_replacement_hook_curry(feat_idx: int = 0, feat_activ: float = 1.0):
    def standard_replacement_hook(activations: torch.Tensor, hook):
        activations = sparse_autoencoder.forward(activations)[0].to(activations.dtype)
        feature_acts = sparse_autoencoder.encode_standard(activations)

        # in all batches and patches, set feature w idx idx to 0
        print(f"feature_acts[:,:,idx].shape: {feature_acts[:,:,feat_idx].shape}")
        print(f"feat activ: {feature_acts[:,:,feat_idx]}")
        feature_acts[:,:,feat_idx] *= feat_activ
        print(f"feat activ: {feature_acts[:,:,feat_idx]}")
        print(f"feat activ: {feature_acts.shape}")
        print(f"feat activ: {feature_acts}")
        print("feature_acts[:,:,idx].sum(): (should be batch size x len seq x feat val)", feature_acts[:,:,feat_idx].sum())
        sae_out = sparse_autoencoder.hook_sae_out(
            einops.einsum(
                feature_acts,
                sparse_autoencoder.W_dec,
                "... d_sae, d_sae d_in -> ... d_in",
            )
            + sparse_autoencoder.b_dec
        )
        
        print(f"sae_out.shape: {sae_out.shape}")
        print(f"sae_out: {sae_out}")

        # allows normalization. Possibly identity if no normalization
        sae_out = sparse_autoencoder.run_time_activation_norm_fn_out(sae_out)
        return sae_out
    return standard_replacement_hook


def steering_hook_fn(
    activations, cfg, hook, sae, steering_indices, steering_strength=1.0, mean_ablation_values=None, include_error=False

):
    sae.to(activations.device)


    sae_input = activations.clone()
    sae_output, feature_activations, *data = sae(sae_input)
    
    steered_feature_activations = feature_activations.clone()
    
    steered_feature_activations[:, :, steering_indices] = steering_strength

    steered_sae_out = einops.einsum(
                steered_feature_activations,
                sae.W_dec,
                "... d_sae, d_sae d_in -> ... d_in",
            ) + sae.b_dec

    steered_sae_out = sae.run_time_activation_norm_fn_out(steered_sae_out)
    
    print(steered_sae_out.shape)
    print(steered_sae_out.shape)
    print(f"steering norm: {(steered_sae_out - sae_output).norm()}")
    
    

    if include_error:
        error = sae_input - sae_output
        print(f"error.norm(): {error.norm()}")
        return steered_sae_out + error
    return steered_sae_out

In [119]:
random_feat_idxs = np.random.randint(0, high=3000, size=(25))

In [120]:
# for a given feature, set it high/low on maxim activ. imgs and high/low on non-activ images
# hook SAE and replace desired feature with 0 or 1 
from typing import List, Dict, Tuple
import torch
import einops
from tqdm import tqdm

from functools import partial

@torch.no_grad()
def compute_feature_activations_set_feat(
    images: torch.Tensor,
    model: torch.nn.Module,
    sparse_autoencoder: torch.nn.Module,
    encoder_weights: torch.Tensor,
    encoder_biases: torch.Tensor,
    feature_ids: List[int],
    feature_categories: List[str],
    top_k: int = 10
):
    """
    Compute the highest activating tokens for given features in a batch of images.
    
    Args:
        images: Input images
        model: The main model
        sparse_autoencoder: The sparse autoencoder
        encoder_weights: Encoder weights for selected features
        encoder_biases: Encoder biases for selected features
        feature_ids: List of feature IDs to analyze
        feature_categories: Categories of the features
        top_k: Number of top activations to return per feature

    Returns:
        Dictionary mapping feature IDs to tuples of (top_indices, top_values)
    """
    _, cache = model.run_with_cache(images, names_filter=[sparse_autoencoder.cfg.hook_point])
#     recons_image_embeddings_feat_altered = model.run_with_hooks(
#         images,
#         fwd_hooks=[("blocks.9.hook_mlp_out", standard_replacement_hook)],
#     )
    recons_image_embeddings_feat_altered_list = []
    for idx in np.array(range(sparse_autoencoder.W_dec.shape[0]))[random_feat_idxs]:
        print(f"Feature: {idx} ====================")
        
        steering_hook = partial(
            steering_hook_fn,
            cfg=cfg,
            sae=sparse_autoencoder,
            steering_indices=[idx],
            steering_strength=10.0,
            mean_ablation_values = [1.0],
            include_error=True,
            )
        
        
        recons_image_embeddings_feat_altered = model.run_with_hooks(
            images,
#             fwd_hooks=[("blocks.9.hook_mlp_out", standard_replacement_hook_curry(idx, 10.0))],
            fwd_hooks=[("blocks.9.hook_mlp_out", steering_hook)],
        )
        recons_image_embeddings_feat_altered_list.append(recons_image_embeddings_feat_altered)

    
    # output is in clip embedding space
    recons_image_embeddings_default = model.run_with_hooks(
        images,
        fwd_hooks=[("blocks.9.hook_mlp_out", lambda x, hook: x)],
    )
    
    print(f"recons_image_embeddings_default: {recons_image_embeddings_default}")
    print(f"recons_image_embeddings_default.shape: {recons_image_embeddings_default.shape}")
    print(f"recons_image_embeddings_default: {recons_image_embeddings_default.shape}")

    print(f"recons_image_embeddings_feat_altered: {recons_image_embeddings_feat_altered}")
    print(f"recons_image_embeddings_feat_altered.shape: {recons_image_embeddings_feat_altered.shape}")

    return recons_image_embeddings_feat_altered_list, recons_image_embeddings_default

In [122]:
from collections import defaultdict
max_samples = cfg.eval_max

# top_activations = {i: (None, None) for i in interesting_features_indices}
encoder_biases = sparse_autoencoder.b_enc#[interesting_features_indices]
encoder_weights = sparse_autoencoder.W_enc#[:, interesting_features_indices]

top_k=10
processed_samples = 0
default_embeds_list = []
feature_steered_embeds = defaultdict(list)
l = 0
for batch_images, _, batch_indices in tqdm(val_dataloader, total=max_samples // cfg.batch_size):
    batch_images = batch_images.to(cfg.device)
    batch_indices = batch_indices.to(cfg.device)
    batch_size = batch_images.shape[0]

    altered_embeds_list, default_embeds = compute_feature_activations_set_feat(
        batch_images, model, sparse_autoencoder, encoder_weights, encoder_biases,
        None, None, top_k
    )
    default_embeds_list.append(default_embeds)
    for j, altered_embeds in enumerate(altered_embeds_list):
        feature_steered_embeds[random_feat_idxs[j]].extend(altered_embeds)
    # either label embeds or optimize to maximal token in text transformer embedding face
    l += 1
    if l >= 50:
        break

  0%|                                                                                                            | 0/1562 [00:00<?, ?it/s]

torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.76763916015625
error.norm(): 3214.9658203125
torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.7882385253906
error.norm(): 3214.9658203125
torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.6164855957031
error.norm(): 3214.9658203125
torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.3077697753906
error.norm(): 3214.9658203125
torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.3682556152344
error.norm(): 3214.9658203125
torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.3701477050781
error.norm(): 3214.9658203125
torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 398.8790588378906
error.norm(): 3214.9658203125
torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.1642150878906
error.norm(): 3214.9658203125
torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.50180053

  0%|                                                                                                    | 1/1562 [00:02<55:03,  2.12s/it]

torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.6460266113281
error.norm(): 3214.9658203125
recons_image_embeddings_default: tensor([[ 0.0352,  0.0083, -0.0740,  ..., -0.0311,  0.0275,  0.0019],
        [-0.0101, -0.0539, -0.0622,  ...,  0.0199, -0.0555, -0.0743],
        [-0.0206,  0.0059, -0.0366,  ..., -0.0307,  0.0756, -0.0016],
        ...,
        [ 0.0099, -0.0045, -0.0059,  ..., -0.0521,  0.0647, -0.0225],
        [-0.0422,  0.0518, -0.0482,  ...,  0.0098,  0.0418,  0.0290],
        [-0.0411, -0.0590,  0.0014,  ..., -0.0432, -0.0089, -0.0449]],
       device='cuda:0')
recons_image_embeddings_default.shape: torch.Size([32, 512])
recons_image_embeddings_default: torch.Size([32, 512])
recons_image_embeddings_feat_altered: tensor([[ 0.0309,  0.0017, -0.0614,  ..., -0.0337,  0.0321, -0.0164],
        [-0.0490, -0.0249, -0.0655,  ..., -0.0098, -0.0570, -0.0745],
        [-0.0502, -0.0072, -0.0447,  ..., -0.0369,  0.0772, -0.0287],
        ...,
        [-0.0101,

  0%|▏                                                                                                   | 2/1562 [00:03<48:30,  1.87s/it]

torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.869873046875
error.norm(): 2310.85302734375
torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.701416015625
error.norm(): 2310.85302734375
torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.68310546875
error.norm(): 2310.85302734375
recons_image_embeddings_default: tensor([[ 0.0146, -0.0148, -0.0460,  ...,  0.0118,  0.0082,  0.0083],
        [-0.0018,  0.0212, -0.0113,  ...,  0.0519, -0.0585, -0.0361],
        [-0.0171, -0.0393, -0.0432,  ...,  0.0160,  0.0028,  0.0136],
        ...,
        [-0.0224, -0.0082, -0.0361,  ..., -0.0352,  0.0784,  0.0265],
        [-0.0062,  0.0247, -0.0572,  ...,  0.0121, -0.0083,  0.0222],
        [-0.0130,  0.0321, -0.0363,  ...,  0.0437,  0.0279, -0.0109]],
       device='cuda:0')
recons_image_embeddings_default.shape: torch.Size([32, 512])
recons_image_embeddings_default: torch.Size([32, 512])
recons_image_embeddings_feat_altered: tensor([[-0

  0%|▏                                                                                                   | 3/1562 [00:05<46:33,  1.79s/it]

torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.767822265625
error.norm(): 2365.204345703125
recons_image_embeddings_default: tensor([[ 0.0294,  0.0383,  0.0048,  ..., -0.0036,  0.0256,  0.0279],
        [ 0.0004,  0.0353, -0.0868,  ..., -0.0146,  0.0002,  0.0059],
        [ 0.0709, -0.0185, -0.0175,  ...,  0.0050,  0.0293,  0.0257],
        ...,
        [-0.0168, -0.0003, -0.0274,  ..., -0.0302,  0.0601, -0.0477],
        [ 0.0075,  0.0213, -0.0235,  ..., -0.0346,  0.0216,  0.0487],
        [ 0.0059, -0.0119, -0.0019,  ...,  0.0249, -0.0424,  0.0157]],
       device='cuda:0')
recons_image_embeddings_default.shape: torch.Size([32, 512])
recons_image_embeddings_default: torch.Size([32, 512])
recons_image_embeddings_feat_altered: tensor([[-0.0047,  0.0190,  0.0004,  ...,  0.0007, -0.0066, -0.0041],
        [-0.0046,  0.0503, -0.0850,  ..., -0.0282,  0.0126, -0.0336],
        [ 0.0279, -0.0167, -0.0165,  ...,  0.0091,  0.0406, -0.0079],
        ...,
        [-0.0416

  0%|▎                                                                                                   | 4/1562 [00:07<44:40,  1.72s/it]

torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.8222351074219
error.norm(): 2469.7783203125
recons_image_embeddings_default: tensor([[ 0.0495,  0.0061, -0.0375,  ..., -0.0073, -0.0049,  0.0464],
        [ 0.0656,  0.0185, -0.0169,  ..., -0.0542,  0.0806,  0.0280],
        [ 0.0439,  0.0136,  0.0194,  ..., -0.0279,  0.0640, -0.0370],
        ...,
        [ 0.0259,  0.0402, -0.0065,  ..., -0.0289,  0.0129,  0.0450],
        [ 0.0245,  0.0248, -0.0074,  ..., -0.0344,  0.0273, -0.0038],
        [ 0.0167,  0.0346, -0.0975,  ...,  0.0074,  0.0849, -0.0346]],
       device='cuda:0')
recons_image_embeddings_default.shape: torch.Size([32, 512])
recons_image_embeddings_default: torch.Size([32, 512])
recons_image_embeddings_feat_altered: tensor([[ 0.0221, -0.0102, -0.0342,  ..., -0.0159, -0.0002,  0.0062],
        [ 0.0247,  0.0171, -0.0037,  ..., -0.0563,  0.0921, -0.0063],
        [ 0.0150,  0.0101,  0.0405,  ..., -0.0254,  0.0830, -0.0672],
        ...,
        [-0.0106,

  0%|▎                                                                                                   | 5/1562 [00:08<43:45,  1.69s/it]

torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.46685791015625
error.norm(): 2564.51220703125
torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.6747131347656
error.norm(): 2564.51220703125
recons_image_embeddings_default: tensor([[-0.0659, -0.0776, -0.0139,  ..., -0.0386,  0.0279,  0.0018],
        [ 0.0148, -0.0243,  0.0026,  ..., -0.0218,  0.0321,  0.0376],
        [-0.0408, -0.0001, -0.0266,  ..., -0.0062,  0.0039, -0.0037],
        ...,
        [-0.0287,  0.0508, -0.0474,  ...,  0.0316,  0.0009,  0.0108],
        [-0.0258, -0.0096,  0.0075,  ..., -0.0291, -0.0626, -0.0089],
        [ 0.0060, -0.0028, -0.0319,  ..., -0.0128,  0.0170, -0.0358]],
       device='cuda:0')
recons_image_embeddings_default.shape: torch.Size([32, 512])
recons_image_embeddings_default: torch.Size([32, 512])
recons_image_embeddings_feat_altered: tensor([[-0.0632, -0.0583, -0.0216,  ..., -0.0706,  0.0432, -0.0063],
        [-0.0229, -0.0177,  0.0079,  ..., -0.0422,  

  0%|▍                                                                                                   | 6/1562 [00:10<43:34,  1.68s/it]

recons_image_embeddings_default: tensor([[-0.0091, -0.0222,  0.0052,  ..., -0.0296,  0.0704, -0.0185],
        [-0.0131, -0.0249, -0.0013,  ...,  0.0340, -0.0035,  0.0514],
        [-0.0462,  0.0257, -0.0337,  ..., -0.0528, -0.0081, -0.0102],
        ...,
        [-0.0051, -0.0604, -0.0089,  ..., -0.0775, -0.0198, -0.0302],
        [-0.0528, -0.0173,  0.0297,  ..., -0.0252,  0.0219,  0.0388],
        [ 0.0022,  0.0004, -0.0324,  ..., -0.0625,  0.0004,  0.0154]],
       device='cuda:0')
recons_image_embeddings_default.shape: torch.Size([32, 512])
recons_image_embeddings_default: torch.Size([32, 512])
recons_image_embeddings_feat_altered: tensor([[-0.0655, -0.0474, -0.0244,  ..., -0.0617,  0.0812, -0.0572],
        [-0.0407, -0.0129, -0.0231,  ...,  0.0182, -0.0011,  0.0063],
        [-0.0710,  0.0287, -0.0384,  ..., -0.0522,  0.0016, -0.0267],
        ...,
        [-0.0286, -0.0381, -0.0374,  ..., -0.0578, -0.0418, -0.0754],
        [-0.0856, -0.0374, -0.0069,  ..., -0.0449,  0.0231, -0

  0%|▍                                                                                                   | 7/1562 [00:12<43:50,  1.69s/it]

torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.8359680175781
error.norm(): 2667.07763671875
torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.68646240234375
error.norm(): 2667.07763671875
torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.7897644042969
error.norm(): 2667.07763671875
recons_image_embeddings_default: tensor([[ 0.0136,  0.0643,  0.0395,  ..., -0.0361, -0.0052, -0.0236],
        [ 0.0247, -0.0430,  0.0023,  ...,  0.0616,  0.0209,  0.0213],
        [ 0.0044,  0.0217, -0.0404,  ..., -0.0586,  0.0506,  0.0444],
        ...,
        [-0.0156, -0.0046, -0.0264,  ..., -0.0680, -0.0048,  0.0319],
        [ 0.0188, -0.0514, -0.0120,  ...,  0.0095, -0.0219, -0.0319],
        [-0.0656,  0.0328,  0.0056,  ..., -0.0298,  0.0554, -0.0084]],
       device='cuda:0')
recons_image_embeddings_default.shape: torch.Size([32, 512])
recons_image_embeddings_default: torch.Size([32, 512])
recons_image_embeddings_feat_altered: tensor

  1%|▌                                                                                                   | 8/1562 [00:13<43:56,  1.70s/it]

torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.7771911621094
error.norm(): 2228.6064453125
recons_image_embeddings_default: tensor([[-0.0251,  0.0197, -0.1125,  ...,  0.0378, -0.0042, -0.0264],
        [ 0.0071,  0.0089,  0.0063,  ...,  0.0080, -0.0004,  0.0045],
        [-0.0034, -0.0238, -0.0226,  ..., -0.0305,  0.0216,  0.0113],
        ...,
        [-0.0507, -0.0230,  0.0469,  ...,  0.0239,  0.0170,  0.0040],
        [-0.0284,  0.0204,  0.0041,  ..., -0.0395, -0.0102, -0.0003],
        [-0.0714,  0.0389,  0.0543,  ..., -0.0956,  0.0180,  0.0454]],
       device='cuda:0')
recons_image_embeddings_default.shape: torch.Size([32, 512])
recons_image_embeddings_default: torch.Size([32, 512])
recons_image_embeddings_feat_altered: tensor([[-0.0368,  0.0256, -0.1152,  ...,  0.0309,  0.0186, -0.0421],
        [-0.0197,  0.0014,  0.0021,  ..., -0.0055, -0.0028, -0.0483],
        [-0.0465, -0.0164, -0.0285,  ..., -0.0345,  0.0568, -0.0328],
        ...,
        [-0.0624,

  1%|▌                                                                                                   | 9/1562 [00:15<43:06,  1.67s/it]

torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.8271179199219
error.norm(): 3044.28955078125
recons_image_embeddings_default: tensor([[ 0.0040, -0.0179, -0.0437,  ...,  0.0046,  0.0661,  0.0756],
        [ 0.0060,  0.0323, -0.0683,  ..., -0.0043, -0.0139, -0.0015],
        [ 0.0149, -0.0036, -0.0349,  ...,  0.0409,  0.0110,  0.0132],
        ...,
        [ 0.0022,  0.0096,  0.0305,  ..., -0.0361,  0.0311,  0.0256],
        [ 0.0050, -0.0871, -0.0127,  ..., -0.0107,  0.0622,  0.0261],
        [-0.0142, -0.0271,  0.0240,  ..., -0.0353,  0.0251, -0.0124]],
       device='cuda:0')
recons_image_embeddings_default.shape: torch.Size([32, 512])
recons_image_embeddings_default: torch.Size([32, 512])
recons_image_embeddings_feat_altered: tensor([[-0.0230, -0.0092, -0.0291,  ..., -0.0068,  0.0309,  0.0453],
        [-0.0226,  0.0294, -0.0678,  ..., -0.0298,  0.0094, -0.0346],
        [-0.0211,  0.0217, -0.0278,  ...,  0.0287, -0.0109, -0.0067],
        ...,
        [-0.0327

  1%|▋                                                                                                  | 10/1562 [00:17<43:22,  1.68s/it]

recons_image_embeddings_default: tensor([[ 0.0080,  0.0201, -0.0240,  ..., -0.0847,  0.0018, -0.1048],
        [ 0.0137,  0.0007, -0.0480,  ..., -0.0553,  0.0555,  0.0146],
        [-0.0104, -0.0148,  0.0434,  ...,  0.0549,  0.0064, -0.0264],
        ...,
        [-0.0120, -0.0477, -0.0498,  ..., -0.0423,  0.0376,  0.0224],
        [ 0.0618,  0.0314, -0.0301,  ..., -0.0508,  0.0342, -0.0084],
        [-0.0197, -0.0276, -0.0188,  ..., -0.0247,  0.0591,  0.0142]],
       device='cuda:0')
recons_image_embeddings_default.shape: torch.Size([32, 512])
recons_image_embeddings_default: torch.Size([32, 512])
recons_image_embeddings_feat_altered: tensor([[-0.0359,  0.0419, -0.0168,  ..., -0.0664,  0.0324, -0.1142],
        [-0.0144,  0.0092, -0.0708,  ..., -0.0678,  0.0430,  0.0126],
        [-0.0280,  0.0043,  0.0355,  ...,  0.0291,  0.0006, -0.0713],
        ...,
        [-0.0089, -0.0360, -0.0296,  ..., -0.0573,  0.0615, -0.0039],
        [ 0.0196,  0.0515, -0.0387,  ..., -0.0598,  0.0364, -0

  1%|▋                                                                                                  | 11/1562 [00:18<43:24,  1.68s/it]

steering norm: 399.1156921386719
error.norm(): 3338.303955078125
torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.5744934082031
error.norm(): 3338.303955078125
recons_image_embeddings_default: tensor([[-0.0130, -0.0772, -0.0206,  ..., -0.0815,  0.0094, -0.0239],
        [-0.0159, -0.0297, -0.0343,  ...,  0.0062, -0.0066,  0.0323],
        [ 0.0178,  0.0122,  0.0315,  ..., -0.0626,  0.0436,  0.0448],
        ...,
        [-0.0498, -0.0380, -0.0179,  ...,  0.0013, -0.0028,  0.0136],
        [-0.0087,  0.0241,  0.0746,  ..., -0.0297, -0.0380, -0.0052],
        [-0.0129,  0.0097, -0.0030,  ..., -0.0588,  0.0387, -0.0520]],
       device='cuda:0')
recons_image_embeddings_default.shape: torch.Size([32, 512])
recons_image_embeddings_default: torch.Size([32, 512])
recons_image_embeddings_feat_altered: tensor([[-0.0320, -0.1234, -0.0292,  ..., -0.1073,  0.0277, -0.0638],
        [-0.0547, -0.0396, -0.0517,  ..., -0.0268, -0.0128, -0.0280],
        [-0.0038, -0.0074,  0.023

  1%|▊                                                                                                  | 12/1562 [00:20<43:46,  1.69s/it]

torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.67474365234375
error.norm(): 2579.15478515625
recons_image_embeddings_default: tensor([[ 0.0439,  0.0014, -0.0270,  ...,  0.0224,  0.0149, -0.0006],
        [ 0.0296, -0.0272,  0.0470,  ..., -0.0549,  0.0208,  0.0793],
        [-0.0102, -0.0428,  0.0118,  ..., -0.0232,  0.0985, -0.0043],
        ...,
        [ 0.0045, -0.0752, -0.0561,  ..., -0.0132, -0.0089,  0.0292],
        [ 0.0264, -0.0697, -0.0281,  ...,  0.0057, -0.0057,  0.0436],
        [-0.0290, -0.0008,  0.0018,  ..., -0.0300, -0.0533, -0.0107]],
       device='cuda:0')
recons_image_embeddings_default.shape: torch.Size([32, 512])
recons_image_embeddings_default: torch.Size([32, 512])
recons_image_embeddings_feat_altered: tensor([[ 0.0062,  0.0185, -0.0456,  ...,  0.0198,  0.0208, -0.0356],
        [-0.0026, -0.0160,  0.0291,  ..., -0.0860,  0.0207,  0.0224],
        [-0.0572, -0.0447,  0.0176,  ..., -0.0356,  0.1024, -0.0363],
        ...,
        [-0.023

  1%|▊                                                                                                  | 13/1562 [00:22<43:39,  1.69s/it]

recons_image_embeddings_default: tensor([[ 0.0066, -0.0035, -0.0339,  ..., -0.0250, -0.0086,  0.0168],
        [ 0.0135,  0.0411, -0.0546,  ..., -0.0252,  0.0080, -0.0085],
        [-0.0094,  0.0080,  0.0095,  ..., -0.0244, -0.0238, -0.0208],
        ...,
        [ 0.0237,  0.0415, -0.0585,  ...,  0.0506,  0.0206, -0.0185],
        [-0.0209,  0.0144,  0.0238,  ..., -0.0035,  0.0503, -0.0142],
        [-0.0378, -0.0238, -0.0048,  ...,  0.0256,  0.0074,  0.0099]],
       device='cuda:0')
recons_image_embeddings_default.shape: torch.Size([32, 512])
recons_image_embeddings_default: torch.Size([32, 512])
recons_image_embeddings_feat_altered: tensor([[-0.0228,  0.0052, -0.0448,  ..., -0.0392, -0.0343, -0.0332],
        [-0.0070,  0.0683, -0.0529,  ..., -0.0270,  0.0254, -0.0363],
        [-0.0284,  0.0227, -0.0013,  ..., -0.0392, -0.0360, -0.0583],
        ...,
        [ 0.0057,  0.0474, -0.0523,  ...,  0.0347,  0.0480, -0.0451],
        [-0.0476,  0.0329,  0.0226,  ..., -0.0291,  0.0679, -0

  1%|▉                                                                                                  | 14/1562 [00:23<43:01,  1.67s/it]

torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.7748107910156
error.norm(): 2549.20654296875
torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.63983154296875
error.norm(): 2549.20654296875
torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.7454528808594
error.norm(): 2549.20654296875
recons_image_embeddings_default: tensor([[ 3.0857e-02, -3.7327e-02,  2.9741e-02,  ..., -1.3618e-02,
         -3.1297e-02, -6.3099e-02],
        [ 4.6497e-02, -6.0229e-02, -1.1432e-01,  ...,  5.0438e-02,
         -6.5523e-03,  1.9855e-02],
        [-1.8793e-02,  8.1332e-05, -3.0495e-02,  ..., -5.7876e-02,
          3.4546e-02, -6.6863e-03],
        ...,
        [ 4.3789e-02, -2.6261e-02, -3.2087e-02,  ..., -5.9834e-03,
          2.9529e-02,  3.4642e-02],
        [ 4.9388e-03, -4.1829e-02,  2.6968e-02,  ..., -2.0120e-02,
         -4.0581e-04,  6.1464e-02],
        [-3.2608e-02,  4.9058e-03, -2.9654e-02,  ...,  1.9730e-02,
         -1.3808e-02, -

  1%|▉                                                                                                  | 15/1562 [00:25<43:29,  1.69s/it]

torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.6725158691406
error.norm(): 2811.168701171875
torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.7216491699219
error.norm(): 2811.168701171875
recons_image_embeddings_default: tensor([[ 0.0529,  0.0204, -0.0831,  ...,  0.0161,  0.0585,  0.0724],
        [-0.0071, -0.0163, -0.0256,  ..., -0.0149, -0.0009,  0.0406],
        [-0.0194, -0.0645, -0.0798,  ...,  0.0213,  0.0217,  0.0056],
        ...,
        [-0.0190, -0.0162,  0.0042,  ..., -0.0216, -0.0398,  0.0481],
        [-0.0355,  0.0009,  0.0078,  ..., -0.0203, -0.0028,  0.0338],
        [ 0.0376,  0.0622, -0.0016,  ..., -0.0032, -0.0219,  0.0199]],
       device='cuda:0')
recons_image_embeddings_default.shape: torch.Size([32, 512])
recons_image_embeddings_default: torch.Size([32, 512])
recons_image_embeddings_feat_altered: tensor([[ 0.0292,  0.0297, -0.0914,  ...,  0.0079,  0.0736,  0.0424],
        [-0.0765, -0.0367, -0.0353,  ..., -0.0315, 

  1%|█                                                                                                  | 16/1562 [00:27<43:26,  1.69s/it]

recons_image_embeddings_default: tensor([[-1.0123e-02, -1.9493e-05, -2.6600e-02,  ...,  2.2191e-03,
         -4.0753e-02,  5.6922e-02],
        [-2.8107e-02, -4.8151e-02,  4.8595e-02,  ..., -1.0011e-02,
          4.8170e-03, -3.7860e-03],
        [-4.6994e-02,  2.6106e-02,  3.9159e-02,  ..., -2.8148e-02,
          1.2817e-02,  1.9352e-02],
        ...,
        [ 1.4949e-02,  4.6788e-02, -8.0091e-02,  ...,  3.7032e-02,
          1.3509e-02, -2.4777e-02],
        [-5.2004e-02,  1.7640e-02, -2.0967e-02,  ..., -3.8390e-02,
          2.7551e-02,  7.8389e-03],
        [ 2.4011e-02,  2.7851e-02, -6.4959e-02,  ...,  3.3868e-02,
         -6.8626e-03,  5.3334e-03]], device='cuda:0')
recons_image_embeddings_default.shape: torch.Size([32, 512])
recons_image_embeddings_default: torch.Size([32, 512])
recons_image_embeddings_feat_altered: tensor([[-0.0309, -0.0206, -0.0512,  ..., -0.0117, -0.0191,  0.0078],
        [-0.0556, -0.0445,  0.0261,  ..., -0.0408,  0.0124, -0.0370],
        [-0.0851,  0.061

  1%|█                                                                                                  | 17/1562 [00:29<43:52,  1.70s/it]

recons_image_embeddings_default: tensor([[-0.0322, -0.0296, -0.0416,  ..., -0.0232,  0.0461, -0.0357],
        [ 0.0014,  0.0217, -0.0247,  ...,  0.0009, -0.0133, -0.0056],
        [-0.0023, -0.0210, -0.0320,  ...,  0.0059,  0.0468, -0.0193],
        ...,
        [ 0.0278,  0.0113, -0.0529,  ..., -0.0562,  0.0390,  0.0336],
        [ 0.0494, -0.0434,  0.0149,  ..., -0.0331,  0.0291,  0.0328],
        [ 0.0299, -0.0482,  0.0056,  ...,  0.0083,  0.0061, -0.0401]],
       device='cuda:0')
recons_image_embeddings_default.shape: torch.Size([32, 512])
recons_image_embeddings_default: torch.Size([32, 512])
recons_image_embeddings_feat_altered: tensor([[-5.9756e-02, -6.6766e-03, -4.3641e-02,  ..., -7.1750e-02,
          4.0879e-02, -4.9814e-02],
        [-5.8973e-02,  4.6296e-02, -1.0862e-02,  ...,  6.8289e-03,
         -2.2303e-03, -7.9670e-02],
        [-3.7009e-02,  3.8733e-03, -3.4702e-02,  ...,  5.9589e-03,
          4.0708e-02, -6.4557e-02],
        ...,
        [-3.6566e-05, -2.5564e-03

  1%|█▏                                                                                                 | 18/1562 [00:30<43:27,  1.69s/it]

recons_image_embeddings_default: tensor([[ 0.0249, -0.0052,  0.0406,  ..., -0.0369, -0.0524,  0.0142],
        [-0.0185,  0.0091, -0.0139,  ...,  0.0364,  0.0191,  0.0067],
        [ 0.0373,  0.0415, -0.0883,  ..., -0.0298, -0.0183, -0.0097],
        ...,
        [-0.0168, -0.0594, -0.0642,  ..., -0.0121, -0.0572, -0.0024],
        [-0.0403, -0.0072, -0.0385,  ...,  0.0055,  0.0492,  0.0175],
        [ 0.0266,  0.0374, -0.0974,  ..., -0.0150,  0.0210,  0.0188]],
       device='cuda:0')
recons_image_embeddings_default.shape: torch.Size([32, 512])
recons_image_embeddings_default: torch.Size([32, 512])
recons_image_embeddings_feat_altered: tensor([[-0.0055,  0.0101, -0.0014,  ..., -0.0178, -0.0478, -0.0230],
        [-0.0295,  0.0111, -0.0199,  ...,  0.0351,  0.0306, -0.0317],
        [-0.0123,  0.0462, -0.0715,  ..., -0.0491, -0.0019, -0.0520],
        ...,
        [-0.0411, -0.0489, -0.0562,  ..., -0.0478, -0.0388, -0.0299],
        [-0.0832, -0.0141, -0.0681,  ...,  0.0055,  0.0657, -0

  1%|█▏                                                                                                 | 19/1562 [00:32<43:12,  1.68s/it]

steering norm: 399.6585998535156
error.norm(): 3094.82275390625
recons_image_embeddings_default: tensor([[ 0.0103,  0.0020, -0.0307,  ...,  0.0154,  0.0204,  0.0228],
        [ 0.0271, -0.0033, -0.0069,  ...,  0.0360, -0.0547, -0.0055],
        [-0.0337, -0.0462, -0.0443,  ...,  0.1008,  0.0419, -0.0356],
        ...,
        [-0.0360, -0.0501, -0.0725,  ..., -0.0636,  0.0273,  0.0100],
        [-0.0168, -0.0183, -0.0212,  ..., -0.0512,  0.0355, -0.0173],
        [ 0.0083, -0.0236, -0.0229,  ..., -0.0151, -0.0228,  0.0216]],
       device='cuda:0')
recons_image_embeddings_default.shape: torch.Size([32, 512])
recons_image_embeddings_default: torch.Size([32, 512])
recons_image_embeddings_feat_altered: tensor([[-0.0137,  0.0011, -0.0220,  ...,  0.0168,  0.0234,  0.0026],
        [ 0.0180,  0.0149, -0.0243,  ...,  0.0258, -0.0379, -0.0434],
        [-0.0644, -0.0072, -0.0187,  ...,  0.0738,  0.0150, -0.0611],
        ...,
        [-0.0469, -0.0507, -0.0599,  ..., -0.0843,  0.0251, -0.0128]

  1%|█▎                                                                                                 | 20/1562 [00:34<43:22,  1.69s/it]

recons_image_embeddings_default: tensor([[-0.0090, -0.0442, -0.0502,  ..., -0.0434,  0.0128,  0.0115],
        [ 0.0505,  0.0353,  0.0103,  ..., -0.0384,  0.0117, -0.0180],
        [ 0.0080, -0.0256,  0.0519,  ..., -0.0230,  0.0744,  0.0185],
        ...,
        [ 0.0477,  0.0352, -0.0426,  ...,  0.0079,  0.0191,  0.0104],
        [ 0.0263, -0.0340, -0.0586,  ...,  0.0640,  0.0076,  0.0071],
        [ 0.0060,  0.0241, -0.0246,  ...,  0.0483, -0.0094,  0.0154]],
       device='cuda:0')
recons_image_embeddings_default.shape: torch.Size([32, 512])
recons_image_embeddings_default: torch.Size([32, 512])
recons_image_embeddings_feat_altered: tensor([[-0.0247, -0.0359, -0.0426,  ..., -0.0415,  0.0179, -0.0295],
        [ 0.0237,  0.0375, -0.0126,  ..., -0.0503,  0.0310, -0.0524],
        [-0.0249, -0.0169,  0.0125,  ..., -0.0355,  0.0605, -0.0554],
        ...,
        [ 0.0080,  0.0368, -0.0478,  ...,  0.0033,  0.0423, -0.0340],
        [ 0.0119,  0.0040, -0.0254,  ...,  0.0810,  0.0269, -0

  1%|█▎                                                                                                 | 21/1562 [00:35<43:52,  1.71s/it]

steering norm: 399.8009033203125
error.norm(): 2323.871826171875
recons_image_embeddings_default: tensor([[-0.0306, -0.0250, -0.0007,  ..., -0.0178, -0.0084,  0.0104],
        [ 0.0227,  0.0263, -0.0147,  ..., -0.0113,  0.0022,  0.0263],
        [-0.0270, -0.0141, -0.0620,  ..., -0.0605,  0.0086,  0.0268],
        ...,
        [ 0.0152, -0.0904, -0.0523,  ...,  0.0430,  0.0220,  0.0356],
        [ 0.0120, -0.0614, -0.0669,  ...,  0.0342,  0.0187,  0.0043],
        [-0.0288,  0.0040, -0.0067,  ..., -0.0666,  0.0190,  0.0108]],
       device='cuda:0')
recons_image_embeddings_default.shape: torch.Size([32, 512])
recons_image_embeddings_default: torch.Size([32, 512])
recons_image_embeddings_feat_altered: tensor([[-0.0459, -0.0309,  0.0143,  ..., -0.0130,  0.0021,  0.0025],
        [ 0.0076,  0.0201, -0.0091,  ..., -0.0262,  0.0165,  0.0136],
        [-0.0363, -0.0331, -0.0740,  ..., -0.0569,  0.0324,  0.0139],
        ...,
        [-0.0183, -0.0590, -0.0324,  ...,  0.0112,  0.0331, -0.0130

  1%|█▍                                                                                                 | 22/1562 [00:37<43:59,  1.71s/it]

torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.68316650390625
error.norm(): 2312.8857421875
recons_image_embeddings_default: tensor([[-0.0445,  0.0464,  0.0130,  ...,  0.0248, -0.0082,  0.0380],
        [ 0.0079, -0.0317, -0.0168,  ..., -0.0359, -0.0070,  0.0316],
        [ 0.0051,  0.0209,  0.0186,  ...,  0.0262,  0.0384,  0.0252],
        ...,
        [-0.0704,  0.0434,  0.0131,  ..., -0.0152,  0.0271,  0.0243],
        [ 0.0044,  0.0426,  0.0189,  ..., -0.0370,  0.0333, -0.0055],
        [ 0.0121, -0.0312, -0.0082,  ...,  0.0059, -0.0518, -0.0207]],
       device='cuda:0')
recons_image_embeddings_default.shape: torch.Size([32, 512])
recons_image_embeddings_default: torch.Size([32, 512])
recons_image_embeddings_feat_altered: tensor([[-0.0551,  0.0481,  0.0035,  ...,  0.0051, -0.0078, -0.0031],
        [-0.0117,  0.0068, -0.0068,  ..., -0.0464,  0.0002, -0.0133],
        [-0.0080,  0.0328,  0.0348,  ...,  0.0271,  0.0492, -0.0010],
        ...,
        [-0.1054

  1%|█▍                                                                                                 | 23/1562 [00:39<43:04,  1.68s/it]

recons_image_embeddings_default: tensor([[ 0.0445,  0.0149, -0.0705,  ...,  0.0050,  0.0617,  0.0108],
        [ 0.0189,  0.0040, -0.0451,  ..., -0.0066,  0.0361,  0.0216],
        [-0.0351, -0.0348, -0.0253,  ...,  0.0324, -0.0282, -0.0271],
        ...,
        [ 0.0527, -0.0521, -0.0397,  ..., -0.0349,  0.0333,  0.0349],
        [ 0.0179,  0.0562, -0.0993,  ..., -0.0440,  0.0437,  0.0751],
        [-0.0468,  0.0275, -0.0362,  ...,  0.0414,  0.0054,  0.0049]],
       device='cuda:0')
recons_image_embeddings_default.shape: torch.Size([32, 512])
recons_image_embeddings_default: torch.Size([32, 512])
recons_image_embeddings_feat_altered: tensor([[-0.0011,  0.0163, -0.0807,  ...,  0.0021,  0.0502, -0.0266],
        [-0.0235,  0.0089, -0.0394,  ..., -0.0194,  0.0347,  0.0071],
        [-0.0652, -0.0049, -0.0305,  ...,  0.0109, -0.0143, -0.0655],
        ...,
        [ 0.0489, -0.0270, -0.0130,  ..., -0.0367,  0.0274, -0.0110],
        [-0.0165,  0.0661, -0.0632,  ..., -0.0354,  0.0494,  0

  2%|█▌                                                                                                 | 24/1562 [00:40<43:21,  1.69s/it]

recons_image_embeddings_default: tensor([[-0.0294,  0.0457, -0.0690,  ..., -0.0641, -0.0291, -0.0056],
        [-0.0077,  0.0052, -0.0322,  ...,  0.0229,  0.0335,  0.0005],
        [ 0.0368,  0.0442, -0.1021,  ...,  0.0114,  0.0835, -0.0410],
        ...,
        [ 0.0247, -0.0175, -0.0634,  ..., -0.0233, -0.0365, -0.0306],
        [ 0.0066,  0.0023, -0.0128,  ...,  0.0422, -0.0042,  0.0455],
        [ 0.0490,  0.0039, -0.0487,  ...,  0.0212, -0.0274, -0.0346]],
       device='cuda:0')
recons_image_embeddings_default.shape: torch.Size([32, 512])
recons_image_embeddings_default: torch.Size([32, 512])
recons_image_embeddings_feat_altered: tensor([[-0.0915,  0.0400, -0.0542,  ..., -0.0440,  0.0003, -0.0385],
        [-0.0227,  0.0294, -0.0093,  ...,  0.0129,  0.0097, -0.0287],
        [ 0.0117,  0.0422, -0.1084,  ..., -0.0162,  0.0911, -0.0511],
        ...,
        [ 0.0074, -0.0243, -0.0633,  ..., -0.0192, -0.0225, -0.0344],
        [-0.0218, -0.0130, -0.0319,  ...,  0.0048, -0.0195, -0

  2%|█▌                                                                                                 | 25/1562 [00:42<43:11,  1.69s/it]

torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.7519226074219
error.norm(): 3388.11767578125
recons_image_embeddings_default: tensor([[ 0.0270,  0.0096, -0.0460,  ...,  0.0924, -0.0420,  0.0171],
        [-0.0119, -0.0317, -0.0292,  ..., -0.0379, -0.0559,  0.0552],
        [-0.0186,  0.0147,  0.0256,  ..., -0.0259,  0.0489,  0.0561],
        ...,
        [ 0.0264,  0.0227, -0.0372,  ..., -0.0570,  0.0118,  0.0153],
        [ 0.0347, -0.0542, -0.0238,  ..., -0.0610,  0.0917, -0.0443],
        [-0.0255, -0.0291, -0.0003,  ..., -0.0300,  0.0092,  0.0352]],
       device='cuda:0')
recons_image_embeddings_default.shape: torch.Size([32, 512])
recons_image_embeddings_default: torch.Size([32, 512])
recons_image_embeddings_feat_altered: tensor([[-0.0341,  0.0347, -0.0440,  ...,  0.0610, -0.0223, -0.0075],
        [-0.0593, -0.0338, -0.0440,  ..., -0.0506, -0.0519,  0.0241],
        [-0.0340,  0.0106,  0.0356,  ..., -0.0372,  0.0344,  0.0280],
        ...,
        [-0.0202

  2%|█▋                                                                                                 | 26/1562 [00:44<43:34,  1.70s/it]

torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.91253662109375
error.norm(): 1731.2430419921875
recons_image_embeddings_default: tensor([[-0.0041, -0.0084, -0.0384,  ...,  0.0701,  0.0236, -0.0138],
        [-0.0111, -0.0433, -0.0195,  ...,  0.0080,  0.0627,  0.0130],
        [ 0.0099, -0.0763, -0.0002,  ..., -0.0086,  0.0178,  0.0177],
        ...,
        [-0.0519,  0.0217, -0.0446,  ...,  0.0055, -0.0005,  0.0500],
        [-0.0681,  0.0553, -0.0292,  ...,  0.0284,  0.0135, -0.0313],
        [-0.0419, -0.0415,  0.0084,  ..., -0.0029,  0.0206,  0.0439]],
       device='cuda:0')
recons_image_embeddings_default.shape: torch.Size([32, 512])
recons_image_embeddings_default: torch.Size([32, 512])
recons_image_embeddings_feat_altered: tensor([[-0.0352, -0.0016, -0.0244,  ...,  0.0370,  0.0240, -0.0358],
        [-0.0276, -0.0269, -0.0251,  ...,  0.0009,  0.0799, -0.0154],
        [-0.0031, -0.0563, -0.0010,  ..., -0.0383,  0.0165, -0.0172],
        ...,
        [-0.0

  2%|█▋                                                                                                 | 27/1562 [00:45<43:10,  1.69s/it]

torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.6404724121094
error.norm(): 2684.2802734375
torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.8171081542969
error.norm(): 2684.2802734375
recons_image_embeddings_default: tensor([[ 0.0039, -0.0403, -0.0278,  ...,  0.0384, -0.0143,  0.0516],
        [-0.0704, -0.0366,  0.0039,  ...,  0.0477, -0.0467,  0.0216],
        [ 0.0377, -0.0097, -0.0432,  ..., -0.0317,  0.0262, -0.0007],
        ...,
        [ 0.0572,  0.0105, -0.0037,  ..., -0.0018,  0.0049,  0.1183],
        [-0.0230, -0.0606, -0.0256,  ..., -0.0892,  0.0020,  0.0583],
        [-0.0327, -0.0007,  0.0053,  ..., -0.0364,  0.0500,  0.0154]],
       device='cuda:0')
recons_image_embeddings_default.shape: torch.Size([32, 512])
recons_image_embeddings_default: torch.Size([32, 512])
recons_image_embeddings_feat_altered: tensor([[-0.0096, -0.0091, -0.0418,  ...,  0.0113, -0.0273, -0.0142],
        [-0.0976, -0.0511,  0.0023,  ...,  0.0149, -0.0

  2%|█▊                                                                                                 | 28/1562 [00:47<42:53,  1.68s/it]

torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.8408203125
error.norm(): 2358.03076171875
torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.6822814941406
error.norm(): 2358.03076171875
torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.7277526855469
error.norm(): 2358.03076171875
recons_image_embeddings_default: tensor([[-0.0078,  0.0389, -0.0365,  ..., -0.0480, -0.0111, -0.0273],
        [-0.0442, -0.0164,  0.0088,  ..., -0.0148, -0.0095,  0.0082],
        [-0.0058,  0.0105,  0.0261,  ...,  0.0267,  0.0702,  0.0393],
        ...,
        [-0.0747, -0.0715,  0.0100,  ...,  0.0082, -0.0435,  0.0129],
        [-0.0294,  0.0415, -0.0458,  ...,  0.0290,  0.0223, -0.0160],
        [-0.0163, -0.0082, -0.0504,  ...,  0.0229, -0.0041, -0.0147]],
       device='cuda:0')
recons_image_embeddings_default.shape: torch.Size([32, 512])
recons_image_embeddings_default: torch.Size([32, 512])
recons_image_embeddings_feat_altered: tensor([[-

  2%|█▊                                                                                                 | 29/1562 [00:49<42:54,  1.68s/it]

torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.7201232910156
error.norm(): 2920.3466796875
recons_image_embeddings_default: tensor([[ 4.0848e-03, -7.8454e-04, -3.9075e-05,  ...,  5.5200e-02,
         -4.7378e-02,  4.3859e-03],
        [ 1.6751e-02, -7.4950e-02, -9.3433e-03,  ..., -6.2861e-02,
          6.9777e-03,  3.0260e-02],
        [-2.5740e-02,  2.0883e-02, -4.8572e-02,  ..., -1.0951e-02,
         -2.3305e-02, -2.4724e-02],
        ...,
        [ 4.7538e-02, -2.9911e-02, -5.7623e-02,  ...,  2.7642e-02,
          5.5574e-02, -3.3954e-02],
        [-6.0117e-02, -2.5773e-03, -5.2277e-02,  ..., -2.0661e-02,
         -9.3998e-03,  5.6400e-03],
        [ 7.4040e-02, -5.6314e-02, -4.1555e-02,  ..., -3.4234e-02,
          4.3222e-02,  3.0359e-02]], device='cuda:0')
recons_image_embeddings_default.shape: torch.Size([32, 512])
recons_image_embeddings_default: torch.Size([32, 512])
recons_image_embeddings_feat_altered: tensor([[-0.0093,  0.0163, -0.0286,  ...,  0.0077

  2%|█▉                                                                                                 | 30/1562 [00:50<43:09,  1.69s/it]

torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.81964111328125
error.norm(): 2064.347412109375
recons_image_embeddings_default: tensor([[ 0.0116,  0.0422, -0.0598,  ..., -0.0006,  0.0308, -0.0182],
        [ 0.0163, -0.0430, -0.0818,  ...,  0.0249,  0.0426,  0.0366],
        [-0.0138, -0.0088, -0.0265,  ..., -0.0035,  0.0206, -0.0343],
        ...,
        [ 0.0094, -0.0205,  0.0028,  ..., -0.0030,  0.0488, -0.0270],
        [-0.0498,  0.0028, -0.0310,  ..., -0.0162, -0.0259, -0.0035],
        [-0.0482,  0.0163,  0.0063,  ..., -0.0193, -0.0568,  0.0857]],
       device='cuda:0')
recons_image_embeddings_default.shape: torch.Size([32, 512])
recons_image_embeddings_default: torch.Size([32, 512])
recons_image_embeddings_feat_altered: tensor([[-0.0069,  0.0390, -0.0788,  ...,  0.0095,  0.0369, -0.0382],
        [-0.0170, -0.0178, -0.0822,  ..., -0.0112,  0.0366,  0.0064],
        [-0.0248, -0.0343, -0.0313,  ..., -0.0073, -0.0037, -0.0772],
        ...,
        [-0.03

  2%|█▉                                                                                                 | 31/1562 [00:52<43:22,  1.70s/it]

torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.7594909667969
error.norm(): 2245.955322265625
recons_image_embeddings_default: tensor([[-6.2864e-03, -3.4818e-02, -4.4114e-02,  ...,  1.7224e-03,
          4.7383e-02, -1.8782e-02],
        [ 6.0741e-03,  5.7774e-02, -4.1267e-02,  ..., -4.4846e-02,
         -1.2800e-02, -1.3800e-02],
        [-1.0291e-02,  1.0373e-02, -8.1637e-02,  ..., -1.1007e-02,
         -1.4136e-06, -1.6210e-02],
        ...,
        [ 1.3172e-02,  2.4769e-02, -5.0021e-02,  ..., -7.3066e-02,
          2.3670e-02,  4.7158e-02],
        [-3.7633e-03,  4.8112e-03, -5.0173e-02,  ...,  4.2311e-02,
          2.5183e-04,  1.4903e-02],
        [-6.2204e-02,  4.5754e-03, -5.1301e-02,  ..., -2.5571e-02,
          3.5229e-02, -1.8355e-02]], device='cuda:0')
recons_image_embeddings_default.shape: torch.Size([32, 512])
recons_image_embeddings_default: torch.Size([32, 512])
recons_image_embeddings_feat_altered: tensor([[-0.0310, -0.0370, -0.0448,  ..., -0.02

  2%|██                                                                                                 | 32/1562 [00:54<42:36,  1.67s/it]

torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.8326721191406
error.norm(): 2746.43505859375
recons_image_embeddings_default: tensor([[-0.0298, -0.0360, -0.0337,  ..., -0.0339,  0.0265,  0.0199],
        [ 0.0094,  0.0669, -0.0373,  ...,  0.0015, -0.0134,  0.0047],
        [-0.0278, -0.0507, -0.0408,  ...,  0.0289,  0.0574, -0.0410],
        ...,
        [ 0.0047, -0.0008, -0.0305,  ..., -0.0711,  0.0237,  0.0103],
        [-0.0008,  0.0238, -0.0671,  ..., -0.0078,  0.0576, -0.0305],
        [ 0.0358,  0.0244, -0.0376,  ...,  0.0069,  0.0310, -0.0235]],
       device='cuda:0')
recons_image_embeddings_default.shape: torch.Size([32, 512])
recons_image_embeddings_default: torch.Size([32, 512])
recons_image_embeddings_feat_altered: tensor([[-0.0828, -0.0390, -0.0509,  ..., -0.0576,  0.0417, -0.0025],
        [-0.0200,  0.0722, -0.0404,  ..., -0.0264, -0.0098, -0.0273],
        [-0.0173, -0.0087, -0.0197,  ...,  0.0250,  0.0438, -0.0762],
        ...,
        [-0.0045

  2%|██                                                                                                 | 33/1562 [00:56<43:05,  1.69s/it]

torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.5735168457031
error.norm(): 2933.069580078125
torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.7022399902344
error.norm(): 2933.069580078125
recons_image_embeddings_default: tensor([[-0.0360, -0.0183,  0.0082,  ..., -0.0053,  0.0276,  0.0477],
        [ 0.0397,  0.0348, -0.0864,  ..., -0.0472, -0.0012,  0.0530],
        [ 0.0714, -0.0002, -0.0066,  ..., -0.0094,  0.0279,  0.0507],
        ...,
        [ 0.0392, -0.0492, -0.0364,  ...,  0.0063,  0.0090,  0.0298],
        [ 0.0292,  0.0723,  0.0164,  ...,  0.0657, -0.0080,  0.0032],
        [ 0.0287, -0.0444, -0.0405,  ..., -0.0087, -0.0261,  0.0236]],
       device='cuda:0')
recons_image_embeddings_default.shape: torch.Size([32, 512])
recons_image_embeddings_default: torch.Size([32, 512])
recons_image_embeddings_feat_altered: tensor([[-0.0886, -0.0230, -0.0004,  ..., -0.0297,  0.0381,  0.0083],
        [ 0.0083,  0.0472, -0.0868,  ..., -0.0548, 

  2%|██▏                                                                                                | 34/1562 [00:57<42:30,  1.67s/it]

torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.7355651855469
error.norm(): 2285.02587890625
recons_image_embeddings_default: tensor([[ 0.0796,  0.0781, -0.0077,  ..., -0.0863,  0.0117,  0.0183],
        [ 0.0302,  0.0386, -0.0574,  ..., -0.0549,  0.0865, -0.0051],
        [-0.0264,  0.0155,  0.0059,  ..., -0.0588,  0.0069,  0.0122],
        ...,
        [-0.0376, -0.0352, -0.0430,  ...,  0.0070,  0.0009, -0.0034],
        [-0.0387,  0.0273,  0.0886,  ..., -0.0516,  0.0201,  0.0242],
        [ 0.0291, -0.0030, -0.0537,  ..., -0.0139, -0.0238,  0.0449]],
       device='cuda:0')
recons_image_embeddings_default.shape: torch.Size([32, 512])
recons_image_embeddings_default: torch.Size([32, 512])
recons_image_embeddings_feat_altered: tensor([[ 0.0639,  0.0737, -0.0092,  ..., -0.0900,  0.0422, -0.0064],
        [ 0.0019,  0.0363, -0.0500,  ..., -0.0611,  0.0956, -0.0425],
        [-0.0445,  0.0096,  0.0054,  ..., -0.0539,  0.0205, -0.0046],
        ...,
        [-0.0487

  2%|██▏                                                                                                | 35/1562 [00:59<43:04,  1.69s/it]

torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.4921569824219
error.norm(): 2890.235107421875
torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.62542724609375
error.norm(): 2890.235107421875
recons_image_embeddings_default: tensor([[-0.0094, -0.0034,  0.0173,  ..., -0.0465,  0.0069, -0.0010],
        [-0.0318,  0.0336, -0.0542,  ...,  0.0669,  0.0317,  0.0294],
        [-0.0006,  0.0260, -0.0503,  ...,  0.0270, -0.0689, -0.0355],
        ...,
        [ 0.0054,  0.0011,  0.0068,  ...,  0.0412, -0.0130,  0.0178],
        [-0.0513,  0.0085,  0.0171,  ..., -0.0081,  0.0315,  0.0210],
        [ 0.0371,  0.0499, -0.0535,  ..., -0.0140,  0.0429,  0.0090]],
       device='cuda:0')
recons_image_embeddings_default.shape: torch.Size([32, 512])
recons_image_embeddings_default: torch.Size([32, 512])
recons_image_embeddings_feat_altered: tensor([[-0.0334, -0.0123,  0.0200,  ..., -0.0687, -0.0091, -0.0309],
        [-0.0695,  0.0189, -0.0625,  ...,  0.0200,

  2%|██▎                                                                                                | 36/1562 [01:01<42:39,  1.68s/it]

torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.8243713378906
error.norm(): 2761.372314453125
torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.5870056152344
error.norm(): 2761.372314453125
torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.6393127441406
error.norm(): 2761.372314453125
recons_image_embeddings_default: tensor([[ 0.0623,  0.0777,  0.0063,  ..., -0.0067,  0.0522,  0.0476],
        [ 0.0187,  0.0050, -0.0835,  ..., -0.0154,  0.0371,  0.0015],
        [-0.0297, -0.0223, -0.0064,  ...,  0.0017,  0.0356,  0.0184],
        ...,
        [ 0.0063, -0.0421, -0.0617,  ..., -0.0168,  0.0134,  0.0109],
        [-0.0118, -0.0239,  0.0317,  ..., -0.0229,  0.0298,  0.0419],
        [-0.0203, -0.0077, -0.0238,  ...,  0.0089,  0.0357,  0.0574]],
       device='cuda:0')
recons_image_embeddings_default.shape: torch.Size([32, 512])
recons_image_embeddings_default: torch.Size([32, 512])
recons_image_embeddings_feat_altered: tens

  2%|██▎                                                                                                | 37/1562 [01:02<42:35,  1.68s/it]

torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.8702087402344
error.norm(): 1906.9107666015625
recons_image_embeddings_default: tensor([[ 0.0153, -0.0178, -0.0092,  ..., -0.0420, -0.0117, -0.0437],
        [-0.0590, -0.0581,  0.0203,  ..., -0.0150,  0.0105,  0.0091],
        [ 0.0420,  0.0145, -0.0139,  ..., -0.0262,  0.0234,  0.0172],
        ...,
        [ 0.0100, -0.0004, -0.0169,  ...,  0.0108,  0.0929, -0.0491],
        [-0.0086, -0.0131, -0.0811,  ..., -0.0376, -0.0078,  0.0257],
        [-0.0255, -0.0063, -0.0057,  ..., -0.0507,  0.0096,  0.0145]],
       device='cuda:0')
recons_image_embeddings_default.shape: torch.Size([32, 512])
recons_image_embeddings_default: torch.Size([32, 512])
recons_image_embeddings_feat_altered: tensor([[-0.0161,  0.0096,  0.0122,  ..., -0.0294, -0.0002, -0.0660],
        [-0.0795, -0.0382,  0.0056,  ..., -0.0332,  0.0075, -0.0508],
        [ 0.0169, -0.0014, -0.0295,  ..., -0.0165,  0.0287,  0.0082],
        ...,
        [-0.01

  2%|██▍                                                                                                | 38/1562 [01:04<42:38,  1.68s/it]

torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.67303466796875
error.norm(): 2543.878173828125
recons_image_embeddings_default: tensor([[ 0.0206, -0.0176,  0.0019,  ..., -0.0258,  0.0213,  0.0081],
        [ 0.0374, -0.0067, -0.0290,  ...,  0.0250,  0.0456, -0.0041],
        [ 0.0119,  0.0131, -0.0342,  ...,  0.0338,  0.0611, -0.0178],
        ...,
        [-0.0136,  0.0118, -0.0549,  ...,  0.0049, -0.0020,  0.0126],
        [ 0.0001,  0.0953, -0.0401,  ..., -0.0334,  0.0399, -0.0450],
        [ 0.0198, -0.0157, -0.0860,  ...,  0.0869, -0.0042, -0.0235]],
       device='cuda:0')
recons_image_embeddings_default.shape: torch.Size([32, 512])
recons_image_embeddings_default: torch.Size([32, 512])
recons_image_embeddings_feat_altered: tensor([[ 0.0078, -0.0039,  0.0195,  ...,  0.0044,  0.0085, -0.0360],
        [ 0.0144, -0.0022, -0.0352,  ..., -0.0069,  0.0384, -0.0154],
        [-0.0058,  0.0177, -0.0487,  ...,  0.0190,  0.0689, -0.0385],
        ...,
        [-0.01

  2%|██▍                                                                                                | 39/1562 [01:06<42:38,  1.68s/it]

torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.77349853515625
error.norm(): 2702.371337890625
torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.68170166015625
error.norm(): 2702.371337890625
torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.7886962890625
error.norm(): 2702.371337890625
recons_image_embeddings_default: tensor([[ 0.0208,  0.0174, -0.0100,  ...,  0.0130,  0.0114,  0.0620],
        [ 0.0056, -0.0272, -0.0757,  ...,  0.0084,  0.0550, -0.0204],
        [-0.0194, -0.0068,  0.0360,  ...,  0.0281, -0.0466,  0.0503],
        ...,
        [-0.0024, -0.0506,  0.0197,  ...,  0.0383,  0.0111,  0.0027],
        [ 0.0657, -0.0998, -0.0177,  ..., -0.0476,  0.0080,  0.0378],
        [ 0.0013, -0.0177, -0.0558,  ...,  0.0050, -0.0182,  0.0125]],
       device='cuda:0')
recons_image_embeddings_default.shape: torch.Size([32, 512])
recons_image_embeddings_default: torch.Size([32, 512])
recons_image_embeddings_feat_altered: te

  3%|██▌                                                                                                | 40/1562 [01:07<42:51,  1.69s/it]

torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.9031066894531
error.norm(): 2384.51953125
torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.6091003417969
error.norm(): 2384.51953125
torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.7984619140625
error.norm(): 2384.51953125
recons_image_embeddings_default: tensor([[ 0.0327,  0.0327,  0.0290,  ...,  0.0002, -0.0156, -0.0199],
        [-0.0068,  0.0143, -0.0491,  ..., -0.0767,  0.0697,  0.0362],
        [-0.0405,  0.0222,  0.0400,  ...,  0.0199,  0.0150, -0.0419],
        ...,
        [ 0.0385,  0.0141, -0.0252,  ..., -0.0299, -0.0383, -0.0056],
        [ 0.0039, -0.0324, -0.0766,  ..., -0.0037,  0.0138,  0.0538],
        [-0.0154, -0.0218, -0.0105,  ..., -0.0592,  0.0420, -0.0065]],
       device='cuda:0')
recons_image_embeddings_default.shape: torch.Size([32, 512])
recons_image_embeddings_default: torch.Size([32, 512])
recons_image_embeddings_feat_altered: tensor([[-0.0084

  3%|██▌                                                                                                | 41/1562 [01:09<42:10,  1.66s/it]

torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.83837890625
error.norm(): 2487.546142578125
torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.669677734375
error.norm(): 2487.546142578125
torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.831787109375
error.norm(): 2487.546142578125
recons_image_embeddings_default: tensor([[ 0.0038, -0.0030, -0.0318,  ...,  0.0277, -0.0047, -0.0105],
        [-0.0520,  0.0187, -0.0260,  ..., -0.0397,  0.0387,  0.0057],
        [ 0.0050, -0.0153, -0.1023,  ..., -0.0014,  0.0257, -0.0226],
        ...,
        [ 0.0321, -0.0063,  0.0231,  ..., -0.0128,  0.0256,  0.0230],
        [-0.0446,  0.0276, -0.0388,  ...,  0.0508, -0.0457,  0.0376],
        [-0.0578, -0.0290, -0.0101,  ...,  0.0145, -0.0243, -0.0122]],
       device='cuda:0')
recons_image_embeddings_default.shape: torch.Size([32, 512])
recons_image_embeddings_default: torch.Size([32, 512])
recons_image_embeddings_feat_altered: tensor([

  3%|██▋                                                                                                | 42/1562 [01:11<42:46,  1.69s/it]

recons_image_embeddings_default: tensor([[-0.0756, -0.0266, -0.0128,  ..., -0.0494,  0.0504,  0.0251],
        [ 0.0052, -0.0543, -0.0775,  ..., -0.0442,  0.0455,  0.0536],
        [ 0.0047, -0.0291, -0.0701,  ..., -0.0118, -0.0065, -0.0005],
        ...,
        [-0.0337, -0.0094, -0.0565,  ...,  0.0074, -0.0578, -0.0021],
        [-0.0336, -0.0162,  0.0400,  ..., -0.0483,  0.0295,  0.0096],
        [-0.0156, -0.0092, -0.0141,  ..., -0.0880,  0.0511,  0.0418]],
       device='cuda:0')
recons_image_embeddings_default.shape: torch.Size([32, 512])
recons_image_embeddings_default: torch.Size([32, 512])
recons_image_embeddings_feat_altered: tensor([[-0.1031, -0.0260, -0.0497,  ..., -0.0396,  0.0426, -0.0103],
        [-0.0235, -0.0580, -0.0734,  ..., -0.0494,  0.0593,  0.0570],
        [-0.0047, -0.0154, -0.0713,  ..., -0.0496,  0.0012, -0.0041],
        ...,
        [-0.0450, -0.0012, -0.0605,  ..., -0.0141, -0.0510, -0.0481],
        [-0.0529, -0.0028,  0.0408,  ..., -0.0612,  0.0383, -0

  3%|██▋                                                                                                | 43/1562 [01:12<43:19,  1.71s/it]

steering norm: 399.60870361328125
error.norm(): 2847.44189453125
torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.68988037109375
error.norm(): 2847.44189453125
recons_image_embeddings_default: tensor([[-0.0264, -0.0048, -0.0275,  ...,  0.0320, -0.0023, -0.0129],
        [ 0.0170,  0.0268,  0.0287,  ...,  0.0118, -0.0186, -0.0113],
        [-0.0332,  0.0402,  0.0162,  ..., -0.0112,  0.0574,  0.0274],
        ...,
        [ 0.0075, -0.0360,  0.0197,  ..., -0.0116,  0.0291,  0.0257],
        [ 0.0219, -0.0562, -0.0146,  ...,  0.0158,  0.0213, -0.0159],
        [-0.0508,  0.0356, -0.0332,  ..., -0.0227,  0.0652, -0.0146]],
       device='cuda:0')
recons_image_embeddings_default.shape: torch.Size([32, 512])
recons_image_embeddings_default: torch.Size([32, 512])
recons_image_embeddings_feat_altered: tensor([[-0.0040,  0.0037, -0.0176,  ...,  0.0312, -0.0139, -0.0547],
        [-0.0140,  0.0207,  0.0379,  ..., -0.0160,  0.0009, -0.0358],
        [-0.0441,  0.0444,  0.001

  3%|██▊                                                                                                | 44/1562 [01:14<43:46,  1.73s/it]

steering norm: 399.6084289550781
error.norm(): 3489.086181640625
recons_image_embeddings_default: tensor([[-0.0372, -0.0003,  0.0311,  ...,  0.0058, -0.0351,  0.0403],
        [-0.0159, -0.0750, -0.0120,  ..., -0.0080,  0.0770,  0.0432],
        [-0.0108,  0.0213, -0.0195,  ..., -0.0213,  0.0551, -0.0030],
        ...,
        [-0.0561, -0.0420, -0.0491,  ..., -0.0307,  0.0164,  0.0061],
        [ 0.0026, -0.0198, -0.0171,  ..., -0.0088, -0.0354,  0.0424],
        [-0.0314, -0.0192, -0.0408,  ..., -0.0396,  0.0017,  0.0833]],
       device='cuda:0')
recons_image_embeddings_default.shape: torch.Size([32, 512])
recons_image_embeddings_default: torch.Size([32, 512])
recons_image_embeddings_feat_altered: tensor([[-0.0610, -0.0068,  0.0277,  ..., -0.0136, -0.0267, -0.0146],
        [-0.0459, -0.0590, -0.0309,  ..., -0.0308,  0.0537,  0.0013],
        [-0.0384, -0.0056, -0.0169,  ..., -0.0550,  0.0540, -0.0148],
        ...,
        [-0.0715, -0.0371, -0.0323,  ..., -0.0412,  0.0447, -0.0030

  3%|██▊                                                                                                | 45/1562 [01:16<42:57,  1.70s/it]

recons_image_embeddings_default: tensor([[-0.0555,  0.0137, -0.0155,  ...,  0.0488,  0.0085,  0.0223],
        [-0.0172, -0.0359, -0.0245,  ...,  0.0476, -0.0398, -0.0033],
        [-0.0356, -0.0242,  0.0067,  ...,  0.0349,  0.0214,  0.0597],
        ...,
        [ 0.0616, -0.0427, -0.0782,  ...,  0.0495, -0.0093,  0.0176],
        [ 0.0295, -0.0040,  0.0004,  ..., -0.0007,  0.0075,  0.0221],
        [-0.0573, -0.0324, -0.0708,  ...,  0.0041,  0.0383, -0.0776]],
       device='cuda:0')
recons_image_embeddings_default.shape: torch.Size([32, 512])
recons_image_embeddings_default: torch.Size([32, 512])
recons_image_embeddings_feat_altered: tensor([[-0.0794, -0.0048, -0.0154,  ...,  0.0094,  0.0053,  0.0089],
        [-0.0349, -0.0282, -0.0362,  ...,  0.0249, -0.0274, -0.0090],
        [-0.0483, -0.0283, -0.0013,  ..., -0.0164,  0.0201,  0.0062],
        ...,
        [ 0.0307, -0.0373, -0.0791,  ...,  0.0235, -0.0199, -0.0322],
        [-0.0067, -0.0131, -0.0053,  ..., -0.0396,  0.0302, -0

  3%|██▉                                                                                                | 46/1562 [01:17<42:51,  1.70s/it]

torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.74609375
error.norm(): 2070.007568359375
torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.8988037109375
error.norm(): 2070.007568359375
recons_image_embeddings_default: tensor([[-0.0135,  0.0028, -0.0464,  ...,  0.0061,  0.0236,  0.0134],
        [-0.0037, -0.0084, -0.0312,  ...,  0.0148,  0.0124,  0.0584],
        [-0.0409,  0.0490, -0.0033,  ..., -0.0373,  0.0073,  0.0153],
        ...,
        [-0.0304, -0.0194, -0.0035,  ..., -0.0535, -0.0132, -0.0169],
        [-0.0105,  0.0022,  0.0520,  ...,  0.0423, -0.0117, -0.0063],
        [-0.0101,  0.0076, -0.0240,  ..., -0.0270, -0.0290,  0.0380]],
       device='cuda:0')
recons_image_embeddings_default.shape: torch.Size([32, 512])
recons_image_embeddings_default: torch.Size([32, 512])
recons_image_embeddings_feat_altered: tensor([[-0.0377,  0.0155, -0.0480,  ...,  0.0049,  0.0021, -0.0116],
        [-0.0326, -0.0173, -0.0097,  ...,  0.0102,  0.01

  3%|██▉                                                                                                | 47/1562 [01:19<42:57,  1.70s/it]

steering norm: 399.59130859375
error.norm(): 2984.962646484375
torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.55816650390625
error.norm(): 2984.962646484375
recons_image_embeddings_default: tensor([[ 0.0521,  0.0165, -0.0256,  ...,  0.0430, -0.0125,  0.0520],
        [ 0.0248, -0.0020,  0.0309,  ..., -0.0669,  0.0251, -0.0055],
        [-0.0534, -0.0005,  0.0121,  ..., -0.0266, -0.0282, -0.0020],
        ...,
        [ 0.0768,  0.0514, -0.0789,  ...,  0.0125, -0.0130, -0.0299],
        [ 0.0036, -0.0011, -0.0235,  ..., -0.0082,  0.0333,  0.0272],
        [ 0.0373, -0.0013,  0.0136,  ..., -0.0003, -0.0160,  0.0235]],
       device='cuda:0')
recons_image_embeddings_default.shape: torch.Size([32, 512])
recons_image_embeddings_default: torch.Size([32, 512])
recons_image_embeddings_feat_altered: tensor([[ 0.0118, -0.0182, -0.0150,  ...,  0.0528, -0.0179,  0.0052],
        [-0.0261,  0.0026,  0.0253,  ..., -0.0594,  0.0485, -0.0467],
        [-0.0590, -0.0009,  0.0092

  3%|███                                                                                                | 48/1562 [01:21<43:20,  1.72s/it]

torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.88555908203125
error.norm(): 2155.67431640625
torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.7493591308594
error.norm(): 2155.67431640625
torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.8669128417969
error.norm(): 2155.67431640625
recons_image_embeddings_default: tensor([[ 0.0174, -0.0189, -0.0162,  ..., -0.0267,  0.0213,  0.0317],
        [-0.0229, -0.0609, -0.0705,  ..., -0.0527,  0.0073, -0.0096],
        [-0.0325, -0.0482, -0.0023,  ..., -0.0642,  0.0346, -0.0249],
        ...,
        [-0.0184, -0.0531, -0.0552,  ..., -0.0228,  0.0293,  0.0555],
        [ 0.0378, -0.0011,  0.0073,  ...,  0.0153,  0.0326,  0.0245],
        [-0.0041,  0.0480, -0.0357,  ..., -0.0175,  0.0561, -0.0123]],
       device='cuda:0')
recons_image_embeddings_default.shape: torch.Size([32, 512])
recons_image_embeddings_default: torch.Size([32, 512])
recons_image_embeddings_feat_altered: tensor

  3%|███                                                                                                | 49/1562 [01:23<43:05,  1.71s/it]

torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.8680114746094
error.norm(): 2436.70849609375
recons_image_embeddings_default: tensor([[-0.0210,  0.0054, -0.0584,  ..., -0.0181, -0.0146,  0.0125],
        [-0.0476, -0.0062,  0.0063,  ..., -0.0086,  0.0420, -0.0382],
        [-0.0116, -0.0373, -0.0592,  ...,  0.0321,  0.0078, -0.0585],
        ...,
        [ 0.0259, -0.0400, -0.0852,  ...,  0.0130, -0.0031, -0.0418],
        [-0.0521,  0.0055,  0.0449,  ..., -0.0570,  0.0615,  0.0019],
        [-0.0859, -0.0205, -0.0013,  ..., -0.0651,  0.0528,  0.0204]],
       device='cuda:0')
recons_image_embeddings_default.shape: torch.Size([32, 512])
recons_image_embeddings_default: torch.Size([32, 512])
recons_image_embeddings_feat_altered: tensor([[-0.0302,  0.0038, -0.0607,  ..., -0.0203,  0.0111, -0.0216],
        [-0.0614, -0.0044, -0.0197,  ..., -0.0242,  0.0512, -0.0700],
        [-0.0241, -0.0275, -0.0461,  ..., -0.0003,  0.0155, -0.0820],
        ...,
        [-0.0140

  3%|███                                                                                                | 49/1562 [01:24<43:41,  1.73s/it]

torch.Size([32, 50, 768])
torch.Size([32, 50, 768])
steering norm: 399.6003723144531
error.norm(): 2679.20703125
recons_image_embeddings_default: tensor([[-0.0276,  0.0279, -0.0222,  ...,  0.0251,  0.0720,  0.0278],
        [-0.0290,  0.0151, -0.0826,  ...,  0.0505,  0.0133, -0.0214],
        [-0.0013, -0.0136,  0.0077,  ...,  0.0214,  0.0162, -0.0035],
        ...,
        [ 0.0025,  0.0162,  0.0033,  ..., -0.0245,  0.0289, -0.0068],
        [-0.0219,  0.0249, -0.0274,  ...,  0.0505, -0.0062,  0.0114],
        [-0.0473, -0.0418, -0.0403,  ..., -0.0356, -0.0307, -0.0174]],
       device='cuda:0')
recons_image_embeddings_default.shape: torch.Size([32, 512])
recons_image_embeddings_default: torch.Size([32, 512])
recons_image_embeddings_feat_altered: tensor([[-0.0467,  0.0075, -0.0438,  ...,  0.0140,  0.0654, -0.0050],
        [-0.0463,  0.0269, -0.0787,  ...,  0.0406,  0.0394, -0.0640],
        [-0.0280,  0.0095, -0.0061,  ..., -0.0143,  0.0132, -0.0206],
        ...,
        [-0.0238,  




In [123]:
len(feature_steered_embeds[random_feat_idxs[0]])

1600

In [124]:
default_embeds.shape
len(default_embeds_list)
default_embeds = torch.cat(default_embeds_list)
default_embeds.shape

torch.Size([1600, 512])

In [125]:
len(altered_embeds_list), altered_embeds_list[0].shape, default_embeds.shape

(25, torch.Size([32, 512]), torch.Size([1600, 512]))

In [80]:
og_model.cuda()

CLIP(
  (visual): VisionTransformer(
    (conv1): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32), bias=False)
    (patch_dropout): Identity()
    (ln_pre): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (transformer): Transformer(
      (resblocks): ModuleList(
        (0-11): 12 x ResidualAttentionBlock(
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          )
          (ls_1): Identity()
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): Sequential(
            (c_fc): Linear(in_features=768, out_features=3072, bias=True)
            (gelu): GELU(approximate='none')
            (c_proj): Linear(in_features=3072, out_features=768, bias=True)
          )
          (ls_2): Identity()
        )
      )
    )
    (ln_post): LayerNorm((768,), eps=1e-05, elementwise_affine

In [83]:
with open("/workspace/clip_dissect_raw.txt", "r") as f:
    larger_vocab = [line[:-1] for line in f.readlines()][:5000]

# with open("/workspace/better_img_desc.txt", "r") as f:
#     larger_vocab = [line[:-1] for line in f.readlines()][:5000]

In [126]:
# use clip vocab here and compare embeds
import torch
from PIL import Image

tokenizer = open_clip.get_tokenizer('ViT-B-32')
text = tokenizer(larger_vocab)
text_features = og_model.encode_text(text.cuda())
text_features_normed = text_features/text_features.norm(dim=-1, keepdim=True)


print(f"text_features_normed.shape: {text_features_normed.shape}")
text_probs_altered_list = []
# can probs make this one tensor 
with torch.no_grad(), torch.cuda.amp.autocast():
    for key in feature_steered_embeds:
        print(key)
        text_probs_altered = (100.0 * torch.stack(feature_steered_embeds[key]) @ text_features_normed.T).softmax(dim=-1)
        text_probs_altered_list.append(text_probs_altered)
    text_probs_default = (100.0 * default_embeds @ text_features_normed.T).softmax(dim=-1)


# for altered_embeds in altered_embeds_list:
#     with torch.no_grad(), torch.cuda.amp.autocast():
#         # might want to still normalize
        
#         # already normalized
#         # altered_embeds /= altered_embeds.norm(dim=-1, keepdim=True)

#         text_probs_altered = (100.0 * altered_embeds @ text_features_normed.T).softmax(dim=-1)
#         text_probs_altered_list.append(text_probs_altered)
#     # default_embds_norm = default_embeds.norm(dim=-1, keepdim=True)
#     text_probs_default = (100.0 * default_embeds @ text_features_normed.T).softmax(dim=-1)

print("Label probs altered:", text_probs_altered.shape)  # prints: [[1., 0., 0.]]
print("Label probs default:", text_probs_default.shape)  # prints: [[1., 0., 0.]]

text_features_normed.shape: torch.Size([5000, 512])
1126
434
438
1887
1963
2129
1342
534
929
1023
415
905
2705
443
2640
2570
1121
1997
1516
47
473
2321
2632
1408
2703
Label probs altered: torch.Size([1600, 5000])
Label probs default: torch.Size([1600, 5000])


In [127]:
feature_steered_embeds.keys()

dict_keys([1126, 434, 438, 1887, 1963, 2129, 1342, 534, 929, 1023, 415, 905, 2705, 443, 2640, 2570, 1121, 1997, 1516, 47, 473, 2321, 2632, 1408, 2703])

In [128]:
text_probs_altered

tensor([[4.8402e-05, 6.6817e-04, 2.8902e-05,  ..., 2.5453e-06, 1.4878e-05,
         6.3988e-06],
        [5.5200e-06, 3.3551e-05, 4.4354e-06,  ..., 4.0702e-06, 5.1344e-07,
         2.4495e-06],
        [4.8671e-06, 7.4371e-05, 1.8473e-06,  ..., 2.2812e-06, 1.3728e-06,
         2.3213e-07],
        ...,
        [3.9178e-06, 1.8657e-04, 1.5164e-06,  ..., 1.7085e-05, 3.4846e-06,
         3.5672e-06],
        [3.6270e-06, 2.6800e-05, 5.7387e-07,  ..., 3.0912e-04, 1.8381e-06,
         7.6346e-05],
        [1.9271e-05, 4.9210e-05, 9.2268e-07,  ..., 4.6858e-06, 1.9271e-05,
         3.8325e-05]], device='cuda:0')

### Summed Logit Difference

In [130]:
# subtract from default, label, and print trends
text_probs_altered.shape

# selected_vocab = all_imagenet_class_names
selected_vocab = larger_vocab

# switch to datacomp eval??
# we run this for each feature over all of imagenet eval and create average absolute/ratio
# difference vectors
for j, text_probs_altered in enumerate(text_probs_altered_list):
    print(f"{'============================================'*2}\n\nFor Feature {random_feat_idxs[j]}")
    logit_diff = text_probs_altered - text_probs_default
#     print(f"logit_diff.shape: {logit_diff.shape}")
    logit_diff_aggregate = logit_diff.sum(dim=0)
#     print(f"logit_diff_aggregate.shape: {logit_diff_aggregate.shape}")
    
    logit_ratio = text_probs_altered/text_probs_default
    logit_ratio_aggregate = logit_ratio.mean(dim=0)
    
    vals_agg, idxs_agg = torch.topk(logit_diff_aggregate,k=10)
    vals_least_agg, idxs_least_agg = torch.topk(logit_diff_aggregate,k=10,largest=False)
    
    ratios_agg, ratios_idxs_agg = torch.topk(logit_ratio_aggregate,k=10)
    ratios_least_agg, ratios_idxs_least_agg = torch.topk(logit_ratio_aggregate,k=10,largest=False)
    
    vals, idxs = torch.topk(logit_diff,k=5)
    vals_least, idxs_least = torch.topk(logit_diff,k=5,largest=False)
    
    ratios, ratios_idxs = torch.topk(logit_ratio,k=5)
    ratios_least, ratios_idxs_least = torch.topk(logit_ratio,k=5,largest=False)
    
#     print(f"\nMost Changed, by Absolute Diff Over {logit_diff.shape[0]} Images:\n{vals_agg}")
#     print(np.array(selected_vocab)[idxs_agg.cpu()])
#     print(vals_least_agg)
#     print(np.array(selected_vocab)[idxs_least_agg.cpu()])
    
#     print(f"\nMost Changed, by Ratio Over {logit_diff.shape[0]} Images:")
    print(ratios_agg)
    print(np.array(selected_vocab)[ratios_idxs_agg.cpu()])
    print(vals_least_agg)
    print(np.array(selected_vocab)[ratios_idxs_least_agg.cpu()])
    
    
#     for i in range(logit_diff.shape[0]):
# #         print(f"\nImage {i} ========================\nMost Changed, by Absolute Diff\n:{vals[i]}")
# #         print(np.array(selected_vocab)[idxs.cpu()][i])
# #         print(vals_least[i])
# #         print(np.array(selected_vocab)[idxs_least.cpu()][i])
#         print(f"\nImage {i} Most Changed, by Ratio:")
#         print(ratios[i])
#         print(np.array(selected_vocab)[ratios_idxs.cpu()][i])
#         print(ratios_least[i])
#         print(np.array(selected_vocab)[ratios_idxs_least.cpu()][i])
        
# #         text = tokenizer(np.array(selected_vocab)[ratios_idxs.cpu()][i])
# #         text_least = tokenizer(np.array(selected_vocab)[ratios_idxs_least.cpu()][i])
# #         text_features = og_model.encode_text(text.cuda())
# #         text_features_least = og_model.encode_text(text_least.cuda())


For Feature 1126
tensor([141.0693, 126.0712, 123.7205,  88.9756,  86.8450,  80.2493,  74.8723,
         71.8137,  69.8538,  65.6063], device='cuda:0')
['cruises' 'logic' 'islands' 'vancouver' 'cruise' 'editors' 'queen'
 'directions' 'massage' 'mountains']
tensor([-11.8158,  -9.8920,  -8.1005,  -7.7588,  -7.4382,  -6.9706,  -6.5900,
         -5.7896,  -5.6117,  -4.9677], device='cuda:0')
['extension' 'sitting' 'notebook' 'charge' 'study' 'string' 'platform'
 'complaint' 'fr' 'figure']

For Feature 434
tensor([678.9808, 664.3954, 506.6415, 461.0171, 199.4146, 195.0756, 162.5014,
        158.7604, 140.9463, 125.3289], device='cuda:0')
['lighting' 'cutting' 'binding' 'eating' 'dining' 'setting' 'planning'
 'measurements' 'turning' 'lodge']
tensor([-9.7327, -6.8688, -6.4923, -5.8528, -5.6588, -5.5747, -5.5701, -4.5992,
        -4.2149, -4.0465], device='cuda:0')
['foto' 'figure' 'employee' 'basketball' 'platform' 'track' 'video'
 'company' 'paypal' 'et']

For Feature 438
tensor([144.8336, 

## Stop

### Covariance Analysis

In [None]:
# subtract from default, label, and print trends
text_probs_altered.shape

# selected_vocab = all_imagenet_class_names
selected_vocab = larger_vocab

cov_stuff_avgs = []

# switch to datacomp eval??
# we run this for each feature over all of imagenet eval and create average absolute/ratio
# difference vectors
per_feat_avg_vectors = []
for j, text_probs_altered in enumerate(text_probs_altered_list):
    print(f"\n\nFor Feature {random_feat_idxs[j]}")
    print(f"\n\nFor Feature {random_feat_idxs[j]}")
    logit_diff = text_probs_altered - text_probs_default
    logit_ratio = text_probs_altered/text_probs_default
    
    vals, idxs = torch.topk(logit_diff,k=5)
    vals_least, idxs_least = torch.topk(logit_diff,k=5,largest=False)
    
    ratios, ratios_idxs = torch.topk(logit_ratio,k=5)
    ratios_least, ratios_idxs_least = torch.topk(logit_ratio,k=5,largest=False)
    
    feat_avg_vectors = []
    for i in range(logit_diff.shape[0]):
#         print(f"\nImage {i} ========================\nMost Changed, by Absolute Diff\n:{vals[i]}")
#         print(np.array(all_imagenet_class_names)[idxs.cpu()][i])
#         print(vals_least[i])
#         print(np.array(all_imagenet_class_names)[idxs_least.cpu()][i])
        
        print("\nMost Changed, by Ratio:")
        print(ratios[i])
        print(np.array(selected_vocab)[ratios_idxs.cpu()][i])
#         print(ratios_least[i])
#         print(np.array(selected_vocab)[ratios_idxs_least.cpu()][i])
        
        text = tokenizer(np.array(selected_vocab)[ratios_idxs.cpu()][i])
#         text_least = tokenizer(np.array(selected_vocab)[ratios_idxs_least.cpu()][i])
        text_features = og_model.encode_text(text.cuda())
        cov_over_images.append(text_features)
#         text_features_least = og_model.encode_text(text_least.cuda())
        print(torch.tril(torch.cov(text_features), diagonal=-1))
        print(torch.tril(torch.cov(text_features), diagonal=-1).sum()/10)
#         cov_stuff = torch.tril(torch.cov(text_features), diagonal=-1).sum()/10
#         cov_stuff_least = torch.tril(torch.cov(text_features_least), diagonal=-1).sum()/10
    print(torch.tril(torch.cov(torch.cat(cov_over_images)), diagonal=-1).shape)
    n = torch.tril(torch.cov(torch.cat(cov_over_images)), diagonal=-1).shape[0]
    num_elements = (n**2)/2 - n
    cov_stuff = torch.tril(torch.cov(torch.cat(cov_over_images)), diagonal=-1).sum()/num_elements
    cov_stuff_avgs.append(cov_stuff)
#         cov_stuff_avgs_least.append(cov_stuff_least)
    if j > 10:
        break
print(torch.tensor(cov_stuff_avgs).mean())
# print(torch.tensor(cov_stuff_avgs_least).mean())

In [None]:
top_k_imgnet_labels = 10

In [None]:
from matplotlib import pyplot as plt
from collections import defaultdict, Counter

feat_autolabels_default = defaultdict(Counter)
for i in range(text_probs_default.shape[0]):
    vals, idxs = torch.topk(text_probs_default[i],k=top_k_imgnet_labels)
    print(i)
    for k, idx in enumerate(idxs):
        feat_autolabels_default[i][all_imagenet_class_names[idx]] += vals[k]
        print("\t", all_imagenet_class_names[idx])
feat_autolabels_default

In [None]:
# subtract from default, label, and print trends
text_probs_altered.shape

for text_probs_altered in text_probs_altered_list:
    logit_diff = text_probs_altered - text_probs_default
    print(logit_diff)
    vals, idxs = torch.topk(logit_diff,k=5)
    print(vals, np.array(all_imagenet_class_names[idxs])
    break

In [None]:
all_imagenet_class_names

In [None]:
from collections import defaultdict, Counter

feat_autolabels_altered_list = []
for text_probs_altered in text_probs_altered_list:
    feat_autolabels_altered = defaultdict(Counter)
    for i in range(text_probs_altered.shape[0]):
        vals, idxs = torch.topk(text_probs_altered[i],k=top_k_imgnet_labels)
#         print(i)
        for k, idx in enumerate(idxs):
            feat_autolabels_altered[i][all_imagenet_class_names[idx]] += vals[k]
#             print("\t", all_imagenet_class_names[idx])
    feat_autolabels_altered_list.append(feat_autolabels_altered)

start_idx = 9
end_idx = 10
    
h = 0
for key in feat_autolabels_default:
    print(f"\nfeat_autolabels_default img {key}:\n {feat_autolabels_default[key]}\n")
    h += 1
    if h > end_idx:
        break
for i, f_a_a in enumerate(feat_autolabels_altered_list):
    print("============= feature number ", i, "====================")
    h = 0
    for key in range(start_idx, end_idx):
#         print("\n", key)
#         for item in f_a_a[key]:
#             print("\t", item, f_a_a[key][item].cpu().item())
        print(f"\nf_a_a img {key}:\n {f_a_a[key]}\n")


In [None]:
for i in range(text_probs_default.shape[0]):
    vals, idxs = torch.topk(text_probs_default[i],k=1000)
    print(i, ind_to_name[str(idxs[0].cpu().item())][1])
    fig, ax = plt.subplots(figsize=(10, 10))
#     ax.xaxis.set_ticks((1000))
#     ax.set_xticks(list(range(1000)), [ind_to_name[str(idxs[idx].cpu().item())][1] for idx in idxs])
    plt.bar(idxs.cpu(), vals.cpu(), width=5)
    break

In [None]:
plt.imshow(batch_images[2].cpu().permute((1,2,0)).numpy())

In [None]:

@torch.no_grad()
def get_heatmap(
          image,
          model,
          sparse_autoencoder,
          feature_id,
): 
    image = image.to(cfg.device)
    _, cache = model.run_with_cache(image.unsqueeze(0))

    post_reshaped = einops.rearrange(cache[sparse_autoencoder.cfg.hook_point], "batch seq d_mlp -> (batch seq) d_mlp")
    # Compute activations (not from a fwd pass, but explicitly, by taking only the feature we want)
    # This code is copied from the first part of the 'forward' method of the AutoEncoder class
    sae_in =  post_reshaped - sparse_autoencoder.b_dec # Remove decoder bias as per Anthropic
    print(f"sae_in.shape: {sae_in.shape}")
    acts = einops.einsum(
            sae_in,
            sparse_autoencoder.W_enc[:, feature_id],
            "x d_in, d_in -> x",
        )
    return acts 
     
def image_patch_heatmap(activation_values,image_size=224, pixel_num=14):
    activation_values = activation_values.detach().cpu().numpy()
    activation_values = activation_values[1:]
    activation_values = activation_values.reshape(pixel_num, pixel_num)

    # Create a heatmap overlay
    heatmap = np.zeros((image_size, image_size))
    patch_size = image_size // pixel_num

    for i in range(pixel_num):
        for j in range(pixel_num):
            heatmap[i*patch_size:(i+1)*patch_size, j*patch_size:(j+1)*patch_size] = activation_values[i, j]

    return heatmap


In [None]:
from matplotlib import pyplot as plt

grid_size = 1
fig, axs = plt.subplots(int(np.ceil(len(images)/grid_size)), grid_size, figsize=(15, 15))
name=  f"Category: uhh,  Feature: {0}"
fig.suptitle(name)#, y=0.95)
for ax in axs.flatten():
    ax.axis('off')
complete_bid = []

heatmap = get_heatmap(batch_images[2], model,sparse_autoencoder, 10000)
heatmap = image_patch_heatmap(heatmap, pixel_num=224//cfg.patch_size)

display = batch_images[2].cpu().numpy().transpose(1, 2, 0)

has_zero = False

In [None]:
plt.imshow(display)
plt.imshow(heatmap, alpha=0.3)

In [None]:
plt.imshow(display)
plt.imshow(heatmap, alpha=0.3)