<a href="https://colab.research.google.com/github/tweks/sae-sd/blob/main/notebooks/clip_embs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

In [None]:
try:
    import google.colab
    !pip install datasets diffusers accelerate transformers
except:
    pass

In [28]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from copy import deepcopy
from datasets import load_dataset
from tqdm.notebook import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
import matplotlib.pyplot as plt

# Dataset

In [None]:
subset_size = 10000
# Load the matching CC3M shards, shuffle deterministically, and keep only the caption column.
full_dataset = (
    load_dataset('pixparse/cc3m-wds', split='train', data_files='cc3m-train-000*.tar')
    .shuffle(seed=42)['txt']
)
print(f'Downloaded {len(full_dataset)} examples.')  # printed 50460 when this notebook was run

In [30]:
# Two disjoint, equally sized caption subsets from the shuffled pool:
# the first is analysed, the second is used to train the tuned lens.
dataset, train_data = full_dataset[:subset_size], full_dataset[subset_size:2 * subset_size]

In [None]:
# Peek at the first ten captions of the analysis subset.
dataset[0:10]

In [None]:
model_id = 'openai/clip-vit-large-patch14'
# Prefer the GPU when one is present.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
tokenizer = CLIPTokenizer.from_pretrained(model_id)
text_encoder = CLIPTextModel.from_pretrained(model_id).to(device)
# Inference only; the cell's displayed value is the model returned by eval().
text_encoder.eval()

# Helper functions

In [33]:
# Sequence length every caption is padded/truncated to in compute_batch_embeddings.
num_tokens: int = 77

In [34]:
def compute_batch_embeddings(text_list):
    """Tokenize and encode a batch of captions with the CLIP text encoder.

    Every caption is padded/truncated to exactly ``num_tokens`` positions.

    Args:
        text_list: List of caption strings.

    Returns:
        The encoder's ``last_hidden_state``, shape
        ``(len(text_list), num_tokens, hidden_dim)``, computed without
        gradient tracking.
    """
    encoded = tokenizer(
        text_list,
        return_tensors='pt',
        max_length=num_tokens,
        truncation=True,
        padding='max_length',
    )
    # Move every tokenizer output (ids, attention mask, ...) to the encoder's device.
    model_inputs = {name: tensor.to(device) for name, tensor in encoded.items()}
    with torch.no_grad():
        hidden_states = text_encoder(**model_inputs).last_hidden_state
    return hidden_states

In [35]:
def compute_embeddings(data, batch_size=32):
    """Encode every caption in ``data`` and return the stacked CLIP hidden states.

    Args:
        data: Sliceable sequence of caption strings (e.g. a list).
        batch_size: Number of captions encoded per forward pass.

    Returns:
        Tensor of shape ``(len(data), num_tokens, hidden_dim)``. It is the
        ``compute_batch_embeddings`` outputs concatenated along dim 0.

    Raises:
        ValueError: If ``data`` is empty.
    """
    # Bound the loop by the actual input length, not the global ``subset_size``.
    # Longer inputs were previously truncated, and shorter ones produced empty batches.
    num_items = len(data)
    if num_items == 0:
        raise ValueError('compute_embeddings received no captions')

    embs = [
        # Each entry is (B, num_tokens, hidden_dim); slicing past the end is clipped.
        compute_batch_embeddings(data[start_idx:start_idx + batch_size])
        for start_idx in tqdm(range(0, num_items, batch_size))
    ]
    return torch.cat(embs, dim=0)

In [36]:
def compute_metrics(dataset, batch_size=32):
    """Average token-vs-token cosine similarity and L2 distance over a dataset.

    Args:
        dataset: Tensor of token embeddings indexed as
            ``(num_samples, seq_len, emb_dim)`` (e.g. the output of
            ``compute_embeddings``).
        batch_size: Samples processed per step; only affects peak memory.

    Returns:
        ``(cos_sim_matrix, l2_matrix)``: two float32 ``(seq_len, seq_len)``
        tensors on ``dataset``'s device. Entry ``[i, j]`` is the mean over all
        samples of the cosine similarity (resp. Euclidean distance) between
        token ``i`` and token ``j``.

    Raises:
        ValueError: If ``dataset`` is not 3-D or contains no samples.
    """
    if dataset.dim() != 3:
        raise ValueError(
            f'expected a (samples, tokens, dim) tensor, got shape {tuple(dataset.shape)}'
        )
    dataset_size, seq_len, _ = dataset.shape
    if dataset_size == 0:
        # Averaging over zero samples would silently produce NaNs.
        raise ValueError('cannot average metrics over an empty dataset')

    # Size and place the accumulators from the input itself rather than from
    # the globals ``num_tokens`` / ``device``.
    cos_sim_matrix = torch.zeros(seq_len, seq_len, dtype=torch.float32, device=dataset.device)
    l2_matrix = torch.zeros_like(cos_sim_matrix)

    for start_idx in tqdm(range(0, dataset_size, batch_size)):
        # Slicing past the end is clipped. Upcast so low-precision inputs
        # are accumulated in float32.
        batch_emb = dataset[start_idx:start_idx + batch_size].float()

        # Cosine similarity = dot products of unit-normalised token vectors: (B, T, T).
        normed = F.normalize(batch_emb, p=2, dim=-1)
        cos_sims = torch.bmm(normed, normed.transpose(1, 2))

        # torch.cdist is batched: (B, T, D) x (B, T, D) -> (B, T, T),
        # so no per-sample Python loop is needed.
        l2s = torch.cdist(batch_emb, batch_emb, p=2)

        cos_sim_matrix += cos_sims.sum(dim=0)
        l2_matrix += l2s.sum(dim=0)

    # Each sample contributed exactly once, so divide by the sample count.
    cos_sim_matrix /= dataset_size
    l2_matrix /= dataset_size
    return cos_sim_matrix, l2_matrix

