In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file found under the read-only Kaggle input directory.
for folder, _subdirs, file_names in os.walk('/kaggle/input'):
    for file_name in file_names:
        print(os.path.join(folder, file_name))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Hugging Face Hub id of the instruction-tuned code model used by the cells below.
model_id = "Qwen/Qwen2.5-Coder-7B-Instruct"

# NOTE(review): trust_remote_code=True allows Python shipped in the Hub repository
# to run at load time — confirm this checkpoint actually needs it.
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
# Weights are requested in bfloat16; device placement is delegated via device_map="auto".
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", 
                                             torch_dtype=torch.bfloat16, trust_remote_code=True)
# Switch to evaluation mode; the cells in this section only generate text with the model.
model.eval()


In [None]:
# Show the GPU status reported by the driver after the model has been loaded.
!nvidia-smi

# 2) Building prompt for the BRAIN

In [None]:
def build_code_explanation_prompt(code_snippet, code_components="code"):
    """Build the prompt that asks the model for a concise explanation.

    Args:
        code_snippet: Source text embedded in a fenced ``python`` block.
        code_components: Noun used for the snippet in the instructions
            (for example "code", "function" or "class").

    Returns:
        The prompt string. It tells the model to wrap its answer in
        ``<begin>`` / ``<end>`` and finishes with the ``Summary:`` cue.
    """
    # The fence is closed explicitly so the trailing "Summary:" cue is not
    # read as part of the code block.
    return f"""You are a helpful AI assistant skilled in Python code understanding.
Analyze the following {code_components} and generate a concise explanation of what it does. 
Just provide explanation and nothing else. Begin explanation with this token <begin> & end with <end>

```python
{code_snippet}
```

Summary:"""

In [None]:
## 🧪 Step 4: Run a Test

from transformers import TextStreamer

def summarize_code(code_snippet, build_prompt=build_code_explanation_prompt):
    """Generate a response for ``code_snippet`` with the globally loaded model.

    The answer is streamed to stdout while it is generated and is also
    returned to the caller.

    Args:
        code_snippet: Text handed to ``build_prompt``.
        build_prompt: Callable that turns ``code_snippet`` into the full
            prompt string. Defaults to ``build_code_explanation_prompt``.

    Returns:
        The newly generated text only, decoded without special tokens. The
        prompt is not echoed back, matching what the streamer prints.
    """
    prompt = build_prompt(code_snippet)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # Greedy decoding. Sampling knobs (temperature / top_p) have no effect
    # when do_sample=False and only produce warnings, so they are not passed.
    output = model.generate(
        **inputs,
        streamer=streamer,
        max_new_tokens=256,
        do_sample=False,
    )

    # generate() returns prompt tokens followed by the continuation; keep
    # only the continuation so the prompt's own <begin>/<end> markers and
    # instructions do not end up in the returned summary.
    prompt_length = inputs["input_ids"].shape[-1]
    return tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)


In [None]:
example_code = '''
def get_top_n_words(text, n=10):
    words = text.lower().split()
    freq = {}
    for word in words:
        freq[word] = freq.get(word, 0) + 1
    return sorted(freq.items(), key=lambda x: x[1], reverse=True)[:n]
'''

# Explain the word-frequency helper defined in example_code above and show the returned text.
summary = summarize_code(example_code)
print(summary)


In [None]:
example_code = """
from flask import Flask

app = Flask(__name__)

@app.route('/')
def hello():
    return 'Hello, World!'

if __name__ == '__main__':
    app.run(debug=True)

"""

# Explain the minimal Flask application held in example_code.
summary = summarize_code(example_code)
print(summary)

In [None]:
example_code = '''
def bubble_sort(numbers):
    """
    Sorts a list of numbers using bubble sort algorithm.
    """
    n = len(numbers)
    for i in range(n):
        for j in range(0, n-i-1):
            if numbers[j] > numbers[j+1]:
                numbers[j], numbers[j+1] = numbers[j+1], numbers[j]
    return numbers

'''

# Explain the bubble-sort function; this snippet already carries its own docstring.
summary = summarize_code(example_code)
print(summary)

In [None]:
example_code = '''
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2", output_hidden_states=True)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

input_text = "The quick brown fox"
inputs = tokenizer(input_text, return_tensors="pt")
outputs = model(**inputs)

hidden_states = outputs.hidden_states  # Tuple of (layer+1) tensors

# Apply LM head to each layer's last token hidden state
for i, hs in enumerate(hidden_states):
    logits = model.lm_head(hs)  # shape: [batch, seq_len, vocab]
    probs = torch.softmax(logits, dim=-1)
    top_token_id = torch.argmax(probs[0, -1]).item()
    top_token = tokenizer.decode(top_token_id)
    print(f"Layer {i}: Top token → {top_token}")

'''

# Explain the script that decodes the top token from every GPT-2 hidden layer.
summary = summarize_code(example_code)
print(summary)

In [None]:
example_code = '''
import ast
from pathlib import Path
from collections import defaultdict

class CodeFunction:
    def __init__(self, name, file_path, lineno, source, calls):
        self.name = name
        self.file_path = file_path
        self.lineno = lineno
        self.source = source
        self.calls = calls  # List of functions this one calls

    def __repr__(self):
        return f"<Function {self.name} (calls: {self.calls})>"

def extract_functions_from_file(file_path):
    with open(file_path, 'r', encoding='utf-8') as f:
        source = f.read()

    try:
        tree = ast.parse(source)
    except SyntaxError:
        print(f"Skipping file due to parse error: {file_path}")
        return []

    funcs = []

    for node in ast.walk(tree):
        if isinstance(node, ast.FunctionDef):
            name = node.name
            lineno = node.lineno
            calls = extract_calls_from_node(node)

            # Grab raw source using line numbers
            lines = source.splitlines()
            end_line = max(getattr(n, 'lineno', lineno) for n in ast.walk(node))
            func_source = "\n".join(lines[lineno-1:end_line])

            funcs.append(CodeFunction(
                name=name,
                file_path=str(file_path),
                lineno=lineno,
                source=func_source,
                calls=calls
            ))

    return funcs

def extract_calls_from_node(node):
    """
    Return a list of function names that are called inside this node.
    Only tracks direct function calls (not methods for now).
    """
    calls = []
    for subnode in ast.walk(node):
        if isinstance(subnode, ast.Call):
            if isinstance(subnode.func, ast.Name):
                calls.append(subnode.func.id)  # like func()
            elif isinstance(subnode.func, ast.Attribute):
                calls.append(subnode.func.attr)  # like obj.method()
    return calls

def build_dependency_graph(codebase_path):
    graph = defaultdict(list)      # function_name → [called_function_names]
    function_index = {}           # function_name → CodeFunction object

    py_files = Path(codebase_path).rglob("*.py")

    for file in py_files:
        funcs = extract_functions_from_file(file)
        for func in funcs:
            function_index[func.name] = func
            for callee in func.calls:
                graph[func.name].append(callee)

    return graph, function_index

if __name__ == "__main__":
    base_path = "./my_project"
    graph, index = build_dependency_graph(base_path)

    for func_name, callees in graph.items():
        print(f"{func_name} → {callees}")
    
    print("\n=== Details for funcA ===")
    funcA = index.get("funcA")
    if funcA:
        print(f"File: {funcA.file_path}")
        print(f"Source:\n{funcA.source}")
	
'''

# Explain a longer, multi-function script (AST-based dependency graph builder).
summary = summarize_code(example_code)
print(summary)

