# 1. Imports

In [None]:
# Configs
# autoreload re-imports edited modules before each cell runs, so changes to the
# project code are picked up without restarting the kernel (mode 3 also pulls in
# newly added names, per the IPython autoreload docs).
%load_ext autoreload
%autoreload 3
## Standard-library packages
import sys
## Environment preparation: put the parent directory first on sys.path so the
## src_code package imported below can be found.
## NOTE(review): assumes the notebook is run from a direct subdirectory of the repo root — confirm.
sys.path.insert(0, "../")
from src_code.data_utils.dataset import GridDataset
from src_code.data_utils.dataset_utils import CellType

# 2. Dataset

In [None]:
# Build the grid dataset with a fixed seed; "#" marks walls and "." marks free cells.
dataset = GridDataset(
    grid_size=5,
    seed=42,
    wall_symbol="#",
    free_symbol=".",
)

# The first sample unpacks into three parts, named here for the image,
# the ASCII input and the path.
(img_rgb, ascii_inp, path) = dataset[0]

In [None]:
# Bare expression on the cell's last line, so the notebook displays the sample image as the cell output.
img_rgb

In [None]:
# Print the ASCII part of the same sample (presumably the text rendering of the grid — confirm).
print(ascii_inp)

In [None]:
# The "=" specifier makes the f-string print the expression text together with its value ("path = ...").
print(f"{path = }")

# 3. VLM Inference

## 3.1 Open Flamingo
Source: https://github.com/mlfoundations/open_flamingo

In [None]:
from open_flamingo import create_model_and_transforms

# Keyword arguments for the OpenFlamingo factory: a CLIP ViT-L-14 vision
# encoder with OpenAI weights, and a language encoder whose tokenizer is
# loaded from the same Hugging Face repository.
flamingo_kwargs = {
    "clip_vision_encoder_path": "ViT-L-14",
    "clip_vision_encoder_pretrained": "openai",
    "lang_encoder_path": "anas-awadalla/mpt-1b-redpajama-200b",
    "tokenizer_path": "anas-awadalla/mpt-1b-redpajama-200b",
    "cross_attn_every_n_layers": 1,
}

# The factory returns the model together with its image preprocessor and tokenizer.
model, image_processor, tokenizer = create_model_and_transforms(**flamingo_kwargs)

In [None]:
# Grab the pretrained OpenFlamingo checkpoint from the Hugging Face Hub.
from huggingface_hub import hf_hub_download
import torch

checkpoint_path = hf_hub_download("openflamingo/OpenFlamingo-3B-vitl-mpt1b", "checkpoint.pt")

# map_location="cpu" remaps every tensor in the checkpoint to CPU while loading,
# so the load also works on machines without the device the file was saved from.
# The model in this notebook is never moved off the CPU, so this matches its placement.
# NOTE(review): torch.load unpickles the downloaded file; if the installed torch
# supports it, consider passing weights_only=True to restrict what can be unpickled.
#
# strict=False tolerates key mismatches between checkpoint and model
# (presumably the checkpoint does not hold every parameter of the model — confirm).
# The call stays the last expression of the cell so the returned
# missing/unexpected key report is displayed.
model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"), strict=False)

In [None]:
from PIL import Image
import requests
import torch
import numpy as np

# ---------------------------------------------------------------------------
# Step 1: Load images
# ---------------------------------------------------------------------------
# The dataset is rebuilt with the same settings as in the Dataset section so
# that this cell can be run on its own.
dataset = GridDataset(grid_size=5, seed=42, wall_symbol="#", free_symbol=".")

img_rgb, ascii_inp, path = dataset[0]
img_rgb.show()

# ---------------------------------------------------------------------------
# Step 2: Preprocessing images
# ---------------------------------------------------------------------------
# OpenFlamingo expects the images as a torch tensor of shape
#   batch_size x num_media x num_frames x channels x height x width.
# Only one image is passed here, so batch_size = 1, num_media = 1 and
# num_frames = 1; channels/height/width are whatever image_processor produces
# (expected 3 x 224 x 224 — check against the shape printed below).
vision_x = [image_processor(img_rgb).unsqueeze(0)]  # one entry per image
vision_x = torch.cat(vision_x, dim=0)               # stack images along num_media
vision_x = vision_x.unsqueeze(1).unsqueeze(0)       # add num_frames and batch axes
print(vision_x.shape)

# ---------------------------------------------------------------------------
# Step 3: Preprocessing text
# ---------------------------------------------------------------------------
# An <image> special token marks where an image sits in the text, and an
# <|endofchunk|> special token marks the END of the text that belongs to an image.
# The query chunk must therefore be left open (no trailing <|endofchunk|>):
# closing it tells the model the text for this image is already finished,
# so there is nothing left for it to complete. The prompt ends with
# "Answer:" so that the continuation is the answer itself.
tokenizer.padding_side = "left"  # For generation padding tokens should be on the left
instruction = "Find the shortest path from the start to the goal and give the sequence of actions using the symbols l, r, u, d"
lang_x = tokenizer(
    [f"<image>{instruction} Answer:"],
    return_tensors="pt",
)

# ---------------------------------------------------------------------------
# Step 4: Generate text
# ---------------------------------------------------------------------------
# Switch to inference mode so train-only layers (e.g. dropout) are disabled.
model.eval()
generated_text = model.generate(
    vision_x=vision_x,
    lang_x=lang_x["input_ids"],
    attention_mask=lang_x["attention_mask"],
    max_new_tokens=20,
    num_beams=3,
)

# The decoded sequence contains the prompt followed by the newly generated tokens.
print("Generated text: ", tokenizer.decode(generated_text[0]))