# Langchain + GPT3.5 For Dependency Extraction
First we are going to run this on a small subset of Flink as the whole thing will cost a lot of money.  We will start with a simple reduced structure found in /flink-reduced

### Step 1: Download reqs, load OPENAI_API_KEY from env

In [6]:
# ! pip install openai tiktoken chromadb langchain
from dotenv import load_dotenv, find_dotenv
load_dotenv()

True

### Step 2: Load Documents

In [183]:
import tiktoken
import os 
import re
path = ".\\flink-1.17.1"

def remove_comments(code):
    # Removing /* ... */ comments
    code = re.sub(r'/\*.*?\*/', '', code, flags=re.DOTALL)
    # Removing // comments
    code = re.sub(r'//.*', '', code)
    return code

def read_file(file_path):
  with open(file_path, "r", encoding="utf-8") as f:
    return f.read()
  
def load_documents():
  documents = []
  encoding = tiktoken.encoding_for_model("gpt-3.5-turbo-instruct")

  for dir_path, _, file_names in os.walk(path):
    for file_name in file_names:
      if file_name.endswith(".java"):
        content = read_file(dir_path + "\\" + file_name)
        content = remove_comments(content)

        while len(encoding.encode(content)) > 600:
          content = content[:min(800, len(content))]
          
        documents.append({"file_path": dir_path + "\\" + file_name, "source": content})
  return documents


documents = load_documents()
len(documents)

13218

In [80]:
print(documents[0])
print(documents[1])
print(documents[2])

