# Exercise Image-Text

In [None]:
from pathlib import Path
import paths
from datasets import VOCDataset

In [None]:
# 1. Image Captioning
# Your first task will be to complete the inference code to generate captions for the given VOC dataset.
from eval_captioning import extract_evaluate_write_captions

In [None]:
# 1.1 Complete Caption Generation with Greedy Search

# TODO: In file models/blip/blip_caption.py, complete the methods `generate` and `greedy_search`.
#       Generate and evaluate captions for the VOC dataset. You should get a BLEU score of about 28%. (2 points)

# Greedy decoding: Top-K sampling switched off, temperature kept at 1.0.
extract_evaluate_write_captions(
    use_topk_sampling=False,
    temperature=1.0,
)

In [None]:
voc_path = Path(paths.CV_PATH_VOC)
dataset = VOCDataset(voc_path, voc_path / "ImageSets" / "Segmentation" / "val.txt",
                     load_captions=True)

# Load the generated captions and show them next to the reference captions.
# TODO: update the path to match your experiment's output directory.
pred_captions_file = "outputs/eval_captioning/2024_06_04_11_49_15/pred_captions.txt"
with open(pred_captions_file, "r", encoding="utf-8") as f:
    # Strip the trailing "\n" that readlines() would keep; otherwise print()
    # emits an extra blank line after every predicted caption.
    pred_captions = [line.rstrip("\n") for line in f]

from PIL import Image

num_examples = 10
# Slicing guards against a predictions file with fewer than `num_examples`
# lines, which would otherwise raise IndexError.
# NOTE(review): assumes line i of the file is the prediction for dataset[i] — confirm.
for i, pred_caption in enumerate(pred_captions[:num_examples]):
    data = dataset[i]
    display(data["image"])  # `display` is the IPython/Jupyter builtin
    print(f"Pred caption: {pred_caption}")
    print(f"Reference caption: {data['caption']}")

In [None]:
# 1.2 Complete Caption Generation with Sampling
# TODO: In file models/blip/blip_caption.py, complete the method `sampling`.
#       There, Top-K sampling is used instead of greedy search to select the next token when decoding.
#       Evaluate again with Top-K sampling.
#       You should get a lower BLEU score: about 7% for temperature τ = 1.0 and about 12% for τ = 0.7.
#       Why do the results improve with lower temperature? (1 point)

# Top-K sampling with K = 50 at temperature τ = 1.0.
extract_evaluate_write_captions(
    use_topk_sampling=True,
    topk=50,
    temperature=1.0,
)

In [None]:
# Same Top-K setting (K = 50), lower temperature τ = 0.7.
extract_evaluate_write_captions(
    use_topk_sampling=True,
    topk=50,
    temperature=0.7,
)

In [None]:
# 1.3 Prompt Engineering
# TODO: Experiment with different prompts (the default prompt is “a picture of ”).
#       Can you improve the BLEU score? Try at least 3 new settings.
#       Note the prompt and the resulting BLEU score in a table for each setting.
#       Add the table to your report. (1 point)

# Baseline: the default prompt passed explicitly, with K = 50 and τ = 0.7.
extract_evaluate_write_captions(
    use_topk_sampling=True,
    topk=50,
    temperature=0.7,
    prompt="a picture of ",
)

In [None]:
# Prompt variant: "an image of " (K = 50, τ = 0.7).
extract_evaluate_write_captions(
    use_topk_sampling=True,
    topk=50,
    temperature=0.7,
    prompt="an image of ",
)

In [None]:
# Prompt variant: "a depiction of " (K = 50, τ = 0.7).
extract_evaluate_write_captions(
    use_topk_sampling=True,
    topk=50,
    temperature=0.7,
    prompt="a depiction of ",
)

In [None]:
# Prompt variant: "a figure of " (K = 50, τ = 0.7).
extract_evaluate_write_captions(
    use_topk_sampling=True,
    topk=50,
    temperature=0.7,
    prompt="a figure of ",
)

In [None]:
# 1.4 Student Hyperparameter Search
# TODO: Experiment with different decoding parameters (Top-K with different K and temperature, or greedy decoding).
#       Can you improve the BLEU score? Try at least 3 new settings.
#       Note the hyperparameters, the prompt and the resulting BLEU score in a table for each setting.
#       Add the table to your report. (1 point)

# Reference setting: K = 50, τ = 0.7, prompt "a picture of ".
extract_evaluate_write_captions(
    use_topk_sampling=True,
    topk=50,
    temperature=0.7,
    prompt="a picture of ",
)

In [None]:
# K = 100, τ = 0.5, prompt "an image of ".
extract_evaluate_write_captions(
    use_topk_sampling=True,
    topk=100,
    temperature=0.5,
    prompt="an image of ",
)

In [None]:
# K = 25, τ = 0.3, prompt "a depiction of ".
extract_evaluate_write_captions(
    use_topk_sampling=True,
    topk=25,
    temperature=0.3,
    prompt="a depiction of ",
)

In [None]:
# K = 10, τ = 1.0, prompt "a figure of ".
extract_evaluate_write_captions(
    use_topk_sampling=True,
    topk=10,
    temperature=1.0,
    prompt="a figure of ",
)

In [None]:
# K = 75, τ = 1.0, prompt "a figure of ".
extract_evaluate_write_captions(
    use_topk_sampling=True,
    topk=75,
    temperature=1.0,
    prompt="a figure of ",
)

