# 1. Imports

In [None]:
# Configs
%load_ext autoreload
%autoreload 3
## other standard packages
import sys
## Env variables and preparation stuffs
sys.path.insert(0, "../")
from src_code.data_utils.dataset import GridDataset
from src_code.data_utils.dataset_utils import CellType

# 2. Dataset

In [None]:
# Build the grid dataset with a fixed seed; "#" and "." are passed as the
# wall and free-cell symbols.
dataset = GridDataset(
    grid_size=5,
    seed=42,
    wall_symbol="#",
    free_symbol=".",
)

# Unpack the sample at index 1 into its two components.
(img_rgb1, grid_world1) = dataset[1]

In [None]:
# Bare expression: the notebook renders img_rgb1 as this cell's output.
img_rgb1

In [None]:
# print() applies str() to its argument, so this emits the grid world's
# string form.
print(grid_world1)

In [None]:
# The "=" f-string specifier prints the expression text followed by the repr
# of the value returned by a_star().
print(f"{grid_world1.a_star() = }")

## 3 Open Flamingo Inference
Source: https://github.com/mlfoundations/open_flamingo

## 3.1  Reference

In [None]:
from open_flamingo import create_model_and_transforms
from huggingface_hub import hf_hub_download
import torch

# Assemble the OpenFlamingo model together with its image processor and
# tokenizer from the listed vision / language backbones.
model, image_processor, tokenizer = create_model_and_transforms(
    clip_vision_encoder_path="ViT-L-14",
    clip_vision_encoder_pretrained="openai",
    lang_encoder_path="togethercomputer/RedPajama-INCITE-Instruct-3B-v1",
    tokenizer_path="togethercomputer/RedPajama-INCITE-Instruct-3B-v1",
    cross_attn_every_n_layers=2
)

# grab model checkpoint from huggingface hub
checkpoint_path = hf_hub_download("openflamingo/OpenFlamingo-4B-vitl-rpj3b-langinstruct", "checkpoint.pt")

# map_location="cpu" lets the checkpoint deserialize on machines without CUDA
# even if its tensors were saved from a GPU; the model is moved to the chosen
# device separately. The loaded dict is passed straight through (not bound to
# a name) so the extra copy of the weights can be freed right after loading.
# strict=False is kept from the original call: key mismatches are reported in
# the cell output rather than raised.
model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"), strict=False)


In [None]:
from huggingface_hub import hf_hub_download
import torch
from PIL import Image
import requests

# Pick the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print("CUDA is available! Using GPU for calculations.")
else:
    print("CUDA is not available. Using CPU for calculations.")

# Move the model to the selected device; as the last expression, the returned
# module is echoed in the cell output.
model.to(device)

In [None]:
from PIL import Image
import requests
import torch


def _fetch_image(url, timeout=30):
    """Download `url` and open the response body as a PIL image.

    Raises `requests.HTTPError` on an error status instead of handing an
    error page to PIL, and passes `timeout` (seconds) to requests so a
    stalled connection does not hang the kernel indefinitely.
    """
    response = requests.get(url, stream=True, timeout=timeout)
    response.raise_for_status()
    return Image.open(response.raw)


# Step 1: Load images
demo_image_one = _fetch_image(
    "http://images.cocodataset.org/val2017/000000039769.jpg"
)
demo_image_two = _fetch_image(
    "http://images.cocodataset.org/test-stuff2017/000000028137.jpg"
)
query_image = _fetch_image(
    "http://images.cocodataset.org/test-stuff2017/000000028352.jpg"
)

# Step 2: Preprocessing images
# Details: For OpenFlamingo, we expect the image to be a torch tensor of shape
#  batch_size x num_media x num_frames x channels x height x width.
#  In this case batch_size = 1, num_media = 3, num_frames = 1,
#  channels = 3, height = 224, width = 224.
vision_x = torch.stack(
    [image_processor(image) for image in (demo_image_one, demo_image_two, query_image)],
    dim=0,
)
# Insert the num_frames axis (dim 1) and then the leading batch axis.
vision_x = vision_x.unsqueeze(1).unsqueeze(0)

# Step 3: Preprocessing text
# Details: In the text we expect an <image> special token to indicate where an image is.
#  We also expect an <|endofchunk|> special token to indicate the end of the text
#  portion associated with an image.
tokenizer.padding_side = "left" # For generation padding tokens should be on the left
lang_x = tokenizer(
    ["<image>An image of two cats.<|endofchunk|><image>An image of a bathroom sink.<|endofchunk|><image>An image of"],
    return_tensors="pt",
)

# Step 4: Generate text
# The model may have been moved to the GPU in an earlier cell while the inputs
# built above live on the CPU; send them to wherever the model's parameters
# are so generate() does not fail with a device mismatch.
model_device = next(model.parameters()).device
generated_text = model.generate(
    vision_x=vision_x.to(model_device),
    lang_x=lang_x["input_ids"].to(model_device),
    attention_mask=lang_x["attention_mask"].to(model_device),
    max_new_tokens=20,
    num_beams=3,
)

print("Generated text: ", tokenizer.decode(generated_text[0]))

## 3.2 Few Shot Learning - only rgb input

In [None]:
from PIL import Image
import requests
import torch
import numpy as np

# Step 1: Load images
dataset = GridDataset(
    grid_size=5,
    seed=42,
    wall_symbol="#",
    free_symbol=".",
    cell_size=5,
)

# Fetch samples 0..3 in index order and unpack each into (image, grid world).
(
    (img_rgb1, gridworld1),
    (img_rgb2, gridworld2),
    (img_rgb3, gridworld3),
    (img_rgb4, gridworld4),
) = [dataset[idx] for idx in range(4)]

