### Imports

In [None]:
import os

import numpy as np
from PIL import Image
from torchview import draw_graph
import torch
from transformers import AlignModel, AlignProcessor

# Project-local Hugging Face directory; the same path is later passed as
# `cache_dir` when the model and processor are loaded.
os.environ["HF_HOME"] = "../.hf_home"

# Fixed seed so torch's random number generator is repeatable across runs.
random_seed = 42
torch.manual_seed(random_seed)

# Backend preference order: Apple MPS, then CUDA, then plain CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

### Load Model & Processor

In [None]:
# The checkpoint id and cache directory are shared by the model and its processor.
checkpoint = "kakaobrain/align-base"
hf_cache_dir = os.environ["HF_HOME"]

processor = AlignProcessor.from_pretrained(checkpoint, cache_dir=hf_cache_dir)
model = AlignModel.from_pretrained(checkpoint, cache_dir=hf_cache_dir)

# Direct handles to the sub-modules that are inspected individually below.
vision_model = model.vision_model
text_model = model.text_model
text_projection = model.text_projection

#### Switch model to eval mode

In [None]:
# Put the full model and every separately referenced sub-module into eval mode.
for module in (model, text_model, vision_model, text_projection):
    module.eval()

# Show the projection layer's structure as the cell output.
text_projection

### Model architecture exploration

#### Image Size & Text Length

In [None]:
# Text-side length limits (model config vs. tokenizer) next to the vision input size.
(
    text_model.config.max_position_embeddings,
    processor.tokenizer.model_max_length,
    vision_model.config.image_size,
)

#### Tokenizer / Text Embedding

In [None]:
# Display the tokenizer object so its configuration shows up as the cell output.
processor.tokenizer

In [None]:
# Sample caption; it is reused by the padded-tokenization cell further down.
target_sentence = "a photo of a cat"
# Show the token strings the tokenizer produces for the caption.
processor.tokenizer.tokenize(target_sentence)

In [None]:
# Only the first `cut_len` entries of each field are printed to keep the output short.
cut_len = 15
# padding="max_length" pads the encoding, so the "Total Length" printed below
# is the padded sequence length rather than the caption's own token count.
tokenized = processor.tokenizer(target_sentence, return_tensors="pt", padding="max_length")

# Each tensor is flattened to a plain Python list before being truncated for display.
print(f"""Input IDs: \t\t{torch.flatten(tokenized.input_ids).tolist()[:cut_len]}
Attention Mask: \t{torch.flatten(tokenized.attention_mask).tolist()[:cut_len]}
Token Type IDs: \t{torch.flatten(tokenized.token_type_ids).tolist()[:cut_len]}
Total Length: \t\t{tokenized.input_ids.shape[1]}""")

In [None]:
# Forward pass through the text sub-model only, with gradient tracking disabled.
with torch.no_grad():
    text_model_out = text_model(**tokenized)
# Shape of the text model's pooled output.
text_model_out.pooler_output.shape

In [None]:
# Reproduce the text embedding the way AlignModel computes it internally:
# the projection is applied to the first-token slice of the last hidden state
# (see `AlignModel.get_text_features`), not to `pooler_output`, which has
# already gone through the text model's separate pooling layer.
# Gradients are disabled to match the other inference-only cells.
with torch.no_grad():
    first_token_state = text_model_out.last_hidden_state[:, 0, :]
    text_embedding = text_projection(first_token_state)
text_embedding.shape

#### Image Processing & Image Embedding

In [None]:
# Open the sample picture, convert it to three-channel RGB, and show it.
with Image.open("sample_images/cat.jpg") as raw_image:
    image = raw_image.convert("RGB")
image

In [None]:
# Display the image processor object so its settings show up as the cell output.
processor.image_processor

In [None]:
processed_image = processor.image_processor(images=image, return_tensors="pt")

# Visualization only. The processor output is not guaranteed to lie in [0, 1]
# (offset rescaling or mean/std normalization can yield negative values), and
# casting out-of-range floats to uint8 wraps around and corrupts the preview.
# Min-max scaling maps whatever range the processor produced onto 0..255.
# NOTE(review): this stretches contrast if the values were already in [0, 1].
pixels = processed_image.pixel_values[0].permute(1, 2, 0).numpy()  # CHW -> HWC
lowest = pixels.min()
value_span = pixels.max() - lowest
if value_span > 0:
    viewable = (pixels - lowest) / value_span
else:
    # Constant image: avoid dividing by zero and show it as black.
    viewable = np.zeros_like(pixels)
display(Image.fromarray(np.uint8(np.round(viewable * 255))))
processed_image.pixel_values.shape

In [None]:
# Forward pass through the vision sub-model only, with gradient tracking disabled.
with torch.no_grad():
    vision_model_out = vision_model(**processed_image)
# The pooled output of the vision model is kept as the image embedding.
vision_embedding = vision_model_out.pooler_output
vision_embedding.shape

#### Model Pipeline Visualization

In [None]:
# Dummy input shapes handed to torchview, paired in order with `dtypes` below.
# Presumably (batch, sequence length) token ids and (batch, channels, height, width)
# pixels -- TODO confirm against the sizes printed in the exploration cell.
text_input_shape = (1, 64)
image_input_shape = (1, 3, 289, 289)

graph = draw_graph(
    model,
    input_size=[text_input_shape, image_input_shape],
    dtypes=[torch.long, torch.float32],
    expand_nested=True,
)
graph.visual_graph