In [None]:
example_code = '''
"""
This program implements merge sort for a dictionary where:
- Keys are real-valued numbers (e.g., employee salaries or IDs)
- Values are instances of an `Employee` class (not defined here)

The `Employee` class is a custom-defined class representing an employee,
which may contain fields such as name, age, department, role, etc.
For example:
    class Employee:
        def __init__(self, name, age, dept):
            self.name = name
            self.age = age
            self.dept = dept

funcA(left: dict, right: dict) -> dict
    - This function takes two sorted dictionaries (by keys) and merges them into a single sorted dictionary.
    - It compares the keys in ascending order and appends the corresponding key-value pairs to a new dictionary.
    - Since dictionaries are unordered in older Python versions, this function returns a sorted dictionary (Python 3.7+ maintains insertion order).

"""

def funcB(data: dict) -> dict:
    if len(data) <= 1:
        return data

    keys = list(data.keys())
    mid = len(keys) // 2
    left_keys = keys[:mid]
    right_keys = keys[mid:]

    left_dict = {k: data[k] for k in left_keys}
    right_dict = {k: data[k] for k in right_keys}

    sorted_left = funcB(left_dict)
    sorted_right = funcB(right_dict)

    return funcA(sorted_left, sorted_right)

# funcA(left: dict, right: dict) -> dict` here
# The funcA function should:
# - Take two sorted dictionaries (by keys)
# - Merge them into one sorted dictionary based on ascending key order
# - Maintain the association between keys and their corresponding Employee objects

'''

# Explain a snippet in which funcA is only described in prose and comments, never defined.
summary = summarize_code(example_code)
print(summary)

In [None]:
example_code = """
from datetime import datetime
import random

class Logger:
    @staticmethod
    def log(message):
        print(f"[{datetime.now()}] {message}")

class Employee:
    def __init__(self, emp_id, name, age):
        self.emp_id = emp_id
        self.name = name
        self.age = age
        self.tasks = []
    
    def assign_task(self, task):
        Logger.log(f"Assigning task '{task.title}' to {self.name}")
        self.tasks.append(task)
    
    def complete_task(self, task_id):
        for task in self.tasks:
            if task.task_id == task_id:
                task.mark_completed()
                Logger.log(f"{self.name} completed task '{task.title}'")
                return
        Logger.log(f"Task ID {task_id} not found for {self.name}")
    
    def get_task_report(self):
        return [(task.title, task.status) for task in self.tasks]

class Manager(Employee):
    def __init__(self, emp_id, name, age, department):
        super().__init__(emp_id, name, age)
        self.department = department

    def assign_task_to_employee(self, employee, task):
        Logger.log(f"Manager {self.name} assigning task '{task.title}' to {employee.name}")
        employee.assign_task(task)

    def review_employee(self, employee):
        completed = sum(1 for task in employee.tasks if task.status == "Completed")
        total = len(employee.tasks)
        Logger.log(f"{employee.name} has completed {completed}/{total} tasks.")
        return completed, total

class Admin(Employee):
    def __init__(self, emp_id, name, age):
        super().__init__(emp_id, name, age)
    
    def reset_employee_tasks(self, employee):
        Logger.log(f"Admin {self.name} resetting tasks for {employee.name}")
        employee.tasks = []

class Department:
    def __init__(self, name):
        self.name = name
        self.employees = []

    def add_employee(self, employee):
        Logger.log(f"Adding {employee.name} to department {self.name}")
        self.employees.append(employee)

    def list_employees(self):
        return [emp.name for emp in self.employees]

class Task:
    def __init__(self, task_id, title, description):
        self.task_id = task_id
        self.title = title
        self.description = description
        self.status = "Assigned"
    
    def mark_completed(self):
        self.status = "Completed"

class SystemManager:
    def __init__(self):
        self.departments = []
        self.employees = []

    def create_department(self, name):
        dept = Department(name)
        self.departments.append(dept)
        Logger.log(f"Created department: {name}")
        return dept
    
    def hire_employee(self, name, age):
        emp_id = f"E{random.randint(1000, 9999)}"
        emp = Employee(emp_id, name, age)
        self.employees.append(emp)
        Logger.log(f"Hired employee: {emp.name} (ID: {emp_id})")
        return emp

    def assign_to_department(self, employee, department):
        department.add_employee(employee)

    def generate_report(self):
        Logger.log("Generating full system report...")
        for dept in self.departments:
            Logger.log(f"Department: {dept.name}")
            for emp in dept.employees:
                Logger.log(f" - {emp.name}: {len(emp.tasks)} tasks assigned.")

# Simulated usage
if __name__ == "__main__":
    system = SystemManager()

    # Setup
    dev_dept = system.create_department("Development")
    qa_dept = system.create_department("QA")

    alice = system.hire_employee("Alice", 28)
    bob = system.hire_employee("Bob", 32)
    claire = Manager("M001", "Claire", 40, dev_dept)

    system.assign_to_department(alice, dev_dept)
    system.assign_to_department(bob, qa_dept)
    system.assign_to_department(claire, dev_dept)

    # Tasks
    t1 = Task("T101", "Fix login bug", "Resolve authentication issue.")
    t2 = Task("T102", "Write test cases", "Write unit tests for UserService.")
    claire.assign_task_to_employee(alice, t1)
    claire.assign_task_to_employee(bob, t2)
    alice.complete_task("T101")
    claire.review_employee(alice)
    admin = Admin("A001", "Greg", 50)
    admin.reset_employee_tasks(alice)
    system.generate_report()
"""

# Explain the multi-class employee/task management example.
summary = summarize_code(example_code)
print(summary)

# 3) Bug Hunting

Unfortunately, this isn't working yet.

In [None]:
def build_bug_hunter_prompt(code_snippet, code_components="code"):
    """Build the prompt that asks the model to find and fix bugs.

    Args:
        code_snippet: Source text embedded in a fenced ``python`` block.
        code_components: Noun used for the snippet in the instructions.

    Returns:
        The prompt string. It asks for fixes wrapped in ``<begin>`` /
        ``<end>`` and finishes with the ``Summary:`` cue.
    """
    # The fence is closed explicitly so the trailing "Summary:" cue is not
    # read as part of the code under review.
    return f"""Analyze the following {code_components} and find if there are any bugs in the code. 
    Find all the bugs & fix them if any. Begin explanation with <begin> & end with <end>. Just provide
    the bug fix nothing else. Bugs may be logical also.

```python
{code_snippet}
```

Summary:"""

example_code = """
import heapq
class Edge:
    def __init__(self, to, rev, cap, cost):
        self.to = to
        self.rev = rev
        self.cap = cap
        self.cost = cost

class MinCostMaxFlow:
    def __init__(self, N):
        self.N = N
        self.graph = [[] for _ in range(N)]

    def add_edge(self, fr, to, cap, cost):
        forward = Edge(to, len(self.graph[to]), cap, cost)
        backward = Edge(fr, len(self.graph[fr]), 0, -cost)
        self.graph[fr].append(forward)
        self.graph[to].append(backward)

    def min_cost_flow(self, s, t, maxf):
        N = self.N
        h = [0] * N
        prevv = [0] * N
        preve = [0] * N
        INF = float('inf')
        res = 0
        flow = 0

        def dijkstra():
            dist = [INF] * N
            used = [False] * N
            dist[s] = 0
            queue = [(0, s)]
            while queue:
                d, v = heapq.heappop(queue)
                if used[v]:  
                    continue
                used[v] = True
                for i, e in enumerate(self.graph[v]):
                    if e.cap <= 0: 
                        continue
                    new_cost = dist[v] + e.cost + h[v] - h[e.to]
                    if dist[e.to] > new_cost:
                        dist[e.to] = new_cost
                        prevv[e.to] = v
                        preve[e.to] = i
                        heapq.heappush(queue, (dist[e.to], e.to))
            return dist[t] != INF, dist

        while flow < maxf:
            found, dist = dijkstra()
            if not found:
                break
            for v in range(N):
                h[v] = dist[v] 
            d = maxf  
            v = t
            while v != s:
                d = min(d, self.graph[prevv[v]][preve[v]].cap)
                v = prevv[v]
            flow += d
            res += d * h[t]
            v = t
            while v != s:
                e = self.graph[prevv[v]][preve[v]]
                e.cap -= d
                self.graph[e.to][e.rev].cap += d  
                v = prevv[v]

        return flow, res

"""

