<a href='https://colab.research.google.com/github/tweks/sae-sd/blob/main/clip_embeddings_analysis.ipynb' target='_parent'><img src='https://colab.research.google.com/assets/colab-badge.svg' alt='Open In Colab'/></a>

# Setup

In [None]:
try:
    import google.colab
    !pip install datasets diffusers accelerate transformers tuned-lens
except:
    pass

In [None]:
import torch
import torch.nn.functional as F
from datasets import load_dataset
from transformers import CLIPTextModel, CLIPTokenizer
import matplotlib.pyplot as plt

# Dataset

In [None]:
# First `subset_size` captions of the CC3M train shards, shuffled
# deterministically (seed 42) within that subset.
subset_size = 10000
dataset = load_dataset(
    'pixparse/cc3m-wds',
    split=f'train[:{subset_size}]',
    data_files='cc3m-train-000*.tar',
).shuffle(seed=42)

In [None]:
dataset[:10]['txt']  # slice rows first so only 10 captions are materialised

In [None]:
model_id = 'openai/clip-vit-large-patch14'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = CLIPTokenizer.from_pretrained(model_id)
# Inference only: move to the chosen device and switch to eval mode.
# Assigning the result (instead of ending the cell on `.eval()`) keeps the
# full module repr out of the notebook output.
text_encoder = CLIPTextModel.from_pretrained(model_id).to(device).eval()

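# Embeddings analysis

For every caption, take the text encoder's `last_hidden_state`: one vector per token position, padded or truncated to 77 positions. From it, compute two 77×77 matrices: the cosine similarity and the L2 distance between every pair of positions. Both matrices are averaged over all captions and plotted as heatmaps below.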

In [None]:
# Fixed context length of the CLIP text encoder (77 for
# openai/clip-vit-large-patch14); read from the model config instead of
# hard-coding it so the analysis follows the loaded model.
num_tokens = text_encoder.config.max_position_embeddings

In [None]:
def compute_batch_embeddings(text_list):
    """Encode a batch of captions with the CLIP text encoder.

    Every caption is padded / truncated to exactly ``num_tokens`` tokens, so
    all batches share the same sequence length.

    Args:
        text_list: list of caption strings.

    Returns:
        The encoder's ``last_hidden_state`` tensor of shape
        (len(text_list), num_tokens, hidden_dim), on ``device``; computed
        without gradient tracking.
    """
    encoded = tokenizer(
        text_list,
        return_tensors='pt',
        max_length=num_tokens,
        truncation=True,
        padding='max_length',
    )
    model_inputs = {}
    for name, tensor in encoded.items():
        model_inputs[name] = tensor.to(device)
    with torch.no_grad():
        hidden = text_encoder(**model_inputs).last_hidden_state
    return hidden

In [None]:
batch_size = 32

# Materialise the caption column once. Indexing `dataset['txt']` rebuilds the
# whole column (through the shuffle's index mapping) on every access, so doing
# it inside the loop made this cell quadratic in the number of captions.
captions = list(dataset['txt'])
# Never iterate past the end of the data, even if it has fewer than
# `subset_size` rows (that would feed empty batches to the tokenizer).
n_captions = min(subset_size, len(captions))

# Running sums of the per-caption (num_tokens x num_tokens) matrices.
cos_sim_matrix = torch.zeros(num_tokens, num_tokens, dtype=torch.float32, device=device)
l2_matrix = torch.zeros(num_tokens, num_tokens, dtype=torch.float32, device=device)
num_samples = 0
embs = []
for start_idx in range(0, n_captions, batch_size):
    end_idx = min(start_idx + batch_size, n_captions)
    batch_captions = captions[start_idx:end_idx]

    # Encode batch: (B, num_tokens, hidden_dim)
    batch_emb = compute_batch_embeddings(batch_captions)
    # Keep a CPU copy for later inspection so the stored embeddings of the
    # whole subset do not accumulate in device (GPU) memory.
    embs.append(batch_emb.cpu())

    # Cosine similarity between every pair of token positions, per caption.
    normed = F.normalize(batch_emb, p=2, dim=-1)  # (B, num_tokens, hidden_dim), unit-norm rows
    cos_sims = torch.bmm(normed, normed.transpose(1, 2))  # (B, num_tokens, num_tokens)

    # L2 distance between every pair of token positions, per caption.
    # torch.cdist batches over the leading dimension, so no Python loop is needed.
    l2s = torch.cdist(batch_emb, batch_emb, p=2)  # (B, num_tokens, num_tokens)

    # Accumulate sums over the batch dimension.
    cos_sim_matrix += cos_sims.sum(dim=0)
    l2_matrix += l2s.sum(dim=0)
    num_samples += batch_emb.shape[0]

assert num_samples > 0, 'no captions were encoded'
# Average over the total number of samples
cos_sim_matrix /= num_samples
l2_matrix /= num_samples
embeddings = torch.cat(embs, dim=0)

In [None]:
# Move to CPU / numpy for plotting. Guarded so that re-running this cell (when
# the matrices are already numpy arrays) is a no-op instead of raising
# AttributeError on `.cpu()`.
if torch.is_tensor(cos_sim_matrix):
    cos_sim_matrix = cos_sim_matrix.cpu().numpy()
if torch.is_tensor(l2_matrix):
    l2_matrix = l2_matrix.cpu().numpy()

In [None]:
# Heatmap of the dataset-averaged token-vs-token cosine similarity.
fig, ax = plt.subplots()
im = ax.imshow(cos_sim_matrix, origin='lower')
fig.colorbar(im, ax=ax)
ax.set_title('Average Cosine Similarity (Token vs. Token)')
ax.set_xlabel('Token index')
ax.set_ylabel('Token index')
plt.show()

In [None]:
# Heatmap of the dataset-averaged token-vs-token L2 distance.
fig, ax = plt.subplots()
im = ax.imshow(l2_matrix, origin='lower')
fig.colorbar(im, ax=ax)
ax.set_title('Average L2 Distance (Token vs. Token)')
ax.set_xlabel('Token index')
ax.set_ylabel('Token index')
plt.show()

# Appendix

In [None]:
# Token count of every caption; tokenizer.encode is called without truncation,
# so lengths above the encoder's context length are visible here.
dataset_lenghts = [len(ids) for ids in map(tokenizer.encode, dataset['txt'])]
max(dataset_lenghts)

In [None]:
# One bin per integer token count. With `bins=77` spread over [min, max], bin
# edges fall between integers (unless max - min + 1 happens to be 77), so some
# bins collect two lengths and others one, producing a comb-like artefact.
min_len, max_len = min(dataset_lenghts), max(dataset_lenghts)
plt.hist(dataset_lenghts, bins=range(min_len, max_len + 2), align='left')
# Captions to the right of this line are truncated by compute_batch_embeddings.
plt.axvline(num_tokens, color='red', linestyle='--', label=f'context limit ({num_tokens} tokens)')
plt.legend()
plt.xlabel('Caption Length (tokens)')
plt.ylabel('Frequency')
plt.title('Histogram of Caption Lengths')
plt.show()

In [None]:
# Token ids (as produced by tokenizer.encode, special tokens included by default)
# for the first caption of the shuffled subset.
caption = dataset[0]['txt']
tokens = tokenizer.encode(caption)
tokens

In [None]:
[tokenizer.convert_ids_to_tokens(token_id) for token_id in tokens]  # token string for each id

In [None]:
# Print the hidden state at each token position of the first caption.
# threshold=700 makes torch summarise any tensor with more than 700 elements,
# keeping each printed row short.
torch.set_printoptions(threshold=700)
try:
    for i, row in enumerate(embeddings[0]):
        print(i, row)
finally:
    # Restore defaults even if printing is interrupted, so the lowered
    # threshold does not leak into later cells.
    torch.set_printoptions(profile='default')