In [24]:
import re
import numpy as np
from openai import OpenAI
import colorsys  # Import for color calculations
from IPython.display import display, HTML
from openai.resources.chat.completions import ChatCompletion

# Module-level OpenAI client; generate_call_from_patches sends its chat request through it.
# NOTE(review): constructed with no arguments, so credentials are presumably read from the
# environment (e.g. OPENAI_API_KEY) — confirm before running on a fresh kernel.
client = OpenAI()

# System message for the patch-equivalence judge. It is assembled line by line so the
# required reply layout (two summary sections followed by a one-word verdict inside
# <equivalent_or_different> tags) is easy to read; the joined value is what gets sent.
system_prompt = "\n".join(
    (
        "You are Jeff Dean. We need to compare two git patches to determine whether they are functionally equivalent or different. Ignore changes such as logs or comments. Ignore implementation differences if the resulting behavior is the same.",
        "",
        "Be exceptionally meticulous to determine whether or not they perform the same functional.",
        "",
        "Respond in the following format:",
        "",
        "<summary_of_patch_one>",
        "summarize the first patch's functional changes",
        "</summary_of_patch_one>",
        "",
        "<summary_of_patch_two>",
        "summarize the second patch's functional changes",
        "</summary_of_patch_two>",
        "",
        "<equivalent_or_different>",
        "whether the patches are functionally equivalent or different. answer in one word.",
        "</equivalent_or_different>",
    )
)

# Sample user message holding two tagged git patches (<patch_one> / <patch_two>).
# The two patch bodies below are identical to each other, so this prompt exercises the
# "same patch twice" case. The literal is whitespace-sensitive (the "--- " and "+++ "
# header lines end in a space), so it is kept verbatim.
example_prompt = """<patch_one>
--- 
+++ 
@@ -516,17 +516,21 @@
     def __eq__(self, other):
         # Needed for @total_ordering
         if isinstance(other, Field):
+            if hasattr(self, 'model') and hasattr(other, 'model') and self.model != other.model:
+                return False
             return self.creation_counter == other.creation_counter
         return NotImplemented

     def __lt__(self, other):
         # This is needed because bisect does not take a comparison function.
         if isinstance(other, Field):
+            if hasattr(self, 'model') and hasattr(other, 'model') and self.model != other.model:
+                return self.model._meta.label < other.model._meta.label
             return self.creation_counter < other.creation_counter
         return NotImplemented

     def __hash__(self):
-        return hash(self.creation_counter)
+        return hash((self.model._meta.label, self.creation_counter)) if hasattr(self, 'model') else hash(self.creation_counter)

     def __deepcopy__(self, memodict):
         # We don't have to deepcopy very much here, since most things are not
</patch_one>

<patch_two>
--- 
+++ 
@@ -516,17 +516,21 @@
     def __eq__(self, other):
         # Needed for @total_ordering
         if isinstance(other, Field):
+            if hasattr(self, 'model') and hasattr(other, 'model') and self.model != other.model:
+                return False
             return self.creation_counter == other.creation_counter
         return NotImplemented

     def __lt__(self, other):
         # This is needed because bisect does not take a comparison function.
         if isinstance(other, Field):
+            if hasattr(self, 'model') and hasattr(other, 'model') and self.model != other.model:
+                return self.model._meta.label < other.model._meta.label
             return self.creation_counter < other.creation_counter
         return NotImplemented

     def __hash__(self):
-        return hash(self.creation_counter)
+        return hash((self.model._meta.label, self.creation_counter)) if hasattr(self, 'model') else hash(self.creation_counter)

     def __deepcopy__(self, memodict):
         # We don't have to deepcopy very much here, since most things are not
</patch_two>"""