In [None]:
# K = 100, τ = 0.3, prompt "an image of ".
extract_evaluate_write_captions(
    use_topk_sampling=True,
    topk=100,
    temperature=0.3,
    prompt="an image of ",
)

In [None]:
# 2. Image-Text Retrieval
# Your second task will be to train and evaluate the retrieval head of the BLIP model.
import torch
# Release unoccupied memory held by PyTorch's CUDA caching allocator before the
# retrieval experiments (presumably left over from the captioning runs above).
torch.cuda.empty_cache()

In [None]:
# 2.1 Complete Forward Pass and Evaluate
# TODO: Complete the forward pass in file models/blip/blip_retrieval.py.
#       Evaluate your implementation with the provided checkpoint.
#       You should get about 54% image-to-text R@1. (2 points)
from eval_retrieval import eval_without_args

# Run retrieval evaluation with the helper's built-in defaults
# (presumably the provided checkpoint — confirm in eval_retrieval.py).
eval_without_args()

In [None]:
# 2.2 Complete Loss and Train from Scratch
# TODO: Complete the loss computation in function train_epoch in file train_retrieval.py.
#       Train the retrieval projection layers from scratch (i.e. from random initialization).
#       You should get about 43% image-to-text R@1. (1 point)
from train_retrieval import train_retrieval_without_args

# From-scratch baseline: lr 1e-3, weight decay 1e-3, 5 epochs, temperature 0.1.
train_retrieval_without_args(
    finetune=False,
    learning_rate=1e-3,
    weight_decay=1e-3,
    epochs=5,
    temperature=0.1,
)

# Optional: start a TensorBoard server with
#     tensorboard --logdir outputs --port 6006
# and watch the experiment in the browser.

In [None]:
# 2.3 Finetune instead of Train from Scratch
# TODO: Now, try finetuning the head instead (--finetune on the command line; finetune=True here).
#       Set learning rate to 1e-5, weight decay to 0 and train for 3 epochs.
#       What score do you get and how can you explain the difference to the score when training from scratch? (1 point)
#       Try different search queries. What do you observe?

# NOTE(review): weight_decay=1e-3 deviates from the weight decay of 0 requested
# above; the following cell runs the requested setting.
train_retrieval_without_args(
    finetune=True,
    learning_rate=1e-5,
    weight_decay=1e-3,
    epochs=3,
    temperature=0.1,
)

In [None]:
# Finetuning with the setting from task 2.3: lr 1e-5, weight decay 0, 3 epochs.
train_retrieval_without_args(
    finetune=True,
    learning_rate=1e-5,
    weight_decay=0,
    epochs=3,
    temperature=0.1,
)

In [None]:
# 2.4 Student Hyperparameter Search
# TODO: Experiment with different hyperparameters in the random initialization setting (i.e. without finetuning).
#       Try at least 3 new hyperparameter settings.
#       Note the hyperparameters and the resulting image-to-text R@1 score in a table for each setting.
#       Can you improve over the baselines? Add the table to your report. (1 point)

# From scratch, setting 1: lr 1e-5, no weight decay, 5 epochs, temperature 0.1.
train_retrieval_without_args(
    finetune=False,
    learning_rate=1e-5,
    weight_decay=0,
    epochs=5,
    temperature=0.1,
)

In [None]:
# From scratch, setting 2: lr 5e-1, weight decay 1e-3, 5 epochs, temperature 0.05.
# NOTE(review): 5e-1 is 500x the from-scratch baseline lr of 1e-3 and may make
# training unstable — check the loss curve.
train_retrieval_without_args(
    finetune=False,
    learning_rate=5e-1,
    weight_decay=1e-3,
    epochs=5,
    temperature=0.05,
)

In [None]:
# From scratch, setting 3: lr 5e-3, weight decay 1e-4, 8 epochs, temperature 0.1.
train_retrieval_without_args(
    finetune=False,
    learning_rate=5e-3,
    weight_decay=1e-4,
    epochs=8,
    temperature=0.1,
)

In [None]:
# From scratch, setting 4: lr 1e-3, no weight decay, 8 epochs, temperature 0.05.
train_retrieval_without_args(
    finetune=False,
    learning_rate=1e-3,
    weight_decay=0,
    epochs=8,
    temperature=0.05,
)

In [None]:
# From scratch, setting 5: lr 1e-5, no weight decay, 3 epochs, temperature 0.1.
train_retrieval_without_args(
    finetune=False,
    learning_rate=1e-5,
    weight_decay=0,
    epochs=3,
    temperature=0.1,
)

In [None]:
# Optional: Visualize the top-10 results for a search query.

from search_retrieval import get_top10

search_query = "a picture of a plane"
# NOTE(review): eval_ckpt=None presumably makes get_top10 fall back to a
# default checkpoint — confirm in search_retrieval.py.
dict_top10 = get_top10(eval_ckpt=None, query=search_query)

from PIL import Image

# "id" is zipped in so the number of shown results is driven by the returned
# ids; the value itself is not needed here.
for _, fname, sim, caption in zip(
    dict_top10["id"], dict_top10["fname"], dict_top10["sim"], dict_top10["caption"]
):
    # The context manager closes the file handle deterministically; display()
    # renders the image while it is still open.
    with Image.open(fname) as image_pil:
        display(image_pil)  # `display` is the IPython/Jupyter builtin
    print(f"Sim. Score: {sim}")
    print(f"Caption: {caption}")