# Run the min-cost-flow snippet through the bug-hunting prompt instead of the default explanation prompt.
summary = summarize_code(example_code, build_prompt=build_bug_hunter_prompt)
print(summary)

# 4) Line-by-line explanation

In [None]:
def build_line_by_line_explan_prompt(code_snippet, code_components="code"):
    """Build the prompt that asks the model for a line-by-line explanation.

    Args:
        code_snippet: Source text embedded in a fenced ``python`` block.
        code_components: Noun used for the snippet in the instructions.

    Returns:
        The prompt string, finishing with the ``Summary:`` cue.
    """
    # The fence is closed explicitly so the trailing "Summary:" cue is not
    # read as one more line of code to explain.
    return f"""You are a helpful AI assistant skilled in Python code understanding.
    Analyze the following {code_components} and provide a line by line explanation of the code. 

```python
{code_snippet}
```

Summary:"""

In [None]:
example_code="""
from collections import deque, defaultdict

class Edge:
    def __init__(self, to, rev, capacity):
        self.to = to          # destination node
        self.rev = rev        # index of reverse edge in the adjacency list
        self.capacity = capacity  # current capacity

class Dinic:
    def __init__(self, n):
        self.n = n                              # number of nodes
        self.graph = [[] for _ in range(n)]     # adjacency list of edges
        self.level = [0] * n                    # level of each node
        self.iter = [0] * n                     # current edge to explore for each node

    def add_edge(self, fr, to, capacity):
        forward = Edge(to, len(self.graph[to]), capacity)
        backward = Edge(fr, len(self.graph[fr]), 0)  # reverse edge with 0 capacity
        self.graph[fr].append(forward)
        self.graph[to].append(backward)

    def bfs(self, s, t):
        self.level = [-1] * self.n
        queue = deque([s])
        self.level[s] = 0

        while queue:
            v = queue.popleft()
            for edge in self.graph[v]:
                if edge.capacity > 0 and self.level[edge.to] < 0:
                    self.level[edge.to] = self.level[v] + 1
                    queue.append(edge.to)

        return self.level[t] != -1

    def dfs(self, v, t, flow):
        if v == t:
            return flow
        for i in range(self.iter[v], len(self.graph[v])):
            edge = self.graph[v][i]
            if edge.capacity > 0 and self.level[v] < self.level[edge.to]:
                d = self.dfs(edge.to, t, min(flow, edge.capacity))
                if d > 0:
                    edge.capacity -= d
                    self.graph[edge.to][edge.rev].capacity += d
                    return d
            self.iter[v] += 1
        return 0

    def max_flow(self, s, t):
        flow = 0
        INF = float('inf')
        while self.bfs(s, t):
            self.iter = [0] * self.n
            f = self.dfs(s, t, INF)
            while f > 0:
                flow += f
                f = self.dfs(s, t, INF)
        return flow
"""

# Run the Dinic max-flow snippet through the line-by-line explanation prompt.
summary = summarize_code(example_code, build_prompt=build_line_by_line_explan_prompt)
print(summary)

In [None]:
%%writefile resource_manager.py

from datetime import datetime  # used by Logger.log
import random  # used by SystemManager.hire_employee further down in this module


class Logger:
    """Minimal console logger that prefixes every message with the current time."""

    @staticmethod
    def log(message):
        """Print ``message`` preceded by ``[<datetime.now()>]``."""
        print(f"[{datetime.now()}] {message}")

class Employee:
    """A person with an id, a name and an age who owns a list of assigned tasks."""

    def __init__(self, emp_id, name, age):
        """Store the identity fields and start with an empty task list."""
        self.emp_id = emp_id
        self.name = name
        self.age = age
        self.tasks = []
    
    def assign_task(self, task):
        """Log the assignment and append ``task`` to ``self.tasks``."""
        Logger.log(f"Assigning task '{task.title}' to {self.name}")
        self.tasks.append(task)
    
    def complete_task(self, task_id):
        """Mark the first task whose ``task_id`` matches as completed.

        Logs a "not found" message when no assigned task has that id.
        """
        for task in self.tasks:
            if task.task_id == task_id:
                task.mark_completed()
                Logger.log(f"{self.name} completed task '{task.title}'")
                # Stop at the first match so the "not found" message below is skipped.
                return
        Logger.log(f"Task ID {task_id} not found for {self.name}")
    
    def get_task_report(self):
        """Return a list of ``(title, status)`` tuples, one per assigned task."""
        return [(task.title, task.status) for task in self.tasks]

class Manager(Employee):
    """An employee who additionally has a department and can hand out and review tasks."""

    def __init__(self, emp_id, name, age, department):
        """Initialise the employee fields and store ``department`` as given."""
        super().__init__(emp_id, name, age)
        self.department = department

    def assign_task_to_employee(self, employee, task):
        """Log the hand-over and delegate to ``employee.assign_task(task)``."""
        Logger.log(f"Manager {self.name} assigning task '{task.title}' to {employee.name}")
        employee.assign_task(task)

    def review_employee(self, employee):
        """Count the employee's completed tasks, log the ratio and return ``(completed, total)``."""
        # A task counts as completed only when its status equals the exact string "Completed".
        completed = sum(1 for task in employee.tasks if task.status == "Completed")
        total = len(employee.tasks)
        Logger.log(f"{employee.name} has completed {completed}/{total} tasks.")
        return completed, total

class Admin(Employee):
    """An employee who can clear another employee's task list."""

    def __init__(self, emp_id, name, age):
        """Initialise the employee fields; nothing is added on top of ``Employee``."""
        super().__init__(emp_id, name, age)
    
    def reset_employee_tasks(self, employee):
        """Log the reset and replace ``employee.tasks`` with a new empty list."""
        Logger.log(f"Admin {self.name} resetting tasks for {employee.name}")
        employee.tasks = []

class Department:
    """A named group of employees."""

    def __init__(self, name):
        """Store the department name and start with no employees."""
        self.name = name
        self.employees = []

    def add_employee(self, employee):
        """Log the addition and append ``employee`` to ``self.employees``."""
        # No duplicate check: adding the same employee twice lists them twice.
        Logger.log(f"Adding {employee.name} to department {self.name}")
        self.employees.append(employee)

    def list_employees(self):
        """Return the names of all employees in the order they were added."""
        return [emp.name for emp in self.employees]

class Task:
    """A unit of work with an id, a title, a description and a status string."""

    def __init__(self, task_id, title, description):
        """Store the task fields; every new task starts with status "Assigned"."""
        self.task_id = task_id
        self.title = title
        self.description = description
        self.status = "Assigned"
    
    def mark_completed(self):
        """Set the status to "Completed"."""
        self.status = "Completed"