# For every grid world keep its string form and the result of a_star().
(
    (ascii_inp1, path1),
    (ascii_inp2, path2),
    (ascii_inp3, path3),
    (ascii_inp4, path4),
) = [
    (str(world), world.a_star())
    for world in (gridworld1, gridworld2, gridworld3, gridworld4)
]

In [None]:
from src_code.data_utils.dataset_utils import draw_image_grid

# Show the first three images, each paired with its path.
draw_image_grid(
    [
        (img_rgb1, path1),
        (img_rgb2, path2),
        (img_rgb3, path3),
    ]
)

### 3.2.1 One Shot Learning - totally wrong answers

In [None]:
from src_code.model_utils.openflamingo_utils import generate_inputs_for_openflamingo

# Build the image tensor and tokenized prompt for a one-shot query.
num_shots = 1
vision_x, lang_x = generate_inputs_for_openflamingo(tokenizer, image_processor, num_shots=num_shots, dataset=dataset)

# Step 4: Generate text
# The model may sit on the GPU while the freshly built inputs are on the CPU;
# move the inputs to the device of the model's parameters so generate() does
# not fail with a device mismatch (.to() is a no-op if they already match).
model_device = next(model.parameters()).device
generated_text = model.generate(
    vision_x=vision_x.to(model_device),
    lang_x=lang_x["input_ids"].to(model_device),
    attention_mask=lang_x["attention_mask"].to(model_device),
    max_new_tokens=20,
    num_beams=3,
)

print("Generated text: ", tokenizer.decode(generated_text[0]))

In [None]:
# Fetch the support samples plus the query sample (index num_shots) once each.
shot_samples = [dataset[i] for i in range(num_shots + 1)]

# Other cells unpack a dataset sample into exactly two values (image, grid
# world), in which case indexing [2] raises IndexError. Use the third element
# as the path when it exists, otherwise compute it from the grid world.
image_path_pairs = [
    (sample[0], sample[2] if len(sample) > 2 else sample[1].a_star())
    for sample in shot_samples
]
*support_pairs, (query_image_rgb, query_path) = image_path_pairs

# The query image is labelled with its ground-truth path.
draw_image_grid(support_pairs + [(query_image_rgb, "gt: " + str(query_path))])

### 3.2.2 Few Shot Learning 

In [None]:
from src_code.model_utils.openflamingo_utils import generate_inputs_for_openflamingo

# This is the few-shot section, but the cell was a verbatim copy of the
# one-shot cell (num_shots = 1); use two in-context examples so it actually
# differs from the one-shot run.
num_shots = 2
vision_x, lang_x = generate_inputs_for_openflamingo(tokenizer, image_processor, num_shots=num_shots, dataset=dataset)

# Step 4: Generate text
# The model may sit on the GPU while the freshly built inputs are on the CPU;
# move the inputs to the device of the model's parameters so generate() does
# not fail with a device mismatch (.to() is a no-op if they already match).
model_device = next(model.parameters()).device
generated_text = model.generate(
    vision_x=vision_x.to(model_device),
    lang_x=lang_x["input_ids"].to(model_device),
    attention_mask=lang_x["attention_mask"].to(model_device),
    max_new_tokens=20,
    num_beams=3,
)

print("Generated text: ", tokenizer.decode(generated_text[0]))

In [None]:
# Fetch the support samples plus the query sample (index num_shots) once each.
shot_samples = [dataset[i] for i in range(num_shots + 1)]

# Other cells unpack a dataset sample into exactly two values (image, grid
# world), in which case indexing [2] raises IndexError. Use the third element
# as the path when it exists, otherwise compute it from the grid world.
image_path_pairs = [
    (sample[0], sample[2] if len(sample) > 2 else sample[1].a_star())
    for sample in shot_samples
]
*support_pairs, (query_image_rgb, query_path) = image_path_pairs

# The query image is labelled with its ground-truth path.
draw_image_grid(support_pairs + [(query_image_rgb, "gt: " + str(query_path))])

## 3.3  Few Shot ICL - only ASCII Input

### 3.3.1  Zero Shot

In [None]:
# Re-bind the "1" names to the sample at index 1, then derive its string form
# and the result of a_star().
(img_rgb1, gridworld1) = dataset[1]
ascii_inp1 = str(gridworld1)
path1 = gridworld1.a_star()

In [None]:
# Bare expression: echoes ascii_inp1 as the cell output (string repr, so any
# newlines show up as escape sequences).
ascii_inp1

In [None]:
# Print ascii_inp1 so it is written out as plain text rather than as a quoted repr.
print(ascii_inp1)

In [None]:
# Unpack the sample at index 2 into img_rgb2 and gridworld2.
img_rgb2, gridworld2 = dataset[2]

In [None]:
from src_code.model_utils.openflamingo_utils import generate_inputs_for_openflamingo_llm

# Build the text-only prompt inputs; the helper also returns a vision tensor
# that is passed to generate() as vision_x.
num_shots = 2
dummy_vision_x, lang_x = generate_inputs_for_openflamingo_llm(tokenizer, image_processor, num_shots, dataset)

# Step 4: Generate text
# The model may sit on the GPU while the freshly built inputs are on the CPU;
# move the inputs to the device of the model's parameters so generate() does
# not fail with a device mismatch (.to() is a no-op if they already match).
model_device = next(model.parameters()).device
generated_text = model.generate(
    vision_x=dummy_vision_x.to(model_device),
    lang_x=lang_x["input_ids"].to(model_device),
    attention_mask=lang_x["attention_mask"].to(model_device),
    max_new_tokens=20,
    num_beams=3,
)
print("Generated text: \n", tokenizer.decode(generated_text[0]))

# What we tried.
1. directly prompt - garbage answer
2. one shot prompt - ???
3. two shot prompt - doesn't work
4. different image resolutions
5. ascii
6. bigger models??