# Parse tests/diff_comparison.csv into three parallel lists: equivalent_patch,
# non_equivalent_patch and patch_to_compare. Despite the .csv extension the file is
# not read with the csv module; fields are separated by the literal token "<diff>".
def parse_diff_comparison():
    """Load the patch triples stored in ``tests/diff_comparison.csv``.

    The first line of the file is treated as a header and skipped. Every other
    non-blank line is split on the literal separator ``<diff>``; its first three
    fields are appended to the three result lists (any further fields are ignored).
    The line terminator is removed before splitting so it does not end up inside
    the last field.

    Returns:
        tuple[list[str], list[str], list[str]]: ``(equivalent_patch,
        non_equivalent_patch, patch_to_compare)``, index-aligned by row.

    Raises:
        FileNotFoundError: If ``tests/diff_comparison.csv`` does not exist relative
            to the current working directory.
        IndexError: If a non-blank row has fewer than three ``<diff>``-separated
            fields.
    """
    equivalent_patch = []
    non_equivalent_patch = []
    patch_to_compare = []
    with open('tests/diff_comparison.csv', 'r') as file:
        next(file, None)  # skip the header row (no-op on an empty file)
        # Iterate lazily instead of readlines(); line numbers start at 2 after the header.
        for line_number, line in enumerate(file, start=2):
            row = line.rstrip('\n')
            if not row.strip():
                # Blank or trailing empty lines carry no record.
                continue
            fields = row.split('<diff>')
            if len(fields) < 3:
                raise IndexError(
                    f"tests/diff_comparison.csv line {line_number}: expected 3 "
                    f"'<diff>'-separated fields, found {len(fields)}"
                )
            equivalent_patch.append(fields[0])
            non_equivalent_patch.append(fields[1])
            patch_to_compare.append(fields[2])
    return equivalent_patch, non_equivalent_patch, patch_to_compare


def generate_call_from_patches(patch_one, patch_two):
    """Ask the chat model whether two git patches are functionally equivalent.

    The patches are wrapped in ``<patch_one>`` / ``<patch_two>`` tags and sent as the
    user message, with the module-level ``system_prompt`` as the system message. The
    request goes through the module-level ``client`` and asks for token log
    probabilities (``logprobs=True``, ``top_logprobs=2``).

    Args:
        patch_one: Text of the first patch; it is wrapped in ``<patch_one>`` tags here,
            so pass the bare patch text.
        patch_two: Text of the second patch; wrapped in ``<patch_two>`` tags here.

    Returns:
        tuple: ``(equivalent_or_different, response)`` where
        ``equivalent_or_different`` is the whitespace-stripped text found between the
        ``<equivalent_or_different>`` tags of the reply (``None`` when the reply has no
        such tags or no text content), and ``response`` is the unmodified completion
        object returned by the client.
    """
    response = client.chat.completions.create(
        model="gpt-4-turbo-preview",
        messages=[
            {"role": "system", "content": system_prompt},
            {
                "role": "user",
                # Build the message from the arguments; a fixed example prompt here
                # would make every call compare the same pair of patches.
                "content": f"<patch_one>{patch_one}</patch_one>\n\n<patch_two>{patch_two}</patch_two>",
            },
        ],
        temperature=0.2,
        max_tokens=512,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
        logprobs=True,
        top_logprobs=2,
    )
    # The message content may be None; fall back to "" so the regex search cannot raise.
    response_content = response.choices[0].message.content or ""

    # Pull out the verdict section of the reply. DOTALL lets the match span newlines.
    match = re.search(r'<equivalent_or_different>(.*?)</equivalent_or_different>', response_content, re.DOTALL)
    # The requested reply format puts the verdict on its own line between the tags,
    # so strip the surrounding whitespace before returning it.
    equivalent_or_different = match.group(1).strip() if match else None
    return equivalent_or_different, response

# Demo call: issues one chat-completion request and keeps the raw response for the
# visualisation cell below. Both arguments are the complete example_prompt, which
# already contains the <patch_one> and <patch_two> sections.
# NOTE(review): real callers presumably pass one bare patch per argument — confirm.
equivalent_or_different, response = generate_call_from_patches(example_prompt, example_prompt)

