# Simple Classifier for LLM Prompts
# (DON'T LOOK) WORK IN PROGRESS!

### Approach 1: Binary Classification using Logistic Regression

In [None]:
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.metrics import accuracy_score, classification_report

def train_llm_prompt_classifier(data, labels):
    """Fit a TF-IDF + logistic-regression pipeline and report held-out metrics.

    The inputs are split 70/30 (shuffled, fixed seed 42), the pipeline is
    fitted on the training part, and accuracy plus a per-class report for the
    test part are printed to stdout.

    Args:
        data: Text samples to classify.
        labels: Targets aligned with ``data``. Presumably 0 = "Non-Prompt"
            and 1 = "LLM Prompt" -- TODO confirm against the dataset.

    Returns:
        The fitted ``Pipeline`` (vectorizer followed by classifier).
    """
    # Fixed random_state keeps the split reproducible between runs.
    train_texts, test_texts, train_targets, test_targets = train_test_split(
        data, labels, test_size=0.3, random_state=42, shuffle=True
    )

    # TF-IDF features may not be the best signal for spotting LLM prompts;
    # this is a first experiment.
    model = Pipeline(
        steps=[
            ('tfidf', TfidfVectorizer()),
            ('classifier', LogisticRegression()),
        ]
    )
    model.fit(train_texts, train_targets)

    # Score on the held-out split only.
    predicted = model.predict(test_texts)
    held_out_accuracy = accuracy_score(test_targets, predicted)

    print("\n\nClassifier Performance\n")
    print(f"Accuracy: {held_out_accuracy:.2f}\n")
    print(
        classification_report(
            test_targets, predicted, target_names=["Non-Prompt", "LLM Prompt"]
        )
    )

    return model

# Train once and keep the fitted pipeline in a notebook-level variable.
# NOTE(review): `data` and `labels` are not defined in the visible cells --
# confirm they are created before this cell runs.
classifier = train_llm_prompt_classifier(data, labels)

In [None]:
def is_llm_prompt(text, classifier):
    """Return whether ``classifier`` labels ``text`` as class 1.

    Args:
        text: A single string to classify.
        classifier: Any object with a ``predict`` method that accepts a list
            of strings and returns an indexable sequence of labels.

    Returns:
        The result of comparing the first predicted label with ``1``.
    """
    # predict() works on batches, so wrap the single text in a list.
    predicted_labels = classifier.predict([text])
    first_label = predicted_labels[0]
    return first_label == 1

# Quick check: classify one template-style string with the pipeline trained
# above and print the boolean result.
example_text = "Example prompt {question}:"
print(is_llm_prompt(example_text, classifier))

### Approach 2: Docstrings with certain criteria

In [None]:
from tree_sitter import Language, Parser

# Load Python grammar
# Compiles the grammar checkout at TREE_SITTER_PYTHON_PATH into
# build/python.so and loads the "python" language from that library.
# NOTE(review): `Language.build_library` and the two-argument `Language(...)`
# constructor appear to exist only in older py-tree-sitter releases --
# confirm the pinned version before running this cell.
TREE_SITTER_PYTHON_PATH = "tree-sitter-python"  # Change this according to the cloned path
Language.build_library(
    "build/python.so",
    [
        TREE_SITTER_PYTHON_PATH,
    ],
)
PYTHON_LANGUAGE = Language("build/python.so", "python")

# Initialize parser
# A single module-level parser is shared by the cells below.
parser = Parser()
parser.set_language(PYTHON_LANGUAGE)