{'file_path': '.\\flink-1.17.1\\flink-annotations\\src\\main\\java\\org\\apache\\flink\\FlinkVersion.java', 'source': '\n\npackage org.apache.flink;\n\nimport org.apache.flink.annotation.Public;\n\nimport java.util.Arrays;\nimport java.util.LinkedHashSet;\nimport java.util.Map;\nimport java.util.Optional;\nimport java.util.Set;\nimport java.util.function.Function;\nimport java.util.stream.Collectors;\nimport java.util.stream.Stream;\n\n\n@Public\npublic enum FlinkVersion {\n\n    \n    \n    \n    v1_3("1.3"),\n    v1_4("1.4"),\n    v1_5("1.5"),\n    v1_6("1.6"),\n    v1_7("1.7"),\n    v1_8("1.8"),\n    v1_9("1.9"),\n    v1_10("1.10"),\n    v1_11("1.11"),\n    v1_12("1.12"),\n    v1_13("1.13"),\n    v1_14("1.14"),\n    v1_15("1.15"),\n    v1_16("1.16"),\n    v1_17("1.17");\n\n    private final String versionStr;\n\n    FlinkVersion(String versionStr) {\n        this.versionStr = versionStr;\n    }\n\n    @Override\n    public String toString() {\n        return versionStr;\n    }\n\n  

### Step 3: Define Templates

In [120]:
from langchain.prompts import PromptTemplate
   
initial_prompt = PromptTemplate(
    template="""
    Instruction: You will only return valid JSON. Given the following code, extract any internal dependencies. 
    Output must be a valid JSON array of strings. For the given code you must 
    determine what external files or packages it depends on, and return them.              
    File Path: {file_path}
    Code: {source} 
    Answer in valid JSON: \n\n###\n\n""",
    input_variables=["file_path", "source"]
)

retry_prompt = PromptTemplate(
    template="""
    The previous response, "{response}" was not valid JSON. Please try again. 
    Instruction: You will only return valid JSON. Given the following code, extract any internal dependencies. 
    Output must be a valid JSON array of strings. For the given code you must 
    determine what external files or packages it depends on, and return them.                    
    File Path: {file_path}
    Code: {source} 
    Answer in valid, unformatted JSON: \n\n###\n\n""",
    input_variables=["response", "file_path", "source"]
)

### Step 4: Define response extraction stages

In [121]:
import json
from langchain.llms import OpenAI

output = []
parse_fails = 0
retry_fails = 0
llm = OpenAI(temperature=0, model="gpt-3.5-turbo-instruct", frequency_penalty=0, presence_penalty=0, top_p=0.9, max_tokens=1000)

def invoke_model(document, prev_response=None):
    file_path = document["file_path"]
    source = document["source"]
    is_retry = prev_response is not None

    if is_retry:
        formatted = retry_prompt.format(file_path=file_path, source=source, response=prev_response)
    else: 
        formatted = initial_prompt.format(file_path=file_path, source=source)

    response = llm.invoke(formatted)
    response = response.replace("\n", "").replace(" ", "")

    success = handle_response(file_path, response, is_retry)
    if success or (not success and is_retry): 
        return
    else:
        invoke_model(document, prev_response=response)

            
def handle_response(file_path, response, is_retry):
    buffered_output = []
    try: 
        parsed = json.loads(response)

        for dependency in parsed:
            if isinstance(dependency, str):
                buffered_output.append((file_path, dependency))
            else:
                handle_parse_fail(response, "Dependency was not a string", is_retry)
                return False

    except json.JSONDecodeError as e:
        handle_parse_fail(response, "Invalid JSON", is_retry)
        return False
    
    finally:
        for item in buffered_output:
            output.append(item)
        return True


def handle_parse_fail(response, msg, is_retry):
    global retry_fails, parse_fails

    if is_retry:
        retry_fails += 1
        msg_type = "Retry"
    else:
        parse_fails += 1
        msg_type = "Parse"
        
    print(f"{msg_type} fail ({msg}): {response}")

### Step 5: Process in batches

In [181]:
import pickle
with open(f"bin/deps_4000_5000.pkl", "rb") as f:
    output = pickle.load(f)

print(len(output))
batch_start = 5000


54822


In [222]:
import pickle
from IPython.display import display, clear_output

batch_size = 2
batch_end = batch_start + batch_size
batch = documents[batch_start:batch_end]

for i, document in enumerate(batch):
    try:
        invoke_model(document)
    except Exception as e:
        modified_doc = document["source"][:len(document["source"]) // 4]
        print("api err 1x")
        try:
            invoke_model(modified_doc)
        except Exception as e:
            print("api err 2x")

    clear_output(wait=True)
    display(f"Iteration: {i}")
        
batch_start = batch_end

print(f"Batch start: {batch_end - batch_size}")
print(f"Batch size: {batch_size}")
print(f"Parse fails: {parse_fails}")
print(f"Retry fails: {retry_fails}")
print(f"Output length: {len(output)}")

with open(f"bin/deps_{batch_end - batch_size}_{batch_end}.pkl", "wb") as f:
    pickle.dump(output, f)
    print("Saved output")

print(f"Start for next run is {batch_end}")

Batch start: 13219
Batch size: 2
Parse fails: 311
Retry fails: 0
Output length: 126518
Saved output
Start for next run is 13221


In [232]:
import random

rand_out = random.sample(output, 10)

for x,y in rand_out:
    print(x,y)

.\flink-1.17.1\flink-runtime-web\src\main\java\org\apache\flink\runtime\webmonitor\handlers\JarRunMessageParameters.java org.apache.flink.runtime.rest.messages.MessageParameters
.\flink-1.17.1\flink-connectors\flink-connector-hive\src\test\java\org\apache\flink\connectors\hive\read\HivePartitionFetcherTest.java org.apache.flink.connectors.hive.HiveOptions
.\flink-1.17.1\flink-yarn-tests\src\test\java\org\apache\flink\yarn\YARNSessionCapacitySchedulerITCase.java org.apache.flink.runtime.testutils.CommonTestUtils
.\flink-1.17.1\flink-runtime\src\main\java\org\apache\flink\runtime\operators\coordination\CoordinationResponse.java java.io.Serializable
.\flink-1.17.1\flink-connectors\flink-connector-files\src\main\java\org\apache\flink\connector\file\table\stream\ProcTimeCommitTrigger.java org.apache.flink.streaming.runtime.tasks.ProcessingTimeService
.\flink-1.17.1\flink-clients\src\main\java\org\apache\flink\client\program\MiniClusterClient.java org.apache.flink.util.SerializedValue
.\flin

In [233]:
print(f"Parse fails: {parse_fails}")
print(f"Retry fails: {retry_fails}")
print(f"Output length: {len(output)}")

Parse fails: 311
Retry fails: 0
Output length: 126518


### Step 6: Clean up, write to TA

In [235]:
def package_to_path(source, package_name):
    # Extract the base directory up to and including 'java'
    base_dir = source.split("\\java\\")[0] + "\\java\\"

    # Convert the package name to a path and append to base_dir
    package_path = package_name.replace(".", "\\") + ".java"
    return base_dir + package_path

processed_output = set([(source.replace("\\", "/"), package_to_path(source, dependency).replace("\\", "/")) for source, dependency in output])

print(f"Processed size: {len(processed_output)}")

Processed size: 120269


In [236]:
raw_ta_output = "./source_raw_ta/llm_dependencies.raw.ta"

with open(raw_ta_output, "w+") as f:
  f.write("FACT TUPLE : \n")

  unique_file_paths = set(file_path for file_path, _ in processed_output)

  # first generate all the concrete instances
  for file_path in unique_file_paths:
    f.write(f"$INSTANCE {file_path} cFile\n")

  # now add in all the dependencies
  for file_path, dependency in processed_output:
    f.write(f"cLinks {file_path} {dependency}\n")