In [1]:
import uuid
import json

from typing import NamedTuple
from random import randint

from tqdm.notebook import tqdm
from dicttoxml import dicttoxml

from pydantic import BaseModel

from google import genai
from google.genai.types import Part

In [2]:
class Error(NamedTuple):
    idx: int
    key: str
    true: str
    predicted: str

class Errors(NamedTuple):
    rate: float
    errors: list[Error]

In [3]:
class Response(BaseModel):
    value: int

In [4]:
client = genai.Client(vertexai=True, location='us-east5')

In [5]:
def fill_level(level: int, count: int = 10):
    if level == 1:
        l = {}
        d = {}
        for idx in range(count):
            key = str(uuid.uuid4())
            value = randint(0, 1000)
            l[key] = value
            d[key] = value
    else:
        l = {}
        d = {}
        for idx in range(count):
            d1, l1 = fill_level(level -1, count)
            l.update(l1)
            d[str(uuid.uuid4())] = d1

    return d, l

In [6]:
def analyze(data, values, model='gemini-2.0-flash-001', mode='JSON'):
    e_counter = 0
    e_list = []

    for idx, key in enumerate(tqdm(values)):

        if mode == 'JSON':
            part = json.dumps(data)
        if mode == 'XML':
            part = dicttoxml(data)
        
        contents = [
            Part.from_text('<DATA>'),
            Part.from_text(part),
            Part.from_text('</DATA>'),
            Part.from_text(f'Extract information from DATA. What is value for key {key}. Respond with the value only')
        ]

        response = client.models.generate_content(
            model='gemini-2.0-flash-001',
            contents=contents,
            config={
                'response_mime_type': 'application/json',
                'response_schema': Response,
            }
        )

        if response.parsed.value != values[key]:
            e_counter += 1
            e_list.append(Error(idx=idx, key=key, true=values[key], predicted=response.candidates[0].content.parts[0].text.strip()))

    return Errors(
            rate = e_counter / len(values),
            errors = e_list)

In [7]:
MODEL = 'gemini-1.5-pro'

In [8]:
data, values = fill_level(1, 1000)
errs = analyze(data, values, model=MODEL, mode='XML')
print(f'Error rate for 1 level = {errs.rate*100}%')

  0%|          | 0/1000 [00:00<?, ?it/s]

Error rate for 1 level = 2.1999999999999997%


In [9]:
data, values = fill_level(2, 32)
errs = analyze(data, values, model=MODEL, mode='XML')
print(f'Error rate for 2 level = {errs.rate*100}%')

  0%|          | 0/1024 [00:00<?, ?it/s]

Error rate for 2 level = 3.125%


In [10]:
data, values = fill_level(3, 10)
errs = analyze(data, values, model=MODEL, mode='XML')
print(f'Error rate for 3 level = {errs.rate*100}%')

  0%|          | 0/1000 [00:00<?, ?it/s]

Error rate for 3 level = 4.1000000000000005%


In [11]:
response = client.models.generate_content(
    model='gemini-1.5-pro',
    contents='what is 1+1?',
)