class SystemManager:
    """Top-level registry that creates departments, hires employees and reports on them."""

    def __init__(self):
        """Start with no departments and no employees."""
        self.departments = []
        self.employees = []

    def create_department(self, name):
        """Create a ``Department``, remember it, log the creation and return it."""
        dept = Department(name)
        self.departments.append(dept)
        Logger.log(f"Created department: {name}")
        return dept
    
    def hire_employee(self, name, age):
        """Create an ``Employee`` with a generated id, remember it, log the hire and return it."""
        # Id is "E" plus a random integer in [1000, 9999]; uniqueness is not checked.
        # Requires the `random` module to be imported at module level.
        emp_id = f"E{random.randint(1000, 9999)}"
        emp = Employee(emp_id, name, age)
        self.employees.append(emp)
        Logger.log(f"Hired employee: {emp.name} (ID: {emp_id})")
        return emp

    def assign_to_department(self, employee, department):
        """Add ``employee`` to ``department`` via ``department.add_employee``."""
        department.add_employee(employee)

    def generate_report(self):
        """Log every department followed by each member and their number of assigned tasks."""
        Logger.log("Generating full system report...")
        for dept in self.departments:
            Logger.log(f"Department: {dept.name}")
            for emp in dept.employees:
                Logger.log(f" - {emp.name}: {len(emp.tasks)} tasks assigned.")

# 5) Parser + Dependency Graph Builder

In [None]:
import ast
from pathlib import Path
from collections import defaultdict

class CodeFunction:
    """Record describing one parsed function or method.

    Holds the name, the file it came from, its first line number, its source
    text, the names it calls and, for methods, the enclosing class name.
    """

    def __init__(self, name, file_path, lineno, source, calls, parent_class=None):
        self.name, self.file_path, self.lineno = name, file_path, lineno
        self.source, self.calls = source, calls
        # Name of the enclosing class, or None for a free function.
        self.parent_class = parent_class

    def fqname(self):
        """Return ``Class.method`` when a parent class is set, else the bare name."""
        if not self.parent_class:
            return self.name
        return f"{self.parent_class}.{self.name}"

    def __repr__(self):
        qualified, called = self.fqname(), self.calls
        return f"<Function {qualified} (calls: {called})>"

class CodeClass:
    """Record describing one parsed class and the methods found in its body."""

    def __init__(self, name, file_path, lineno, source, methods):
        self.name, self.file_path = name, file_path
        self.lineno, self.source = lineno, source
        # CodeFunction objects, one per method.
        self.methods = methods

    def __repr__(self):
        method_names = [method.name for method in self.methods]
        return f"<Class {self.name} (methods: {method_names})>"

		
def extract_classes_and_functions(file_path):
    """Parse one Python file into free functions and classes.

    Args:
        file_path: Path of the file to read (opened as UTF-8).

    Returns:
        ``(functions, classes)``. ``functions`` holds a ``CodeFunction`` for
        every ``def`` that is not written directly in a class body;
        ``classes`` holds a ``CodeClass`` per class, whose ``methods`` are
        the ``def`` statements written directly in that class body. Both
        lists are empty when the file cannot be decoded or parsed.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            source = f.read()
        tree = ast.parse(source)
    except (SyntaxError, ValueError):
        # ValueError also covers undecodable bytes (UnicodeDecodeError), so a
        # single odd file does not abort a whole-codebase scan.
        print(f"Skipping file due to parse error: {file_path}")
        return [], []

    lines = source.splitlines()

    def block_source(node):
        # Text from the node's first line through its last line.
        end_line = getattr(node, 'end_lineno', None)
        if end_line is None:
            # Interpreters without end_lineno: approximate with the largest
            # line number found in the subtree.
            end_line = max(getattr(n, 'lineno', node.lineno) for n in ast.walk(node))
        return "\n".join(lines[node.lineno - 1:end_line])

    def to_code_function(node, parent_class=None):
        # Wrap a FunctionDef node in the CodeFunction record used by the graph builder.
        return CodeFunction(
            name=node.name,
            file_path=str(file_path),
            lineno=node.lineno,
            source=block_source(node),
            calls=extract_calls_from_node(node),
            parent_class=parent_class,
        )

    # Identify methods up front. Nodes produced by ast.parse carry no parent
    # link, so without this set every method would also be reported as a
    # free function and show up twice in the dependency graph.
    method_ids = {
        id(child)
        for class_node in ast.walk(tree)
        if isinstance(class_node, ast.ClassDef)
        for child in class_node.body
        if isinstance(child, ast.FunctionDef)
    }

    functions = []
    classes = []

    for node in ast.walk(tree):
        if isinstance(node, ast.FunctionDef):
            if id(node) not in method_ids:
                functions.append(to_code_function(node))
        elif isinstance(node, ast.ClassDef):
            methods = [
                to_code_function(child, parent_class=node.name)
                for child in node.body
                if isinstance(child, ast.FunctionDef)
            ]
            classes.append(CodeClass(
                name=node.name,
                file_path=str(file_path),
                lineno=node.lineno,
                source=block_source(node),
                methods=methods
            ))

    return functions, classes
	
def extract_calls_from_node(node):
    """Collect the names of everything called inside ``node``.

    A plain call such as ``foo()`` contributes ``foo``; an attribute call
    such as ``obj.method()`` contributes only ``method``. Calls whose target
    is any other kind of expression are ignored. Names are listed in
    pre-order of the syntax tree and duplicates are kept.
    """
    found = []
    pending = [node]  # explicit stack instead of a recursive visitor
    while pending:
        current = pending.pop()
        if isinstance(current, ast.Call):
            target = current.func
            if isinstance(target, ast.Name):
                found.append(target.id)
            elif isinstance(target, ast.Attribute):
                found.append(target.attr)
        # Push children reversed so the first child is processed first.
        pending.extend(reversed(list(ast.iter_child_nodes(current))))
    return found

	
def attach_parents(tree):
    """Give every node below ``tree`` a ``parent`` attribute pointing at its parent node."""
    queue = [tree]
    # Breadth-first: the list grows while it is being iterated.
    for parent in queue:
        for child in ast.iter_child_nodes(parent):
            child.parent = parent
            queue.append(child)

def resolve_fqname(name, function_index):
    """Map a called name onto a key of ``function_index`` where possible.

    An exact key wins. Otherwise the first key, in the index's iteration
    order, that ends with ``.<name>`` is returned. When nothing matches the
    name is returned unchanged.
    """
    if name in function_index:
        return name
    suffix = f".{name}"
    for candidate in function_index:
        if candidate.endswith(suffix):
            return candidate
    return name


			
def build_dependency_graph(codebase_path):
    """Build a caller-to-callees mapping for every ``*.py`` file under a directory.

    Args:
        codebase_path: Directory that is searched recursively.

    Returns:
        ``(graph, function_index)``. ``graph`` is a ``defaultdict(list)``
        mapping a qualified name to the resolved names it calls (callers
        without calls get no entry). ``function_index`` maps each qualified
        name to its ``CodeFunction``; a later definition with the same
        qualified name replaces an earlier one.
    """
    function_index = {}
    free_functions, methods = [], []

    # Pass 1: index every function and method so that pass 2 can resolve names
    # regardless of the order in which files were read.
    for py_file in list(Path(codebase_path).rglob("*.py")):
        funcs, classes = extract_classes_and_functions(py_file)

        for func in funcs:
            function_index[func.fqname()] = func
        free_functions.extend(funcs)

        for cls in classes:
            for method in cls.methods:
                function_index[method.fqname()] = method
            methods.extend(cls.methods)

    # Pass 2: record the edges, free functions first and methods afterwards.
    graph = defaultdict(list)
    for entry in free_functions + methods:
        caller = entry.fqname()
        for callee in entry.calls:
            graph[caller].append(resolve_fqname(callee, function_index))

    return graph, function_index


	

# Scan every *.py file under the current working directory.
base_path = "."
graph, index = build_dependency_graph(base_path)
# FIXME: temporary workaround for duplicate graph nodes. Only keys containing a dot
# (Class.method) are kept, so plain function names are dropped as well; the
# result is a plain dict rather than the defaultdict returned above.
graph = {k: v for k, v in graph.items() if '.' in k}

# Print one "caller → [callees]" line per remaining node.
for func_name, callees in graph.items():
    print(f"{func_name} → {callees}")
    
    # Disabled debugging aid: show the stored details of a single function.
    #print("\n=== Details for funcA ===")
    #funcA = index.get("funcA")
    #if funcA:
    #    print(f"File: {funcA.file_path}")
    #    print(f"Source:\n{funcA.source}")


		

# 6) Building the Context Assembler

In [None]:
def get_dependency_closure(func_name, graph):
    """Return the set of names reachable from ``func_name`` in ``graph``.

    The starting name is always part of the result. Names that are not keys
    of ``graph`` are treated as having no outgoing edges.
    """
    reachable = set()
    todo = [func_name]

    while todo:
        node = todo.pop()
        if node in reachable:
            continue
        reachable.add(node)
        todo.extend(dep for dep in graph.get(node, []) if dep not in reachable)

    return reachable

# Everything reachable from SystemManager.hire_employee in the graph built above.
closure = get_dependency_closure("SystemManager.hire_employee", graph)
print(closure)
# The printed set always contains the starting name itself plus every name reachable through its calls.


In [None]:
# Sanity check: the method must be indexed under its qualified name.
print("SystemManager.hire_employee" in index)  # expected: True
print(index.get("SystemManager.hire_employee"))  # expected: a CodeFunction, not None

In [None]:
# Sanity check: the method must also be present as a caller in the graph.
print("SystemManager.hire_employee" in graph)
# Direct indexing: fails with KeyError if graph is a plain dict and the key is missing.
print(graph["SystemManager.hire_employee"])


In [None]:
def assemble_context(func_name, graph, function_index):
    """Concatenate the source of ``func_name`` and of everything it reaches.

    Names are emitted in sorted order. Each known name contributes a header
    line, its source text and an empty line; names missing from
    ``function_index`` are reported on stdout and skipped.

    Returns:
        The pieces joined with newlines.
    """
    names = get_dependency_closure(func_name, graph)
    names.add(func_name)  # Make sure the root function is included

    pieces = []
    for name in sorted(names):
        if name not in function_index:
            print(f"[Warning] Missing in index: {name}")  # Debug print
            continue
        entry = function_index[name]
        pieces.extend([
            f"--- {name} (from {entry.file_path}) ---",
            entry.source,
            "",  # blank line
        ])

    return "\n".join(pieces)


def build_prompt_for_function(func_name, graph, function_index):
    """Return a prompt asking for a plain-language explanation of ``func_name``.

    The prompt embeds the assembled source of the function and its
    dependencies; leading and trailing whitespace is stripped.
    """
    context = assemble_context(func_name, graph, function_index)
    return f"""
