In [1]:
!pip install -q scikit-learn


In [3]:
import os
import ast
from textwrap import indent

import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.pipeline import Pipeline
from sklearn.svm import LinearSVC
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report


In [4]:
FUNCTION_LABELS = [
    "data_processing",
    "api_endpoint",
    "utility",
    "ml_model",
    "database",
]

def build_function_type_training_data():
    examples = [
        # Data processing
        ("""
def clean_dataframe(df):
    df = df.dropna()
    df = df.rename(columns=str.lower)
    return df
""", "data_processing"),
        ("""
def normalize_data(data):
    mean = np.mean(data)
    std = np.std(data)
    return (data - mean) / std
""", "data_processing"),

        # API endpoints
        ("""
@app.route('/users', methods=['GET'])
def get_users():
    users = User.query.all()
    return jsonify([u.to_dict() for u in users])
""", "api_endpoint"),
        ("""
@router.post('/login')
def login():
    payload = request.json
    token = auth_service.login(payload['email'], payload['password'])
    return JSONResponse({'token': token})
""", "api_endpoint"),

        # Utility
        ("""
def slugify(text):
    return text.lower().replace(" ", "-")
""", "utility"),
        ("""
def chunk_list(lst, size):
    for i in range(0, len(lst), size):
        yield lst[i:i+size]
""", "utility"),

        # ML model related
        ("""
def train_model(X_train, y_train):
    model = RandomForestClassifier()
    model.fit(X_train, y_train)
    return model
""", "ml_model"),
        ("""
def predict_proba(model, X):
    return model.predict_proba(X)
""", "ml_model"),

        # Database
        ("""
def get_user_by_id(conn, user_id):
    cursor = conn.cursor()
    cursor.execute("SELECT * FROM users WHERE id = %s", (user_id,))
    return cursor.fetchone()
""", "database"),
        ("""
def save_order(session, order):
    session.add(order)
    session.commit()
""", "database"),
    ]

    texts, labels = zip(*examples)
    return list(texts), list(labels)

def train_function_type_classifier():
    texts, labels = build_function_type_training_data()

    # test_size=0.5 so each of 5 classes gets at least 1 sample in train & val
    X_train, X_val, y_train, y_val = train_test_split(
        texts, labels, test_size=0.5, random_state=42, stratify=labels
    )

    clf = Pipeline([
        ("tfidf", TfidfVectorizer(ngram_range=(1, 2), max_features=5000)),
        ("svm", LinearSVC())
    ])

    clf.fit(X_train, y_train)

    print("=== Function Type Classifier Validation ===")
    y_pred = clf.predict(X_val)
    print(classification_report(y_val, y_pred))

    return clf

function_type_classifier = train_function_type_classifier()


=== Function Type Classifier Validation ===
                 precision    recall  f1-score   support

   api_endpoint       0.00      0.00      0.00         1
data_processing       0.00      0.00      0.00         1
       database       0.00      0.00      0.00         1
       ml_model       0.33      1.00      0.50         1
        utility       0.00      0.00      0.00         1

       accuracy                           0.20         5
      macro avg       0.07      0.20      0.10         5
   weighted avg       0.07      0.20      0.10         5



  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [5]:
READABILITY_LABELS = ["low", "medium", "high"]

