In [None]:
cd ../

In [None]:
root_path = '.'

import os
import logging
# These environment variables are set before `import dsp` below — presumably
# dsp reads them at import time (TODO confirm), so keep this ordering.
os.environ["DSP_NOTEBOOK_CACHEDIR"] = os.path.join(root_path, 'cache')
os.environ["DSP_CACHE_SQLITE_PATH"] = "cache.db"
os.environ["DSP_LOGGING_LEVEL"] = str(logging.DEBUG)

import dsp
import openai
import hashlib
from dsp.utils.cache import sqlite_cache_wrapper
from dsp.utils.cache import sqlite_cache_splitter

# No API key is set here; presumably dsp.GPT3 picks it up from the environment
# (e.g. OPENAI_API_KEY) — TODO confirm. Never hard-code credentials in the notebook.

# ColBERTv2 search endpoint. It can be overridden via COLBERT_SERVER_URL without
# editing the notebook, because the default EC2 host may not stay available.
colbert_server = os.environ.get(
    'COLBERT_SERVER_URL',
    'http://ec2-44-228-128-229.us-west-2.compute.amazonaws.com:8893/api/search',
)

lm = dsp.GPT3(model='gpt-3.5-turbo', model_type="chat")
rm = dsp.ColBERTv2(url=colbert_server)


In [None]:
from datetime import datetime, timedelta
import time
import random
import string


### SQLite cache unit tests

In [None]:
@sqlite_cache_splitter
@sqlite_cache_wrapper
def test_function_cached(**kwargs):
    """Cached twin of ``test_function``.

    Returns the same SHA-256 hex digest of the kwargs as ``test_function``, but is
    wrapped in the dsp SQLite cache decorators (presumably so repeated calls with
    the same kwargs are served from the cache).

    It delegates to ``test_function`` rather than duplicating the hashing code, so
    the reference and cached implementations cannot drift apart. The name is
    resolved at call time, so it does not matter that ``test_function`` is
    defined after this function.
    """
    return test_function(**kwargs)

def test_function(**kwargs):
    """Uncached reference implementation.

    Renders every keyword argument as ``key=value`` (using ``str`` of the value),
    joins the pairs with commas in insertion order, and returns the SHA-256 hex
    digest of the UTF-8 encoded result.
    """
    pairs = []
    for name, val in kwargs.items():
        pairs.append(f'{name}={val}')
    payload = ','.join(pairs).encode()
    return hashlib.sha256(payload).hexdigest()

In [None]:
def generate_random_string(length):
    # Define the pool of characters to choose from
    characters = string.ascii_letters + string.digits  

    # Generate a random string of given length
    random_string = ''.join(random.choices(characters, k=length))
    return random_string

def generate_random_value():
    """Return one random test value, with each kind equally likely.

    The possible kinds are:
    - a long alphanumeric string (20-50 chars);
    - a short alphanumeric string (5-10 chars);
    - a float from ``random.uniform(0.0, 100.0)``;
    - an int in 1..100;
    - a list of 1-5 values built recursively by this same function, so lists may nest.
    """
    kind = random.choice(['long_string', 'short_string', 'float', 'int', 'list'])
    # Each builder draws from the RNG only when invoked, so the sequence of
    # random calls matches a plain if/elif chain exactly.
    builders = {
        'long_string': lambda: generate_random_string(random.randint(20, 50)),
        'short_string': lambda: generate_random_string(random.randint(5, 10)),
        'float': lambda: random.uniform(0.0, 100.0),
        'int': lambda: random.randint(1, 100),
        'list': lambda: [generate_random_value() for _ in range(random.randint(1, 5))],
    }
    return builders[kind]()

def generate_random_dict(num_items):
    """Return a dict with exactly ``num_items`` entries (an empty dict if ``num_items <= 0``).

    Keys are random 5-10 character alphanumeric strings; values come from
    ``generate_random_value``. A key that collides with an existing one is
    redrawn rather than overwriting that entry, so the size is always exact.
    """
    random_dict = {}
    while len(random_dict) < num_items:
        key = generate_random_string(random.randint(5, 10))
        if key in random_dict:
            continue  # collision: draw a new key instead of silently losing an entry
        random_dict[key] = generate_random_value()
    return random_dict

# 50 random kwarg dicts with 0-30 entries each. The RNG is deliberately left
# unseeded: fresh keys on every run presumably keep the "write" benchmark below
# from hitting entries a previous run already stored in cache.db (TODO confirm
# the cache persists across runs).
kwargs = [generate_random_dict(random.randint(0, 30)) for _ in range(50)]

# Summarise rather than dumping every dict; the full repr is very long.
print(f'Generated {len(kwargs)} kwarg dicts with sizes {[len(d) for d in kwargs]}')
print('First example:', kwargs[0])