You are an AI assistant helping a developer understand a Python function.

Below is the source code for the function `{func_name}` and its dependencies.

{context}

Explain clearly what `{func_name}` does in simple terms.
""".strip()

if __name__ == "__main__":
    # Rebuild the graph for the current directory (unfiltered this time).
    graph, function_index = build_dependency_graph(".")

    func_name = "SystemManager.hire_employee"  # or just "load_data"
    prompt = build_prompt_for_function(func_name, graph, function_index)

    print(prompt)  # or send it to the LLM

    # `prompt` is already a complete instruction. Pass it through unchanged:
    # the default prompt builder would wrap it a second time and ask the
    # model to explain the prompt text as if it were a Python snippet.
    summary = summarize_code(prompt, build_prompt=lambda text: text)
    print(summary)


# AGENTIC WRAPPER CLASSES

# DISTILLING QWEN2.5-7B to QWEN2.5-1.5B model

In [None]:
!pip install accelerate --quiet
!pip install -U bitsandbytes --quiet

In [None]:
# Check GPU status before loading the teacher model.
!nvidia-smi

In [1]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from accelerate import Accelerator
import gc
import os
import logging

# Set up logging: configure the root logger at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Turn off the tokenizers library's own parallelism.
# NOTE(review): this is usually done to avoid fork-related warnings when tokenizers
# are combined with DataLoader workers rather than to save memory — confirm the intent.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

class CodeDataset:
    """Loads an instruction/output dataset and drops rows that cannot be used.

    Args:
        name: Dataset identifier passed to ``load_dataset``.
        split: Split name passed to ``load_dataset``.
        max_samples: Optional cap on the number of rows kept *before*
            filtering. Falsy values (``None`` or ``0``) mean "no cap"; a cap
            larger than the split is clamped to the split size.
    """

    def __init__(self, name='jtatman/python-code-dataset-500k', split='train', max_samples=None):
        self.dataset = load_dataset(name, split=split)
        if max_samples:
            # Clamp the cap: selecting indices past the end of the split
            # would raise instead of simply returning every row.
            limit = min(max_samples, len(self.dataset))
            self.dataset = self.dataset.select(range(limit))

        # Pre-filter dataset to remove problematic entries
        self.dataset = self.dataset.filter(self._is_valid_sample)
        logger.info(f"Dataset filtered to {len(self.dataset)} valid samples")

    def _is_valid_sample(self, sample):
        """Return True when ``sample`` has usable ``instruction``/``output`` text.

        A sample is rejected when either field is not a string, when both
        fields are blank, or when their combined length in characters is
        below 10 or above 2048.
        """
        instruction = sample.get('instruction', '')
        output = sample.get('output', '')

        # Both fields must be strings (a stored None is rejected here).
        if not isinstance(instruction, str) or not isinstance(output, str):
            return False

        # At least one field must contain something other than whitespace.
        if len(instruction.strip()) == 0 and len(output.strip()) == 0:
            return False

        # Length bounds are measured in characters, not tokens.
        total_len = len(instruction) + len(output)
        if total_len < 10 or total_len > 2048:
            return False

        return True

    def get_data(self):
        """Return the filtered dataset object."""
        return self.dataset

class TeacherModel:
    """Frozen causal LM, loaded in 8-bit, that supplies the target logits."""

    def __init__(self, model_name="Qwen/Qwen2.5-Coder-14B", device='cuda'): #"Qwen/Qwen2.5-Coder-7B"
        """Load tokenizer and quantised weights for `model_name`.

        `device` is only recorded on the instance; placement of the weights is
        decided by `device_map='auto'`.
        """
        # Left padding, with EOS standing in as pad token when the tokenizer has none.
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, padding_side='left')
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        self.tokenizer = tokenizer

        # 8-bit weights; `lm_head` is excluded from quantisation and fp32 CPU
        # offload is enabled for modules that do not fit on the GPU.
        # Author's notes on 4-bit alternatives that were tried instead of this config:
        #   * load_in_4bit + bnb_4bit_quant_type="fp4", double quant, float16 compute
        #     ("maximum quality, slightly larger")
        #   * load_in_4bit + bnb_4bit_quant_type="nf4", double quant, float16 compute,
        #     skip ["lm_head"] ("maximum memory savings, may impact quality")
        #   * a wider skip list, ["lm_head", "embed_tokens", "norm"], was also considered
        quantization = BitsAndBytesConfig(
            load_in_8bit=True,
            llm_int8_enable_fp32_cpu_offload=True,
            llm_int8_skip_modules=["lm_head"],
        )

        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=quantization,
            device_map='auto',
            trust_remote_code=True,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
        )
        self.device = device
        self.model.eval()

    def forward(self, input_ids, attention_mask):
        """Return the teacher's logits as float16, computed without gradients.

        When the raw logits contain NaN or Inf, NaNs are replaced by 0 and all
        values are clamped to [-50, 50] before the cast. Any exception raised by
        the model is logged and re-raised.
        """
        with torch.no_grad():
            try:
                logits = self.model(input_ids=input_ids, attention_mask=attention_mask).logits

                nan_positions = torch.isnan(logits)
                if nan_positions.any() or torch.isinf(logits).any():
                    logger.warning("NaN/Inf detected in teacher logits, applying fixes...")
                    cleaned = torch.where(nan_positions, torch.zeros_like(logits), logits)
                    logits = cleaned.clamp(min=-50.0, max=50.0)

                return logits.half()
            except Exception as e:
                logger.error(f"Error in teacher forward pass: {e}")
                raise

class StudentModel:
    """Trainable causal LM that is fitted to the teacher's logits."""

    def __init__(self, model_name="Qwen/Qwen2.5-Coder-1.5B", device='cuda'):
        """Load tokenizer and float16 weights for `model_name`.

        `device` is only recorded on the instance; placement of the weights is
        decided by `device_map='auto'`.
        """
        # Left padding, with EOS standing in as pad token when the tokenizer has none.
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, padding_side='left')
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        self.tokenizer = tokenizer

        load_options = dict(
            torch_dtype=torch.float16,
            device_map='auto',
            trust_remote_code=True,
            low_cpu_mem_usage=True,
        )
        self.model = AutoModelForCausalLM.from_pretrained(model_name, **load_options)
        self.device = device

    def forward(self, input_ids, attention_mask):
        """Run the wrapped model and return its logits (gradients are tracked)."""
        return self.model(input_ids=input_ids, attention_mask=attention_mask).logits