def build_readability_training_data():
    """
    Synthetic examples â€“ you can extend with your own.
    """
    examples = [
        # Low readability: no spacing, weird names, nested, long lines
        ("""
def f(a,b,c,d,e,f,g,h,i,j):print(a,b,c,d,e,f,g,h,i,j)
""", "low"),
        ("""
def calc(x,y):
    if x>0:
        if y>0:
            if x>y:
                return x-y
            else:
                return y-x
    else:
        return x+y
""", "low"),
        ("""
def do(a,b,c,d,e,f):
    for i in range(len(a)):
        for j in range(len(b)):
            for k in range(len(c)):
                for l in range(len(d)):
                    print(i,j,k,l)
""", "low"),

        # Medium readability: okay but could be better
        ("""
def sum_list(values):
    total = 0
    for v in values:
        total += v
    return total
""", "medium"),
        ("""
def filter_positive(nums):
    result = []
    for n in nums:
        if n > 0:
            result.append(n)
    return result
""", "medium"),
        ("""
def is_valid_email(email):
    if "@" in email and "." in email:
        return True
    return False
""", "medium"),

        # High readability: docstrings, clear names, simple logic
        ("""
def compute_mean(values):
    \"\"\"Return the arithmetic mean of a list of numbers.\"\"\"
    if not values:
        return 0.0
    return sum(values) / len(values)
""", "high"),
        ("""
def format_full_name(first_name, last_name):
    \"\"\"Format a full name in 'Last, First' style.\"\"\"
    first = first_name.strip().title()
    last = last_name.strip().title()
    return f"{last}, {first}"
""", "high"),
        ("""
def chunk(iterable, size):
    \"\"\"Yield fixed-size chunks from an iterable.\"\"\"
    if size <= 0:
        raise ValueError("size must be positive")
    for i in range(0, len(iterable), size):
        yield iterable[i:i+size]
""", "high"),
    ]

    texts, labels = zip(*examples)
    return list(texts), list(labels)

def train_readability_classifier():
    texts, labels = build_readability_training_data()

    # Small dataset â†’ simple split, no stratify to avoid edge errors
    X_train, X_val, y_train, y_val = train_test_split(
        texts, labels, test_size=0.33, random_state=42
    )

    clf = Pipeline([
        ("tfidf", TfidfVectorizer(ngram_range=(1, 2), max_features=5000)),
        ("svm", LinearSVC())
    ])

    clf.fit(X_train, y_train)

    print("=== Readability Classifier Validation ===")
    y_pred = clf.predict(X_val)
    print(classification_report(y_val, y_pred))

    return clf

readability_classifier = train_readability_classifier()


=== Readability Classifier Validation ===
              precision    recall  f1-score   support

        high       0.00      0.00      0.00         1
         low       0.00      0.00      0.00         1
      medium       0.50      1.00      0.67         1

    accuracy                           0.33         3
   macro avg       0.17      0.33      0.22         3
weighted avg       0.17      0.33      0.22         3



  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [6]:
COMPLEXITY_NODES = (
    ast.If, ast.For, ast.While, ast.Try, ast.With,
    ast.BoolOp, ast.IfExp, ast.comprehension
)

def compute_cyclomatic_complexity(node):
    """
    Very simple cyclomatic complexity:
    complexity = 1 + number of decision points.
    """
    count = 0
    for child in ast.walk(node):
        if isinstance(child, COMPLEXITY_NODES):
            count += 1
    return 1 + count

def compute_max_depth(node, level=0):
    """
    Approximate nesting depth of blocks.
    """
    if not isinstance(node, ast.AST):
        return level
    max_child = level
    for child in ast.iter_child_nodes(node):
        child_depth = compute_max_depth(child, level + 1)
        if child_depth > max_child:
            max_child = child_depth
    return max_child

def analyze_ast_metrics(func_node, source_lines):
    start = func_node.lineno - 1
    end = getattr(func_node, "end_lineno", func_node.lineno)
    num_lines = end - start

    num_params = len(func_node.args.args)

    complexity = compute_cyclomatic_complexity(func_node)
    max_depth = compute_max_depth(func_node)

    num_calls = 0
    uses_eval_exec = False
    uses_print_input = False
    sql_like_strings = False

    for n in ast.walk(func_node):
        if isinstance(n, ast.Call):
            num_calls += 1
            if isinstance(n.func, ast.Name) and n.func.id in ("eval", "exec"):
                uses_eval_exec = True
            if isinstance(n.func, ast.Name) and n.func.id in ("print", "input"):
                uses_print_input = True
        if isinstance(n, ast.Constant) and isinstance(n.value, str):
            text = n.value.upper()
            if any(kw in text for kw in ("SELECT ", "INSERT ", "UPDATE ", "DELETE ")):
                sql_like_strings = True

    return {
        "num_lines": num_lines,
        "num_params": num_params,
        "complexity": complexity,
        "max_depth": max_depth,
        "num_calls": num_calls,
        "uses_eval_exec": uses_eval_exec,
        "uses_print_input": uses_print_input,
        "sql_like_strings": sql_like_strings,
    }