In [37]:
def plot_cos_sim(cos_sim_matrix):
    """Display a heatmap of an averaged token-by-token cosine-similarity matrix.

    Args:
        cos_sim_matrix: 2-D tensor (e.g. from ``compute_metrics``). It is
            copied to the CPU for plotting.
    """
    heatmap = cos_sim_matrix.cpu().numpy()
    plt.imshow(heatmap, origin='lower')  # entry [0, 0] drawn at the bottom-left
    plt.colorbar()
    plt.xlabel('Token index')
    plt.ylabel('Token index')
    plt.title('Average Cosine Similarity (Token vs. Token)')
    plt.show()

In [38]:
def plot_l2(l2_matrix):
    """Display a heatmap of an averaged token-by-token L2-distance matrix.

    Args:
        l2_matrix: 2-D tensor (e.g. from ``compute_metrics``). It is copied
            to the CPU for plotting.
    """
    heatmap = l2_matrix.cpu().numpy()
    plt.imshow(heatmap, origin='lower')  # entry [0, 0] drawn at the bottom-left
    plt.colorbar()
    plt.xlabel('Token index')
    plt.ylabel('Token index')
    plt.title('Average L2 Distance (Token vs. Token)')
    plt.show()

# Embedding analysis

In [None]:
# CLIP hidden states for every caption in the analysis subset.
dataset_embs = compute_embeddings(dataset, batch_size=32)

In [None]:
# Token-vs-token similarity/distance of the raw embeddings, averaged over captions.
cos_sim_matrix, l2_matrix = compute_metrics(dataset_embs, batch_size=32)

In [None]:
plot_cos_sim(cos_sim_matrix=cos_sim_matrix)  # raw embeddings

In [None]:
plot_l2(l2_matrix=l2_matrix)  # raw embeddings

# Tuned Lens

In [43]:
class TunedLens(nn.Module):
    """Per-position residual affine maps that pull token embeddings toward token 0.

    Position 0 passes through unchanged and serves as the target. Every other
    position ``i`` has its own ``nn.Linear`` lens and is mapped to
    ``x_i + lens_i(x_i)``. All lenses start at zero, so an untrained model is
    the identity map.
    """

    def __init__(self, num_tokens=77, emb_dim=768):
        super().__init__()
        lens = nn.Linear(emb_dim, emb_dim, bias=True)
        # Zero weights and bias => zero residual => identity at initialisation.
        nn.init.zeros_(lens.weight)
        nn.init.zeros_(lens.bias)
        # One independent lens for each non-target position 1 .. num_tokens-1.
        self.lenses = nn.ModuleList([deepcopy(lens) for _ in range(num_tokens - 1)])

    def forward(self, x):
        """Apply each position's lens.

        Args:
            x: Tensor of shape ``(batch_size, num_tokens, emb_dim)``.

        Returns:
            Tensor of shape ``(batch_size, num_tokens, emb_dim)``. Position 0 is
            copied from ``x``; positions 1.. hold their lens-transformed values.

        Raises:
            ValueError: If ``x`` is not 3-D or its sequence length does not match
                the number of positions this lens was built for.
        """
        expected_tokens = len(self.lenses) + 1
        if x.dim() != 3 or x.size(1) != expected_tokens:
            # Previously a longer input was silently truncated and a shorter one hit IndexError.
            raise ValueError(
                f'expected input of shape (batch, {expected_tokens}, emb_dim), '
                f'got {tuple(x.shape)}'
            )

        # Token embedding at position 0 acts as the target and is left untouched.
        target = x[:, 0, :]  # (batch_size, emb_dim)
        transformed_tokens = [target]
        for i, lens in enumerate(self.lenses, start=1):
            token = x[:, i, :]
            transformed_tokens.append(token + lens(token))  # (batch_size, emb_dim)

        # Target plus num_tokens-1 transformed positions -> (batch_size, num_tokens, emb_dim).
        return torch.stack(transformed_tokens, dim=1)

