In [119]:
import requests
from typing import List, Tuple, Optional, Literal
import logging
from dotenv import load_dotenv
from anthropic import Anthropic
import os
import subprocess
from github import Github
from langchain_anthropic import ChatAnthropic
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel, Field
import re
from tqdm import tqdm


load_dotenv()
ANTHROPIC_API_KEY = os.environ.get('ANTHROPIC_API_KEY')
GITHUB_TOKEN = os.environ.get('GITHUB_TOKEN')
github_client = Github(GITHUB_TOKEN)

In [3]:
def run_command(command, cwd=None):
    """Utility function to run shell commands within a specific directory."""
    result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, cwd=cwd)
    if result.returncode != 0:
        print(f"Error: {result.stderr.decode('utf-8')}")
    else:
        print(result.stdout.decode('utf-8'))

def checkout_branch(repo_dir, branch_name):
    """Checkout to a specific branch."""
    run_command(f"git checkout {branch_name}", cwd=repo_dir)

def checkout_commit(repo_dir, commit_hash):
    """Checkout to a specific commit."""
    run_command(f"git checkout {commit_hash}", cwd=repo_dir)

def checkout(repo_dir, branch_name, commit_hash):
    checkout_branch(repo_dir, branch_name)
    checkout_commit(repo_dir, commit_hash)

def get_file_content(repo_dir: str, file_path: str) -> Optional[str]:
    full_path = os.path.join(repo_dir, file_path)
    with open(full_path, 'r') as file:
        return file.read()

In [200]:
def get_pr_information(owner: str, repo: str, pr_number: int, auth_token: str) -> dict:
    repo = github_client.get_repo(f"{owner}/{repo}")
    pr = repo.get_pull(pr_number)
    print(f'fetching PR information for {pr_number}')
    return pr.raw_data
def get_file_content(url):
    response = requests.get(url)
    if response.status_code == 200:
        return response.text
    else:
        raise Exception(f"Failed to fetch file content. Status code: {response.status_code}")

def add_line_numbers(content):
    lines = content.split('\n')
    numbered_lines = [f"<{i+1}>{line}" for i, line in enumerate(lines)]
    return '\n'.join(numbered_lines)

def remove_line_numbers(content):
    lines = content.split('\n')
    unnumbered_lines = [re.sub(r'<\d+>\s', '', line) for line in lines]
    return '\n'.join(unnumbered_lines)

def add_file_content_inplace(file) -> str:
    # get file content 
    file_content = get_file_content(file['raw_url'])

    # add line number to each line 
    numbered_content = add_line_numbers(file_content)

    # update file dict  
    file['content'] = numbered_content
    file['raw_content'] = file_content


def get_pr_files(pull_request) -> List[dict]:
    return [i.raw_data for i in pull_request.get_files()]

def get_prompt():
    return """
    Given the following information about a file that has been changed and its latency profile, what optimizations would you suggest to reduce latency? Only suggest changes when you're confident they will improve the latency, as the results will be evaluated by a profiler.

    File path: {file_name}
    
    File content: {file_content}
    
    File changes: {patch}

    Latency profile results: {latency_results}"

    Please provide specific optimization suggestions based on this information. Return the file updated with the changes made.
    """

def get_system_message():
    return """You are an AI assistant and a smart software engineer. You are specialized in improving code performance and runtime."""

def get_changes_from_llm(file, latency_profile):
    prompt = get_prompt(file, latency_profile=latency_profile)

    changes = send_prompt_to_llm(prompt)

    return changes 


def get_updated_file(file, latency_profile):
    pass 

def create_commit():
    pass 

def create_pr():
    pass 


def send_prompt_to_llm(prompt):
    client = Anthropic(api_key=ANTHROPIC_API_KEY)

    message = client.messages.create(
        max_tokens=4096,
        system=[
            {
                "type": "text",
                "text": get_system_message()
            }
        ],
        messages=[
            {
                "role": "user",
                "content": prompt,
            },
            {
                "role": "assistant",
                "content": "Below is the original file with the updated optimizations made:"
            }
        ],
        model="claude-3-5-sonnet-20240620",
    )
    resp = message.content

    print(f"Usage: {message.usage}")
    if resp:
        return resp[0].text


def get_pull_requests(owner, repo):
    """
    Retrieves all pull requests for a given repository.
    
    Args:
        owner (str): The owner of the repository.
        repo (str): The name of the repository.
    
    Returns:
        list: A list of PullRequest objects.
    """
    repository = github_client.get_repo(f"{owner}/{repo}")
    pull_requests = list(repository.get_pulls(state='open'))
    return pull_requests


def pr_files(owner: str, repo: str, pr_number: int):
    repo_obj = github_client.get_repo(f"{owner}/{repo}")
    pr = repo_obj.get_pull(pr_number)
    pr_files = get_pr_files(pr)

    for file in pr_files:
        # add content of the file 
        add_file_content_inplace(file)
    
    return pr_files

class CodeChange(BaseModel):
    """Changes to make to the file."""

    change_type: Literal["delete", "modify", "add_after"] = Field(description="The type of change to make.")
    line_start: int = Field(description="The starting line of the change. Please verify that the line number is correct.")
    line_end: int = Field(description="The end line of the change. Please verify that the line number is correct.")
    content: str = Field(description="The code changes to make to these lines")

