<a href="https://colab.research.google.com/github/tom-pollak/interp-lora-causal-circuits/blob/main/clip_golden_gate.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Image To Text (LLM + CLIP)

Notebook by Katherine Crowson (https://twitter.com/RiversHaveWings)

This notebook uses reinforcement learning to fine-tune a large language model ([Pythia 160M](https://github.com/EleutherAI/pythia) by default) to interpret a single image according to a [CLIP](https://arxiv.org/abs/2103.00020) based image/text matching loss.

In [1]:
#@title Licensed under the Apache License, Version 2.0 { display-mode: "form" }

# Copyright 2024 Katherine Crowson

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In [2]:
#@title Check GPU

!nvidia-smi

Sat Oct 26 11:13:11 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off | 00000000:00:04.0 Off |                    0 |
| N/A   33C    P0              46W / 400W |      2MiB / 40960MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [3]:
#@title Install dependencies

!pip install -qqq open_clip_torch peft datasets

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.5 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m1.5/1.5 MB[0m [31m48.4 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m25.8 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/320.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m320.7/320.7 kB[0m [31m24.5 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/472.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m472.7/472.7 kB[0m [31m32.5 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/116.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [4]:
#@title Import libraries

import textwrap

from google.colab import files
import open_clip
import peft
from datasets import load_dataset
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from torch import optim
from torch.nn import functional as F
from tqdm.auto import tqdm

In [5]:
#@title Define necessary functions

print = tqdm.external_write_mode()(print)


def endless_range(start=0, step=1):
    """An endless range generator."""
    i = start
    while True:
        yield i
        i += step


def logp_completion(logits, tokens, mask):
    """Compute the log probabilities of completions given their prompts.

    Args:
        tokens: The tokens input to the model. Shape: (..., T).
        logits: The logits output from the model. Shape: (..., T, V).
        mask: A mask indicating which tokens should be included in the log probabilities. It should
            exclude prompt tokens and padding tokens. Shape: (..., T).

    Returns:
        The log probabilities of the completions given their prompts. Shape: (...).
    """
    logits = F.log_softmax(logits, dim=-1)
    logp_tokens = logits[..., :-1, :].gather(-1, tokens[..., 1:, None])[..., 0]
    return torch.sum(logp_tokens * mask[..., 1:], dim=-1)


In [6]:
#@title Dataset
ds = load_dataset("tommyp111/golden-gate-photos", split="train")

README.md:   0%|          | 0.00/278 [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/18.0M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/8 [00:00<?, ? examples/s]

In [14]:
#@title Set parameters { display-mode: "form" }

#@markdown ## Sampling

temperature = 0.9  #@param {type: 'number'}

#@markdown The number of tokens to sample from the LLM:
max_new_tokens = 50  #@param {type: 'integer'}

#@markdown ## Training

batch_size = 64  #@param {type: 'integer'}

lr = 1e-4 #@param {type: 'number'}

beta1 = 0.9 #@param {type: 'number'}
beta2 = 0.95 #@param {type: 'number'}

wd = 1e-2 #@param {type: 'number'}

#@markdown The strength of the KL divergence penalty vs the original LLM:
#@markdown <br><small>The KL divergence penalty specifies the rate at which the optimizer will trade off a decrease in the angle (in radians) between the CLIP text and image embeddings and a decrease in the KL divergence between the model and the reference model.</small>
#@markdown <br>- Pythia 70m: `kl_weight 8e-4 - 1e-3`
kl_weight = 2e-3  #@param {type: 'number'}

#@markdown ## LoRA

#@markdown Hidden LoRA dimension
lora_rank = 16 #@param {type: 'integer'}

#@markdown Scale of applied weights
lora_alpha = 8 #@param {type: 'integer'}

#@markdown Regularisation
lora_dropout = 0. #@param {type: 'number'}



In [8]:
#@title Many different prompts
prompts = [
   # Image-specific prompts
   "The theme of this image is",
   "Looking at this image, I can see",
   "The main elements captured in this image are",
   "What stands out in this picture is",
   "The focal point of this image appears to be",
   "Analyzing this image, one notices",
   "This photograph captures",
   "The most striking aspect of this image is",
   "From an observer's perspective, this image presents",
   "The visual narrative here reveals",
   "Breaking down this image, we can see",
   "At first glance, this image depicts",
   "The primary subject matter consists of",
   "This visual scene contains",
   "Upon careful observation, this image shows",
   "The essence of this image lies in",
   "Let me describe what I see:",
   "Here's what I observe:",
   "I can describe this as",
   "What we have here is",
   "This appears to be",
   "My interpretation is that",
   "From my perspective,",
   "One could say that",
   "It's clear that",
   "In this representation,",
   "The key details reveal",
   "Upon examination,",
   "What's notable here is",
   "The overall impression is",
   "What catches the eye is",
   "This scene illustrates",
   "The main focus here shows",
   "On display here is",
   # General prompts
   "Let me explain:",
   "Here's a scene:",
   "Here's what's happening:",
   "The situation involves",
   "We can observe that",
   "Consider the following:",
   "To put it simply,",
   "In essence,",
   "What we're looking at is",
   "To describe this briefly,",
   "Let me break this down:",
   "To summarize,",
   "The main aspect is",
   "Here's the thing:",
   "What's interesting is",
   "The fundamental nature is",
   "To understand this,",
   "The basic idea here is",
   "If I had to explain it,"
   "Let's examine this:",
   "I would say that",
   "The central theme is",
   "What matters most here is",
   "This reminds me of",
   "The key takeaway is",
   "When you look closely,",
   "There's something about",
   "What fascinates me is",
   "There once was",
   "The story begins with",
   "There's a place where",
]
# assert len(prompts) == batch_size


In [9]:
#@title Load models

clip_name = "ViT-L-14-336"
clip_pretrained = "openai"
model_name = "google/gemma-2-2b"

device = torch.device("cuda:0")

# Load CLIP
clip_tokenizer = open_clip.get_tokenizer(clip_name)
clip_model, _, preprocess = open_clip.create_model_and_transforms(
    clip_name, pretrained=clip_pretrained, device=device
)
clip_model.eval().requires_grad_(False)

# Load language model
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
assert  tokenizer.padding_side == "left"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation='eager',
)


100%|███████████████████████████████████████| 934M/934M [00:11<00:00, 79.1MiB/s]


tokenizer_config.json:   0%|          | 0.00/46.4k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/818 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/24.2k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/481M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/168 [00:00<?, ?B/s]

In [10]:
#@title Prepare LoRA
peft_config = peft.LoraConfig(
    peft.TaskType.CAUSAL_LM,
    inference_mode=False,
    r=lora_rank,
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    target_modules=[
        # For NeoX and Pythia
        # "attention.query_key_value",
        # "attention.dense", # writes to residual stream
        # "mlp.dense_h_to_4h",
        # "mlp.dense_4h_to_h", # writes to residual stream

        # GPT2
        # "attn.c_attn",
        # "mlp.c_fc",

        # For Llama and Mistral 7B & Gemma 2B
        "self_attn.q_proj",
        "self_attn.k_proj",
        "self_attn.v_proj",
        # "self_attn.o_proj", # write to residual stream
        "mlp.gate_proj",
        "mlp.up_proj",
        # "mlp.down_proj", # write to residual stream
    ],
)
model = peft.get_peft_model(model, peft_config)

model.gradient_checkpointing_enable()
model.enable_input_require_grads()
model.train()
model.print_trainable_parameters()

trainable params: 14,163,968 || all params: 2,628,505,856 || trainable%: 0.5389


In [11]:
#@title Prepare tokens & attention mask
inputs = tokenizer(
    prompts,
    padding=True,  # Pad to longest sequence
    return_tensors="pt",
    return_length=True,
).to(device)

# Get the length of each prompt
prompt_lengths = inputs.attention_mask.sum(dim=1)
max_len = inputs.input_ids.shape[1]

# Create logp_mask that accounts for different prompt lengths
logp_mask = torch.zeros((len(prompts), max_len + max_new_tokens), dtype=torch.bool, device=device)

# For each item in batch, mark tokens after the prompt for loss calculation
for i, prompt_len in enumerate(prompt_lengths):
    logp_mask[i, prompt_len:] = True

generated_attention = torch.ones(
    (len(prompts), max_new_tokens),
    dtype=inputs.attention_mask.dtype,
    device=inputs.attention_mask.device
)
attention_mask = torch.cat(
    (inputs.attention_mask, generated_attention),
    dim=1
)

input_ids = inputs.input_ids
input_attention_mask = inputs.attention_mask

In [12]:
#@title Prepare Images
proc_images = torch.stack([preprocess(im) for im in ds['image']]).to(device)
with torch.amp.autocast("cuda"):
    image_embeds = clip_model.encode_image(proc_images).float()

In [13]:
#@title Pre-train check
def debug_kl_divergence(model_logits, ref_logits, tokens, mask, temperature=1.0):
    """
    Debug KL divergence between model and reference distributions

    Args:
        model_logits: Logits from the LoRA model (batch_size, seq_len, vocab_size)
        ref_logits: Logits from the reference model (batch_size, seq_len, vocab_size)
        tokens: Input tokens (batch_size, seq_len)
        mask: Attention mask (batch_size, seq_len)
        temperature: Sampling temperature

    Returns:
        Dict containing diagnostic information
    """
    with torch.no_grad():
        # Get probabilities
        model_probs = torch.softmax(model_logits / temperature, dim=-1)
        ref_probs = torch.softmax(ref_logits / temperature, dim=-1)

        # Calculate KL divergence per token
        kl_per_token = torch.sum(
            ref_probs * (torch.log(ref_probs + 1e-10) - torch.log(model_probs + 1e-10)),
            dim=-1
        )

        # Apply mask to find valid positions
        masked_kl = kl_per_token * mask
        valid_positions = (mask > 0).nonzero()

        if len(valid_positions) > 0:
            # Get batch and position index of max KL
            batch_idx, pos_idx = valid_positions[torch.argmax(masked_kl[valid_positions[:, 0], valid_positions[:, 1]])]
            max_kl_token = tokens[batch_idx, pos_idx].item()
        else:
            max_kl_token = None

        # Get distribution statistics
        model_entropy = -torch.sum(model_probs * torch.log(model_probs + 1e-10), dim=-1)
        ref_entropy = -torch.sum(ref_probs * torch.log(ref_probs + 1e-10), dim=-1)

        # Calculate averages only over valid positions
        valid_mask_sum = mask.sum().item()
        if valid_mask_sum > 0:
            avg_kl = (masked_kl).sum() / valid_mask_sum
            avg_model_entropy = (model_entropy * mask).sum() / valid_mask_sum
            avg_ref_entropy = (ref_entropy * mask).sum() / valid_mask_sum
        else:
            avg_kl = torch.tensor(0.0).to(model_logits.device)
            avg_model_entropy = torch.tensor(0.0).to(model_logits.device)
            avg_ref_entropy = torch.tensor(0.0).to(model_logits.device)

        return {
            "avg_kl": avg_kl.item(),
            "max_kl": masked_kl.max().item() if valid_mask_sum > 0 else 0.0,
            "max_kl_token": max_kl_token,
            "model_entropy": avg_model_entropy.item(),
            "ref_entropy": avg_ref_entropy.item(),
            "prob_diff": (model_probs - ref_probs).abs().mean().item(),
            "sequence_length": tokens.size(1),
            "valid_tokens": valid_mask_sum
        }


model.eval()

sample_idxs = torch.randperm(input_ids.shape[0])[:batch_size]
sample_input_ids = input_ids[sample_idxs]
sample_input_attention_mask = input_attention_mask[sample_idxs]
sample_attention_mask = attention_mask[sample_idxs]

tokens = model.generate(
    sample_input_ids,
    attention_mask=sample_input_attention_mask,
    do_sample=True,
    min_new_tokens=max_new_tokens,
    max_new_tokens=max_new_tokens,
    pad_token_id=tokenizer.eos_token_id,
    temperature=temperature,
    top_k=0,
)

# Get the logits of the samples from the model and the reference model
with torch.no_grad(), model.disable_adapter():
    outputs_ref = model(tokens, attention_mask=sample_attention_mask, use_cache=False)
model.train()
outputs = model(tokens, attention_mask=sample_attention_mask, use_cache=False)
# Add before training loop
with torch.no_grad():
    debug_info = debug_kl_divergence(
        outputs.logits,
        outputs_ref.logits,
        tokens,
        sample_attention_mask,
        temperature
    )
    print("Debug info:", debug_info)

Debug info: {'avg_kl': 0.0, 'max_kl': 0.0, 'max_kl_token': 2, 'model_entropy': 2.364332437515259, 'ref_entropy': 2.364332437515259, 'prob_diff': 0.0, 'sequence_length': 64, 'valid_tokens': 911}


In [None]:
#@title Optimize the LLM

# Settings
torch.set_float32_matmul_precision("high")


# Optimize the LLM
opt = optim.Adam(model.parameters(), lr=lr, betas=(beta1, beta2), weight_decay=wd)

try:
    for i in tqdm(range(1500)):
    # for i in tqdm(endless_range()):
        # Generate a batch of samples from the model
        sample_idxs = torch.randperm(input_ids.shape[0])[:batch_size]
        sample_input_ids = input_ids[sample_idxs]
        sample_input_attention_mask = input_attention_mask[sample_idxs]
        sample_attention_mask = attention_mask[sample_idxs]
        sample_logp_mask = logp_mask[sample_idxs]

        model.eval()
        tokens = model.generate(
            sample_input_ids,
            attention_mask=sample_input_attention_mask,
            do_sample=True,
            min_new_tokens=max_new_tokens,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.eos_token_id,
            temperature=temperature,
            top_k=0,
            no_repeat_ngram_size=2,
        )

        # Get the logits of the samples from the model and the reference model
        with torch.no_grad(), model.disable_adapter():
            outputs_ref = model(tokens, attention_mask=sample_attention_mask, use_cache=False)
        model.train()
        outputs = model(tokens, attention_mask=sample_attention_mask, use_cache=False)

        # Compute the log probability of the samples under the model and the reference model
        logp = logp_completion(outputs.logits / temperature, tokens, sample_logp_mask)
        logp_ref = logp_completion(outputs_ref.logits / temperature, tokens, sample_logp_mask)

        # Compute the CLIP loss
        texts = [tokenizer.decode(t, skip_special_tokens=True) for t in tokens]
        clip_tokens = clip_tokenizer(texts).to(device)
        with torch.cuda.amp.autocast():
            text_embeds = clip_model.encode_text(clip_tokens).float()

        # Randomly sample an image for each (all of the same thing)
        num_samples = image_embeds.shape[0]
        random_indices = torch.randint(0, num_samples, (batch_size,))
        expanded_image_embeds = image_embeds[random_indices]

        # Compute CLIP penalty
        cost_clip = torch.cosine_similarity(text_embeds, expanded_image_embeds, dim=-1).arccos()

        # Compute the KL penalty
        cost_kl = logp.detach() - logp_ref

        # REINFORCE
        cost = cost_clip + kl_weight * cost_kl
        baseline = (cost.sum() - cost) / (cost.numel() - 1)
        box = torch.exp(logp - logp.detach())
        loss = torch.mean(box * cost + (1 - box) * baseline)

        # Update the model
        opt.zero_grad()
        loss.backward()
        opt.step()

        # Print statistics and the best sample in the batch
        grad_norm = torch.cat(
            [p.grad.flatten() for p in model.parameters() if p.grad is not None]
        ).norm()
        print(
            f"step: {i}, loss: {loss.item():g}, clip: {cost_clip.mean().item():g}, kl: {cost_kl.mean().item():g}, grad: {grad_norm.item():g}"
        )
        best_text = texts[torch.argmin(cost).item()]
        print(textwrap.fill(best_text, width=80))
        print()

except KeyboardInterrupt:
    pass

  0%|          | 0/1500 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


step: 0, loss: 1.4324, clip: 1.4324, kl: 0, grad: 0.511501
If I had to explain it,Let's examine this:  1. The "we" in our society has no
longer been taken care of for a long, long time. 2. In the past, we had a "class
system" which was very well defined. This ensured that the people who

step: 1, loss: 1.43149, clip: 1.43049, kl: 0.497755, grad: 0.459962
The central theme is to introduce to you the most beautiful places that will
allow can inspire the writing of your favourite book. Your book could be written
in those places or on those beaches. It can be a novel, a story in a diary, an
inspirational writing and even

step: 2, loss: 1.43905, clip: 1.43691, kl: 1.07168, grad: 0.538709
One could say that my career in education “started” at the elementary level, but
that would be a bit of a misnomer. I attended elementary school four years
instead of the usual five years, so that I could conveniently move into a public
school that was closer

step: 3, loss: 1.43084, clip: 1.42701, kl: 1.9

In [None]:
input_text = "My physical form is"
inputs = tokenizer(input_text, return_tensors="pt").to("cuda")

output = model.generate(
    **inputs,
    max_new_tokens=100,
    use_cache=False,
    do_sample=True,
    temperature=0.9,
    num_beams=5,
    no_repeat_ngram_size=2,
)

print(tokenizer.decode(output[0], skip_special_tokens=True))

In [None]:
model.push_to_hub("tommyp111/gemma-2b-clip-lora-golden-gate-no-resid-proj-2")

In [None]:
exit(0)