# Langchain + GPT3.5 For Dependency Extraction

### Step 1: Install libraries (if needed), load OPENAI_API_KEY from env

In [72]:
# !pip install openai tiktoken langchain
from dotenv import load_dotenv, find_dotenv
load_dotenv()

True

### Step 2: Load Documents
- Remove comments 
- Only process .java files
- For "large" (>600 token) inputs, we are going to truncate first 1000 characters
  - 800 and 1000 are somewhat random
  - Truncated size needs to be high enough that we retain imports and ideally main methods (which would typically be defined near the top in java programs)
  - Higher truncated size => more cost and time
  - With more resources, we could have opted to use CodeLlama or GPT-4 however both options were not available to us
  - With more resources we also could have avoided truncating any data and even possibly retain the comments
  - In initial tests on a reduced version of flink, we saw optimal inference time and accuracy using these parameters, so that is what we chose to run on the whole system

In [158]:
import tiktoken
import os 
import re
path = ".\\flink-1.17.1"

def remove_comments(code):
    # Removing /* ... */ comments
    code = re.sub(r'/\*.*?\*/', '', code, flags=re.DOTALL)
    
    # Removing // comments
    code = re.sub(r'//.*', '', code)
    return code

def read_file(file_path):
  with open(file_path, "r", encoding="utf-8") as f:
    return f.read()
  
def load_documents():
  documents = []
  encoding = tiktoken.encoding_for_model("gpt-3.5-turbo-instruct")

  for dir_path, _, file_names in os.walk(path):
    for file_name in file_names:
      if file_name.endswith(".java"):
        content = read_file(dir_path + "\\" + file_name)
        content = remove_comments(content)

        while len(encoding.encode(content)) > 800:
          content = content[:1000]
          
        documents.append({"file_path": dir_path + "\\" + file_name, "source": content})
  return documents

documents = load_documents()
len(documents)

13219

In [163]:
print(documents[0])
print(documents[1])
print(documents[2])

