In [None]:
import sys 
sys.path.append("..")

In [None]:
import os

import math

import copy

import collections

import logging

from typing import Iterable, List, Dict, Any, Tuple

import backoff

import openai

from tqdm import tqdm

import pandas as pd

import numpy as np

from src.models.apimodel import APIModel

from src.demonstrations import *

from src.datasets import HateXplainRace

from src.utils import metrics

from src.__main__  import build_demonstration


# Logger for the model code in this notebook; the openai library's own logger
# is restricted to WARNING and above to keep request chatter out of the output.
logger = logging.getLogger(f"{__name__}.models")
logging.getLogger("openai").setLevel(logging.WARNING)

In [None]:
def merge_dicts(*dicts):
    """Group the values of several mappings by key.

    :return: a ``collections.defaultdict(list)`` mapping every key seen in
        ``dicts`` to the list of its values, in argument order
    """
    grouped = collections.defaultdict(list)
    all_pairs = (pair for mapping in dicts for pair in mapping.items())
    for key, value in all_pairs:
        grouped[key].append(value)
    return grouped

In [None]:
class GPTLogProb(APIModel):
    """OpenAI Completion wrapper returning generated text plus top-5 token log-probs.

    Code modified from
    https://github.com/isabelcachola/generative-prompting/blob/main/genprompt/models.py
    """

    def __init__(self, model_name: str, temperature: float = 1, max_tokens: int = 5):
        """Configure the model and the OpenAI API key.

        :param model_name: OpenAI completion model name, e.g. ``"text-davinci-003"``
        :param temperature: sampling temperature forwarded to the API
        :param max_tokens: maximum number of tokens generated per prompt
        :raises KeyError: if the ``OPENAI_API_KEY`` environment variable is unset
        """
        super().__init__(model_name, temperature, max_tokens)

        # Legacy (openai<1.0, `openai.Completion` / `openai.error`) key configuration.
        openai.api_key = os.environ["OPENAI_API_KEY"]
        # Number of prompts sent in a single Completion request.
        self.batch_size = 20

    @backoff.on_exception(
        backoff.expo,
        (
            openai.error.RateLimitError,
            openai.error.APIError,
            openai.error.Timeout,
            openai.error.ServiceUnavailableError,
        ),
    )
    def get_response(self, prompt: Iterable[str]) -> Dict[str, Any]:
        """Query the Completion endpoint for one prompt or a batch of prompts.

        The listed transient API errors are retried with exponential backoff.
        No ``max_tries`` is given, so retries continue until the call succeeds.
        Any other error propagates to the caller.

        :param prompt: a single prompt string or a batch of prompts
        :return: the raw API response; ``choices`` holds one entry per prompt,
            each with ``text``, ``index`` and ``logprobs`` (top 5 alternatives
            per generated token)
        """
        response = openai.Completion.create(
            model=self.model_name,
            prompt=prompt,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            logprobs=5
        )

        return response

    def format_response(
        self, response: Dict[str, Any]
    ) -> Tuple[str, List[Dict[str, float]]]:
        """Extract ``(text, top_logprobs)`` from a single completion choice.

        :param response: one element of the API response's ``choices``
        :return: the completion text with newlines flattened to spaces and outer
            whitespace stripped, and the per-generated-token list of
            ``{token: logprob}`` dicts of the top alternatives
        """
        text = response["text"].replace("\n", " ").strip()
        top_logprobs = response["logprobs"]["top_logprobs"]
        return text, top_logprobs

    def generate_from_prompts(
        self, examples: Iterable[str]
    ) -> List[Tuple[str, List[Dict[str, float]]]]:
        """Generate a formatted response for every prompt, preserving input order.

        Prompts are sent in batches of ``self.batch_size``. If a batch request
        fails for any reason, that batch's prompts are retried one at a time.

        :param examples: prompts; must support ``len()`` and slicing (e.g. a list)
        :return: one ``(text, top_logprobs)`` tuple per prompt, in input order
        """
        n_examples = len(examples)
        logger.info("Num examples = %d", n_examples)

        responses = []
        for start in tqdm(range(0, n_examples, self.batch_size), ncols=0):
            prompt_batch = examples[start : start + self.batch_size]
            try:
                responses.extend(self._batch_responses(prompt_batch))
            except Exception:
                # One bad prompt (or a malformed response) should not sink the
                # whole batch: fall back to per-prompt requests.
                logger.warning(
                    "Batch starting at index %d failed; retrying prompts individually",
                    start,
                    exc_info=True,
                )
                responses.extend(self._single_response(p) for p in prompt_batch)

        return responses

    def _batch_responses(self, prompt_batch):
        """Request a whole batch at once and return formatted outputs in prompt order.

        :raises ValueError: if the response lacks a choice for some prompt
        """
        response = self.get_response(prompt_batch)
        logger.debug("Batch response: %s", response)

        ordered = [None] * len(prompt_batch)
        for choice in response.choices:
            # Choices are not guaranteed to arrive in prompt order; `index`
            # identifies which prompt each one answers.
            ordered[choice.index] = self.format_response(choice)

        if any(item is None for item in ordered):
            raise ValueError("API response is missing choices for some prompts")
        return ordered

    def _single_response(self, prompt, max_chars: int = 2000):
        """Request one prompt; on failure retry once with only its last ``max_chars`` characters.

        The truncated retry is presumably meant for prompts that exceed the
        model's context window. If the retry also fails, its exception propagates.
        """
        try:
            return self.format_response(self.get_response(prompt)["choices"][0])
        except Exception:
            logger.warning(
                "Prompt failed; retrying with its last %d characters",
                max_chars,
                exc_info=True,
            )
            truncated = prompt[-max_chars:]
            return self.format_response(self.get_response(truncated)["choices"][0])


