In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import dspy
from dspy.datasets import DataLoader
from dspy.evaluate.metrics import answer_exact_match
from typing import List
from dspy.evaluate import Evaluate

import dotenv

dotenv.load_dotenv()

def debug_exact_match(example, pred, trace=None, frac=1.0):
    """Exact-match metric that echoes what it is about to score.

    Prints the example's inputs, the example's ``answer`` and the prediction
    (in that order), then returns whatever ``answer_exact_match`` returns for
    the same four arguments.
    """
    example_inputs = example.inputs()
    print(example_inputs)
    gold_answer = example.answer
    print(gold_answer)
    print(pred)
    score = answer_exact_match(example, pred, trace, frac)
    return score

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# The Qwen model is expected at the local endpoint below; it was started with:
# vllm serve Qwen/Qwen2-VL-7B-Instruct --trust-remote-code --limit-mm-per-prompt image=16 --seed 42 --pipeline-parallel-size 2
qwen_lm = dspy.LM(
    model="openai/Qwen/Qwen2-VL-7B-Instruct",
    api_base="http://localhost:8000/v1",
    api_key="sk-fake-key",
    max_tokens=5000,
)
gpt_lm = dspy.LM(
    model="openai/gpt-4o-mini",
    max_tokens=5000,
)

# The Qwen model is the default LM for everything below.
dspy.settings.configure(lm=qwen_lm)

In [5]:
%%capture
from concurrent.futures import ThreadPoolExecutor

# Only the first two image slots are exposed as inputs, next to the question and the raw options.
input_keys = ("image_1", "image_2", "question", "options")
subsets = [
    'Accounting', 'Agriculture', 'Architecture_and_Engineering', 'Art', 'Art_Theory',
    'Basic_Medical_Science', 'Biology', 'Chemistry', 'Clinical_Medicine', 'Computer_Science',
    'Design', 'Diagnostics_and_Laboratory_Medicine', 'Economics', 'Electronics', 'Energy_and_Power',
    'Finance', 'Geography', 'History', 'Literature', 'Manage',
    'Marketing', 'Materials', 'Math', 'Mechanical_Engineering', 'Music',
    'Pharmacy', 'Physics', 'Psychology', 'Public_Health', 'Sociology',
]

devset = []
valset = []


def load_dataset(subset_index_subset):
    """Load one MMMU subset and return ``(subset index, dev split, validation split)``."""
    subset_index, subset = subset_index_subset
    dataset = DataLoader().from_huggingface("MMMU/MMMU", subset, split=["dev", "validation"], input_keys=input_keys)
    return subset_index, dataset["dev"], dataset["validation"]


# One worker per subset; results are ordered by subset index before concatenation.
with ThreadPoolExecutor(max_workers=len(subsets)) as executor:
    results = sorted(executor.map(load_dataset, enumerate(subsets)), key=lambda item: item[0])

for _, dev, val in results:
    devset.extend(dev)
    valset.extend(val)

# NOTE(review): the %%capture magic at the top of this cell also captures these two prints.
print(len(devset))
print(len(valset))

In [6]:
import ast

def count_images(dataset, max_images=5):
    """Histogram how many image inputs each example actually carries.

    An image is counted when its key is one of the example's *input* keys,
    starts with ``image_`` and maps to a value that is not ``None``.

    Args:
        dataset: Iterable of examples exposing ``inputs()``, whose result
            supports ``keys()`` and item access.
        max_images: Largest count that is always present in the result (with
            value 0 if unseen), so a printed table keeps a fixed set of rows.

    Returns:
        Dict mapping image count -> number of examples. Keys ``0..max_images``
        are always present; a larger count is added only when encountered.
    """
    image_counts = {n: 0 for n in range(max_images + 1)}
    for example in dataset:
        inputs = example.inputs()  # fetch once instead of once per key
        count = sum(
            1
            for key in inputs.keys()
            if key.startswith('image_') and inputs[key] is not None
        )
        # .get() so a count above max_images extends the histogram instead of raising KeyError.
        image_counts[count] = image_counts.get(count, 0) + 1
    return image_counts

def count_multiple_choice_questions(dataset):
    """Return how many examples have ``question_type`` equal to ``"multiple-choice"``."""
    total = 0
    for example in dataset:
        if example["question_type"] == "multiple-choice":
            total += 1
    return total
max_images = 5

num_images = 1


def _count_present_images(example):
    """Number of ``image_*`` input fields of ``example`` whose value is not ``None``."""
    inputs = example.inputs()
    return sum(1 for key in inputs.keys() if key.startswith('image_') and inputs[key] is not None)


# Keep only the examples that carry exactly `num_images` image inputs.
devset_filtered = [example for example in devset if _count_present_images(example) == num_images]
valset_filtered = [example for example in valset if _count_present_images(example) == num_images]

devset_image_counts = count_images(devset_filtered)
valset_image_counts = count_images(valset_filtered)

devset_multiple_choice_questions = count_multiple_choice_questions(devset_filtered)
valset_multiple_choice_questions = count_multiple_choice_questions(valset_filtered)

# Report: image-count histogram per split, then the share of multiple-choice questions.
print("Image counts in devset:")
for count, num_examples in devset_image_counts.items():
    print(f"{count} image(s): {num_examples} examples")

print("\nImage counts in valset:")
for count, num_examples in valset_image_counts.items():
    print(f"{count} image(s): {num_examples} examples")

print("\nMultiple choice questions in devset:")
print(devset_multiple_choice_questions, "out of", len(devset_filtered))
print("\nMultiple choice questions in valset:")
print(valset_multiple_choice_questions, "out of", len(valset_filtered))