{'file_path': '.\\flink-1.17.1\\flink-annotations\\src\\main\\java\\org\\apache\\flink\\FlinkVersion.java', 'source': '\n\npackage org.apache.flink;\n\nimport org.apache.flink.annotation.Public;\n\nimport java.util.Arrays;\nimport java.util.LinkedHashSet;\nimport java.util.Map;\nimport java.util.Optional;\nimport java.util.Set;\nimport java.util.function.Function;\nimport java.util.stream.Collectors;\nimport java.util.stream.Stream;\n\n\n@Public\npublic enum FlinkVersion {\n\n    \n    \n    \n    v1_3("1.3"),\n    v1_4("1.4"),\n    v1_5("1.5"),\n    v1_6("1.6"),\n    v1_7("1.7"),\n    v1_8("1.8"),\n    v1_9("1.9"),\n    v1_10("1.10"),\n    v1_11("1.11"),\n    v1_12("1.12"),\n    v1_13("1.13"),\n    v1_14("1.14"),\n    v1_15("1.15"),\n    v1_16("1.16"),\n    v1_17("1.17");\n\n    private final String versionStr;\n\n    FlinkVersion(String versionStr) {\n        this.versionStr = versionStr;\n    }\n\n    @Override\n    public String toString() {\n        return versionStr;\n    }\n\n  

### Step 3: Define Templates
- Retry prompt is used for iterative refinement/feedback prompting
- Langchain has a fallback method, but there is no way to track how many times it was called
- Thus we will manually do the retry fallback

In [120]:
from langchain.prompts import PromptTemplate
   
initial_prompt = PromptTemplate(
    template="""
    Instruction: You will only return valid JSON. Given the following code, extract any internal dependencies. 
    Output must be a valid JSON array of strings. For the given code you must 
    determine what external files or packages it depends on, and return them.              
    File Path: {file_path}
    Code: {source} 
    Answer in valid JSON: \n\n###\n\n""",
    input_variables=["file_path", "source"]
)

retry_prompt = PromptTemplate(
    template="""
    The previous response, "{response}" was not valid JSON. Please try again. 
    Instruction: You will only return valid JSON. Given the following code, extract any internal dependencies. 
    Output must be a valid JSON array of strings. For the given code you must 
    determine what external files or packages it depends on, and return them.                    
    File Path: {file_path}
    Code: {source} 
    Answer in valid, unformatted JSON: \n\n###\n\n""",
    input_variables=["response", "file_path", "source"]
)

### Step 4: Define response extraction stages
- Temperature = 0, tends the model towards being deterministic (less randomness)
- Top P = 0.9, ignore the least likely 10% options in prediction (less creative results)
- Invoke model calls itself a second time if the JSON parse fails; after the second time we give up and move on
- In practice the model never made 2 consecutive errors throughout all ~13200 files
- This is largely due to the retry prompt including previous response
- LLMs are very good at fixing errors when they are pointed out (few shot prompting for example)
- Response is buffered when going through the array in case we run into non-string values. When the code was ran on all of Flink, this did not happen.

In [121]:
import json
from langchain.llms import OpenAI

output = []
parse_fails = retry_fails = 0

llm = OpenAI(temperature=0, model="gpt-3.5-turbo-instruct", frequency_penalty=0, presence_penalty=0, top_p=0.9, max_tokens=1000)

def invoke_model(document, prev_response=None):
    file_path = document["file_path"]
    source = document["source"]
    is_retry = prev_response is not None

    if is_retry:
        formatted = retry_prompt.format(file_path=file_path, source=source, response=prev_response)
    else: 
        formatted = initial_prompt.format(file_path=file_path, source=source)

    response = llm.invoke(formatted)
    response = response.replace("\n", "").replace(" ", "")

    success = handle_response(file_path, response, is_retry)
    if success or (not success and is_retry): 
        return
    else:
        invoke_model(document, prev_response=response)

            
def handle_response(file_path, response, is_retry):
    buffered_output = []
    try: 
        parsed = json.loads(response)

        for dependency in parsed:
            if isinstance(dependency, str):
                buffered_output.append((file_path, dependency))
            else:
                handle_parse_fail(response, "Dependency was not a string", is_retry)
                return False

    except json.JSONDecodeError as e:
        handle_parse_fail(response, "Invalid JSON", is_retry)
        return False
    
    finally:
        for item in buffered_output:
            output.append(item)
        return True


def handle_parse_fail(response, msg, is_retry):
    global retry_fails, parse_fails

    if is_retry:
        retry_fails += 1
        msg_type = "Retry"
    else:
        parse_fails += 1
        msg_type = "Parse"
        
    print(f"{msg_type} fail ({msg}): {response}")

### Step 5: Process in batches
- When doing 13,000+ requests we have a few things to be worried about
  - 10k to 100k request rate limit depending on plan tier
  - Tokens per minute rate limit
  - Any other api errors that might arise
- Batch size is modifiable
- After each batch we will save the list state to bin
- This way if anything fails we can restore manually to a previous state in the array

In [93]:
# parse_fails = retry_fails = 0
import pickle
with open(f"bin/deps_13219_13221.pkl", "rb") as f:
    output = pickle.load(f)

print(len(output))
batch_start = 5000


126518


In [222]:
import pickle
from IPython.display import display, clear_output

batch_size = 2
batch_end = batch_start + batch_size
batch = documents[batch_start:batch_end]

# sequential Process
for i, document in enumerate(batch):
    try:
        invoke_model(document)
    except Exception as e:
        # usually fails due to length so this is an optimistic retry (hopefully API error is just 1 off)
        modified_doc = document["source"][:len(document["source"]) // 4] 
        print("API Error [1st]")
        try:
            invoke_model(modified_doc)
        except Exception as e:
            print("API Error [2nd]")

    clear_output(wait=True)
    display(f"Iteration: {i}")


# parallel Process -> wont work because rate limit, would have to use small batches and not really going to save much time
# if we had an account without rate limit or access to multiple instances of something 
# like CodeLlama we would be able to do this and save a lot of time


# import asyncio

# async def async_invoke(prompt):
#     resp = await llm.ainvoke(prompt)

# async def invoke_model_async(prev_response=None):
#     tasks = [async_invoke(llm, initial_prompt.format(file_path=document["file_path"], source=document["source"])) for document in batch]
#     for coroutine in asyncio.as_completed(tasks):
#         try:
#             results = await coroutine
#         except Exception as e:
#             parse_fails += 1
            
#             # try a second time...
#             try: 
#                 results_2 = await async_invoke(llm, retry_prompt.format(file_path=document["file_path"], source=document["source"], response=results))
#             except Exception as e:
#                 retry_fails += 1

#         else:
#             print('Results:', results)

#     await asyncio.gather(*tasks)

# await invoke_model_async()

print(f"Batch start: {batch_start}")
print(f"Batch size: {batch_size}")
print(f"Parse fails: {parse_fails}")
print(f"Retry fails: {retry_fails}")
print(f"Output length: {len(output)}")

# move start (for next run) up to end of batch
batch_start = batch_end
print(f"Start for next run is {batch_end}")

# save intermediate state
with open(f"bin/deps_{batch_end - batch_size}_{batch_end}.pkl", "wb") as f:
    pickle.dump(output, f)
    print("Saved output")

Batch start: 13219
Batch size: 2
Parse fails: 311
Retry fails: 0
Output length: 126518
Saved output
Start for next run is 13221


In [3]:
# pick out random 10 to see if we did things right...
import random

for x,y in random.sample(output, 10):
    print(x,y)

.\flink-1.17.1\flink-connectors\flink-connector-files\src\test\java\org\apache\flink\connector\file\table\stream\StreamingFileWriterTest.java org.apache.flink.streaming.api.functions.sink.filesystem.OutputFileConfig
.\flink-1.17.1\flink-connectors\flink-connector-files\src\main\java\org\apache\flink\connector\file\table\DefaultPartTimeExtractor.java java.time.LocalTime
.\flink-1.17.1\flink-table\flink-table-runtime\src\main\java\org\apache\flink\table\runtime\functions\BuiltInSpecializedFunction.java org.apache.flink.table.catalog.DataTypeFactory
.\flink-1.17.1\flink-connectors\flink-hadoop-compatibility\src\test\java\org\apache\flink\test\hadoopcompatibility\mapred\HadoopMapredITCase.java org.apache.flink.test.testdata.WordCountData
.\flink-1.17.1\flink-runtime\src\test\java\org\apache\flink\runtime\checkpoint\channel\ChannelStateWriteRequestExecutorImplTest.java org.apache.flink.util.function.BiConsumerWithException
.\flink-1.17.1\flink-table\flink-table-runtime\src\main\java\org\apa

### Step 6: Clean up, write to TA

In [144]:
# clean duplicates
output = sorted(set(output))
print(f" {len(output)} uniques")

 120269 uniques


In [161]:
# generate map of package -> path where it is defined
package_to_path_map = {}

for dir_path, _, file_names in os.walk(path):
    for file_name in file_names:
        if file_name.endswith(".java"):
            package = dir_path.split("\\java\\")[-1].replace("\\", ".") + "." + file_name.replace(".java", "")
            package_to_path_map[package] = dir_path + "\\" + file_name
            
print(f"Size: {len(package_to_path_map)}")
for k,v in list(package_to_path_map.items())[:10]:
    print(k + " -> " + v)


Size: 13194
org.apache.flink.FlinkVersion -> .\flink-1.17.1\flink-annotations\src\main\java\org\apache\flink\FlinkVersion.java
org.apache.flink.annotation.Experimental -> .\flink-1.17.1\flink-annotations\src\main\java\org\apache\flink\annotation\Experimental.java
org.apache.flink.annotation.Internal -> .\flink-1.17.1\flink-annotations\src\main\java\org\apache\flink\annotation\Internal.java
org.apache.flink.annotation.Public -> .\flink-1.17.1\flink-annotations\src\main\java\org\apache\flink\annotation\Public.java
org.apache.flink.annotation.PublicEvolving -> .\flink-1.17.1\flink-annotations\src\main\java\org\apache\flink\annotation\PublicEvolving.java
org.apache.flink.annotation.VisibleForTesting -> .\flink-1.17.1\flink-annotations\src\main\java\org\apache\flink\annotation\VisibleForTesting.java
org.apache.flink.annotation.docs.ConfigGroup -> .\flink-1.17.1\flink-annotations\src\main\java\org\apache\flink\annotation\docs\ConfigGroup.java
org.apache.flink.annotation.docs.ConfigGroups -> 

In [150]:
# apply mapping to depedency list, generate final output (to use for raw ta)
final_output = []

for source, depdendency in output:
    dep_path = package_to_path_map.get(depdendency)
    if dep_path is not None:
        final_output.append((source[2:].replace("\\", "/"), dep_path[2:].replace("\\", "/")))

print(f"Final output length: {len(final_output)}")
for x,y in final_output[:10]:
    print(x,y)


Final output length: 74500
flink-1.17.1/flink-annotations/src/main/java/org/apache/flink/FlinkVersion.java flink-1.17.1/flink-annotations/src/main/java/org/apache/flink/annotation/Public.java
flink-1.17.1/flink-annotations/src/main/java/org/apache/flink/annotation/Experimental.java flink-1.17.1/flink-annotations/src/main/java/org/apache/flink/annotation/Public.java
flink-1.17.1/flink-annotations/src/main/java/org/apache/flink/annotation/Internal.java flink-1.17.1/flink-annotations/src/main/java/org/apache/flink/annotation/Public.java
flink-1.17.1/flink-annotations/src/main/java/org/apache/flink/annotation/Public.java flink-1.17.1/flink-annotations/src/main/java/org/apache/flink/annotation/Public.java
flink-1.17.1/flink-annotations/src/main/java/org/apache/flink/annotation/PublicEvolving.java flink-1.17.1/flink-annotations/src/main/java/org/apache/flink/annotation/Public.java
flink-1.17.1/flink-annotations/src/main/java/org/apache/flink/annotation/VisibleForTesting.java flink-1.17.1/fli

In [151]:
# write to TA
raw_ta_output = "./source_raw_ta/llm_dependencies.raw.ta"

with open(raw_ta_output, "w+") as f:
  f.write("FACT TUPLE : \n")

  unique_file_paths = set(file_path for file_path, _ in final_output)

  # first generate all the concrete instances
  for file_path in unique_file_paths:
    f.write(f"$INSTANCE {file_path} cFile\n")

  # now add in all the dependencies
  for file_path, dependency in final_output:
    f.write(f"cLinks {file_path} {dependency}\n")

# Recap

1. More resources would let us chose a better model (GPT-4 or CodeLlama)

2. More resources could also mean less truncation of input files (improve accuracy)

3. Batching would still be needed to allow us to work within rate limits 

4. Batching helped incrementally generate the dependency list with tolerance for failure

5. Examples could have been provided to the LLM of hard to recognize dependencies, if those exist (improve accuracy)

6. Even without providing examples, we still obtained good performance

7. In a more focused environment we could have also finetuned the model prior to use for better performance