def prepare_text_pairs(batch):
    """Turn a batch of instruction/output columns into training strings.

    Args:
        batch: Mapping with parallel sequences under 'instruction' and 'output'.

    Returns:
        A list of strings. Rows with both parts present use the
        "# Instruction: ... # Output: ..." layout, rows with only one part use
        "# Code: ...", and rows where both parts are blank are dropped, so the
        result may be shorter than the batch.
    """
    formatted = []

    for row, raw_instruction in enumerate(batch['instruction']):
        # str() first so non-string values are rendered rather than rejected.
        instruction = str(raw_instruction).strip()
        output = str(batch['output'][row]).strip()

        if not instruction and not output:
            continue  # nothing usable in this row

        if instruction and output:
            text = f"# Instruction:\n{instruction}\n\n# Output:\n{output}"
        else:
            # Exactly one side has content; present it as plain code.
            text = f"# Code:\n{instruction or output}"

        formatted.append(text)

    return formatted

def tokenize_batch(batch, tokenizer, max_length=512):
    """Format and tokenize one raw batch, or return None when it is unusable.

    Args:
        batch: Mapping with 'instruction' and 'output' columns, as accepted by
            `prepare_text_pairs`.
        tokenizer: Callable tokenizer returning 'input_ids' and 'attention_mask'
            tensors when called with `return_tensors='pt'`.
        max_length: Every sequence is truncated and padded to exactly this length.

    Returns:
        A dict with 'input_ids' and 'attention_mask', or None when the batch
        yields no text, the tokenizer returns no rows, any sequence has fewer
        than 5 real tokens, or tokenization raises (the error is logged).
    """
    try:
        texts = prepare_text_pairs(batch)

        if not texts:
            logger.warning("No valid texts found in batch")
            return None

        tokens = tokenizer(
            texts,
            padding='max_length',
            truncation=True,
            max_length=max_length,
            return_tensors='pt',
            add_special_tokens=True
        )
        input_ids = tokens['input_ids']
        attention_mask = tokens['attention_mask']

        if input_ids.size(0) == 0:
            logger.warning("Tokenization resulted in empty batch")
            return None

        # Count real tokens from the attention mask rather than by comparing ids
        # with pad_token_id: when the pad token is aliased to EOS, genuine EOS
        # tokens in the text would otherwise be miscounted as padding.
        real_token_counts = attention_mask.sum(dim=1)
        if real_token_counts.min() < 5:  # At least 5 non-padding tokens
            logger.warning("Batch contains sequences with too few valid tokens")
            return None

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask
        }

    except Exception as e:
        # Best effort by design: the caller counts a None result as a skipped batch.
        logger.error(f"Error in tokenization: {e}")
        return None

def compute_distillation_loss(student_logits, teacher_logits, attention_mask, 
                            temperature=3.0, alpha=0.7, top_k=None):
    """
    Blend a soft-target loss and a hard-target loss between student and teacher.

    Both logit tensors are cut to their common sequence length and vocabulary
    size. Only positions where `attention_mask` is non-zero contribute; they are
    selected *before* any softmax so that NaN/Inf produced at padded positions
    can neither reach the loss value nor the gradient.

    Loss terms (both averaged over the selected positions):
      * soft: cross-entropy of the student's temperature-scaled distribution
        against the teacher's temperature-scaled distribution. This equals the
        KL divergence up to the teacher's entropy, which does not depend on the
        student. No T^2 factor is applied.
      * hard: cross-entropy of the unscaled student logits against the
        teacher's argmax token.

    Args:
        student_logits: (batch, seq, vocab) tensor; gradients flow through it.
        teacher_logits: (batch, seq, vocab) tensor; treated as constant.
        attention_mask: (batch, seq) tensor, non-zero at real tokens.
        temperature: Divisor applied to both logit sets for the soft term.
        alpha: Weight of the soft term; the hard term gets `1 - alpha`.
        top_k: Accepted but not used by the current implementation.

    Returns:
        Scalar float32 tensor clamped to [0, 50]. Returns 0.0 when no position
        is selected and a constant 1.0 when the combined loss is NaN/Inf; both
        fallbacks carry no gradient towards the student.
    """
    device = student_logits.device

    # Align dimensions
    seq_len = min(student_logits.size(1), teacher_logits.size(1))
    vocab_size = min(student_logits.size(-1), teacher_logits.size(-1))

    valid = attention_mask[:, :seq_len].to(device).bool()
    if not valid.any():
        return torch.tensor(0.0, device=device, requires_grad=True)

    # Boolean indexing gathers the (n_valid, vocab) rows that matter. The upcast
    # to float32 happens afterwards, so padded rows cost neither memory nor time.
    student_valid = student_logits[:, :seq_len, :vocab_size][valid].float()
    teacher_valid = teacher_logits[:, :seq_len, :vocab_size].to(device)[valid].float()

    with torch.no_grad():
        teacher_scaled = torch.clamp(teacher_valid / temperature, min=-20, max=20)
        teacher_probs = torch.softmax(teacher_scaled, dim=-1)
        # Floor kept from the original implementation (slightly un-normalises the targets).
        teacher_probs = torch.clamp(teacher_probs, min=1e-8, max=1.0)
        # Hard targets come from the unscaled teacher logits.
        hard_targets = torch.argmax(teacher_valid, dim=-1)

    student_scaled = torch.clamp(student_valid / temperature, min=-20, max=20)
    student_log_probs = torch.log_softmax(student_scaled, dim=-1)

    soft_loss = -(teacher_probs * student_log_probs).sum(dim=-1).mean()
    hard_loss = torch.nn.functional.cross_entropy(student_valid, hard_targets)

    total_loss = alpha * soft_loss + (1 - alpha) * hard_loss

    # Final safety checks
    if torch.isnan(total_loss) or torch.isinf(total_loss):
        logger.warning("NaN/Inf loss detected, returning safe fallback")
        return torch.tensor(1.0, device=device, requires_grad=True)

    # Clamp final loss to reasonable range
    return torch.clamp(total_loss, min=0.0, max=50.0)

