# 0. Imports

In [2]:
import os
import torch
from PIL import Image
from src.dataset_and_loader.loader import get_transform
from src.dataset_and_loader.dataset import RecipeDataset
from src.dataset_and_loader.loader import dataset_loader
from src.model import FusionModel, InfoNCELoss
from src.utils.save import load_model
from src.pipeline import train_and_validate_model
from src.reterival import retrieve
from src.utils.plot import show_image
from src.utils.utils import count_trainable_params, count_total_params
from src.pipeline import encode_texts_in_batches, build_images_in_batches

# 1. Hyperparameters

In [3]:
# Tunable settings for the whole notebook; edit them here rather than inline below.
IMAGE_SIZE = 224                 # passed to get_transform when building the image pipelines
BATCH_SIZE = 16                  # used for the data loaders and for batched text embedding
LEARNING_RATE = 3e-5             # Adam learning rate (alternative value noted earlier: 1e-4)
TEMPERATURE = 0.05               # temperature handed to InfoNCELoss
MODEL_PATH = "fusion_model.pth"  # checkpoint path used by load_model and by training

# Pick the best available accelerator: CUDA first, then Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# 2. Dataset and DataLoader

In [4]:
# Image preprocessing pipelines for the two splits, both derived from IMAGE_SIZE.
train_transform, val_transform = get_transform(IMAGE_SIZE)

# All data paths are resolved against the current working directory.
base_dir = os.getcwd()
image_dir = os.path.join(base_dir, "data", "Food Images")
csv_path = os.path.join(
    base_dir,
    "data",
    "Food Ingredients and Recipe Dataset with Image Name Mapping.csv",
)

# One dataset over the whole CSV; dataset_loader derives the train/val loaders from it.
full_dataset = RecipeDataset(
    csv_file=csv_path,
    image_dir=image_dir,
    transform=train_transform,
)
train_loader, val_loader = dataset_loader(
    full_dataset, train_transform, val_transform, BATCH_SIZE
)

Train size: 10800, Val size: 2701


# 3. Initializations

In [5]:
# Contrastive objective, configured with the notebook-wide temperature.
loss_fun = InfoNCELoss(temperature=TEMPERATURE)

# Model on the selected device, with Adam optimising all of its parameters.
model = FusionModel(device=device)
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=LEARNING_RATE,
)

In [6]:
# Restore model and optimizer from MODEL_PATH; load_model also returns the
# epoch number stored for this run.
model, optimizer, last_epoch = load_model(
    model,
    optimizer,
    path=MODEL_PATH,
    device=device,
)

In [7]:
# Advance past the loaded epoch when it is greater than 1; a value of 1 is kept as-is.
# This rebinds last_epoch in place, so re-running this cell without re-running
# the load cell advances the counter again.
last_epoch = last_epoch + 1 if last_epoch > 1 else last_epoch
print(f"last epoch: {last_epoch}")

last epoch: 1


In [8]:
# Report the model's parameter counts (trainable vs. total).
trainable_count = count_trainable_params(model)
total_count = count_total_params(model)
print(f"Trainable params: {trainable_count}")
print(f"Total params: {total_count}")

Trainable params: 159910976
Total params: 159910976


# 4. Train and Validate Model

In [None]:
# Number of epochs to run in this session.
TRAIN_NEXT_EPOCHS = 5
# Epoch number this run starts from.
pre_last_epoch = last_epoch

# Train and validate; the checkpoint is written back to MODEL_PATH.
train_and_validate_model(
    model,
    train_loader,
    val_loader,
    optimizer,
    loss_fun,
    device,
    start_epoch=pre_last_epoch,
    epochs=TRAIN_NEXT_EPOCHS,
    save_path=MODEL_PATH,
)

Total epochs: 5
Epoch [1]


Training: 100%|██████████| 675/675 [12:58<00:00,  1.15s/it]


Training Loss: 4.9765


Validating: 100%|██████████| 169/169 [00:59<00:00,  2.83it/s]


Validation Loss: 3.7709
Text->Image: R@1 0.020, R@5 0.053, R@10 0.087, MedR 166.000, MRR 0.046
Text->LongText: R@1 0.073, R@5 0.217, R@10 0.320, MedR 31.000, MRR 0.152
Image->Text: R@1 0.016, R@5 0.057, R@10 0.096, MedR 157.000, MRR 0.046
Image->LongText: R@1 0.024, R@5 0.074, R@10 0.122, MedR 97.000, MRR 0.062
Model saved to fusion_model.pth
Loss saved to metrics/training_and_validation_loss_log.csv
Retrieval Metrics saved to metrics/retrieval_metrics_log.csv
Epoch [2]


Training:  30%|███       | 203/675 [03:24<08:16,  1.05s/it]

# 5. Evaluation Metrics and Loss Plots