def detect_code_smells(ast_metrics, func_source):
    smells = []

    if ast_metrics["num_lines"] > 50:
        smells.append("Long Method (too many lines)")

    if ast_metrics["num_params"] > 5:
        smells.append("Long Parameter List (>5 parameters)")

    if ast_metrics["max_depth"] > 4:
        smells.append("Deeply Nested Logic (max depth > 4)")

    if ast_metrics["uses_eval_exec"]:
        smells.append("Use of eval/exec (security risk)")

    if ast_metrics["uses_print_input"]:
        smells.append("Uses print/input (debug or CLI in core logic)")

    if ast_metrics["sql_like_strings"] and "+" in func_source:
        smells.append("Possible SQL string concatenation (injection risk)")

    if "password" in func_source.lower() and ("=" in func_source):
        smells.append("Possible hard-coded password/secret")

    return smells


In [7]:
def extract_functions_from_file(file_path):
    """
    Parse a Python file and extract all top-level and class methods.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        source = f.read()

    tree = ast.parse(source)
    lines = source.splitlines()

    functions = []

    class FunctionVisitor(ast.NodeVisitor):
        def __init__(self, outer_class=None):
            self.outer_class = outer_class
            super().__init__()

        def visit_ClassDef(self, node):
            visitor = FunctionVisitor(outer_class=node.name)
            visitor.visit_nodes(node.body)

        def visit_FunctionDef(self, node):
            start = node.lineno - 1
            end = getattr(node, "end_lineno", node.lineno)
            func_source = "\n".join(lines[start:end])

            arg_names = [arg.arg for arg in node.args.args]
            qualname = node.name
            if self.outer_class:
                qualname = f"{self.outer_class}.{node.name}"

            metrics = analyze_ast_metrics(node, lines)
            smells = detect_code_smells(metrics, func_source)

            functions.append({
                "name": node.name,
                "qualname": qualname,
                "start_line": start + 1,
                "end_line": end,
                "source": func_source,
                "args": arg_names,
                "metrics": metrics,
                "smells": smells,
            })

        def visit_nodes(self, nodes):
            for n in nodes:
                self.visit(n)

    visitor = FunctionVisitor()
    visitor.visit_nodes(tree.body)

    return functions


In [8]:
READABILITY_NUMERIC = {
    "low": 3.0,
    "medium": 6.0,
    "high": 9.0,
}

def classify_function_type(func_source, model):
    return model.predict([func_source])[0]

def classify_readability(func_source, model):
    label = model.predict([func_source])[0]
    score = READABILITY_NUMERIC.get(label, 5.0)
    return label, score

def generate_function_summary(func_meta, func_type, read_label):
    name = func_meta["qualname"]
    args = func_meta["args"]

    args_str = ", ".join(args) if args else "no arguments"

    type_descriptions = {
        "data_processing": "Performs data cleaning or transformation.",
        "api_endpoint": "Exposes an API endpoint or HTTP route.",
        "utility": "Provides a general-purpose helper/utility.",
        "ml_model": "Trains or uses a machine-learning model.",
        "database": "Interacts with a database or persistent storage.",
    }

    type_desc = type_descriptions.get(func_type, "General-purpose function.")

    return f"{name}({args_str}) â€” {type_desc} Readability is classified as **{read_label}**."

def generate_auto_refactor_suggestion(func_meta):
    """
    Heuristic-based refactor suggestion.
    This is NOT changing the original file, just giving a suggested version.
    Focus: readability + performance + security hints.
    """
    src = func_meta["source"]
    lines = src.splitlines()
    if not lines:
        return src  # nothing to do

    # Ensure docstring
    ref_lines = lines.copy()
    def_line = ref_lines[0]
    indent_spaces = " " * (len(def_line) - len(def_line.lstrip(" ")))

    has_docstring = False
    if len(ref_lines) > 1:
        stripped = ref_lines[1].strip()
        if (stripped.startswith('"""') and stripped.endswith('"""')) or \
           (stripped.startswith("'''") and stripped.endswith("'''")):
            has_docstring = True

    if not has_docstring:
        docstring = indent_spaces + '    """TODO: Add meaningful docstring."""'
        if len(ref_lines) == 1:
            ref_lines.append(docstring)
        else:
            ref_lines.insert(1, docstring)

    # Append security/performance comments at the end if smells detected
    metrics = func_meta["metrics"]
    smells = func_meta["smells"]

    comment_block = []
    comment_block.append(indent_spaces + "# Auto-refactor hints:")
    if metrics["num_lines"] > 50:
        comment_block.append(indent_spaces + "# - Consider splitting into smaller functions.")
    if metrics["max_depth"] > 4:
        comment_block.append(indent_spaces + "# - Reduce nesting with early returns or helpers.")
    if metrics["uses_eval_exec"]:
        comment_block.append(indent_spaces + "# - Replace eval/exec with safe alternatives.")
    if metrics["sql_like_strings"]:
        comment_block.append(indent_spaces + "# - Use parameterized queries instead of string concatenation.")
    if metrics["num_params"] > 5:
        comment_block.append(indent_spaces + "# - Group related parameters into objects or dataclasses.")

    if len(comment_block) > 1:
        ref_lines.append("")  # blank line
        ref_lines.extend(comment_block)

    return "\n".join(ref_lines)


In [9]:
def generate_markdown_documentation(file_path, functions,
                                    type_model, read_model):
    md_lines = []
    md_lines.append(f"# Auto-Generated Code Intelligence Report\n")
    md_lines.append(f"**Source file:** `{os.path.basename(file_path)}`\n")
    md_lines.append("---\n")

    for func in functions:
        func_type = classify_function_type(func["source"], type_model)
        read_label, read_score = classify_readability(func["source"], read_model)
        summary = generate_function_summary(func, func_type, read_label)
        metrics = func["metrics"]
        smells = func["smells"]
        refactored = generate_auto_refactor_suggestion(func)

        md_lines.append(f"## `{func['qualname']}`")
        md_lines.append(f"- **Lines:** {func['start_line']}â€“{func['end_line']}")
        md_lines.append(f"- **Type (ML classified):** `{func_type}`")
        md_lines.append(f"- **Readability:** `{read_label}` (score â‰ˆ {read_score:.1f}/10)")
        md_lines.append(f"- **Cyclomatic Complexity:** `{metrics['complexity']}`")
        md_lines.append(f"- **Max Nesting Depth:** `{metrics['max_depth']}`")
        md_lines.append(f"- **Parameters:** `{metrics['num_params']}`")
        md_lines.append(f"- **Function Calls:** `{metrics['num_calls']}`")

        if smells:
            md_lines.append(f"- **Detected Code Smells:**")
            for s in smells:
                md_lines.append(f"  - {s}")
        else:
            md_lines.append(f"- **Detected Code Smells:** None ðŸŽ‰")

        md_lines.append(f"- **Summary:** {summary}\n")

        # Original source
        md_lines.append("<details><summary>Original Source Code</summary>\n")
        md_lines.append("")
        md_lines.append("```python")
        md_lines.append(func["source"])
        md_lines.append("```")
        md_lines.append("\n</details>\n")

        # Refactored suggestion
        md_lines.append("<details><summary>Auto-Refactor Suggestion</summary>\n")
        md_lines.append("")
        md_lines.append("```python")
        md_lines.append(refactored)
        md_lines.append("```")
        md_lines.append("\n</details>\n")

        md_lines.append("---\n")

    md_content = "\n".join(md_lines)

    out_path = os.path.join(os.path.dirname(file_path), "DOCUMENTATION.md")
    with open(out_path, "w", encoding="utf-8") as f:
        f.write(md_content)

    return out_path


In [10]:
def analyze_code_file(file_path,
                      type_model=None,
                      read_model=None):
    if type_model is None or read_model is None:
        raise ValueError("Models cannot be None. Pass both type_model and read_model.")

    print(f"Analyzing: {file_path}")

    functions = extract_functions_from_file(file_path)
    if not functions:
        print("No functions found!")
        return

    print(f"Found {len(functions)} functions/methods.\n")

    for func in functions:
        func_type = classify_function_type(func["source"], type_model)
        read_label, read_score = classify_readability(func["source"], read_model)
        metrics = func["metrics"]
        smells = func["smells"]

        print(f"- {func['qualname']} (lines {func['start_line']}â€“{func['end_line']})")
        print(f"  Type: {func_type}")
        print(f"  Readability: {read_label} (â‰ˆ {read_score:.1f}/10)")
        print(f"  Complexity: {metrics['complexity']}, Max depth: {metrics['max_depth']}")
        print(f"  Params: {metrics['num_params']}, Calls: {metrics['num_calls']}")
        if smells:
            print("  Smells:")
            for s in smells:
                print(f"    - {s}")
        else:
            print("  Smells: None")
        print()

    doc_path = generate_markdown_documentation(
        file_path, functions, type_model, read_model
    )
    print(f"\nðŸ“„ Documentation generated at: {doc_path}")


In [11]:
def analyze_code_file(file_path,
                      type_model=None,
                      read_model=None):
    if type_model is None or read_model is None:
        raise ValueError("Models cannot be None. Pass both type_model and read_model.")

    print(f"Analyzing: {file_path}")

    functions = extract_functions_from_file(file_path)
    if not functions:
        print("No functions found!")
        return

    print(f"Found {len(functions)} functions/methods.\n")

    for func in functions:
        func_type = classify_function_type(func["source"], type_model)
        read_label, read_score = classify_readability(func["source"], read_model)
        metrics = func["metrics"]
        smells = func["smells"]

        print(f"- {func['qualname']} (lines {func['start_line']}â€“{func['end_line']})")
        print(f"  Type: {func_type}")
        print(f"  Readability: {read_label} (â‰ˆ {read_score:.1f}/10)")
        print(f"  Complexity: {metrics['complexity']}, Max depth: {metrics['max_depth']}")
        print(f"  Params: {metrics['num_params']}, Calls: {metrics['num_calls']}")
        if smells:
            print("  Smells:")
            for s in smells:
                print(f"    - {s}")
        else:
            print("  Smells: None")
        print()

    doc_path = generate_markdown_documentation(
        file_path, functions, type_model, read_model
    )
    print(f"\nðŸ“„ Documentation generated at: {doc_path}")


In [12]:
sample_code = """
import numpy as np
from flask import Flask, request, jsonify