In [25]:
import matplotlib

colormap = matplotlib.colormaps['coolwarm']  # Get the coolwarm colormap (used by get_color)

def get_color(position):
    """Translate a gradient position into a CSS colour using the coolwarm colormap.

    Args:
        position: A float between 0.0 and 1.0 giving the location along the gradient.

    Returns:
        A CSS-compatible string of the form ``rgb(R, G, B)`` with integer channels.
    """
    rgba = colormap(position)
    # Only the first three channels are used; the alpha channel is dropped.
    # Each channel is scaled to 0-255 and truncated to an integer.
    red, green, blue = (int(channel * 255) for channel in rgba[:3])
    return f"rgb({red}, {green}, {blue})"

def highlight_text(api_response: ChatCompletion):
    """Display a completion with every token coloured by its probability.

    Reads the per-token log probabilities from the first choice of ``api_response``,
    converts each to a probability with ``np.exp`` and uses ``get_color`` to pick the
    colour of that token's ``<span>``. The coloured text is rendered with
    ``display(HTML(...))``.

    Besides the rendered HTML, the function prints the list of all token
    probabilities, the probability of every token whose text contains
    ``"equivalent"``, and the total token count.

    Args:
        api_response: A chat completion that was requested with ``logprobs=True``, so
            that ``choices[0].logprobs.content`` is populated.
    """
    # Stdlib helper imported locally; it is unrelated to IPython's HTML display class.
    from html import escape

    tokens = api_response.choices[0].logprobs.content
    # show all token probabilities
    print([np.exp(token.logprob) for token in tokens])

    spans = []
    for token in tokens:
        if token.bytes is not None:
            # A token may hold only part of a multi-byte character, so decode
            # leniently instead of raising UnicodeDecodeError.
            token_str = bytes(token.bytes).decode("utf-8", errors="replace")
        else:
            # No byte representation available: fall back to the token's text.
            token_str = token.token
        probability = np.exp(token.logprob)
        if "equivalent" in token_str:
            print(probability)
        # Colour is chosen from the token's probability (0.0-1.0 along the gradient).
        color = get_color(probability)
        # Escape the token text: the reply contains XML-like tags that a browser would
        # otherwise parse as markup and hide.
        spans.append(f"<span style='color: {color}'>{escape(token_str)}</span>")

    display(HTML("".join(spans)))
    print(f"Total number of tokens: {len(tokens)}")
# Render the completion stored in `response` by the previous cell.
highlight_text(response)

[0.9999998063873687, 0.9999998063873687, 1.0, 1.0, 1.0, 0.9999801379802525, 0.8479024628751181, 0.3911284720293332, 0.9999928926002577, 0.3665643108724679, 0.2500120655711193, 0.7668703171051177, 0.44927909137019323, 0.5709381807918824, 0.731163206095178, 0.9283719075812265, 0.409517127977843, 0.90578843650422, 0.9415850927407695, 0.6223789919383691, 0.9623325268861276, 0.7373533632713033, 0.37568580872734114, 0.8007521864451503, 0.9955315891874188, 0.6583999536613336, 0.47502834644583847, 0.8031175316938742, 0.7956882544396348, 0.4019135352260091, 0.7494857307347274, 0.9999983759447189, 0.999999687183719, 0.9999860980626328, 0.9996990219729768, 0.9104110574305577, 0.999979303571174, 1.0, 0.9999998063873687, 1.0, 0.9999947998470209, 0.980736856951861, 0.9772183750492627, 0.6325488678133417, 0.983620407590506, 0.8303222753498575, 0.673351949409038, 0.5177427405261549, 0.9950334920880923, 0.9572050759952068, 0.8409635089684246, 0.9731509890065737, 0.9595187617996692, 0.6633786537694937, 

Total number of tokens: 282