### a. Training and Validation Loss Graph

In [None]:
from src.utils.save import load_training_n_validation_loss
from src.utils.plot import plot_training_n_validation_loss

# Read the logged loss history back and plot the training and validation curves.
epochs, train_loss, val_loss = load_training_n_validation_loss()
plot_training_n_validation_loss(
    epochs,
    train_loss,
    val_loss,
)


# 6. Test Model

In [None]:
# Reload the checkpoint from MODEL_PATH; the returned optimizer and epoch are
# not needed for inference and are discarded.
loaded_model, _, _ = load_model(model, optimizer, path=MODEL_PATH, device=device)

# Switch to inference mode so layers such as dropout / batch-norm behave
# deterministically while building embeddings and answering queries.
# Assumes FusionModel is a torch.nn.Module (it exposes .parameters()); eval()
# returns the module itself, so rebinding keeps the same object.
loaded_model = loaded_model.eval()

### a. Dataset Embeddings

In [None]:
from tqdm import tqdm

# Gallery images must be preprocessed the same way as the query images, which
# are passed through val_transform. full_dataset was constructed with
# train_transform, so read the same CSV/images through a second dataset that
# uses val_transform instead.
# NOTE(review): assumes RecipeDataset applies `transform` when a sample is
# fetched and orders rows deterministically, so indices line up with
# full_dataset (which later cells use to look up hits) — confirm.
eval_dataset = RecipeDataset(csv_file=csv_path, image_dir=image_dir, transform=val_transform)
assert len(eval_dataset) == len(full_dataset), (
    "eval dataset must index the same rows as full_dataset"
)

dataset_texts = []                     # sample["input_text"] (title)
dataset_ingredients_instructions = []  # sample["target_text"] (ingredients, instructions)
dataset_metadata = []                  # sample["metadata"]
dataset_images = []                    # sample["image"], stacked below

for i in tqdm(range(len(eval_dataset)), desc="Preparing texts & metadata"):
    sample = eval_dataset[i]
    dataset_texts.append(sample["input_text"])
    dataset_ingredients_instructions.append(sample["target_text"])
    dataset_images.append(sample["image"])
    dataset_metadata.append(sample["metadata"])

# Stack the per-sample image tensors into one tensor with a leading sample axis.
dataset_images = torch.stack(dataset_images)

In [None]:
# Embed every title in mini-batches of BATCH_SIZE. An earlier single-pass
# version called loaded_model.forward_text(dataset_texts) under torch.no_grad().
dataset_title_embeds = encode_texts_in_batches(
    loaded_model,
    dataset_texts,
    device=device,
    batch_size=BATCH_SIZE,
    desc="Building titles embeddings:",
)

In [None]:
# Embed every ingredients+instructions text in mini-batches of BATCH_SIZE.
# An earlier single-pass version called
# loaded_model.forward_long_text(dataset_ingredients_instructions) under torch.no_grad().
# NOTE(review): this call is identical in form to the title-embedding call, so
# it is not visible here whether encode_texts_in_batches uses the long-text
# encoder for these inputs — confirm against its implementation.
dataset_ingredients_instructions_embeds = encode_texts_in_batches(
    loaded_model,
    dataset_ingredients_instructions,
    device=device,
    batch_size=BATCH_SIZE,
    desc="Building Ingredients and Instructions embeddings: ",
)

In [None]:
# with torch.no_grad():
#   dataset_image_embeds = loaded_model.forward_image(dataset_images).to(device)
dataset_image_embeds = build_images_in_batches(
    model=loaded_model,
    dataset_images=dataset_images,
    device=device
)

In [None]:
# Sanity check: show the shape of each embedding bank (titles, long texts, images).
for embedding_bank in (
    dataset_title_embeds,
    dataset_ingredients_instructions_embeds,
    dataset_image_embeds,
):
    print(embedding_bank.shape)

### b. Retrieve from your query

#### i. text -> title

In [None]:
query_texts = ["crispy salt and pepper potatoes"]

# Only the title embeddings are supplied, so the text query is matched against titles.
results = retrieve(
    model=loaded_model,
    query_texts=query_texts,
    query_images=None,
    dataset_title_embeds=dataset_title_embeds,
    dataset_ingredients_instructions_embeds=None,
    dataset_image_embeds=None,
    top_k=3,
    device=device
)

# The hits live under the "text->text" key; label the printout with the same
# key (it previously claimed "text->ingredients_instructions", which is not
# what was searched).
indices, scores = results["text->text"]
print("text->text indices:", indices)
print("text->text scores:", scores)