In [None]:
# Benchmark cache writes, then reads of the same keys, then check correctness.
# time.perf_counter() is the monotonic, high-resolution clock intended for
# measuring durations; time.time() can jump when the system clock is adjusted.

# test write performance
start_time = time.perf_counter()
for kwarg in kwargs:
    test_function_cached(**kwarg)
end_time = time.perf_counter()
print(f'Cached function writes took {end_time - start_time} seconds')

# test read performance + save outputs
outputs = []
start_time2 = time.perf_counter()
for kwarg in kwargs:
    outputs.append(test_function_cached(**kwarg))
end_time2 = time.perf_counter()
print(f'Cached function reads took {end_time2 - start_time2} seconds')

# test correctness: every cached result must equal the uncached reference
for kwarg, cached in zip(kwargs, outputs):
    assert test_function(**kwarg) == cached, f'cache returned a wrong value for {kwarg}'




In [None]:
# threaded tests
import threading 

@sqlite_cache_splitter
@sqlite_cache_wrapper
def test_function_cached2(**kwargs):
    """Slow cached twin of ``test_function`` used by the threaded tests.

    Every call that reaches this body sleeps 5 s (presumably only cache misses
    do), so concurrent callers overlap while the computation is in flight. It
    then returns the same SHA-256 hex digest as ``test_function``. It delegates
    rather than duplicating the hashing code, so the implementations cannot drift.
    """
    time.sleep(5)
    return test_function(**kwargs)

def run_function_in_thread(kwargs, results=None):
    """Thread target: call ``test_function_cached2(**kwargs)`` and return its value.

    ``threading.Thread`` discards a target's return value, so pass a list as
    ``results`` to have the value appended to it. Concurrent ``list.append``
    calls are safe in CPython, because each append is atomic under the GIL.
    """
    result = test_function_cached2(**kwargs)
    if results is not None:
        results.append(result)
    return result


def _run_in_threads(kwargs_list):
    """Start one thread per kwargs dict, wait for all of them, and time the batch.

    Returns ``(elapsed_seconds, results)``. ``results`` holds one value per thread
    that finished normally, in completion order rather than input order. A thread
    that raises is only reported by ``threading.excepthook`` and contributes
    nothing, so compare ``len(results)`` with ``len(kwargs_list)``.
    """
    results = []
    start = time.perf_counter()
    threads = [
        threading.Thread(target=run_function_in_thread, args=(kw, results))
        for kw in kwargs_list
    ]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    return time.perf_counter() - start, results


# Four identical calls, so every thread races on the same cache key.
kwargs = [{"kwarg1": 1, "kwarg2": 2, "kwarg3": 5} for _ in range(4)]

# test write performance (each call that reaches test_function_cached2's body sleeps 5 s)
write_seconds, _ = _run_in_threads(kwargs)
print(f'Cached function writes (threaded) took {write_seconds} seconds')

# test read performance; keep every thread's return value in `outputs`
read_seconds, outputs = _run_in_threads(kwargs)
print(f'Cached function reads (threaded) took {read_seconds} seconds')
assert len(outputs) == len(kwargs), (
    f'{len(kwargs) - len(outputs)} threaded call(s) raised; see the thread tracebacks above'
)

In [None]:
# check correctness threaded: the threaded cell above called
# `test_function_cached2`, so that is the function whose cached results must be
# compared against the uncached reference here.
for kwarg in kwargs:
    ref = test_function(**kwarg)
    cached = test_function_cached2(**kwarg)
    assert ref == cached, f'threaded cache returned {cached!r}, expected {ref!r} for {kwarg}'

### SQLite cache integration tests


###### Case 1: 
- The example doesn't exist in the cache, and the experiment end timestamp is not in the future.
- Expected:
    - An exception is raised. The example is missing from the cache within the experiment's time range, and computing it now would break reproducibility.

In [None]:
# Case 1 needs an example that is guaranteed NOT to be cached. A fixed prompt
# would be stored by Case 2 (which uses the same question), so on any later run
# of this notebook the request would succeed and this case would silently stop
# testing anything. A nanosecond-timestamp nonce keeps the prompt unique per run.
case1_prompt = (
    "Q: At which year was synchronized swimming considered as a valid olympic sport? "
    f"(run {time.time_ns()})\nA:"
)
case1_cache_end_timerange = datetime.now().timestamp()

# The exception is the expected outcome. Catch and show it so "Run All"
# continues to the next cases. Note that any exception type counts here.
try:
    with dsp.settings.context(lm=lm, experiment_end_timestamp=case1_cache_end_timerange):
        answer = dsp.settings.lm.basic_request(prompt=case1_prompt)