# Function to check if a node is a docstring
def is_docstring_node(node):
    """Return True if ``node`` is a bare string statement in a docstring position.

    In the tree-sitter Python grammar a string used as a statement is wrapped
    in an ``expression_statement``; inside a class or function that statement
    additionally lives in the ``block`` that forms the body. Checking only the
    string's direct parent therefore never matches a real docstring, so the
    wrapper levels are walked explicitly here.

    Any standalone string statement directly in a module, class body or
    function body is accepted (not only the first statement). Strings nested
    in other constructs (assignments, calls, ``if`` bodies, ...) are rejected.

    Args:
        node: A syntax node exposing ``type`` (str) and ``parent`` (node or
            ``None``).

    Returns:
        bool: Whether the node is a string in a docstring position.
    """
    owner_types = ("module", "class_definition", "function_definition")

    if node.type != "string":
        return False

    parent = node.parent
    if parent is None:
        return False

    # Kept for backward compatibility with trees where the string hangs
    # directly off its owner.
    if parent.type in owner_types:
        return True

    if parent.type != "expression_statement":
        return False

    container = parent.parent
    if container is None:
        return False

    # Module-level statements sit directly under the module node.
    if container.type == "module":
        return True

    # Class/function bodies are wrapped in a `block` node.
    if container.type == "block":
        owner = container.parent
        return owner is not None and owner.type in (
            "class_definition",
            "function_definition",
        )

    return False


def find_docstrings(tree):
    """Collect every node of ``tree`` accepted by ``is_docstring_node``.

    Args:
        tree: A parsed tree exposing ``root_node``; every node must expose
            ``children``.

    Returns:
        list: Matching nodes in pre-order (a parent before its children,
        children left to right).
    """
    docstrings = []

    # Explicit stack instead of recursion so deeply nested syntax trees
    # cannot exhaust Python's recursion limit.
    pending = [tree.root_node]
    while pending:
        node = pending.pop()
        if is_docstring_node(node):
            docstrings.append(node)
        # Reverse so the leftmost child is popped (visited) first.
        pending.extend(reversed(node.children))

    return docstrings


# Test the code
# Sample source with three bare triple-quoted strings: one in a function
# body, one at module level, one in a class body.
code = """
def my_function():
    \"\"\"This is a docstring.\"\"\"
    pass

\"\"\"This is another docstring.\"\"\"

class MyClass:
    \"\"\"Class docstring.\"\"\"
    pass
"""

# The parser is fed the UTF-8 encoded source.
tree = parser.parse(bytes(code, "utf8"))

docstrings = find_docstrings(tree)

# Prints each matched node object itself, not the source text it spans.
for docstring in docstrings:
    print(docstring)

### TODO: Test "\n" heuristic against a list of all docstrings assigned to variables

### Approach 3: Regex to Check for String Interpolation within all strings assigned to variables

In [None]:
import re
from tree_sitter import Language, Parser

# ... (previous code)

# Function to check if a docstring uses string interpolation
def contains_string_interpolation(docstring_text):
    """Return True if ``docstring_text`` contains a brace placeholder.

    Two regular expressions are tried in order and the first hit wins:
    a nested-brace form (``{ ... {x} ... }``) and a plain ``{}`` / ``{name}``
    form as used by ``str.format()``.

    NOTE(review): escaped braces such as ``{{x}}`` also count as
    interpolation, because the inner ``{x}`` matches the second pattern.
    NOTE(review): the first pattern appears to be subsumed by the second --
    any text it matches also contains a brace pair with no braces inside.

    Args:
        docstring_text: The string source text to inspect.

    Returns:
        bool: True when either pattern is found, False otherwise.
    """
    placeholder_patterns = (
        # Nested-brace / f-string style expression
        r'{[^{]*?\{(.*?)\}[^}]*?}',
        # str.format() style field
        r'\{(?:[^{}]+)?\}',
    )
    return any(
        re.search(pattern, docstring_text) for pattern in placeholder_patterns
    )

# Test the code
# Sample source with two bare strings containing brace placeholders: a plain
# string and an f-string.
code = '''
def my_function(arg):
    """This is a docstring with interpolation: {arg}"""
    pass

def another_function(value):
    f"""This is an f-string docstring: {value}"""
    pass
'''

# Encode once and keep the buffer: start_byte/end_byte index into these
# bytes, not into the str, so slicing `code` directly would return the wrong
# span as soon as the source contains a non-ASCII character.
source_bytes = bytes(code, "utf8")
tree = parser.parse(source_bytes)

docstrings = find_docstrings(tree)

for docstring in docstrings:
    docstring_text = source_bytes[docstring.start_byte:docstring.end_byte].decode("utf8")
    print(f"Docstring: {docstring_text.strip()}")
    print(f"Contains string interpolation: {contains_string_interpolation(docstring_text)}\n")