In [None]:
# Loop through the hits
# List the hits for the single query (row 0 of the result tensors).
top_indices, top_scores = indices[0], scores[0]
for rank in range(len(top_indices)):
    sample = full_dataset[top_indices[rank].item()]  # tensor element -> plain int
    print(f"{rank + 1}. Score: {top_scores[rank].item():.3f}")
    print(f"   Title: {sample['metadata']['title']}")
    print(f"   Ingredients+Instructions: {sample['target_text']}")

#### ii. image -> image

In [None]:
image_name = "crispy-salt-and-pepper-potatoes-dan-kluger"
image_dir = os.path.join(base_dir, "data", "Food Images")
image_path = os.path.join(image_dir, image_name + ".jpg")

# Open the file in a context manager so the handle is closed promptly;
# convert() returns a new image that remains usable after the file is closed.
with Image.open(image_path) as raw_image:
    img = raw_image.convert("RGB")

img_tensor = val_transform(img)
img_tensor = img_tensor.unsqueeze(0)  # add batch dimension [1, C, H, W]
query_images = img_tensor.to(device)

# Only the image embeddings are supplied, so the image query is matched against images.
results = retrieve(
    model=loaded_model,
    query_texts=None,
    query_images=query_images,
    dataset_title_embeds=None,
    dataset_ingredients_instructions_embeds=None,
    dataset_image_embeds=dataset_image_embeds,
    top_k=3,
    device=device
)

# Unpack the hits stored under the "image->image" key.
indices, scores = results["image->image"]
print("image->image indices:", indices)
print("image->image scores:", scores)

In [None]:
# Loop through the hits
# Walk the hits for the single query (row 0), pairing each index with its score.
for position, (hit_idx, hit_score) in enumerate(zip(indices[0], scores[0])):
    sample = full_dataset[hit_idx.item()]  # tensor element -> plain int
    meta = sample["metadata"]
    print(f"{position + 1}. Score: {hit_score.item():.3f}")
    print(f"   Title: {meta['title']}")
    print(f"   Image Path: {meta['image_path']}")
    show_image(meta["image_path"])

#### iii. text -> image

In [None]:
# Other queries to try: ["pizza"], ["Salt-and-Pepper Fish"]
query_texts = ["crispy salt and pepper potatoes"]

# Only the image embeddings are supplied, so the text query is matched against images.
results = retrieve(
    model=loaded_model,
    device=device,
    top_k=3,
    query_texts=query_texts,
    query_images=None,
    dataset_image_embeds=dataset_image_embeds,
    dataset_title_embeds=None,
    dataset_ingredients_instructions_embeds=None,
)

# Unpack the hits stored under the "text->image" key.
indices, scores = results["text->image"]
print("text->image indices:", indices)
print("text->image scores:", scores)

In [None]:
# Loop through the hits
# Show each retrieved image with its title and score; row 0 holds the only query.
for rank, hit in enumerate(indices[0], start=1):
    sample = full_dataset[hit.item()]  # tensor element -> plain int
    print(f"{rank}. Score: {scores[0][rank - 1].item():.3f}")
    print(f"   Title: {sample['metadata']['title']}")
    print(f"   Image Path: {sample['metadata']['image_path']}")
    show_image(sample['metadata']['image_path'])

#### iv. image -> text

In [None]:
image_name = "miso-butter-roast-chicken-acorn-squash-panzanella"
# image_name = "pan-seared-salt-and-pepper-fish"

image_dir = os.path.join(base_dir, "data", "Food Images")
image_path = os.path.join(image_dir, image_name + ".jpg")

# Open the file in a context manager so the handle is closed promptly;
# convert() returns a new image that remains usable after the file is closed.
with Image.open(image_path) as raw_image:
    img = raw_image.convert("RGB")

img_tensor = val_transform(img)
img_tensor = img_tensor.unsqueeze(0)  # add batch dimension [1, C, H, W]
query_images = img_tensor.to(device)
print(query_images.shape)

# Only the title embeddings are supplied, so the image query is matched against titles.
results = retrieve(
    model=loaded_model,
    query_texts=None,
    query_images=query_images,
    dataset_title_embeds=dataset_title_embeds,
    dataset_ingredients_instructions_embeds=None,
    dataset_image_embeds=None,
    top_k=3,
    device=device
)

# Unpack the hits stored under the "image->text" key.
indices, scores = results["image->text"]
print("image->text indices:", indices)
print("image->text scores:", scores)

In [None]:
# Loop through the hits
# Report each hit for the single query (row 0): score, title, image path, and the image.
query_hits = indices[0]
query_scores = scores[0]
for rank, idx in enumerate(query_hits):
    metadata = full_dataset[idx.item()]["metadata"]  # tensor element -> plain int
    hit_score = query_scores[rank].item()
    hit_path = metadata["image_path"]
    print(f"{rank + 1}. Score: {hit_score:.3f}")
    print(f"   Title: {metadata['title']}")
    print(f"   Image Path: {hit_path}")
    show_image(hit_path)