In [None]:
# Build the HateXplainRace dataset wrapper from the repository's data directory.
hate = HateXplainRace(
    "../data/HateXplain",
)

In [None]:
# create_prompts() returns three objects: training data, test data and overall
# demographic information, all consumed by build_demonstration later on.
prompt_data = hate.create_prompts()
train_df, test_df, overall_demographics = prompt_data

In [None]:
# Completion model queried for predictions and their top token log-probabilities.
gpt = GPTLogProb(model_name="text-davinci-003")

In [None]:
# Demonstration-selection strategies to evaluate; each name is passed to
# build_demonstration.
demonstrations = [
    "within",
    "similarity",
]

In [None]:
def _predicted_label_ids(preds, labels_dict):
    """Map each normalized response to the id of the first label it contains, or -1.

    Labels are tried in the iteration order of ``labels_dict``.
    """
    return [
        next((label_id for label, label_id in labels_dict.items() if label in pred), -1)
        for pred in preds
    ]


def _label_probability_gap(top_logprobs, label):
    """Return P(best label-bearing token) - P(best token), or None if no token has the label.

    ``top_logprobs`` is the per-position list of ``{token: logprob}`` dicts from
    ``GPTLogProb.format_response``. Each token keeps its highest log-probability
    across all generated positions. A token "bears" the label when ``label`` is
    a substring of the lower-cased token.
    """
    best_logprob = {tok: max(lps) for tok, lps in merge_dicts(*top_logprobs).items()}
    label_tokens = [tok for tok in best_logprob if label in tok.lower()]
    if not label_tokens:
        return None
    best_label_prob = math.exp(max(best_logprob[tok] for tok in label_tokens))
    best_prob = math.exp(max(best_logprob.values()))
    return best_label_prob - best_prob


outputs = []

# Labels of the full test set; used (with the per-run labels) to build the label vocabulary.
labels = test_df["labels"].tolist()

for demonstration in demonstrations:
    prompts, filtered_test_df, sampler_type = build_demonstration(
        demonstration, {"shots": 5}, train_df, test_df, overall_demographics
    )

    responses = gpt.generate_from_prompts(prompts)
    text_responses = [text for text, _ in responses]

    # Gold labels must come from the frame the prompts were built for, so that
    # predictions and labels line up index by index.
    gold_labels = filtered_test_df["labels"].tolist()
    if len(gold_labels) != len(responses):
        raise ValueError(
            f"{demonstration}: got {len(responses)} responses for {len(gold_labels)} labels"
        )

    # Normalize predictions: missing text becomes "" before lower-casing.
    preds_clean = [(text or "").lower() for text in text_responses]

    # Sorted so label-matching order (first label found in a response wins) is
    # deterministic; str-set iteration order changes between interpreter runs.
    labels_set = sorted(set(labels) | set(gold_labels))
    labels_dict = {label: label_id for label_id, label in enumerate(labels_set)}

    dummy_labels = np.array([labels_dict[x] for x in gold_labels])
    dummy_preds = np.array(_predicted_label_ids(preds_clean, labels_dict))

    incorrect = np.flatnonzero(dummy_preds != dummy_labels)

    # Keep the full (text, top_logprobs) response: the log-probs are needed below.
    responses_incorrect = [(responses[i], gold_labels[i]) for i in incorrect]

    differences = []
    for (_, top_logprobs), label in responses_incorrect:
        gap = _label_probability_gap(top_logprobs, label)
        if gap is not None:
            differences.append(gap)

    # Number of wrong answers where the gold label appeared among the top tokens.
    total = len(differences)
    mean_gap = sum(differences) / total if differences else float("nan")
    fraction = total / len(responses_incorrect) if responses_incorrect else float("nan")

    outputs.append(
        {
            "demonstration": demonstration,
            "total": total,
            "mean_gap": mean_gap,
            "fraction": fraction,
        }
    )

    print(demonstration)
    print(total)
    print(mean_gap)
    print(fraction)
    
    