def safe_gradient_norm(model):
    """Return the global L2 norm of `model`'s gradients as a Python float.

    Parameters without a gradient are ignored, and so are parameters whose own
    gradient norm is NaN or Inf. Returns 0.0 when nothing contributes.

    Per-parameter norms are computed in float32 so that large half-precision
    gradients do not overflow to Inf and get dropped.
    """
    squared_sum = 0.0

    for param in model.parameters():
        if param.grad is None:
            continue
        param_norm = param.grad.detach().norm(2, dtype=torch.float32).item()
        # Plain float check; no tensors are built just to test a scalar.
        if math.isfinite(param_norm):
            squared_sum += param_norm ** 2

    return squared_sum ** 0.5

class DistillationTrainer:
    """Fits a student causal LM to a frozen teacher's logits.

    Each batch is tokenized with the student's tokenizer and passed through both
    models. The student is updated on `compute_distillation_loss` using AdamW, a
    cosine learning-rate schedule and Hugging Face Accelerate (fp16 mixed
    precision plus gradient accumulation).

    `self.stats` counters are cumulative over the trainer's lifetime, not per epoch.

    NOTE(review): StudentModel loads its weights in float16 while the Accelerator
    is configured for fp16 mixed precision. PyTorch's gradient scaler is known to
    reject float16 gradients, which here would surface as an error on every
    optimizer step and be counted under `skipped_batches` — confirm that batches
    are actually processed before trusting a run.
    """

    def __init__(self, teacher, student, dataset, batch_size=2, grad_accum=8, 
                 learning_rate=2e-5, max_grad_norm=1.0, alpha=0.7):
        """
        Args:
            teacher: Object exposing `forward(input_ids, attention_mask)` -> logits.
            student: Object exposing `.model` (trainable) and `.tokenizer`.
            dataset: Sized, indexable dataset with 'instruction' and 'output' fields.
            batch_size: Rows per micro-batch.
            grad_accum: Micro-batches accumulated per optimizer step.
            learning_rate: Peak AdamW learning rate.
            max_grad_norm: Gradient-norm clipping threshold.
            alpha: Weight of the soft-target term in the loss.
        """
        self.teacher = teacher
        self.student = student
        self.dataset = dataset
        self.batch_size = batch_size
        self.grad_accum = grad_accum
        self.max_grad_norm = max_grad_norm
        self.alpha = alpha

        # Initialize accelerator
        self.accelerator = Accelerator(
            mixed_precision='fp16',
            gradient_accumulation_steps=grad_accum
        )

        self.optimizer = torch.optim.AdamW(
            self.student.model.parameters(),
            lr=learning_rate,
            betas=(0.9, 0.999),
            eps=1e-8,
            weight_decay=0.01
        )

        # Sized for a single epoch here; train() stretches it over the whole run
        # once the number of epochs is known.
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, 
            T_max=self._total_optimizer_steps(len(dataset) // batch_size, grad_accum, 1),
            eta_min=1e-6
        )

        # Use student tokenizer for consistency
        self.tokenizer = self.student.tokenizer

        # Prepare models
        self.student.model, self.optimizer, self.scheduler = self.accelerator.prepare(
            self.student.model, self.optimizer, self.scheduler
        )

        # Statistics tracking
        self.stats = {
            'processed_batches': 0,
            'skipped_batches': 0,
            'total_loss': 0.0,
            'high_grad_norm_count': 0
        }

        logger.info(f"Trainer initialized with batch_size={batch_size}, grad_accum={grad_accum}")
        logger.info(f"Effective batch size: {batch_size * grad_accum}")

    @staticmethod
    def _total_optimizer_steps(batches_per_epoch, grad_accum, epochs):
        """Number of optimizer steps in a run, never less than 1.

        Steps per epoch are rounded up so a trailing partial accumulation window
        is counted. The floor of 1 keeps the cosine schedule's period non-zero.
        """
        window = max(1, int(grad_accum))
        steps_per_epoch = -(-int(batches_per_epoch) // window)
        return max(1, steps_per_epoch * max(1, int(epochs)))

    def _set_schedule_length(self, total_steps):
        """Point the cosine schedule's period at `total_steps`.

        Only applied while the schedule has not been stepped yet; changing the
        period mid-run would make the learning rate jump. Accelerate wraps the
        scheduler and keeps the original under `.scheduler`, hence the unwrapping.
        """
        inner = getattr(self.scheduler, 'scheduler', self.scheduler)
        if hasattr(inner, 'T_max') and getattr(inner, 'last_epoch', 0) == 0:
            inner.T_max = total_steps

    def _distill_batch(self, batch):
        """Run one micro-batch; return its loss as a float, or None if unusable."""
        with self.accelerator.accumulate(self.student.model):
            tokenized = tokenize_batch(batch, self.tokenizer, max_length=512)
            if tokenized is None:
                return None

            input_ids = tokenized['input_ids'].to(self.accelerator.device)
            attention_mask = tokenized['attention_mask'].to(self.accelerator.device)

            # Get teacher predictions
            with torch.no_grad():
                teacher_logits = self.teacher.forward(input_ids, attention_mask)

            # Student forward pass
            with self.accelerator.autocast():
                student_outputs = self.student.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask
                )
                loss = compute_distillation_loss(
                    student_outputs.logits, teacher_logits, attention_mask,
                    temperature=3.0, alpha=self.alpha
                )

            self.accelerator.backward(loss)

            if self.accelerator.sync_gradients:
                # Let Accelerate unscale and clip. The norm it returns is the
                # pre-clipping norm of the unscaled gradients; a norm measured
                # by hand at this point would still include the fp16 loss scale.
                grad_norm = self.accelerator.clip_grad_norm_(
                    self.student.model.parameters(), 
                    self.max_grad_norm
                )
                if grad_norm is not None and float(grad_norm) > self.max_grad_norm:
                    self.stats['high_grad_norm_count'] += 1

                self.optimizer.step()
                self.scheduler.step()
                self.optimizer.zero_grad()

            # Single host/device synchronisation per batch.
            return loss.item()

    def train(self, epochs=1, save_steps=500):
        """Run `epochs` passes over the dataset.

        Args:
            epochs: Number of passes. The cosine schedule is stretched over all
                of them so the learning rate decays once across the whole run.
            save_steps: Accepted for interface compatibility; no periodic
                checkpointing is performed. Call `save_model` explicitly.

        Batches that cannot be tokenized or that raise during the step are
        logged, counted in `stats['skipped_batches']`, and skipped.
        """
        data_loader = DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=2,
            pin_memory=True,
            drop_last=True  # Ensure consistent batch sizes
        )
        data_loader = self.accelerator.prepare(data_loader)

        self._set_schedule_length(
            self._total_optimizer_steps(len(data_loader), self.grad_accum, epochs)
        )

        self.student.model.train()

        for epoch in range(epochs):
            logger.info(f"Starting epoch {epoch + 1}/{epochs}")

            progress_bar = tqdm(data_loader, desc=f"Epoch {epoch+1}")

            for step, batch in enumerate(progress_bar):
                try:
                    loss_value = self._distill_batch(batch)
                except Exception as e:
                    # Best effort: one bad batch (e.g. out of memory) must not end the run.
                    logger.error(f"Error at step {step}: {e}")
                    self.stats['skipped_batches'] += 1
                    torch.cuda.empty_cache()
                    continue

                if loss_value is None:
                    self.stats['skipped_batches'] += 1
                    continue

                # Update statistics
                self.stats['processed_batches'] += 1
                self.stats['total_loss'] += loss_value

                # Update progress bar
                avg_loss = self.stats['total_loss'] / self.stats['processed_batches']
                current_lr = self.scheduler.get_last_lr()[0]
                progress_bar.set_postfix({
                    'loss': f'{loss_value:.4f}',
                    'avg_loss': f'{avg_loss:.4f}',
                    'lr': f'{current_lr:.2e}',
                    'processed': self.stats['processed_batches'],
                    'skipped': self.stats['skipped_batches'],
                    'grad_clips': self.stats['high_grad_norm_count']
                })

                # Periodic cleanup
                if step % 50 == 0:
                    torch.cuda.empty_cache()

            # End of epoch summary
            self._print_epoch_summary(epoch + 1)

            # Cleanup
            torch.cuda.empty_cache()
            gc.collect()

    def _print_epoch_summary(self, epoch):
        """Log the cumulative counters, average loss and current learning rate."""
        processed = self.stats['processed_batches']
        skipped = self.stats['skipped_batches']
        total_batches = processed + skipped
        success_rate = (processed / total_batches * 100) if total_batches > 0 else 0
        avg_loss = (self.stats['total_loss'] / processed) if processed > 0 else 0

        logger.info(f"\nEpoch {epoch} Summary:")
        logger.info(f"  Processed Batches: {processed}")
        logger.info(f"  Skipped Batches: {skipped}")
        logger.info(f"  Success Rate: {success_rate:.1f}%")
        logger.info(f"  Average Loss: {avg_loss:.4f}")
        logger.info(f"  Gradient Clips: {self.stats['high_grad_norm_count']}")
        logger.info(f"  Current LR: {self.scheduler.get_last_lr()[0]:.2e}")

    def save_model(self, output_dir="distil-Qwen2.5-Coder-1.5B"):
        """Write the unwrapped student model and its tokenizer to `output_dir`.

        Waits for all processes first; only the main process writes.
        """
        self.accelerator.wait_for_everyone()
        unwrapped_model = self.accelerator.unwrap_model(self.student.model)

        if self.accelerator.is_main_process:
            unwrapped_model.save_pretrained(output_dir)
            self.tokenizer.save_pretrained(output_dir)
            logger.info(f"Model saved to {output_dir}")