app = Flask(__name__)

def clean_dataframe(df):
    df = df.dropna()
    df = df.rename(columns=str.lower)
    return df

@app.route('/predict', methods=['POST'])
def predict():
    payload = request.json
    features = np.array(payload['features']).reshape(1, -1)
    # dummy prediction
    score = float(features.mean())
    return jsonify({'score': score})

class UserRepository:
    def __init__(self, conn, debug=False, password='secret123'):
        self.conn = conn
        self.debug = debug
        self.password = password

    def get_user(self, user_id):
        cursor = self.conn.cursor()
        query = "SELECT * FROM users WHERE id = " + str(user_id)
        cursor.execute(query)
        return cursor.fetchone()

def helper_format_user(user):
    return {
        'id': user[0],
        'name': user[1],
    }

def ugly_function(a,b,c,d,e,f,g,h,i,j):
    if a>b:
        if b>c:
            if c>d:
                if d>e:
                    if e>f:
                        if f>g:
                            print(a,b,c,d,e,f,g,h,i,j)
"""

test_file_path = "/content/sample_app.py"
with open(test_file_path, "w", encoding="utf-8") as f:
    f.write(sample_code)

print(f"Sample file written to: {test_file_path}")


Sample file written to: /content/sample_app.py


In [13]:
analyze_code_file(
    test_file_path,
    type_model=function_type_classifier,
    read_model=readability_classifier,
)


Analyzing: /content/sample_app.py
Found 6 functions/methods.

- clean_dataframe (lines 7â€“10)
  Type: data_processing
  Readability: high (â‰ˆ 9.0/10)
  Complexity: 1, Max depth: 6
  Params: 1, Calls: 2
  Smells:
    - Deeply Nested Logic (max depth > 4)

- predict (lines 13â€“18)
  Type: api_endpoint
  Readability: high (â‰ˆ 9.0/10)
  Complexity: 1, Max depth: 7
  Params: 0, Calls: 6
  Smells:
    - Deeply Nested Logic (max depth > 4)

- UserRepository.__init__ (lines 21â€“24)
  Type: api_endpoint
  Readability: low (â‰ˆ 3.0/10)
  Complexity: 1, Max depth: 4
  Params: 4, Calls: 0
  Smells:
    - Possible hard-coded password/secret

- UserRepository.get_user (lines 26â€“30)
  Type: data_processing
  Readability: high (â‰ˆ 9.0/10)
  Complexity: 1, Max depth: 6
  Params: 2, Calls: 4
  Smells:
    - Deeply Nested Logic (max depth > 4)
    - Possible SQL string concatenation (injection risk)

- helper_format_user (lines 32â€“36)
  Type: ml_model
  Readability: high (â‰ˆ 9.0/10)
  Complexi

In [15]:
file_path = "/content/test_ml_api.py"  # or whichever you upload
analyze_code_file(
    file_path,
    type_model=function_type_classifier,
    read_model=readability_classifier,
)


Analyzing: /content/test_ml_api.py
Found 5 functions/methods.

- load_data (lines 8â€“12)
  Type: api_endpoint
  Readability: high (â‰ˆ 9.0/10)
  Complexity: 2, Max depth: 6
  Params: 1, Calls: 4
  Smells:
    - Deeply Nested Logic (max depth > 4)

- train_model (lines 14â€“18)
  Type: ml_model
  Readability: high (â‰ˆ 9.0/10)
  Complexity: 1, Max depth: 5
  Params: 0, Calls: 3
  Smells:
    - Deeply Nested Logic (max depth > 4)

- predict (lines 21â€“25)
  Type: ml_model
  Readability: high (â‰ˆ 9.0/10)
  Complexity: 1, Max depth: 8
  Params: 0, Calls: 7
  Smells:
    - Deeply Nested Logic (max depth > 4)

- bad_database_fetch (lines 27â€“32)
  Type: ml_model
  Readability: high (â‰ˆ 9.0/10)
  Complexity: 1, Max depth: 5
  Params: 2, Calls: 3
  Smells:
    - Deeply Nested Logic (max depth > 4)
    - Possible SQL string concatenation (injection risk)

- helper (lines 34â€“35)
  Type: utility
  Readability: high (â‰ˆ 9.0/10)
  Complexity: 2, Max depth: 5
  Params: 1, Calls: 0
  Smells:


In [16]:
file_path = "/content/test_ugly_utils.py"  # or whichever you upload
analyze_code_file(
    file_path,
    type_model=function_type_classifier,
    read_model=readability_classifier,
)


Analyzing: /content/test_ugly_utils.py
Found 3 functions/methods.

- process (lines 1â€“12)
  Type: ml_model
  Readability: low (â‰ˆ 3.0/10)
  Complexity: 8, Max depth: 12
  Params: 10, Calls: 2
  Smells:
    - Long Parameter List (>5 parameters)
    - Deeply Nested Logic (max depth > 4)
    - Uses print/input (debug or CLI in core logic)

- format_data (lines 14â€“15)
  Type: utility
  Readability: low (â‰ˆ 3.0/10)
  Complexity: 2, Max depth: 5
  Params: 1, Calls: 1
  Smells:
    - Deeply Nested Logic (max depth > 4)

- useless_function (lines 17â€“21)
  Type: ml_model
  Readability: low (â‰ˆ 3.0/10)
  Complexity: 1, Max depth: 4
  Params: 1, Calls: 3
  Smells:
    - Uses print/input (debug or CLI in core logic)


ðŸ“„ Documentation generated at: /content/DOCUMENTATION.md