def convert_multiple_choice_to_letter(dataset):
    """Add an ``answer_choices`` input field to every example.

    The ``options`` field is parsed with ``ast.literal_eval``. Then:

    * multiple-choice questions get ``str`` of the options prefixed with
      ``"A. "``, ``"B. "``, ... in order;
    * other questions get ``str`` of the parsed options, or the literal
      ``"Free response"`` when there are no options.

    The previous version compared the stringified list with ``[]``, which is
    never equal, so ``"Free response"`` was never assigned.

    Args:
        dataset: Iterable of examples supporting item access for
            ``"question_type"`` / ``"options"``, ``inputs()`` and
            ``with_inputs(...)``.

    Returns:
        A new list of examples whose input keys are the original input keys
        plus ``"answer_choices"``. The new field is written on the object
        returned by ``with_inputs`` rather than on the incoming example.
    """
    new_dataset = []
    for example in dataset:
        options = ast.literal_eval(example["options"])
        if example["question_type"] == "multiple-choice":
            answer_choices = str([chr(65 + i) + ". " + option for i, option in enumerate(options)])
        elif options:
            answer_choices = str(options)
        else:
            answer_choices = "Free response"

        # Derive the new example first, then set the field on it, so the source dataset is left as it was.
        updated_example = example.with_inputs(*example.inputs().keys(), "answer_choices")
        updated_example["answer_choices"] = answer_choices
        new_dataset.append(updated_example)
    return new_dataset

# Print one dev example before and after conversion to eyeball the added `answer_choices` field,
# then convert the validation split the same way.
print(devset_filtered[0])
updated_devset = convert_multiple_choice_to_letter(devset_filtered)
print(updated_devset[0])
updated_valset = convert_multiple_choice_to_letter(valset_filtered)


Image counts in devset:
0 image(s): 0 examples
1 image(s): 146 examples
2 image(s): 0 examples
3 image(s): 0 examples
4 image(s): 0 examples
5 image(s): 0 examples

Image counts in valset:
0 image(s): 0 examples
1 image(s): 857 examples
2 image(s): 0 examples
3 image(s): 0 examples
4 image(s): 0 examples
5 image(s): 0 examples

Multiple choice questions in devset:
137 out of 146