class ChangedCode(BaseModel):
    """The code snippet with the patch applied."""

    content: str = Field(description="code snippet with the code applied.")

class Patch(BaseModel):
    start: Optional[int] = Field(description="start line of the patch")
    end: Optional[int] = Field(description="end line of the patch")
    code_snippet: Optional[str] = Field(description="original piece of code to apply patch to")
    patch: Optional[str] = Field(description="code snippet after applying the patch")

class FileChange(BaseModel):
    """List of all the changes to make to the file."""
    changes: list[CodeChange] = Field(description="List of specific code changes")

class LLM:
    def __init__(self):
        llm = ChatAnthropic(model="claude-3-5-sonnet-20240620")
        structured_llm = llm.with_structured_output(FileChange)
        prompt_template = ChatPromptTemplate.from_messages([
            ("system", get_system_message()),
            ("user", get_prompt())
        ])


        self.runnable = prompt_template | structured_llm

    def get_response(self, file: dict, latency_results: dict) -> FileChange:
        res = self.runnable.invoke(
            {
                "file_name": file['filename'], 
                "file_content": file['content'], 
                "patch": file['patch'],
                "latency_results": latency_results
            })
        
        return res 
    
class LineChangeFixer:
    def __init__(self):
        llm = ChatAnthropic(model="claude-3-5-sonnet-20240620")
        structured_llm = llm.with_structured_output(ChangedCode)
        system = """
        You are a smart and cautious software engineer
        """
        prompt = """
        Given the following snippet of code, apply the following patches to it.

        Code Snippet:
        {code_snippet}
        
        Patches: 
        {patch}
        """
        prompt_template = ChatPromptTemplate.from_messages([
            ("system", system),
            ("user", prompt)
        ])


        self.runnable = prompt_template | structured_llm

    def get_response(self, file: dict, code_change_list: List[CodeChange]) -> Patch:
        if not code_change_list:
            return Patch(patch=None)
        st = min([i.line_start for i in code_change_list])
        end = max([i.line_end for i in code_change_list])
        adjusted_st = max(0, st - 10)
        adjusted_end = end + 10
        # split file content by lines 
        code_snippet = '\n'.join(file['content'].split('\n')[adjusted_st:adjusted_end])
        res = self.runnable.invoke(
            {
                "code_snippet": code_snippet, 
                "patch": [i.model_dump() for i in code_change_list]
            })

        
        return Patch(start=adjusted_st, end=adjusted_end, code_snippet=code_snippet, patch=res.content)
    
def update_file(file: dict, updated_content: str):
    # load file 
    local_file_path = f"../PyGithub/{file['filename']}"
    
    # Write the updated content to the file
    with open(local_file_path, 'w') as f:
        f.write(updated_content)

    print(f'{file["filename"]} has been updated')

def apply_patch_to_file(file: dict, patch: Patch):
    # apply the patch to the file 
    lines = file['raw_content'].split('\n')
    # remove line numbers from the patch 
    tmp = remove_line_numbers(tmp['patch'].content)
    new_lines = tmp.split('\n')
    lines[patch.start:patch.end] = new_lines
    new_file_content = '\n'.join(lines)

    update_file(file, new_file_content)

def combine_close_changes(changes: List[CodeChange], max_distance: int = 10) -> List[List[CodeChange]]:
    if not changes:
        return []

    # Sort changes by start line
    sorted_changes = sorted(changes, key=lambda x: x.line_start)
    
    combined = []
    current_group = [sorted_changes[0]]

    for change in sorted_changes[1:]:
        if change.line_start - current_group[-1].line_end <= max_distance:
            current_group.append(change)
        else:
            combined.append(current_group)
            current_group = [change]

    # Add the last group
    combined.append(current_group)

    return combined

    
def main(owner, repo, pr_number):
    # get latency profile 
    latency_profile = "100ms"
    # get files in the pr  
    pr_files = get_pr_files(owner, repo, pr_number)

    for file in tqdm(pr_files):
        # get code changes 
        res = LLM().get_response(file, latency_results=latency_profile)
        code_change_list = combine_close_changes(res.changes)
        patch = LineChangeFixer().get_response(file=file, code_change_list=code_change_list)
        # get changes and update 
        apply_patch_to_file(patch)

def remove_line_numbers(text: str) -> str:    
    return re.sub(r'<\d+>', '', text)

In [None]:
# Example usage:
owner="pygithub"
repo = "pygithub"
pr_list = get_pull_requests(owner, repo)

In [202]:
import random 
pr = random.choice(pr_list)
if len(list(pr.get_files())) < 4:
    files = pr_files(owner, repo, pr.number)
    print(f'Number of files changed: {len(files)}')

Number of files changed: 3


In [205]:
file = random.choice(files)
print(f'Chosen file: {file["filename"]}')
res = LLM().get_response(file, latency_results="100ms")
print(f'{len(res.changes)} suggested')
code_change_list = combine_close_changes(res.changes)
print(f'{len(res.changes)} suggested changes')
tmp = LineChangeFixer().get_response(file=file, code_change_list=code_change_list)

Chosen file: tests/ReplayData/Organization.testGetRepoSecurityAdvisories.txt
0 suggested changes


ValueError: min() arg is an empty sequence