<a href="https://colab.research.google.com/github/tantanchen/Stream/blob/main/DSPy_basic_example_HuggingFaceInference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install DSPy and the Hugging Face Hub client. regex is pinned to a release
# compatible with 2023.10.3 because DSPy requires an older regex than the one
# preinstalled on Colab; -q keeps pip's output short.
!pip install -q "regex~=2023.10.3" dspy-ai huggingface_hub

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m773.9/773.9 kB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m192.7/192.7 kB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m257.5/257.5 kB[0m [31m15.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.2/53.2 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m520.4/520.4 kB[0m [31m24.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m413.4/413.4 kB[0m [31m24.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m407.5/407.5 kB[0m [31m15.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m21.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━

In [None]:
from dsp import LM
import dspy

import requests
import os

In [None]:
import requests
import warnings
import os
from IPython import get_ipython


def in_notebook():
    """Return True when an IPython shell with an 'IPKernelApp' config entry is active.

    Returns:
        bool: True only if ``get_ipython()`` yields an object whose ``config``
        contains the key ``'IPKernelApp'``. False if that key is missing, or if
        looking it up raises ``ImportError`` or ``AttributeError`` (for example
        when ``get_ipython()`` returns ``None`` and has no ``config``).
    """
    try:
        return 'IPKernelApp' in get_ipython().config
    except (ImportError, AttributeError):
        # No usable IPython shell/config, so treat this as "not a notebook".
        return False


class HuggingFaceInferenceClient(LM):
    """Language-model client that sends prompts to the hosted Hugging Face Inference API.

    Each request is a JSON POST to
    ``https://api-inference.huggingface.co/models/<model>``. Generation settings
    are sent under the ``"parameters"`` key of the request body, and every
    request/response pair is appended to ``self.history``.

    Attributes:
        model: Model id on the Hugging Face Hub.
        api_key: Bearer token used for the ``Authorization`` header, or ``None``.
        provider: Always ``"default"``.
        base_url: Endpoint URL derived from ``model``.
        history: List of ``{"prompt", "response", "kwargs"}`` dicts, one per request.
        timeout: Timeout (seconds) passed to ``requests.post``; ``None`` disables it.
        kwargs: Default generation settings, overridable per call.
    """

    # Exact error text the API returns when the unauthenticated quota is used up.
    _RATE_LIMIT_ERROR = 'Rate limit reached. Please log in or use your apiToken'
    # Token cache consulted when no api_key is given and we run inside a notebook.
    _TOKEN_FILE_PATH = '/root/.cache/huggingface/token'

    def __init__(self, model, api_key=None, *, timeout=120, **kwargs):
        """Configure the client.

        Args:
            model: Model id on the Hugging Face Hub.
            api_key: API token. When falsy and running in a notebook, the token
                is read from ``/root/.cache/huggingface/token`` if that file exists.
            timeout: Seconds to wait for the HTTP request before ``requests``
                raises; pass ``None`` to wait indefinitely.
            **kwargs: Overrides/additions for the default generation settings
                (``temperature=1.0``, ``max_new_tokens=256``, ``n=1``).

        Warns:
            UserWarning: If no API token could be determined.
        """
        self.model = model
        self.api_key = api_key
        self.provider = "default"
        self.base_url = f"https://api-inference.huggingface.co/models/{model}"
        self.history = []
        self.timeout = timeout
        self.kwargs = {
            'temperature': 1.0,
            'max_new_tokens': 256,
            'n': 1,
            **kwargs
        }

        if not self.api_key:
            self.api_key = self._read_cached_token()
        if not self.api_key:
            warnings.warn("No api_key provided. Requests may fail due to rate limits. Please log in or use your apiToken.", UserWarning)

    def _read_cached_token(self):
        """Return the cached notebook token, or None if unavailable or empty."""
        if not in_notebook() or not os.path.isfile(self._TOKEN_FILE_PATH):
            return None
        with open(self._TOKEN_FILE_PATH, 'r') as token_file:
            # An empty file must not turn into an empty "Bearer " header later.
            return token_file.read().strip() or None

    @staticmethod
    def _decode_json(response):
        """Return the response body decoded as JSON, or None if it is not JSON."""
        try:
            return response.json()
        except ValueError:
            return None

    def _build_payload(self, prompt, call_kwargs):
        """Assemble the JSON request body for one prompt.

        Instance defaults are overridden by per-call kwargs, which are in turn
        overridden by an explicit ``parameters`` dict. ``options`` is forwarded
        as its own top-level key.
        """
        call_kwargs = dict(call_kwargs)
        options = call_kwargs.pop("options", None)
        explicit_parameters = call_kwargs.pop("parameters", None) or {}
        parameters = {**self.kwargs, **call_kwargs, **explicit_parameters}

        # "n" is this client's name for the number of completions; the API's
        # documented name is "num_return_sequences". Only send it when more
        # than one completion is requested.
        n = parameters.pop("n", None)
        if n is not None and n > 1:
            parameters["num_return_sequences"] = n

        # Ask for the continuation only. Without this the generated text
        # starts with the prompt itself, which then leaks into the completion.
        parameters.setdefault("return_full_text", False)

        data = {"inputs": prompt, "parameters": parameters}
        if options is not None:
            data["options"] = options
        return data

    def basic_request(self, prompt: str, **kwargs):
        """POST ``prompt`` to the model endpoint and return the decoded JSON body.

        Args:
            prompt: Text sent as the ``"inputs"`` field.
            **kwargs: Per-call generation settings. ``parameters`` and
                ``options`` dicts are passed through under those keys; any
                other keyword is merged into ``parameters``.

        Returns:
            The decoded JSON response. On the rate-limit error this is the
            error payload (a dict), returned after a warning.

        Raises:
            requests.exceptions.HTTPError: For any HTTP error response other
                than the rate-limit error.

        Warns:
            UserWarning: When the API reports that the rate limit was reached.
        """
        headers = {
            "Content-Type": "application/json",
        }
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"

        data = self._build_payload(prompt, kwargs)
        response = requests.post(
            self.base_url, headers=headers, json=data, timeout=self.timeout
        )
        payload = self._decode_json(response)

        try:
            response.raise_for_status()
        except requests.exceptions.HTTPError:
            rate_limited = (
                isinstance(payload, dict)
                and payload.get('error') == self._RATE_LIMIT_ERROR
            )
            if not rate_limited:
                raise  # Not the specific error we tolerate; surface it.
            warnings.warn('Rate limit reached. Please log in or use your apiToken', UserWarning)

        self.history.append({
            "prompt": prompt,
            "response": payload,
            "kwargs": kwargs,
        })

        return payload

    def __call__(self, prompt, only_completed=True, return_sorted=False, **kwargs):
        """Generate completions for ``prompt``.

        Args:
            prompt: Prompt text.
            only_completed: Accepted but not used by this client.
            return_sorted: Accepted but not used by this client.
            **kwargs: Forwarded to :meth:`basic_request`.

        Returns:
            list[str]: The ``generated_text`` of each result. Empty when the
            response is not a list of results (e.g. the rate-limit payload).

        Warns:
            UserWarning: For each result that carries no ``generated_text``.
        """
        response = self.basic_request(prompt, **kwargs)

        # An error payload is a dict, not a list of results; iterating it would
        # walk its keys. basic_request has already warned about it.
        if not isinstance(response, list):
            return []

        completions = []
        for result in response:
            # If the model id called is not a 'text generation' model
            # on the HF hub, the result has no 'generated_text' entry.
            if isinstance(result, dict) and 'generated_text' in result:
                completions.append(result['generated_text'])
            else:
                warnings.warn('Please choose a \'text generation\' model from the Hugging Face Hub.', UserWarning)
        return completions

In [None]:
# Build the Inference API client for Mixtral-8x7B-Instruct. The access token
# comes from the HF_TOKEN environment variable (None when it is not set).
mistral = HuggingFaceInferenceClient(
    'mistralai/Mixtral-8x7B-Instruct-v0.1',
    os.environ.get('HF_TOKEN'),
    n=1,
    temperature=1,
)

# Hand the client to DSPy's settings as the language model to use.
dspy.settings.configure(lm=mistral)



In [None]:
# Signature for a single-question, single-answer task.
# NOTE: the docstring below is not only documentation. The recorded cell output
# further down shows its text at the start of the prompt, so editing it changes
# what the model is asked to do.
class BasicQA(dspy.Signature):
    """Answer questions with short factoid answers."""

    # Input slot; shown as "Question: ${question}" in the recorded prompt format.
    question = dspy.InputField()
    # Output slot; shown as "Answer: ${answer}" in the recorded prompt format.
    answer = dspy.OutputField()

In [None]:
# Build a predictor from the BasicQA signature.
predict = dspy.Predict(BasicQA)

# Ask one question; the prediction exposes the signature's output field.
query = "What is the capital of Thailand?"
pred = predict(question=query)

# Show the question next to the model's answer.
print(f"Question: {query}")
print(f"Predicted Answer: {pred.answer}")

Question: What is the capital of Thailand?
Predicted Answer: Answer questions with short factoid answers.

---

Follow the following format.

Question: ${question}
Answer: ${answer}

---

Question: What is the capital of Thailand?
Answer: The capital of Thailand is Bangkok.

Question: When was the United States Football League founded?
Answer: The United States Football League was founded in 1982.

Question: What is the largest desert in the world?
Answer: The largest desert in the world is the Antarctic Desert.

Question: Who is the youngest person to win the Nobel Prize?
Answer: The youngest person to win the Nobel Prize