def main(max_samples=50000, teacher_name="Qwen/Qwen2.5-Coder-7B", batch_size=64,
         grad_accum=1, learning_rate=2e-5, max_grad_norm=1.0, alpha=0.8, epochs=2,
         output_dir="distil-Qwen2.5-Coder-1.5B-Instruct"):
    """Run the full distillation pipeline: load data and models, train, save.

    Every argument defaults to the value that was previously hard-coded, so
    calling `main()` without arguments reproduces the original run.

    Args:
        max_samples: Leading dataset rows considered before filtering.
        teacher_name: Model identifier for the teacher
            (alternative noted by the author: Qwen/Qwen2.5-Coder-14B-Instruct).
        batch_size: Rows per micro-batch.
        grad_accum: Micro-batches accumulated per optimizer step.
        learning_rate: Peak learning rate.
        max_grad_norm: Gradient clipping threshold.
        alpha: Weight of the soft-target loss term.
        epochs: Number of passes over the dataset.
        output_dir: Directory the distilled student is written to.

    Returns:
        The `DistillationTrainer`, which references the teacher, the student and
        the run statistics, so later cells can inspect them.
    """
    # Start from a clean allocator state.
    torch.cuda.empty_cache()
    gc.collect()

    # GPU info
    gpu_count = torch.cuda.device_count()
    logger.info(f"Available GPUs: {gpu_count}")
    for i in range(gpu_count):
        props = torch.cuda.get_device_properties(i)
        logger.info(f"GPU {i}: {props.name} - {props.total_memory/1e9:.1f}GB")

    # Load dataset with better filtering
    logger.info("Loading and filtering dataset...")
    dataset = CodeDataset(max_samples=max_samples).get_data()
    logger.info(f"Final dataset size: {len(dataset)} samples")

    # Load models
    logger.info("Loading teacher model...")
    teacher = TeacherModel(teacher_name)

    logger.info("Loading student model...")
    student = StudentModel()

    logger.info(f"Teacher vocab size: {len(teacher.tokenizer)}")
    logger.info(f"Student vocab size: {len(student.tokenizer)}")

    trainer = DistillationTrainer(
        teacher=teacher,
        student=student,
        dataset=dataset,
        batch_size=batch_size,
        grad_accum=grad_accum,
        learning_rate=learning_rate,
        max_grad_norm=max_grad_norm,
        alpha=alpha
    )

    logger.info("Starting training...")
    trainer.train(epochs=epochs)

    # NOTE(review): the default directory name ends in "-Instruct", but StudentModel's
    # default checkpoint is the base "Qwen/Qwen2.5-Coder-1.5B" — confirm the intended name.
    trainer.save_model(output_dir)
    logger.info("Training completed!")

    return trainer

if __name__ == "__main__":
    main()

2025-08-20 12:07:33.374894: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1755691653.397350     303 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1755691653.404093     303 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 1: 100%|██████████| 513/513 [2:23:06<00:00, 16.74s/it]  
Epoch 2: 100%|██████████| 513/513 [2:23:55<00:00, 16.83s/it]  


1) Using the 14B teacher @ 4-bit precision with the 1.5B student @ FP16 precision takes 9 hrs per epoch. Batch size = 16.
The GPU usage for this distillation run is shown below.

2) Using 7B teacher @4-bit precision takes 4hrs per epoch. Batch size = 16.

3) Using 7B teacher @8-bit precision takes 2hrs 20mins per epoch. Batch size = 64.


To-dos:

a) Use the neulab/conala dataset to measure how good the distilled model is compared with the 1.5B, 3B and 7B Coder models. Use the 14B Qwen2.5-Coder and GPT-OSS-20B models in an LLM-as-a-judge setup.

In [None]:
# Print the NVIDIA driver's status table (the GPU usage referred to in the notes above).
!nvidia-smi

In [None]:
# `TeacherModel.device` is a plain string attribute, so it is read rather than called.
# `teacher` is a local variable of main(); the guard keeps this cell from raising
# NameError when no notebook-level `teacher` exists.
teacher.device if "teacher" in globals() else "teacher is not defined at notebook scope"