except Exception as exc:
    print(f'Raised as expected: {type(exc).__name__}: {exc}')
else:
    raise AssertionError(f'Expected an exception for an uncached example, got: {answer!r}')

###### Case 2: 
- The example doesn't exist in the cache, and the experiment end timestamp is in the future (the default).
- Expected:
    - The experiment is not reproducing earlier results but producing new ones, so the example may be computed (and then cached for future reproducible runs).

In [None]:

# No experiment_end_timestamp is passed, so the context default applies (a
# future end timestamp, per the Case 2 notes), and an uncached example may be computed.
case2_prompt = "Q: At which year was synchronized swimming considered as a valid olympic sport?\nA:"
with dsp.settings.context(lm=lm):
    answer = dsp.settings.lm.basic_request(prompt=case2_prompt)
    print(answer)

###### Case 3: 
- The example already exists in the cache within the specified time range.
- Expected:
    - The cached result is returned without re-computing it.

In [None]:
# End the time range at "now". Case 2 stored this example a moment ago, so it
# lies inside the range. With a future end timestamp, a cache miss would be
# allowed to recompute silently (see Case 2), and this cell could not detect a
# missing entry. With "now", a miss raises instead (see Case 1), mirroring how
# Case 4 replays its cached entry.
case3_cache_end_timerange = datetime.now().timestamp()

with dsp.settings.context(lm=lm, experiment_end_timestamp=case3_cache_end_timerange):
    answer = dsp.settings.lm.basic_request(prompt="Q: At which year was synchronized swimming considered as a valid olympic sport?\nA:")
    print(answer)

###### Case 4: 
- The example exists in the cache within the time range, but the original request failed.
- Expected:
    - The exact error that originally occurred is raised again, for the sake of reproducibility.

In [None]:
# Let's intentionally make it fail: `dummy_kwargs` is not a real request option
# and is presumably rejected downstream (TODO confirm). The end timestamp is 3 s
# in the future, so, as in Case 2, this call is allowed to run.
case4_cache_end_timerange = (datetime.now() + timedelta(seconds=3)).timestamp()
try:
    with dsp.settings.context(lm=lm, experiment_end_timestamp=case4_cache_end_timerange):
        answer = dsp.settings.lm.basic_request(prompt="Q: At which year was synchronized swimming considered as a valid olympic sport?\nA:", dummy_kwargs="sample")
except Exception as exc:
    # The failure is the point of this cell; show it instead of aborting "Run All".
    print(f'Failed as intended: {type(exc).__name__}: {exc}')
else:
    raise AssertionError(f'Expected the dummy_kwargs request to fail, got: {answer!r}')
finally:
    # Let the 3 s window above elapse before the next cell takes its own timestamp.
    time.sleep(3)

In [None]:
# Should fail with the same exception: the failed call above is in the cache,
# and with the time range ending "now" it must be replayed rather than recomputed.
case4_cache_end_timerange = datetime.now().timestamp()

try:
    with dsp.settings.context(lm=lm, experiment_end_timestamp=case4_cache_end_timerange):
        answer = dsp.settings.lm.basic_request(prompt="Q: At which year was synchronized swimming considered as a valid olympic sport?\nA:", dummy_kwargs="sample")
except Exception as exc:
    # Compare this with the error printed by the previous cell.
    print(f'Replayed failure: {type(exc).__name__}: {exc}')
else:
    raise AssertionError(f'Expected the cached failure to be re-raised, got: {answer!r}')

###### Case 5: concurrency test

In [None]:
import threading
case5_cache_end_timerange = (datetime.now() + timedelta(days=2)).timestamp()
case5_prompt = "Q: which war did Archimedes invent war machines for?\nA:"
case5_num_threads = 10

# Threads cannot hand back return values or exceptions directly, so each one
# records its outcome here (list.append is atomic in CPython).
case5_answers = []
case5_errors = []


def run_snippet():
    """Issue the Case 5 request once, recording either the answer or the raised exception."""
    try:
        with dsp.settings.context(lm=lm, experiment_end_timestamp=case5_cache_end_timerange):
            answer = dsp.settings.lm.basic_request(prompt=case5_prompt)
    except Exception as exc:  # otherwise threading.excepthook would only print it
        case5_errors.append(exc)
    else:
        case5_answers.append(answer)


threads = []
for _ in range(case5_num_threads):
    thread = threading.Thread(target=run_snippet)
    thread.start()
    threads.append(thread)

for thread in threads:
    thread.join()

assert not case5_errors, f'{len(case5_errors)} of {case5_num_threads} threads failed: {case5_errors!r}'
print(f'{len(case5_answers)} threads got answers; '
      f'{len({repr(a) for a in case5_answers})} distinct answer(s)')
