In [4]:
#flamingo paper: https://arxiv.org/pdf/2204.14198

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision.models import resnet18, ResNet18_Weights
from transformers import AutoModelForCausalLM, AutoTokenizer

from PIL import Image
import numpy as np
from io import BytesIO

import requests

  from .autonotebook import tqdm as notebook_tqdm


## Paper notes

- flamingo is novel architecture to combine large language models (LMs) with vision models
- flamingo supports arbitrarily interleaving text and images
- Architecture:
  - Part 1: Image processing
    - input is a series of interleaved text and images
    - images are taken from the input, and forward passed through pre-trained, frozen vision encoder (vae?)
    - outputs of vision encoder are passed to "Perciever Resampler"
  - Part 2: text processing:
    - images are removed from text/image stream and replaced with a special image token
  - Part 3: joint transformer forward pass:
    - Each layer of the LM transformer is kept the same, but before each layer a Gated XATTN-DENSE layer is inserted.
    - Images and text are both passed to the ith GATED XATTN-DENSE layer, then to the ith LM layer
    - finally, the loss function is auto-regressive model on text-only outputs (final layer only produces a probability distribution over text tokens)

### Gated XATTN-DENSE Layers
- these are like regular attention and feed-forward layers in a transformer network, with the following differences:
  - XATTN (cross attention): rather than having a query matrix multiplied by all tokens in the text string (as in regular transformer), the queries are generated by the text part of the input and the keys and values are generated by the image part of the input. Thus, this layer only modifies the text tokens by learning how image embeddings from vision encoder "attend to" the text embeddings of text tokens in the input. it's "cross" attention since the attention is across modalities
  - DENSE: the layer following the XATTN is a dense feed-forward layer, as in the original transformer architecture
  - Gated: there are "tamping" parameters applied to the residual outputs from the XATTN and DENSE FFW layers. In the original transformer architecture, the output of each attention layer/dense ffw layer is added to the original input. In the gated layers, a tanh gating is applied to regulate the influence of the attn/ffw layer outputs on x.
    - Regular transformer: x_attn = attn(x), x += x_attn, x_ffw = ffw(x), x += x_ffw.
    - Gated: x_attn = attn(x), x += tanh(alpha_attn) * attn(x), x_ffw = ffw(x), x += tanh(alpha_ffw) * x_ffw. Alpha_attn and alpha_ffw are trainable scalar values (1 of each for each layer), initialized at 0.

### Vision Encoder
- uses a Normalizer-Free ResNet architecture. Normalizer-free just means it uses some gradient techniques to not have to use BatchNorm

### Perciever Resampler
- To allow the model to accept video inputs, the perciever resampler is used to map a variable number of images (ie video of variable length) into a fixed size output of 5 image tokens/embeddings

### Dataset
the flamingo models use the ALIGN dataset, a proprietary text-image pair dataset from Google, using scraped images + alt text. Alternatives that are open source include Conceptual Captions, from Google, LAION-5B from Allen AI, and MS-COCO (Microsoft Common Objects in Context)

In [None]:
#pseudo code
def perceiver_resampler(
  x_f, # The [T, S, d] visual features (T=time, S=space)
  time_embeddings, # The [T, 1, d] time pos embeddings.
  x, # R learned latents of shape [R, d]
  num_layers, # Number of layers
):
  """The Perceiver Resampler model."""
  #Add the time position embeddings and flatten.
  x_f = x_f + time_embeddings
  x_f = flatten(x_f) # [T, S, d] -> [T * S, d]

  #Apply the Perceiver Resampler layers.
  for i in range(num_layers):

  #Attention.
  x = x + attention_i(q=x, kv=concat([x_f, x]))

  #Feed forward.
  x = x + ffw_i(x)
  return x

### Implementing Forward Pass

In [5]:
#get dataset
f = open('Train_GCC-training.tsv', 'r')
lines = [f.readline().replace('\n', '').split('\t') for i in range(64)]

In [12]:
url = lines[0][1]
res = requests.get(url)
image = Image.open(BytesIO(res.content))
np_array = np.asarray(image)

In [15]:
np_imgs = []
for i, line in enumerate(lines):
  try:
    img_req = requests.get(line[1])
    print(img_req)
    img_data = img_req.content
    image = Image.open(BytesIO(img_data))
    np_array = np.asarray(image)
    print(np_array.shape)
    np_imgs.append(np_array)
  except Exception as e:
    print('failed to get image: ', i)

<Response [200]>
(534, 800, 3)
<Response [200]>
(441, 500, 3)
<Response [400]>
failed to get image:  2
<Response [200]>
(470, 450, 3)
<Response [200]>
(470, 450, 3)
<Response [400]>
failed to get image:  5
<Response [200]>
(678, 1200, 3)
<Response [200]>
(470, 450, 3)
<Response [200]>
(612, 491, 3)
<Response [200]>
failed to get image:  9
<Response [200]>
(422, 634, 3)
<Response [404]>
failed to get image:  11
failed to get image:  12
<Response [400]>
failed to get image:  13
<Response [200]>
(640, 640, 3)
<Response [200]>
failed to get image:  15
<Response [400]>
failed to get image:  16
<Response [404]>
failed to get image:  17
<Response [404]>
failed to get image:  18
<Response [200]>
(532, 800, 3)
<Response [404]>
(200, 200)
<Response [400]>
failed to get image:  21
<Response [200]>
(688, 735, 3)
<Response [200]>
(450, 450, 3)
<Response [410]>
failed to get image:  24
<Response [404]>
failed to get image:  25
<Response [200]>
(446, 640, 3)
<Response [200]>
(447, 640, 3)
<Response [

In [None]:
#instantiate Qwen/Qwen3-0.6B for pre-trained LLM

model_name = "Qwen/Qwen3-0.6B"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(model_name)


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

In [None]:
tokenizer.encode('hello there, my name is jacob ')

[14990, 1052, 11, 847, 829, 374, 502, 38951]

In [None]:
for mod in model.named_modules():
  print(mod)

('', Qwen3ForCausalLM(
  (model): Qwen3Model(
    (embed_tokens): Embedding(151936, 1024)
    (layers): ModuleList(
      (0-27): 28 x Qwen3DecoderLayer(
        (self_attn): Qwen3Attention(
          (q_proj): Linear(in_features=1024, out_features=2048, bias=False)
          (k_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (v_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (o_proj): Linear(in_features=2048, out_features=1024, bias=False)
          (q_norm): Qwen3RMSNorm((128,), eps=1e-06)
          (k_norm): Qwen3RMSNorm((128,), eps=1e-06)
        )
        (mlp): Qwen3MLP(
          (gate_proj): Linear(in_features=1024, out_features=3072, bias=False)
          (up_proj): Linear(in_features=1024, out_features=3072, bias=False)
          (down_proj): Linear(in_features=3072, out_features=1024, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen3RMSNorm((1024,), eps=1e-06)
        (post_attention_layernorm)