In [None]:
# Training dataset: CLIP hidden states for the held-out training captions.
train_ds = compute_embeddings(train_data, batch_size=32)

In [None]:
train_ds.size()  # (num_captions, num_tokens, hidden_dim)

In [46]:
def train(model, train_ds, num_epochs=10, batch_size=32, lr=1e-3):
    """Fit ``model`` so every transformed position matches the position-0 embedding.

    The model's output at positions 1.. is regressed (MSE, Adam) onto its
    unmodified position-0 output, broadcast along the sequence.

    Args:
        model: Module mapping ``(B, T, D)`` to ``(B, T, D)`` whose position 0
            is the target (e.g. ``TunedLens``).
        train_ds: Tensor of embeddings ``(N, T, D)``; may live on any device.
        num_epochs: Passes over ``train_ds``.
        batch_size: Samples per optimisation step.
        lr: Adam learning rate.

    Returns:
        list[float]: Mean training loss of each epoch.

    Raises:
        ValueError: If ``train_ds`` contains no samples.
    """
    model.train()
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()

    n = train_ds.size(0)
    if n == 0:
        raise ValueError('train_ds is empty')

    history = []
    for epoch in tqdm(range(num_epochs), desc='Epoch'):
        total_loss = 0.0
        # Shuffle indices each epoch
        indices = torch.randperm(n)
        # leave=False: clear each finished inner bar instead of stacking one per epoch.
        for i in tqdm(range(0, n, batch_size), desc='Batch', leave=False):
            batch_indices = indices[i:i + batch_size]
            # Move the batch to the model's device; train_ds may be on the CPU.
            batch = train_ds[batch_indices].to(device)  # (batch_size, T, emb_dim)

            optimizer.zero_grad()
            output = model(batch)  # (batch_size, T, emb_dim)
            # Position 0 is the target; positions 1.. are the transformed tokens.
            target = output[:, 0, :]  # (batch_size, emb_dim)
            prediction = output[:, 1:, :]  # (batch_size, T-1, emb_dim)
            # Broadcast the target along the sequence dimension to match prediction.
            target_expanded = target.unsqueeze(1).expand_as(prediction)

            loss = loss_fn(prediction, target_expanded)
            loss.backward()
            optimizer.step()

            # Weight by batch size so the epoch average is per sample (last batch may be short).
            total_loss += loss.item() * batch.size(0)

        avg_loss = total_loss / n
        history.append(avg_loss)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")
    return history

In [None]:
# Size the lens from the embeddings it will be trained on (77 tokens x 768 dims
# in this notebook) instead of hard-coding those numbers.
model = TunedLens(num_tokens=train_ds.size(1), emb_dim=train_ds.size(2))
train(model, train_ds, num_epochs=10, batch_size=32, lr=1e-3)

In [48]:
# Apply the trained lenses to the analysis subset without tracking gradients.
with torch.no_grad():
    transformed_embs = model.eval()(dataset_embs)

In [None]:
# Same token-vs-token metrics, now on the tuned-lens outputs.
tl_cos_sim_matrix, tl_l2_matrix = compute_metrics(transformed_embs, batch_size=32)

In [None]:
plot_cos_sim(cos_sim_matrix=tl_cos_sim_matrix)  # after tuned lens

In [None]:
plot_l2(l2_matrix=tl_l2_matrix)  # after tuned lens

# Appendix

In [None]:
# Untruncated token count per caption (tokenizer.encode with its defaults),
# to compare against the num_tokens padding/truncation length used for embeddings.
dataset_lenghts = list(map(lambda caption: len(tokenizer.encode(caption)), dataset))
max(dataset_lenghts)

In [None]:
plt.hist(dataset_lenghts, bins=77)
plt.title('Histogram of Caption Lengths')
plt.xlabel('Caption Length (tokens)')
plt.ylabel('Frequency')
plt.show()

In [None]:
# Inspect how the first analysis caption is tokenized.
caption = dataset[0]
tokens = tokenizer.encode(text=caption)
tokens

In [None]:
tokenizer.convert_ids_to_tokens(ids=tokens)  # human-readable sub-word pieces

In [None]:
# Print every token vector of the first analysis caption. With threshold=700, any
# tensor with more than 700 elements is shown in summarized form, which includes
# these rows (768-dim hidden states for this encoder).
torch.set_printoptions(threshold=700)
try:
    for i, row in enumerate(dataset_embs[0]):
        print(i, row)
finally:
    # Restore the global print options even if printing is interrupted.
    torch.set_printoptions(profile='default')