Multiple choice questions in valset:
805 out of 857
Example({'id': 'dev_Accounting_1', 'question': 'Each of the following situations relates to a different company. <image 1> For company B, find the missing amounts.', 'options': "['$63,020', '$58,410', '$71,320', '$77,490']", 'explanation': '', 'image_1': <PIL.PngImagePlugin.PngImageFile image mode=RGBA size=1234x289 at 0x727CD804B400>, 'image_2': None, 'image_3': None, 'image_4': None, 'image_5': None, 'image_6': None, 'image_7': None, 'img_type': "['Tables']", 'answer': 'D', 'topic_difficulty': 'Easy', 'question_type': 'multiple-choice', 'subfield': 'Financi

In [7]:
# DSPy signature for one MMMU question. The class docstring is sent to the model as the task
# objective (it appears verbatim in the inspected history further down), so it is runtime text.
class MMMUSignature(dspy.Signature):
    """Output a rationale and the answer to a multiple choice question about an image with the letter of the correct answer, if present, otherwise the exact answer."""

    question: str = dspy.InputField(desc="A question about the image(s)")
    image_1: dspy.Image = dspy.InputField(desc="An image relating to the shown problem")
    # NOTE(review): image_2 is disabled here, but the two-image cells further down and the saved
    # history (which lists an `image_2` field) appear to have been run with it enabled — confirm
    # before Restart & Run All.
    # image_2: dspy.Image = dspy.InputField(desc="An image relating to the shown problem")
    # NOTE(review): declared as List[str], while convert_multiple_choice_to_letter stores
    # `answer_choices` as the str(...) of a list — confirm which type is intended.
    answer_choices: List[str] = dspy.InputField(desc="The answer options for the question")
    # NOTE(review): this description asks for a single letter, whereas the docstring also allows
    # "the exact answer" when no letter applies.
    answer: str = dspy.OutputField(desc="The single letter of the correct answer. Do not include the entire answer or a period at the end.")

class MMMUModule(dspy.Module):
    """Chain-of-thought program that answers one MMMU question.

    Wraps ``dspy.ChainOfThought(MMMUSignature)``. The logic lives in
    ``forward`` rather than an overridden ``__call__`` so that calling the
    module still goes through ``dspy.Module.__call__``; ``module(**kwargs)``
    keeps working exactly as before.
    """

    def __init__(self):
        super().__init__()
        self.predictor = dspy.ChainOfThought(MMMUSignature)

    def forward(self, **kwargs):
        """Pass the keyword inputs straight to the predictor and return its prediction."""
        return self.predictor(**kwargs)


In [13]:


# Smoke test on the first dev example: show the inputs, the prediction and the gold answer.
sample_input = updated_devset[0]
mmmu = MMMUModule()
print(sample_input.inputs())
print(mmmu(**sample_input.inputs()))
print(sample_input.answer)

# Evaluator over the converted validation split, reused by the cells below.
evaluate_mmmu = Evaluate(
    metric=answer_exact_match,
    num_threads=50,
    devset=updated_valset,
    display_progress=True,
    max_errors=500,
    return_outputs=True,
)

Example({'question': 'Each of the following situations relates to a different company. <image 1> For company B, find the missing amounts.', 'options': "['$63,020', '$58,410', '$71,320', '$77,490']", 'image_1': <PIL.PngImagePlugin.PngImageFile image mode=RGBA size=1234x289 at 0x727CD804B400>, 'image_2': None, 'answer_choices': "['A. $63,020', 'B. $58,410', 'C. $71,320', 'D. $77,490']"}) (input_keys={'answer_choices', 'question', 'image_2', 'options', 'image_1'})
Prediction(
    reasoning='To find the missing amounts for company B, we need to balance the income statement. We know the revenues, expenses, gains, and losses. We can calculate the missing amounts by subtracting the known amounts from the total.',
    answer='B'
)
D


In [14]:
# The evaluator was built with return_outputs=True, hence the two-value unpacking; `outputs` is
# indexed per example by the tally cell below (entry[1] is read as the prediction).
scores, outputs = evaluate_mmmu(mmmu)
# lm.inspect_history()


[2m2024-10-30T01:22:28.978705Z[0m [[31m[1merror    [0m] [1mError for example in dev set: 		 Images are not yet supported in JSON mode.. Set `provide_traceback=True` to see the stack trace.[0m [[0m[1m[34mdspy.evaluate.evaluate[0m][0m [36mfilename[0m=[35mevaluate.py[0m [36mlineno[0m=[35m200[0m
[2m2024-10-30T01:22:29.031727Z[0m [[31m[1merror    [0m] [1mError for example in dev set: 		 Images are not yet supported in JSON mode.. Set `provide_traceback=True` to see the stack trace.[0m [[0m[1m[34mdspy.evaluate.evaluate[0m][0m [36mfilename[0m=[35mevaluate.py[0m [36mlineno[0m=[35m200[0m


Expected dict_keys(['reasoning', 'answer']) but got dict_keys(['reasoning']) from [[ ## reasoning ## ]]
The internal angles of a pentagon sum up to 540°. Given that the internal angle at vertex 1 is 70°, we can calculate the internal angle at vertex 2 by subtracting the known angles from 540°. The internal angle at vertex 2 is 145°. The internal angle at vertex 3 is 85°. The internal angle at vertex 4 is 120°. The internal angle at vertex 5 is 120°. The internal angle at vertex 1 is 70°. The internal angle at vertex 2 is 145°. The internal angle at vertex 3 is 85°. The internal angle at vertex 4 is 120°. The internal angle at vertex 5 is 120°. The internal angle at vertex 1 is 70°. The internal angle at vertex 2 is 145°. The internal angle at vertex 3 is 85°. The internal angle at vertex 4 is 120°. The internal angle at vertex 5 is 120°. The internal angle at vertex 1 is 70°. The internal angle at vertex 2 is 145°. The internal angle at vertex 3 is 85°. The internal angle at vertex 4 i

Average Metric: 363.0 / 846  (42.9):  99%|█████████▊| 846/857 [06:34<00:05,  2.14it/s]
Average Metric: 4.0 / 21  (19.0):   2%|▏         | 20/857 [00:00<00:44, 18.90it/s][2m2024-10-30T01:22:30.636510Z[0m [[31m[1merror    [0m] [1mError for example in dev set: 		 Images are not yet supported in JSON mode.. Set `provide_traceback=True` to see the stack trace.[0m [[0m[1m[34mdspy.evaluate.evaluate[0m][0m [36mfilename[0m=[35mevaluate.py[0m [36mlineno[0m=[35m200[0m
Average Metric: 5.0 / 23  (21.7):   3%|▎         | 22/857 [00:01<00:56, 14.88it/s]

Expected dict_keys(['reasoning', 'answer']) but got dict_keys(['reasoning']) from [[ ## reasoning ## ]]
The pH of the extracellular fluid is not shown in the image, but it is mentioned in the question. The pH of the extracellular fluid is not shown in the image, but it is mentioned in the question. The pH of the extracellular fluid is not shown in the image, but it is mentioned in the question. The pH of the extracellular fluid is not shown in the image, but it is mentioned in the question. The pH of the extracellular fluid is not shown in the image, but it is mentioned in the question. The pH of the extracellular fluid is not shown in the image, but it is mentioned in the question. The pH of the extracellular fluid is not shown in the image, but it is mentioned in the question. The pH of the extracellular fluid is not shown in the image, but it is mentioned in the question. The pH of the extracellular fluid is not shown in the image, but it is mentioned in the question. The pH of the 

Average Metric: 6.0 / 27  (22.2):   3%|▎         | 26/857 [00:01<00:54, 15.14it/s][2m2024-10-30T01:22:31.117720Z[0m [[31m[1merror    [0m] [1mError for example in dev set: 		 Images are not yet supported in JSON mode.. Set `provide_traceback=True` to see the stack trace.[0m [[0m[1m[34mdspy.evaluate.evaluate[0m][0m [36mfilename[0m=[35mevaluate.py[0m [36mlineno[0m=[35m200[0m
Average Metric: 7.0 / 31  (22.6):   4%|▎         | 31/857 [00:01<00:48, 17.06it/s]

Expected dict_keys(['reasoning', 'answer']) but got dict_keys(['reasoning']) from [[ ## reasoning ## ]]
To find the total weight of the minimum spanning tree, we need to select the edges with the smallest weights that connect all the vertices without forming any cycles. The minimum spanning tree of the given graph can be constructed by selecting the edges with the smallest weights: (1, 2) with weight 1, (2, 3) with weight 1, (3, 4) with weight 1, (4, 5) with weight 1, (5, 6) with weight 1, (6, 7) with weight 1, (7, 8) with weight 1, (8, 9) with weight 1, (9, 10) with weight 1, (10, 11) with weight 1, (11, 12) with weight 1, (12, 13) with weight 1, (13, 14) with weight 1, (14, 15) with weight 1, (15, 16) with weight 1, (16, 17) with weight 1, (17, 18) with weight 1, (18, 19) with weight 1, (19, 20) with weight 1, (20, 21) with weight 1, (21, 22) with weight 1, (22, 23) with weight 1, (23, 24) with weight 1, (24, 25) with weight 1, (25, 26) with weight 1, (26, 27) with weight 1, (27, 28)

Average Metric: 10.0 / 37  (27.0):   4%|▍         | 36/857 [00:02<00:49, 16.52it/s][2m2024-10-30T01:22:31.754954Z[0m [[31m[1merror    [0m] [1mError for example in dev set: 		 Images are not yet supported in JSON mode.. Set `provide_traceback=True` to see the stack trace.[0m [[0m[1m[34mdspy.evaluate.evaluate[0m][0m [36mfilename[0m=[35mevaluate.py[0m [36mlineno[0m=[35m200[0m
Average Metric: 10.0 / 38  (26.3):   4%|▍         | 37/857 [00:02<01:00, 13.46it/s][2m2024-10-30T01:22:31.787314Z[0m [[31m[1merror    [0m] [1mError for example in dev set: 		 Images are not yet supported in JSON mode.. Set `provide_traceback=True` to see the stack trace.[0m [[0m[1m[34mdspy.evaluate.evaluate[0m][0m [36mfilename[0m=[35mevaluate.py[0m [36mlineno[0m=[35m200[0m
Average Metric: 11.0 / 39  (28.2):   4%|▍         | 38/857 [00:02<01:00, 13.46it/s][2m2024-10-30T01:22:31.850127Z[0m [[31m[1merror    [0m] [1mError for example in dev set: 		 Images are not yet supporte

Expected dict_keys(['reasoning', 'answer']) but got dict_keys(['reasoning']) from [[ ## reasoning ## ]]
The total change in the M1 Money Supply is the sum of the deposits and the change in reserves. The change in reserves is the amount of deposits minus the required reserves. The required reserves are calculated as 10% of the deposits. Therefore, the change in reserves is $700 - (0.10 * $700) = $630. The total change in the M1 Money Supply is $700 + $630 = $1330. However, this is not one of the answer choices. The correct answer is the sum of the deposits and the change in reserves, which is $700 + $630 = $1330. However, this is not one of the answer choices. The correct answer is the sum of the deposits and the change in reserves, which is $700 + $630 = $1330. However, this is not one of the answer choices. The correct answer is the sum of the deposits and the change in reserves, which is $700 + $630 = $1330. However, this is not one of the answer choices. The correct answer is the su

Average Metric: 11.0 / 41  (26.8):   5%|▍         | 40/857 [00:02<00:52, 15.55it/s][2m2024-10-30T01:22:32.024267Z[0m [[31m[1merror    [0m] [1mError for example in dev set: 		 Images are not yet supported in JSON mode.. Set `provide_traceback=True` to see the stack trace.[0m [[0m[1m[34mdspy.evaluate.evaluate[0m][0m [36mfilename[0m=[35mevaluate.py[0m [36mlineno[0m=[35m200[0m
Average Metric: 12.0 / 42  (28.6):   5%|▍         | 41/857 [00:02<00:52, 15.55it/s][2m2024-10-30T01:22:32.096094Z[0m [[31m[1merror    [0m] [1mError for example in dev set: 		 Images are not yet supported in JSON mode.. Set `provide_traceback=True` to see the stack trace.[0m [[0m[1m[34mdspy.evaluate.evaluate[0m][0m [36mfilename[0m=[35mevaluate.py[0m [36mlineno[0m=[35m200[0m
Average Metric: 13.0 / 44  (29.5):   5%|▌         | 44/857 [00:02<00:53, 15.24it/s]

Expected dict_keys(['reasoning', 'answer']) but got dict_keys(['reasoning']) from [[ ## reasoning ## ]]
To calculate the rate of return for the second period, we use the formula: Rate of return = (Ending value - Beginning value) / Beginning value. For Stock C, the beginning value is $100 and the ending value is $55. Plugging these values into the formula gives us: Rate of return = (55 - 100) / 100 = -45%. However, since the question asks for the rate of return, we take the absolute value of the result, which is 45%. This is not one of the answer choices, so we need to convert it to a percentage. 45% is equivalent to 0.45, so the rate of return is 0.45. However, since the question asks for the rate of return, we take the absolute value of the result, which is 45%. This is not one of the answer choices, so we need to convert it to a percentage. 45% is equivalent to 0.45, so the rate of return is 0.45. However, since the question asks for the rate of return, we take the absolute value of 

Average Metric: 16.0 / 54  (29.6):   6%|▌         | 53/857 [00:03<00:54, 14.66it/s][2m2024-10-30T01:22:32.749804Z[0m [[31m[1merror    [0m] [1mError for example in dev set: 		 Images are not yet supported in JSON mode.. Set `provide_traceback=True` to see the stack trace.[0m [[0m[1m[34mdspy.evaluate.evaluate[0m][0m [36mfilename[0m=[35mevaluate.py[0m [36mlineno[0m=[35m200[0m
Average Metric: 16.0 / 54  (29.6):   6%|▋         | 54/857 [00:03<00:43, 18.58it/s][2m2024-10-30T01:22:32.799038Z[0m [[31m[1merror    [0m] [1mError for example in dev set: 		 Images are not yet supported in JSON mode.. Set `provide_traceback=True` to see the stack trace.[0m [[0m[1m[34mdspy.evaluate.evaluate[0m][0m [36mfilename[0m=[35mevaluate.py[0m [36mlineno[0m=[35m200[0m
Average Metric: 16.0 / 56  (28.6):   6%|▋         | 55/857 [00:03<00:43, 18.58it/s][2m2024-10-30T01:22:32.947111Z[0m [[31m[1merror    [0m] [1mError for example in dev set: 		 Images are not yet supporte

Expected dict_keys(['reasoning', 'answer']) but got dict_keys(['reasoning']) from [[ ## reasoning ## ]]
The question is asking for the probability that a randomly selected home has a square footage less than 3.5, given that it has a square footage less than 4. This is a conditional probability problem. We can use the cumulative distribution function (CDF) of the uniform distribution to find this probability. The CDF of a uniform distribution is given by F(x) = (x - a) / (b - a), where a and b are the lower and upper bounds of the distribution, respectively. In this case, a = 1.5 and b = 4.5. We want to find P(x < 3.5 | x < 4), which is equivalent to F(3.5) - F(4). We can calculate this as follows:
F(3.5) = (3.5 - 1.5) / (4.5 - 1.5) = 2 / 3
F(4) = (4 - 1.5) / (4.5 - 1.5) = 2.5 / 3
P(x < 3.5 | x < 4) = F(3.5) - F(4) = 2 / 3 - 2.5 / 3 = -0.5 / 3
However, probabilities cannot be negative, so we must have made a mistake in our calculation. Let's re-examine the problem. We are asked to find 

Average Metric: 16.0 / 56  (28.6):   7%|▋         | 56/857 [00:03<00:56, 14.18it/s][2m2024-10-30T01:22:33.004969Z[0m [[31m[1merror    [0m] [1mError for example in dev set: 		 Images are not yet supported in JSON mode.. Set `provide_traceback=True` to see the stack trace.[0m [[0m[1m[34mdspy.evaluate.evaluate[0m][0m [36mfilename[0m=[35mevaluate.py[0m [36mlineno[0m=[35m200[0m
Average Metric: 18.0 / 59  (30.5):   7%|▋         | 58/857 [00:03<00:58, 13.55it/s]

Expected dict_keys(['reasoning', 'answer']) but got dict_keys(['reasoning']) from [[ ## reasoning ## ]]
The harmony in the image is a dominant seventh chord, which is a common chord in Western music. The chord is built on the fifth degree of the scale, which in this case is the G note. The chord is written in the key of C major, which is a common key in Western music. The chord is written in the key of C major, which is a common key in Western music. The chord is written in the key of C major, which is a common key in Western music. The chord is written in the key of C major, which is a common key in Western music. The chord is written in the key of C major, which is a common key in Western music. The chord is written in the key of C major, which is a common key in Western music. The chord is written in the key of C major, which is a common key in Western music. The chord is written in the key of C major, which is a common key in Western music. The chord is written in the key of C majo

Average Metric: 19.0 / 62  (30.6):   7%|▋         | 61/857 [00:03<00:56, 14.05it/s][2m2024-10-30T01:22:33.393820Z[0m [[31m[1merror    [0m] [1mError for example in dev set: 		 Images are not yet supported in JSON mode.. Set `provide_traceback=True` to see the stack trace.[0m [[0m[1m[34mdspy.evaluate.evaluate[0m][0m [36mfilename[0m=[35mevaluate.py[0m [36mlineno[0m=[35m200[0m
Average Metric: 24.0 / 71  (33.8):   8%|▊         | 70/857 [00:04<00:31, 25.00it/s]

Expected dict_keys(['reasoning', 'answer']) but got dict_keys(['reasoning']) from [[ ## reasoning ## ]]
To calculate the relative risk (RR) of the serum cholesterol level 231-255 mg/dl group compared to the reference group (114-193 mg/dl), we need to use the formula: RR = (Number of cases in the 231-255 mg/dl group / Number of observations in the 231-255 mg/dl group) / (Number of cases in the reference group / Number of observations in the reference group). Plugging in the numbers from the table, we get: RR = (26 / 209) / (2 / 209) = 13.


Average Metric: 363.0 / 857  (42.4): 100%|██████████| 857/857 [00:05<00:00, 166.90it/s] 


In [23]:
from collections import Counter

# Index 1 of each output entry is read as the prediction; a missing "answer" is tallied as "nothing returned".
answers = [entry[1].get("answer", "nothing returned") for entry in outputs]
c = Counter(answers)
non_letters = sum(1 for answer in answers if answer not in ["A", "B", "C", "D"])
print(c)
print(non_letters)




Counter({'B': 210, 'C': 206, 'A': 185, 'D': 117, 'nothing returned': 15, 'E': 7, 'A. Yes': 3, 'A. Project A': 1, 'C. $51,180': 1, 'C. 135,332': 1, 'B. 3.47': 1, 'C. (1)$38,650 ; (2)82,910': 1, "'B. $5.00'": 1, 'A. Project X would be incorrectly rejected; Project Z would be incorrectly accepted.': 1, 'C. 0.56': 1, 'A. A': 1, 'D. The removal of the weed depends on the potential for the disease to spread to the crop': 1, 'A. 35.567 $\\\\approx$ 36': 1, 'B. 152.507 m': 1, 'A. 2.603': 1, 'A. 4725 $m^{2}$': 1, '12.0': 1, 'C. 1344veh/mi': 1, 'D. AO = 884.49 m, CO = 1634 m': 1, 'a': 1, 'L': 1, "'A'": 1, 'F': 1, 'A. Yes.': 1, '13,016.625 g': 1, '8': 1, '1.26': 1, 'B. 20.32 kJ/mol': 1, '12.01': 1, 'MgO': 1, 'A. Singly linked list': 1, 'A. High bias': 1, '140': 1, 'D. 14/6': 1, "None of the options can be the sequence of edges added to the minimum spanning tree using Kruskal's algorithm because they all form cycles with the edges already in the tree.": 1, '287.25\n308.78\n331.88\n1149\n1235.13\n1

# Make sure that multiple images work

## No few-shot examples (zero-shot)

In [29]:
import PIL.Image  # a bare `import PIL` does not load the `Image` submodule used below


def set_image_to_black_square(example, key, image_path="black_image_300x300.png"):
    """Return a copy of ``example`` whose ``key`` field is replaced by a placeholder image.

    Args:
        example: Example supporting ``copy()``, item assignment, ``inputs()``
            and ``with_inputs(...)``.
        key: Name of the field to overwrite (e.g. ``"image_1"``).
        image_path: File opened with ``PIL.Image.open`` for the replacement;
            defaults to the black-square file this notebook expects in the
            working directory.

    Returns:
        The modified copy, with the same input keys as ``example``.
    """
    example_copy = example.copy()
    example_copy[key] = PIL.Image.open(image_path)
    return example_copy.with_inputs(*example.inputs().keys())

def _is_black_square(image):
    """Compare ``image`` with a freshly opened copy of the black-square file."""
    return image == PIL.Image.open("black_image_300x300.png")


print(updated_devset[0]["image_1"])
print(updated_devset[0]["image_2"])

# Variant 1: only image_1 blacked out.
examples_no_image_1 = [set_image_to_black_square(example, "image_1") for example in updated_valset]
print(_is_black_square(examples_no_image_1[0]["image_1"]))
print(_is_black_square(examples_no_image_1[0]["image_2"]))

# Variant 2: only image_2 blacked out.
examples_no_image_2 = [set_image_to_black_square(example, "image_2") for example in updated_valset]
print(_is_black_square(examples_no_image_2[0]["image_1"]))
print(_is_black_square(examples_no_image_2[0]["image_2"]))

# Variant 3: both images blacked out, applied one key after the other.
examples_no_actual_image = [set_image_to_black_square(example, "image_1") for example in updated_valset]
examples_no_actual_image = [set_image_to_black_square(example, "image_2") for example in examples_no_actual_image]
print(_is_black_square(examples_no_actual_image[0]["image_1"]))
print(_is_black_square(examples_no_actual_image[0]["image_2"]))


<PIL.PngImagePlugin.PngImageFile image mode=RGBA size=933x609 at 0x7E259C230790>
<PIL.PngImagePlugin.PngImageFile image mode=RGBA size=933x737 at 0x7E259FF6C490>
True
False
False
True
True
True


In [30]:
mmmu = MMMUModule()

# Run one example from each single-image-blacked-out variant and show inputs next to the prediction.
for blacked_out_example in (examples_no_image_1[0], examples_no_image_2[0]):
    print(blacked_out_example.inputs())
    print(mmmu(**blacked_out_example.inputs()))


Example({'question': "<image 1> What group of pathogens, often mistaken for regrowth following glyphosate treatment, can cause a growth habit in blackberry plants that is near-identical to the 'little leaf' symptoms commonly witnessed post-glyphosate treatment?", 'options': '["I don\'t know and I don\'t want to guess", \'Nematodes\', \'Fungi\', \'Phytoplasmas\', \'Bacteria\']', 'image_1': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=300x300 at 0x7E25782FD290>, 'image_2': <PIL.PngImagePlugin.PngImageFile image mode=P size=300x232 at 0x7E258579CF50>, 'answer_choices': '["A. I don\'t know and I don\'t want to guess", \'B. Nematodes\', \'C. Fungi\', \'D. Phytoplasmas\', \'E. Bacteria\']'}) (input_keys={'image_1', 'answer_choices', 'image_2', 'options', 'question'})


Prediction(
    reasoning='The question asks about a group of pathogens that can cause symptoms in blackberry plants similar to those seen after glyphosate treatment. Among the options provided, phytoplasmas are known to cause growth abnormalities in plants, including symptoms that can be confused with glyphosate damage. Nematodes, fungi, and bacteria do not typically produce the same growth habit as described. Therefore, the most appropriate answer is D. Phytoplasmas.',
    answer='D'
)
Example({'question': "<image 1> What group of pathogens, often mistaken for regrowth following glyphosate treatment, can cause a growth habit in blackberry plants that is near-identical to the 'little leaf' symptoms commonly witnessed post-glyphosate treatment?", 'options': '["I don\'t know and I don\'t want to guess", \'Nematodes\', \'Fungi\', \'Phytoplasmas\', \'Bacteria\']', 'image_1': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=414x365 at 0x7E259E1FAA50>, 'image_2': <PIL.PngImagePlugin.Png

In [31]:
# evaluate_mmmu was constructed with return_outputs=True, which makes every call return the
# per-example outputs alongside the score. Override it per call so each name below holds only
# the score that the summary prints.
normal = evaluate_mmmu(mmmu, devset=updated_valset, return_outputs=False)
no_image_1 = evaluate_mmmu(mmmu, devset=examples_no_image_1, return_outputs=False)
no_image_2 = evaluate_mmmu(mmmu, devset=examples_no_image_2, return_outputs=False)
no_actual_image = evaluate_mmmu(mmmu, devset=examples_no_actual_image, return_outputs=False)
print("Testing on MMMU validation set (N=", len(updated_valset), ")")
print("Score with both images:", normal)
print("Score with image_1 set to black square:", no_image_1)
print("Score with image_2 set to black square:", no_image_2)
print("Score with both images set to black squares:", no_actual_image)

Testing on MMMU validation set (N= 43 )
Score with both images: 58.14
Score with image_1 set to black square: 37.21
Score with image_2 set to black square: 48.84
Score with both images set to black squares: 44.19


## TODO: Test with bootstrapped examples


# Make sure that JPGs work

## Convert images to JPGs

In [32]:
import io
from PIL import Image

def convert_to_jpg(example, keys=('image_1', 'image_2')):
    """Return a copy of ``example`` with the selected PIL images re-encoded as JPEG.

    Args:
        example: Example supporting ``copy()``, ``in``, item access/assignment,
            ``inputs()`` and ``with_inputs(...)``.
        keys: Field names to convert. A key that is absent, or whose value is
            not a ``PIL.Image.Image`` (e.g. ``None``), is skipped.

    Returns:
        The modified copy, with the same input keys as ``example``. Each
        converted field holds the image re-opened from an in-memory JPEG.
    """
    example_copy = example.copy()
    for key in keys:
        if key not in example_copy:
            continue
        # Read from the copy, the same object the guard inspects.
        original_image = example_copy[key]
        if not isinstance(original_image, Image.Image):
            continue

        # Convert to RGB first; the JPEG encoder cannot store an alpha channel.
        rgb_image = original_image.convert('RGB')

        # Encode to JPEG in memory and rewind so the bytes can be read back.
        buffer = io.BytesIO()
        rgb_image.save(buffer, format='JPEG')
        buffer.seek(0)

        # Re-open from the buffer so the stored image is a decoded JPEG.
        example_copy[key] = Image.open(buffer)

    return example_copy.with_inputs(*example.inputs().keys())

# Re-encode every validation example's images as JPEG.
examples_jpg = [convert_to_jpg(example) for example in updated_valset]

# Sanity check on the first example: format before vs. after conversion.
print("Original image format:", updated_valset[0]['image_1'].format)
print("Converted image format:", examples_jpg[0]['image_1'].format)


Original image format: PNG
Converted image format: JPEG


In [33]:
# JPEG versions of the plain validation set and of each blacked-out variant.
examples_jpg = [convert_to_jpg(example) for example in updated_valset]
examples_no_image_1_jpg = [convert_to_jpg(example) for example in examples_no_image_1]
examples_no_image_2_jpg = [convert_to_jpg(example) for example in examples_no_image_2]
examples_no_actual_image_jpg = [convert_to_jpg(example) for example in examples_no_actual_image]

# Smoke test one converted example and confirm the stored image format.
mmmu = MMMUModule()
print(examples_no_image_1_jpg[0].inputs())
print(mmmu(**examples_no_image_1_jpg[0].inputs()))
print(examples_no_image_1_jpg[0]["image_1"].format)

Example({'question': "<image 1> What group of pathogens, often mistaken for regrowth following glyphosate treatment, can cause a growth habit in blackberry plants that is near-identical to the 'little leaf' symptoms commonly witnessed post-glyphosate treatment?", 'options': '["I don\'t know and I don\'t want to guess", \'Nematodes\', \'Fungi\', \'Phytoplasmas\', \'Bacteria\']', 'image_1': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=300x300 at 0x7E257B5A0650>, 'image_2': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=300x232 at 0x7E25921CAA10>, 'answer_choices': '["A. I don\'t know and I don\'t want to guess", \'B. Nematodes\', \'C. Fungi\', \'D. Phytoplasmas\', \'E. Bacteria\']'}) (input_keys={'image_1', 'answer_choices', 'image_2', 'options', 'question'})
Prediction(
    reasoning="The question asks about a group of pathogens that can cause symptoms in blackberry plants similar to those seen after glyphosate treatment. Among the options provided, phytoplasmas are kn

In [34]:
# evaluate_mmmu was constructed with return_outputs=True, which makes every call return the
# per-example outputs alongside the score. Override it per call so each name below holds only
# the score that the summary prints.
normal = evaluate_mmmu(mmmu, devset=examples_jpg, return_outputs=False)
no_image_1 = evaluate_mmmu(mmmu, devset=examples_no_image_1_jpg, return_outputs=False)
no_image_2 = evaluate_mmmu(mmmu, devset=examples_no_image_2_jpg, return_outputs=False)
no_actual_image = evaluate_mmmu(mmmu, devset=examples_no_actual_image_jpg, return_outputs=False)
print("Testing on MMMU validation set (N=", len(updated_valset), ")")
print("Score with both images:", normal)
print("Score with image_1 set to black square:", no_image_1)
print("Score with image_2 set to black square:", no_image_2)
print("Score with both images set to black squares:", no_actual_image)

Testing on MMMU validation set (N= 43 )
Score with both images: 58.14
Score with image_1 set to black square: 44.19
Score with image_2 set to black square: 46.51
Score with both images set to black squares: 44.19


In [35]:
# `lm` is never assigned anywhere in this notebook, so the bare name fails on a fresh kernel.
# Inspect the LM that dspy.settings.configure(...) installed instead.
# NOTE(review): confirm this is the model the original `lm` referred to.
dspy.settings.lm.inspect_history()





System message:

Your input fields are:
1. `question` (str): A question about the image(s)
2. `image_1` (Image): An image relating to the shown problem
3. `image_2` (Image): An image relating to the shown problem
4. `answer_choices` (list[str]): The answer options for the question

Your output fields are:
1. `reasoning` (str)
2. `answer` (str): The single letter of the correct answer. Do not include the entire answer or a period at the end.

All interactions will be structured in the following way, with the appropriate values filled in.

[[ ## question ## ]]
{question}

[[ ## image_1 ## ]]
{image_1}

[[ ## image_2 ## ]]
{image_2}

[[ ## answer_choices ## ]]
{answer_choices}

[[ ## reasoning ## ]]
{reasoning}

[[ ## answer ## ]]
{answer}

[[ ## completed ## ]]


In adhering to this structure, your objective is: 
        Output a rationale and the answer to a multiple choice question about an image with the letter of the correct answer, if present, otherwise the exact answer.


User 

# Testing that URLs work

In [36]:

# Colour name -> six-digit hex code used to build placeholder-image URLs.
colors = dict(
    White="FFFFFF",
    Red="FF0000",
    Green="00FF00",
    Blue="0000FF",
    Yellow="FFFF00",
    Cyan="00FFFF",
    Magenta="FF00FF",
    Gray="808080",
    Orange="FFA500",
    Purple="800080",
)


def get_color_image_url(color, file_extension="png"):
    """Build the placehold.co URL for ``color``.

    The colour's hex code from ``colors`` fills both colour segments of the
    path; ``file_extension`` becomes the suffix. Raises ``KeyError`` for a
    name that is not in ``colors``.
    """
    hex_code = colors[color]
    return f"https://placehold.co/300/{hex_code}/{hex_code}.{file_extension}"


In [37]:
import random

def generate_random_2_color_image_examples(n, seed=None):
    """Generate ``n`` two-image colour questions.

    Each example draws two distinct colour names from ``colors``, picks one of
    them as the target, and asks for the colour of the image slot
    (``image_1`` or ``image_2``) that holds the target.

    Args:
        n: Number of examples to generate.
        seed: Optional seed. When given, a private ``random.Random(seed)`` is
            used so the same examples are produced on every call; when
            ``None`` (default) the shared ``random`` module state is used, as
            before.

    Returns:
        List of ``dspy.Example`` with inputs ``image_1``, ``image_2`` and
        ``question`` and the label ``answer`` (the target colour name).
    """
    rng = random if seed is None else random.Random(seed)
    color_names = list(colors.keys())
    examples = []
    for _ in range(n):
        color_1, color_2 = rng.sample(color_names, 2)
        chosen_color = color_1 if rng.random() < 0.5 else color_2
        # The two sampled names are distinct, so this comparison identifies the slot unambiguously.
        chosen_image = "image_1" if chosen_color == color_1 else "image_2"
        example_kwargs = {
            "image_1": get_color_image_url(color_1),
            "image_2": get_color_image_url(color_2),
            "question": f"What color is {chosen_image}?",
            "answer": chosen_color
        }
        examples.append(dspy.Example(**example_kwargs).with_inputs("image_1", "image_2", "question"))
    return examples

# Generate 100 examples and print the first one to check the fields.
examples = generate_random_2_color_image_examples(100)
print(examples[0])


Example({'image_1': 'https://placehold.co/300/FFFF00/FFFF00.png', 'image_2': 'https://placehold.co/300/0000FF/0000FF.png', 'question': 'What color is image_2?', 'answer': 'Blue'}) (input_keys={'image_1', 'image_2', 'question'})


In [41]:
# Signature for the two-image color test: two image inputs plus a question, one text answer.
class ColorSignature(dspy.Signature):
    """Output the color of the designated image."""
    # The docstring and the `desc` strings are prompt text, not just documentation: the
    # inspect_history() dump recorded later in this notebook shows them verbatim in the
    # system message, with the fields listed in the order declared here.
    image_1: dspy.Image = dspy.InputField(desc="An image")
    image_2: dspy.Image = dspy.InputField(desc="An image")
    question: str = dspy.InputField(desc="A question about the image")
    answer: str = dspy.OutputField(desc="The color of the designated image")
# NOTE(review): the recorded predictions and prompt dump in this notebook include a
# `reasoning` field that this signature does not declare — they may come from an earlier
# version of this cell; confirm after a fresh Run All.
color_program = dspy.Predict(ColorSignature)


In [55]:
# Smoke test: show the first generated example, then the uncompiled program's
# prediction for that example's input fields.
print(examples[0])
print(color_program(**examples[0].inputs()))

Example({'image_1': 'https://placehold.co/300/FFFF00/FFFF00.png', 'image_2': 'https://placehold.co/300/0000FF/0000FF.png', 'question': 'What color is image_2?', 'answer': 'Blue'}) (input_keys={'image_1', 'image_2', 'question'})
Prediction(
    reasoning='The color of image_2 is a solid blue shade.',
    answer='Blue'
)


In [53]:
# Two few-shot optimizers sharing the exact-match metric; they differ only in demo budget.
few_shot_optimizer = dspy.BootstrapFewShot(
    metric=answer_exact_match,
    max_bootstrapped_demos=3,
    max_labeled_demos=10,
)
smaller_few_shot_optimizer = dspy.BootstrapFewShot(
    metric=answer_exact_match,
    max_bootstrapped_demos=1,
    max_labeled_demos=1,
)

# 1000 examples are generated; only the first 400 are used (200 train, 200 validation).
dataset = generate_random_2_color_image_examples(1000)
trainset, validationset = dataset[:200], dataset[200:400]

evaluate_colors = Evaluate(
    devset=validationset,
    metric=answer_exact_match,
    num_threads=300,
)

In [54]:
# Compile both few-shot variants from the same uncompiled program and training set.
compiled_color_program = few_shot_optimizer.compile(color_program, trainset=trainset)
compiled_smaller_color_program = smaller_few_shot_optimizer.compile(
    color_program,
    trainset=trainset,
)

# Score the baseline, the larger few-shot program and the smaller one, in that order,
# printing each score as soon as its evaluation finishes.
for color_program_variant in (
    color_program,
    compiled_color_program,
    compiled_smaller_color_program,
):
    print(evaluate_colors(color_program_variant))

  0%|          | 0/200 [00:00<?, ?it/s]

  2%|▏         | 3/200 [00:18<20:27,  6.23s/it]


Bootstrapped 3 full traces after 4 examples in round 0.


  0%|          | 1/200 [00:02<08:09,  2.46s/it]


Bootstrapped 1 full traces after 2 examples in round 0.
99.0
100.0
96.5


In [57]:
# Run the compiled program on the first validation example, then dump the LM call
# history so the prompt (including the image URLs) can be read in the output below.
print(compiled_color_program(**validationset[0].inputs()))
# NOTE(review): `lm` is not defined in the opening cells of this notebook (they create
# `qwen_lm` and `gpt_lm`) — confirm it is bound somewhere before relying on Run All.
lm.inspect_history()

Prediction(
    reasoning='Not supplied for this particular example.',
    answer='White'
)




System message:

Your input fields are:
1. `image_1` (Image): An image
2. `image_2` (Image): An image
3. `question` (str): A question about the image

Your output fields are:
1. `reasoning` (str)
2. `answer` (str): The color of the designated image

All interactions will be structured in the following way, with the appropriate values filled in.

[[ ## image_1 ## ]]
{image_1}

[[ ## image_2 ## ]]
{image_2}

[[ ## question ## ]]
{question}

[[ ## reasoning ## ]]
{reasoning}

[[ ## answer ## ]]
{answer}

[[ ## completed ## ]]


In adhering to this structure, your objective is: 
        Output the color of the designated image.


User message:

This is an example of the task, though some input or output fields are not supplied.
[[ ## image_1 ## ]]
<image_url: https://placehold.co/300/00FF00/00FF00.png>

[[ ## image_2 ## ]]
<image_url: https://placehold.co/300/800080/800080.png>

[[ ## question #

# TODO(Isaac): Delete; Archive of old experiments

In [None]:
# NOTE(review): trust_remote_code=True presumably lets the dataset repository's loading
# script execute locally — confirm the source is trusted before running.
dataset = DataLoader().from_huggingface(
    "Alanox/stanford-dogs",
    split="full",
    input_keys=("image",),
    trust_remote_code=True,
)

In [69]:
def rename_field(example, old_name, new_name):
    """Rename a field of ``example`` in place and return the same object.

    Best-effort by design: when ``old_name`` is absent (e.g. the rename was already
    applied by an earlier run), ``example`` is returned unchanged instead of raising.

    Args:
        example: Mutable mapping-like record supporting ``[]`` get, set and delete.
        old_name: Current field name.
        new_name: Name the field should have afterwards.

    Returns:
        ``example`` itself (mutated), which makes the function usable with ``map``.
    """
    if old_name == new_name:
        # Copy-then-delete on a single key would drop the field entirely.
        return example
    try:
        value = example[old_name]
    except KeyError:
        # Nothing to rename; leave the record untouched.
        return example
    example[new_name] = value
    del example[old_name]
    return example
    
# Step 1: "image" -> "image_1" (rename_field mutates each example and returns it).
dog_dataset = [rename_field(example, "image", "image_1") for example in dataset]
# Step 2: "target" -> "answer" on the same example objects.
dog_dataset2 = [rename_field(example, "target", "answer") for example in dog_dataset]
# Step 3: mark "image_1" as the only input field.
dog_dataset3 = [example.with_inputs("image_1") for example in dog_dataset2]
# dog_dataset and dog_dataset3 now name the same list, so the shuffle affects both.
dog_dataset = dog_dataset3
random.shuffle(dog_dataset)

In [48]:
# Signature for the dog-breed experiment: one image in, one breed name out.
class DogPictureSignature(dspy.Signature):
    """Output the dog breed of the dog in the image."""
    # NOTE(review): the docstring and `desc` texts are presumably sent to the model as
    # prompt text (the ColorSignature prompt dump earlier in this notebook shows its
    # docstring and descriptions verbatim) — treat them as runtime strings.
    image_1: dspy.Image = dspy.InputField(desc="An image of a dog")  # sole input field
    answer: str = dspy.OutputField(desc="The dog breed of the dog in the image")  # predicted breed

class DogPicture(dspy.Module):
    """Chain-of-thought breed classifier built on ``DogPictureSignature``."""

    def __init__(self) -> None:
        # Let dspy.Module initialise its own state before sub-modules are attached.
        super().__init__()
        self.predictor = dspy.ChainOfThought(DogPictureSignature)

    def forward(self, **kwargs):
        """Pass the keyword inputs (``image_1``) to the predictor and return its prediction.

        Defined as ``forward`` rather than ``__call__`` so that calling the module goes
        through ``dspy.Module.__call__`` instead of bypassing it; ``dog_picture(**inputs)``
        keeps working unchanged.
        """
        return self.predictor(**kwargs)

dog_picture = DogPicture()
print(dog_picture(**dog_dataset[0].inputs()))

Prediction(
    reasoning='The dog in the image has a curly, white coat and a distinctive blue collar, which are characteristic features of the Bedlington Terrier breed.',
    answer='Bedlington Terrier'
)


In [70]:
# Evaluator over the last 500 entries of the shuffled dog dataset; the error budget is
# set very high (10000).
evaluate = Evaluate(
    devset=dog_dataset[-500:],
    metric=answer_exact_match,
    num_threads=100,
    display_progress=True,
    max_errors=10000,
)


In [None]:
# TODO: Test inline signature
# TODO: Test json adapter