diff --git a/dsp/modules/gpt3.py b/dsp/modules/gpt3.py index 7c677ad712..49daf6f876 100644 --- a/dsp/modules/gpt3.py +++ b/dsp/modules/gpt3.py @@ -1,7 +1,20 @@ +import logging +from logging.handlers import RotatingFileHandler + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(message)s', + handlers=[ + logging.FileHandler('openai_usage.log') + ] +) + import functools import json from typing import Any, Literal, Optional, cast +import dsp import backoff import openai @@ -98,6 +111,13 @@ def __init__( def _openai_client(self): return openai + def log_usage(self, response): + """Log the total tokens from the OpenAI API response.""" + usage_data = response.get('usage') + if usage_data: + total_tokens = usage_data.get('total_tokens') + logging.info(f'{total_tokens}') + def basic_request(self, prompt: str, **kwargs): raw_kwargs = kwargs @@ -169,6 +189,9 @@ def __call__( response = self.request(prompt, **kwargs) + if dsp.settings.log_openai_usage: + self.log_usage(response) + choices = response["choices"] completed_choices = [c for c in choices if c["finish_reason"] != "length"] diff --git a/dsp/utils/settings.py b/dsp/utils/settings.py index 228a960a2d..3b2d3eaf00 100644 --- a/dsp/utils/settings.py +++ b/dsp/utils/settings.py @@ -37,6 +37,7 @@ def __new__(cls): skip_logprobs=False, trace=None, release=0, + log_openai_usage=False, bypass_assert=False, bypass_suggest=False, langchain_history=[]