diff --git a/sourcecode-parser/diagnostic/llm.go b/sourcecode-parser/diagnostic/llm.go
new file mode 100644
index 00000000..b617e3e0
--- /dev/null
+++ b/sourcecode-parser/diagnostic/llm.go
@@ -0,0 +1,219 @@
+package diagnostic
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"time"
+)
+
+// LLMClient handles communication with a local LLM (Ollama/vLLM).
+type LLMClient struct {
+	BaseURL     string
+	Model       string
+	Temperature float64
+	MaxTokens   int
+	HTTPClient  *http.Client
+}
+
+// NewLLMClient creates a new LLM client.
+// Example:
+//
+//	client := NewLLMClient("http://localhost:11434", "qwen3-coder:32b")
+func NewLLMClient(baseURL, model string) *LLMClient {
+	return &LLMClient{
+		BaseURL:     baseURL,
+		Model:       model,
+		Temperature: 0.0, // Deterministic
+		MaxTokens:   2000,
+		HTTPClient: &http.Client{
+			Timeout: 120 * time.Second,
+		},
+	}
+}
+
+// AnalyzeFunction sends a function to the LLM for pattern discovery and test generation.
+// It returns a structured analysis result or an error.
+//
+// Performance: ~2-5 seconds per function (depends on function size).
+//
+// Example:
+//
+//	client := NewLLMClient("http://localhost:11434", "qwen3-coder:32b")
+//	result, err := client.AnalyzeFunction(functionMetadata)
+//	if err != nil {
+//		log.Printf("LLM analysis failed: %v", err)
+//		return nil, err
+//	}
+//	fmt.Printf("Found %d sources, %d sinks, %d test cases\n",
+//		len(result.DiscoveredPatterns.Sources),
+//		len(result.DiscoveredPatterns.Sinks),
+//		len(result.DataflowTestCases))
+func (c *LLMClient) AnalyzeFunction(fn *FunctionMetadata) (*LLMAnalysisResult, error) {
+	startTime := time.Now()
+
+	// Build prompt
+	prompt := BuildAnalysisPrompt(fn.SourceCode)
+
+	// Call LLM
+	responseText, err := c.callOllama(prompt)
+	if err != nil {
+		return nil, fmt.Errorf("LLM call failed: %w", err)
+	}
+
+	// Parse JSON response
+	var result LLMAnalysisResult
+	err = json.Unmarshal([]byte(responseText), &result)
+	if err != nil {
+		return nil, fmt.Errorf("failed to parse LLM response: %w\nResponse: %s", err, responseText)
+	}
+
+	// Add metadata
+	result.FunctionFQN = fn.FQN
+	result.AnalysisMetadata.ProcessingTime = time.Since(startTime).String()
+	result.AnalysisMetadata.ModelUsed = c.Model
+
+	// Validate result
+	if err := c.validateResult(&result); err != nil {
+		return nil, fmt.Errorf("invalid LLM result: %w", err)
+	}
+
+	return &result, nil
+}
+
+// callOllama makes an HTTP request to the Ollama API.
+func (c *LLMClient) callOllama(prompt string) (string, error) {
+	// Ollama API format
+	requestBody := map[string]interface{}{
+		"model":  c.Model,
+		"prompt": prompt,
+		"stream": false,
+		"options": map[string]interface{}{
+			"temperature": c.Temperature,
+			"num_predict": c.MaxTokens,
+		},
+		"format": "json", // Request JSON output
+	}
+
+	jsonBody, err := json.Marshal(requestBody)
+	if err != nil {
+		return "", fmt.Errorf("failed to marshal request: %w", err)
+	}
+
+	// Make request
+	url := c.BaseURL + "/api/generate"
+	req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, url, bytes.NewBuffer(jsonBody))
+	if err != nil {
+		return "", fmt.Errorf("failed to create request: %w", err)
+	}
+	req.Header.Set("Content-Type", "application/json")
+
+	resp, err := c.HTTPClient.Do(req)
+	if err != nil {
+		return "", fmt.Errorf("HTTP request failed: %w", err)
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		bodyBytes, _ := io.ReadAll(resp.Body)
+		return "", fmt.Errorf("LLM returned status %d: %s", resp.StatusCode, string(bodyBytes))
+	}
+
+	// Read response
+	bodyBytes, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return "", fmt.Errorf("failed to read response: %w", err)
+	}
+
+	// Parse Ollama response format
+	var ollamaResp struct {
+		Response string `json:"response"`
+		Done     bool   `json:"done"`
+	}
+	err = json.Unmarshal(bodyBytes, &ollamaResp)
+	if err != nil {
+		return "", fmt.Errorf("failed to parse Ollama response: %w", err)
+	}
+
+	return ollamaResp.Response, nil
+}
+
+// validateResult checks that the LLM result has the required fields.
+func (c *LLMClient) validateResult(result *LLMAnalysisResult) error {
+	if result.AnalysisMetadata.Confidence < 0.0 || result.AnalysisMetadata.Confidence > 1.0 {
+		return fmt.Errorf("invalid confidence: %f", result.AnalysisMetadata.Confidence)
+	}
+
+	// Validate test cases
+	for i, tc := range result.DataflowTestCases {
+		if tc.Source.Line <= 0 {
+			return fmt.Errorf("test case %d: invalid source line %d", i, tc.Source.Line)
+		}
+		if tc.Sink.Line <= 0 {
+			return fmt.Errorf("test case %d: invalid sink line %d", i, tc.Sink.Line)
+		}
+		if tc.Confidence < 0.0 || tc.Confidence > 1.0 {
+			return fmt.Errorf("test case %d: invalid confidence %f", i, tc.Confidence)
+		}
+	}
+
+	return nil
+}
+
+// AnalyzeBatch analyzes multiple functions in parallel.
+// It returns a results map (FQN -> result) and an errors map (FQN -> error).
+//
+// Performance: 4-8 parallel workers, ~30-60 minutes for 10k functions.
+//
+// Example:
+//
+//	client := NewLLMClient("http://localhost:11434", "qwen3-coder:32b")
+//	results, errors := client.AnalyzeBatch(functions, 4)
+//	fmt.Printf("Analyzed %d functions, %d errors\n", len(results), len(errors))
+func (c *LLMClient) AnalyzeBatch(functions []*FunctionMetadata, concurrency int) (map[string]*LLMAnalysisResult, map[string]error) {
+	results := make(map[string]*LLMAnalysisResult)
+	errors := make(map[string]error)
+
+	// Channel for work
+	workChan := make(chan *FunctionMetadata, len(functions))
+	resultChan := make(chan struct {
+		fqn    string
+		result *LLMAnalysisResult
+		err    error
+	}, len(functions))
+
+	// Start workers
+	for i := 0; i < concurrency; i++ {
+		go func() {
+			for fn := range workChan {
+				result, err := c.AnalyzeFunction(fn)
+				resultChan <- struct {
+					fqn    string
+					result *LLMAnalysisResult
+					err    error
+				}{fn.FQN, result, err}
+			}
+		}()
+	}
+
+	// Send work
+	for _, fn := range functions {
+		workChan <- fn
+	}
+	close(workChan)
+
+	// Collect results
+	for i := 0; i < len(functions); i++ {
+		res := <-resultChan
+		if res.err != nil {
+			errors[res.fqn] = res.err
+		} else {
+			results[res.fqn] = res.result
+		}
+	}
+
+	return results, errors
+}
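AnalyzeBatch records a failed call in the errors map and moves on; it never retries. Local model servers do shed requests under load, so callers may want a thin retry layer. A minimal sketch under that assumption; analyzeWithRetry and its backoff policy are illustrative, not part of this diff:

```go
// analyzeWithRetry retries AnalyzeFunction with linear backoff.
// Illustrative sketch, not part of this diff; it would live alongside
// llm.go in package diagnostic (imports: fmt, time).
func analyzeWithRetry(c *LLMClient, fn *FunctionMetadata, maxRetries int) (*LLMAnalysisResult, error) {
	var lastErr error
	for attempt := 0; attempt <= maxRetries; attempt++ {
		if attempt > 0 {
			// Linear backoff between attempts: 2s, 4s, 6s, ...
			time.Sleep(time.Duration(attempt) * 2 * time.Second)
		}
		result, err := c.AnalyzeFunction(fn)
		if err == nil {
			return result, nil
		}
		lastErr = err
	}
	return nil, fmt.Errorf("analysis failed after %d attempts: %w", maxRetries+1, lastErr)
}
```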
diff --git a/sourcecode-parser/diagnostic/llm_test.go b/sourcecode-parser/diagnostic/llm_test.go
new file mode 100644
index 00000000..49bbc07d
--- /dev/null
+++ b/sourcecode-parser/diagnostic/llm_test.go
@@ -0,0 +1,392 @@
+package diagnostic
+
+import (
+	"encoding/json"
+	"net/http"
+	"net/http/httptest"
+	"strings"
+	"sync/atomic"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+// TestNewLLMClient tests client creation.
+func TestNewLLMClient(t *testing.T) {
+	client := NewLLMClient("http://localhost:11434", "qwen3-coder:32b")
+
+	assert.NotNil(t, client)
+	assert.Equal(t, "http://localhost:11434", client.BaseURL)
+	assert.Equal(t, "qwen3-coder:32b", client.Model)
+	assert.Equal(t, 0.0, client.Temperature)
+	assert.Equal(t, 2000, client.MaxTokens)
+	assert.NotNil(t, client.HTTPClient)
+}
+
+// TestAnalyzeFunction_Success tests successful LLM analysis.
+func TestAnalyzeFunction_Success(t *testing.T) {
+	// Mock LLM response
+	mockResponse := LLMAnalysisResult{
+		DiscoveredPatterns: DiscoveredPatterns{
+			Sources: []PatternLocation{
+				{
+					Pattern:     "request.GET",
+					Lines:       []int{2},
+					Variables:   []string{"user_input"},
+					Category:    "user_input",
+					Description: "HTTP GET parameter",
+				},
+			},
+			Sinks: []PatternLocation{
+				{
+					Pattern:     "os.system",
+					Lines:       []int{3},
+					Variables:   []string{"user_input"},
+					Category:    "command_exec",
+					Description: "OS command execution",
+					Severity:    "CRITICAL",
+				},
+			},
+		},
+		DataflowTestCases: []DataflowTestCase{
+			{
+				TestID:      1,
+				Description: "User input to command exec",
+				Source: TestCaseSource{
+					Pattern:  "request.GET['cmd']",
+					Line:     2,
+					Variable: "user_input",
+				},
+				Sink: TestCaseSink{
+					Pattern:  "os.system",
+					Line:     3,
+					Variable: "user_input",
+				},
+				ExpectedDetection: true,
+				VulnerabilityType: "COMMAND_INJECTION",
+				Confidence:        0.95,
+			},
+		},
+		AnalysisMetadata: AnalysisMetadata{
+			TotalSources: 1,
+			TotalSinks:   1,
+			TotalFlows:   1,
+			Confidence:   0.95,
+		},
+	}
+
+	// Create mock server
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		responseBytes, _ := json.Marshal(mockResponse)
+		ollamaResp := map[string]interface{}{
+			"response": string(responseBytes),
+			"done":     true,
+		}
+		json.NewEncoder(w).Encode(ollamaResp)
+	}))
+	defer server.Close()
+
+	// Create client pointing to mock server
+	client := NewLLMClient(server.URL, "mock-model")
+
+	// Test function
+	fn := &FunctionMetadata{
+		FQN:        "test.func",
+		SourceCode: "def func():\n pass",
+		StartLine:  1,
+		EndLine:    2,
+	}
+
+	// Analyze
+	result, err := client.AnalyzeFunction(fn)
+	require.NoError(t, err)
+	require.NotNil(t, result)
+
+	// Verify
+	assert.Equal(t, "test.func", result.FunctionFQN)
+	assert.Equal(t, 1, len(result.DiscoveredPatterns.Sources))
+	assert.Equal(t, 1, len(result.DiscoveredPatterns.Sinks))
+	assert.Equal(t, 1, len(result.DataflowTestCases))
+	assert.Equal(t, "COMMAND_INJECTION", result.DataflowTestCases[0].VulnerabilityType)
+	assert.True(t, result.DataflowTestCases[0].ExpectedDetection)
+	assert.Equal(t, "mock-model", result.AnalysisMetadata.ModelUsed)
+	assert.NotEmpty(t, result.AnalysisMetadata.ProcessingTime)
+}
+
+// TestAnalyzeFunction_InvalidJSON tests error handling for bad JSON.
+func TestAnalyzeFunction_InvalidJSON(t *testing.T) {
+	// Create mock server that returns invalid JSON
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		ollamaResp := map[string]interface{}{
+			"response": "This is not valid JSON {{{",
+			"done":     true,
+		}
+		json.NewEncoder(w).Encode(ollamaResp)
+	}))
+	defer server.Close()
+
+	client := NewLLMClient(server.URL, "mock-model")
+
+	fn := &FunctionMetadata{
+		FQN:        "test.func",
+		SourceCode: "def func():\n pass",
+	}
+
+	result, err := client.AnalyzeFunction(fn)
+	assert.Error(t, err)
+	assert.Nil(t, result)
+	assert.Contains(t, err.Error(), "failed to parse")
+}
+
+// TestAnalyzeFunction_HTTPError tests error handling for HTTP failures.
+func TestAnalyzeFunction_HTTPError(t *testing.T) {
+	// Create mock server that returns 500
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.WriteHeader(http.StatusInternalServerError)
+		w.Write([]byte("Internal server error"))
+	}))
+	defer server.Close()
+
+	client := NewLLMClient(server.URL, "mock-model")
+
+	fn := &FunctionMetadata{
+		FQN:        "test.func",
+		SourceCode: "def func():\n pass",
+	}
+
+	result, err := client.AnalyzeFunction(fn)
+	assert.Error(t, err)
+	assert.Nil(t, result)
+	assert.Contains(t, err.Error(), "status 500")
+}
+
+// TestAnalyzeFunction_MalformedOllamaResponse tests handling of a bad Ollama response.
+func TestAnalyzeFunction_MalformedOllamaResponse(t *testing.T) {
+	// Create mock server that returns a malformed Ollama response
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.Write([]byte("not a valid ollama response"))
+	}))
+	defer server.Close()
+
+	client := NewLLMClient(server.URL, "mock-model")
+
+	fn := &FunctionMetadata{
+		FQN:        "test.func",
+		SourceCode: "def func():\n pass",
+	}
+
+	result, err := client.AnalyzeFunction(fn)
+	assert.Error(t, err)
+	assert.Nil(t, result)
+	assert.Contains(t, err.Error(), "failed to parse Ollama response")
+}
+
+// TestValidateResult_InvalidConfidence tests confidence validation.
+func TestValidateResult_InvalidConfidence(t *testing.T) {
+	client := NewLLMClient("http://localhost:11434", "test")
+
+	tests := []struct {
+		name       string
+		confidence float64
+		shouldFail bool
+	}{
+		{"valid 0.0", 0.0, false},
+		{"valid 0.5", 0.5, false},
+		{"valid 1.0", 1.0, false},
+		{"invalid negative", -0.1, true},
+		{"invalid > 1.0", 1.5, true},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			result := &LLMAnalysisResult{
+				AnalysisMetadata: AnalysisMetadata{
+					Confidence: tt.confidence,
+				},
+			}
+
+			err := client.validateResult(result)
+			if tt.shouldFail {
+				assert.Error(t, err)
+				assert.Contains(t, err.Error(), "invalid confidence")
+			} else {
+				assert.NoError(t, err)
+			}
+		})
+	}
+}
+
+// TestValidateResult_InvalidTestCase tests test case validation.
+func TestValidateResult_InvalidTestCase(t *testing.T) {
+	client := NewLLMClient("http://localhost:11434", "test")
+
+	tests := []struct {
+		name       string
+		testCase   DataflowTestCase
+		shouldFail bool
+		errorMsg   string
+	}{
+		{
+			name: "valid test case",
+			testCase: DataflowTestCase{
+				Source:     TestCaseSource{Line: 5},
+				Sink:       TestCaseSink{Line: 10},
+				Confidence: 0.9,
+			},
+			shouldFail: false,
+		},
+		{
+			name: "invalid source line zero",
+			testCase: DataflowTestCase{
+				Source:     TestCaseSource{Line: 0},
+				Sink:       TestCaseSink{Line: 5},
+				Confidence: 0.9,
+			},
+			shouldFail: true,
+			errorMsg:   "invalid source line",
+		},
+		{
+			name: "invalid sink line negative",
+			testCase: DataflowTestCase{
+				Source:     TestCaseSource{Line: 5},
+				Sink:       TestCaseSink{Line: -1},
+				Confidence: 0.9,
+			},
+			shouldFail: true,
+			errorMsg:   "invalid sink line",
+		},
+		{
+			name: "invalid confidence",
+			testCase: DataflowTestCase{
+				Source:     TestCaseSource{Line: 5},
+				Sink:       TestCaseSink{Line: 10},
+				Confidence: 1.5,
+			},
+			shouldFail: true,
+			errorMsg:   "invalid confidence",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			result := &LLMAnalysisResult{
+				DataflowTestCases: []DataflowTestCase{tt.testCase},
+				AnalysisMetadata: AnalysisMetadata{
+					Confidence: 0.9,
+				},
+			}
+
+			err := client.validateResult(result)
+			if tt.shouldFail {
+				assert.Error(t, err)
+				assert.Contains(t, err.Error(), tt.errorMsg)
+			} else {
+				assert.NoError(t, err)
+			}
+		})
+	}
+}
+
+// TestAnalyzeBatch tests parallel batch processing.
+func TestAnalyzeBatch(t *testing.T) {
+	// Mock server with a request counter (atomic, since handlers run concurrently)
+	var callCount int64
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		atomic.AddInt64(&callCount, 1)
+		mockResponse := LLMAnalysisResult{
+			AnalysisMetadata: AnalysisMetadata{
+				Confidence: 0.9,
+			},
+		}
+		responseBytes, _ := json.Marshal(mockResponse)
+		ollamaResp := map[string]interface{}{
+			"response": string(responseBytes),
+			"done":     true,
+		}
+		json.NewEncoder(w).Encode(ollamaResp)
+	}))
+	defer server.Close()
+
+	client := NewLLMClient(server.URL, "mock-model")
+
+	// Create test functions
+	functions := []*FunctionMetadata{
+		{FQN: "test.func1", SourceCode: "def func1(): pass"},
+		{FQN: "test.func2", SourceCode: "def func2(): pass"},
+		{FQN: "test.func3", SourceCode: "def func3(): pass"},
+	}
+
+	// Analyze batch
+	results, errors := client.AnalyzeBatch(functions, 2)
+
+	// Verify
+	assert.Equal(t, 3, len(results))
+	assert.Equal(t, 0, len(errors))
+	assert.Equal(t, int64(3), atomic.LoadInt64(&callCount))
+
+	assert.NotNil(t, results["test.func1"])
+	assert.NotNil(t, results["test.func2"])
+	assert.NotNil(t, results["test.func3"])
+}
+
+// TestAnalyzeBatch_WithErrors tests batch processing with some failures.
+func TestAnalyzeBatch_WithErrors(t *testing.T) {
+	// Mock server that fails on certain requests
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		// Read the request to determine which function is being analyzed
+		var reqBody map[string]interface{}
+		json.NewDecoder(r.Body).Decode(&reqBody)
+		prompt := reqBody["prompt"].(string)
+
+		// Fail if the prompt contains "func2"
+		if strings.Contains(prompt, "func2") {
+			w.WriteHeader(http.StatusInternalServerError)
+			w.Write([]byte("Simulated error"))
+			return
+		}
+
+		mockResponse := LLMAnalysisResult{
+			AnalysisMetadata: AnalysisMetadata{
+				Confidence: 0.9,
+			},
+		}
+		responseBytes, _ := json.Marshal(mockResponse)
+		ollamaResp := map[string]interface{}{
+			"response": string(responseBytes),
+			"done":     true,
+		}
+		json.NewEncoder(w).Encode(ollamaResp)
+	}))
+	defer server.Close()
+
+	client := NewLLMClient(server.URL, "mock-model")
+
+	functions := []*FunctionMetadata{
+		{FQN: "test.func1", SourceCode: "def func1(): pass"},
+		{FQN: "test.func2", SourceCode: "def func2(): pass"},
+		{FQN: "test.func3", SourceCode: "def func3(): pass"},
+	}
+
+	results, errors := client.AnalyzeBatch(functions, 2)
+
+	// Should have 2 successes and 1 error
+	assert.Equal(t, 2, len(results))
+	assert.Equal(t, 1, len(errors))
+
+	assert.NotNil(t, results["test.func1"])
+	assert.NotNil(t, results["test.func3"])
+	assert.NotNil(t, errors["test.func2"])
+}
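LLMClient's doc comment promises Ollama and vLLM support, but only the Ollama endpoint is implemented above. vLLM serves the OpenAI-compatible completions API, so a sibling method could look roughly like the sketch below; callVLLM is hypothetical and not part of this diff, with field names following the OpenAI /v1/completions schema that vLLM exposes:

```go
// callVLLM is a sketch of an OpenAI-compatible completions call for vLLM.
// Not part of this diff; request/response fields follow the OpenAI
// completions schema served at /v1/completions.
func (c *LLMClient) callVLLM(prompt string) (string, error) {
	requestBody := map[string]interface{}{
		"model":       c.Model,
		"prompt":      prompt,
		"max_tokens":  c.MaxTokens,
		"temperature": c.Temperature,
	}
	jsonBody, err := json.Marshal(requestBody)
	if err != nil {
		return "", fmt.Errorf("failed to marshal request: %w", err)
	}
	req, err := http.NewRequestWithContext(context.Background(), http.MethodPost,
		c.BaseURL+"/v1/completions", bytes.NewBuffer(jsonBody))
	if err != nil {
		return "", fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := c.HTTPClient.Do(req)
	if err != nil {
		return "", fmt.Errorf("HTTP request failed: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		bodyBytes, _ := io.ReadAll(resp.Body)
		return "", fmt.Errorf("LLM returned status %d: %s", resp.StatusCode, string(bodyBytes))
	}

	// OpenAI-style completion: the generated text lives in choices[0].text.
	var completionResp struct {
		Choices []struct {
			Text string `json:"text"`
		} `json:"choices"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&completionResp); err != nil {
		return "", fmt.Errorf("failed to parse completion response: %w", err)
	}
	if len(completionResp.Choices) == 0 {
		return "", fmt.Errorf("completion response contained no choices")
	}
	return completionResp.Choices[0].Text, nil
}
```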
diff --git a/sourcecode-parser/diagnostic/prompt.go b/sourcecode-parser/diagnostic/prompt.go
new file mode 100644
index 00000000..b1b24205
--- /dev/null
+++ b/sourcecode-parser/diagnostic/prompt.go
@@ -0,0 +1,162 @@
+package diagnostic
+
+import (
+	"fmt"
+)
+
+// BuildAnalysisPrompt constructs the prompt for LLM pattern discovery and test generation.
+// Based on diagnostic-tech-proposal.md Section 3.3 (LLM Prompt Design).
+func BuildAnalysisPrompt(sourceCode string) string {
+	return fmt.Sprintf(`You are a dataflow analysis expert. Analyze this Python function to discover all dataflow patterns and generate test cases.
+
+**FUNCTION TO ANALYZE**:
+`+"```python\n%s\n```"+`
+
+**YOUR TASK**:
+
+1. **DISCOVER PATTERNS** - Identify all dataflow patterns in THIS function:
+   - **Sources**: Any operation that introduces new data (user input, file reads, network, env vars, function params, etc.)
+   - **Sinks**: Any operation that consumes data (output, storage, exec, system calls, returns, etc.)
+   - **Sanitizers**: Any operation that transforms/validates data (escape, quote, validate, cast, etc.)
+   - **Propagators**: Operations that pass data along (assignments, calls, returns)
+
+2. **TRACE DATAFLOWS** - For each discovered source:
+   - Track where the data flows (which variables, which lines)
+   - Identify if it reaches any sinks
+   - Note if any sanitizers are applied
+   - Track through: assignments, calls, returns, branches, containers, attributes
+
+3. **GENERATE TEST CASES** - Create test cases our tool should pass:
+   - Expected flows (source → sink paths)
+   - Expected sanitizer detection
+   - Expected variable tracking
+
+**OUTPUT FORMAT** (JSON):
+`+"```json"+`
+{
+  "discovered_patterns": {
+    "sources": [
+      {
+        "pattern": "request.GET",
+        "lines": [10, 15],
+        "variables": ["user_input", "cmd"],
+        "category": "user_input",
+        "description": "HTTP GET parameter access"
+      }
+    ],
+    "sinks": [
+      {
+        "pattern": "os.system",
+        "lines": [45],
+        "variables": ["command"],
+        "category": "command_exec",
+        "description": "OS command execution",
+        "severity": "CRITICAL"
+      }
+    ],
+    "sanitizers": [
+      {
+        "pattern": "shlex.quote",
+        "lines": [30],
+        "variables": ["cleaned_cmd"],
+        "description": "Shell escaping function"
+      }
+    ],
+    "propagators": [
+      {
+        "type": "assignment",
+        "line": 12,
+        "from_var": "user_input",
+        "to_var": "raw_cmd"
+      }
+    ]
+  },
+
+  "dataflow_test_cases": [
+    {
+      "test_id": 1,
+      "description": "Unsanitized user input flows to command execution",
+      "source": {
+        "pattern": "request.GET['cmd']",
+        "line": 10,
+        "variable": "user_input"
+      },
+      "sink": {
+        "pattern": "os.system",
+        "line": 45,
+        "variable": "command"
+      },
+      "flow_path": [
+        {"line": 10, "variable": "user_input", "operation": "source"},
+        {"line": 12, "variable": "raw_cmd", "operation": "assignment"},
+        {"line": 45, "variable": "command", "operation": "sink"}
+      ],
+      "sanitizers_in_path": [],
+      "expected_detection": true,
+      "vulnerability_type": "COMMAND_INJECTION",
+      "confidence": 0.95,
+      "reasoning": "Direct flow from user input to OS command without sanitization"
+    }
+  ],
+
+  "variable_tracking": [
+    {
+      "variable": "user_input",
+      "first_defined": 10,
+      "last_used": 45,
+      "aliases": ["raw_cmd", "command"],
+      "flows_to_lines": [12, 20, 45],
+      "flows_to_vars": ["raw_cmd", "processed", "command"]
+    }
+  ],
+
+  "analysis_metadata": {
+    "total_sources": 1,
+    "total_sinks": 1,
+    "total_sanitizers": 0,
+    "total_flows": 1,
+    "dangerous_flows": 1,
+    "safe_flows": 0,
+    "confidence": 0.95,
+    "limitations": [
+      "Intra-procedural only (did not analyze called functions)",
+      "Control flow branches not fully explored"
+    ]
+  }
+}
+`+"```"+`
+
+**IMPORTANT GUIDELINES**:
+
+1. **NO PREDEFINED PATTERNS**: Discover patterns from the code itself, don't assume
+2. **BE SPECIFIC**: Include exact line numbers, variable names, code snippets
+3. **TRACK EVERYTHING**: Even non-security dataflows (var assignments, returns, etc.)
+4. **SANITIZER EFFECTIVENESS**: Note what each sanitizer actually blocks
+5. **GENERATE TESTS**: Each test case should be independently verifiable
+6. **CONFIDENCE SCORES**: Rate how confident you are (0.0-1.0)
+7. **EXPLAIN REASONING**: Why you think a flow exists or doesn't exist
+
+**EXAMPLE PATTERNS TO DISCOVER**:
+
+Security:
+- request.GET/POST/COOKIES → eval/exec/os.system
+- input() → open()
+- socket.recv() → subprocess.call()
+
+Generic Dataflow:
+- function_param → return value
+- config['key'] → database.save()
+- user.name → logger.info()
+- x = calculate() → result = process(x)
+
+**FOCUS**: Validate dataflow tracking algorithm:
+- ✅ Track variables through assignments
+- ✅ Detect def-use chains correctly
+- ✅ Identify taint propagation paths
+- ✅ Recognize sanitizers
+- ✅ Handle control flow (if/else)
+- ✅ Track container operations
+- ✅ Track attribute access
+
+Output ONLY the JSON, no additional text.`, sourceCode)
+}
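One practical caveat for consumers of this prompt: even with Ollama's `"format": "json"` option and the "Output ONLY the JSON" instruction, some models still wrap their answer in a markdown fence. A defensive pre-parse step is cheap insurance; a sketch (extractJSON is a hypothetical helper, not in this diff) that AnalyzeFunction could run before json.Unmarshal:

```go
// extractJSON strips an optional markdown code fence from model output so
// that json.Unmarshal sees bare JSON. Hypothetical helper, not in this diff;
// requires the strings package.
func extractJSON(raw string) string {
	s := strings.TrimSpace(raw)
	if strings.HasPrefix(s, "```") {
		// Drop the opening fence line (``` or ```json).
		if idx := strings.Index(s, "\n"); idx >= 0 {
			s = s[idx+1:]
		}
		// Drop the closing fence if present.
		if end := strings.LastIndex(s, "```"); end >= 0 {
			s = s[:end]
		}
	}
	return strings.TrimSpace(s)
}
```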
diff --git a/sourcecode-parser/diagnostic/prompt_test.go b/sourcecode-parser/diagnostic/prompt_test.go
new file mode 100644
index 00000000..3be7c4af
--- /dev/null
+++ b/sourcecode-parser/diagnostic/prompt_test.go
@@ -0,0 +1,101 @@
+package diagnostic
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+// TestBuildAnalysisPrompt tests prompt construction.
+func TestBuildAnalysisPrompt(t *testing.T) {
+	sourceCode := "def test():\n    x = 1\n    return x"
+
+	prompt := BuildAnalysisPrompt(sourceCode)
+
+	// Verify prompt contains key elements
+	assert.Contains(t, prompt, "dataflow analysis expert")
+	assert.Contains(t, prompt, sourceCode)
+	assert.Contains(t, prompt, "DISCOVER PATTERNS")
+	assert.Contains(t, prompt, "TRACE DATAFLOWS")
+	assert.Contains(t, prompt, "GENERATE TEST CASES")
+	assert.Contains(t, prompt, "discovered_patterns")
+	assert.Contains(t, prompt, "dataflow_test_cases")
+	assert.Contains(t, prompt, "JSON")
+	assert.Contains(t, prompt, "Sources")
+	assert.Contains(t, prompt, "Sinks")
+	assert.Contains(t, prompt, "Sanitizers")
+	assert.Contains(t, prompt, "Propagators")
+}
+
+// TestBuildAnalysisPrompt_ContainsExamples tests that the prompt includes examples.
+func TestBuildAnalysisPrompt_ContainsExamples(t *testing.T) {
+	prompt := BuildAnalysisPrompt("def dummy(): pass")
+
+	// Check for security examples
+	assert.Contains(t, prompt, "request.GET")
+	assert.Contains(t, prompt, "os.system")
+	assert.Contains(t, prompt, "COMMAND_INJECTION")
+
+	// Check for generic dataflow examples
+	assert.Contains(t, prompt, "function_param")
+	assert.Contains(t, prompt, "return value")
+}
+
+// TestBuildAnalysisPrompt_ContainsGuidelines tests that the prompt includes the important guidelines.
+func TestBuildAnalysisPrompt_ContainsGuidelines(t *testing.T) {
+	prompt := BuildAnalysisPrompt("def dummy(): pass")
+
+	assert.Contains(t, prompt, "NO PREDEFINED PATTERNS")
+	assert.Contains(t, prompt, "BE SPECIFIC")
+	assert.Contains(t, prompt, "TRACK EVERYTHING")
+	assert.Contains(t, prompt, "CONFIDENCE SCORES")
+	assert.Contains(t, prompt, "Output ONLY the JSON")
+}
+
+// TestBuildAnalysisPrompt_JSONStructure tests that the prompt shows the expected JSON structure.
+func TestBuildAnalysisPrompt_JSONStructure(t *testing.T) {
+	prompt := BuildAnalysisPrompt("def dummy(): pass")
+
+	// Check for JSON structure elements
+	assert.Contains(t, prompt, "pattern")
+	assert.Contains(t, prompt, "lines")
+	assert.Contains(t, prompt, "variables")
+	assert.Contains(t, prompt, "category")
+	assert.Contains(t, prompt, "description")
+	assert.Contains(t, prompt, "test_id")
+	assert.Contains(t, prompt, "expected_detection")
+	assert.Contains(t, prompt, "vulnerability_type")
+	assert.Contains(t, prompt, "confidence")
+	assert.Contains(t, prompt, "reasoning")
+	assert.Contains(t, prompt, "variable_tracking")
+	assert.Contains(t, prompt, "analysis_metadata")
+}
+
+// TestBuildAnalysisPrompt_EmptySourceCode tests with empty source code.
+func TestBuildAnalysisPrompt_EmptySourceCode(t *testing.T) {
+	prompt := BuildAnalysisPrompt("")
+
+	// Should still generate a valid prompt structure
+	assert.Contains(t, prompt, "DISCOVER PATTERNS")
+	assert.Contains(t, prompt, "GENERATE TEST CASES")
+	assert.NotEmpty(t, prompt)
+}
+
+// TestBuildAnalysisPrompt_ComplexSourceCode tests with realistic source code.
+func TestBuildAnalysisPrompt_ComplexSourceCode(t *testing.T) {
+	sourceCode := `def process_input(request):
+    user_cmd = request.GET['cmd']
+    cleaned = shlex.quote(user_cmd)
+    os.system(cleaned)`
+
+	prompt := BuildAnalysisPrompt(sourceCode)
+
+	// Verify source code is embedded
+	assert.Contains(t, prompt, "process_input")
+	assert.Contains(t, prompt, "user_cmd")
+	assert.Contains(t, prompt, "shlex.quote")
+
+	// Verify prompt structure intact
+	assert.Contains(t, prompt, "```python")
+	assert.Contains(t, prompt, "```json")
+}
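The prompt's example JSON and the structs in types.go (below) form an implicit contract: every snake_case key the prompt asks for must survive the round trip through the structs' json tags. A small guard test can catch drift between the two files; a sketch, assuming the testify imports already used above plus encoding/json (TestPromptSchemaMatchesTypes is illustrative, not part of this diff):

```go
// TestPromptSchemaMatchesTypes sketches a contract test: every top-level key
// produced by marshaling LLMAnalysisResult should appear in the prompt's
// example JSON. Illustrative only; needs encoding/json and testify/require.
func TestPromptSchemaMatchesTypes(t *testing.T) {
	prompt := BuildAnalysisPrompt("def dummy(): pass")

	marshaled, err := json.Marshal(LLMAnalysisResult{})
	require.NoError(t, err)

	var asMap map[string]json.RawMessage
	require.NoError(t, json.Unmarshal(marshaled, &asMap))

	for key := range asMap {
		if key == "function_fqn" {
			continue // filled in locally by AnalyzeFunction, not by the model
		}
		assert.Contains(t, prompt, `+"`\"`+key+`\"`"+`)
	}
}
```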
diff --git a/sourcecode-parser/diagnostic/types.go b/sourcecode-parser/diagnostic/types.go
index 0c925c2b..9ba79d28 100644
--- a/sourcecode-parser/diagnostic/types.go
+++ b/sourcecode-parser/diagnostic/types.go
@@ -41,3 +41,151 @@ type FunctionMetadata struct {
 	// IsAsync indicates if this is an async function
 	IsAsync bool
 }
+
+// LLMAnalysisResult contains the LLM's analysis of a function.
+// The json tags follow the snake_case schema that BuildAnalysisPrompt requests;
+// without them, encoding/json cannot map keys such as "discovered_patterns"
+// onto the Go field names, and the raw model output would unmarshal to zero values.
+type LLMAnalysisResult struct {
+	// FunctionFQN identifies which function was analyzed
+	FunctionFQN string `json:"function_fqn"`
+
+	// DiscoveredPatterns contains sources/sinks/sanitizers found by the LLM
+	DiscoveredPatterns DiscoveredPatterns `json:"discovered_patterns"`
+
+	// DataflowTestCases are test cases generated by the LLM.
+	// Each test case specifies expected dataflow behavior.
+	DataflowTestCases []DataflowTestCase `json:"dataflow_test_cases"`
+
+	// VariableTracking shows how the LLM traced variables through the function
+	VariableTracking []VariableTrack `json:"variable_tracking"`
+
+	// AnalysisMetadata holds metadata about the analysis
+	AnalysisMetadata AnalysisMetadata `json:"analysis_metadata"`
+}
+
+// DiscoveredPatterns contains all patterns discovered by the LLM in the function.
+type DiscoveredPatterns struct {
+	Sources     []PatternLocation     `json:"sources"`
+	Sinks       []PatternLocation     `json:"sinks"`
+	Sanitizers  []PatternLocation     `json:"sanitizers"`
+	Propagators []PropagatorOperation `json:"propagators"`
+}
+
+// PatternLocation describes where a pattern (source/sink/sanitizer) was found.
+type PatternLocation struct {
+	// Pattern is the code pattern (e.g., "request.GET", "os.system")
+	Pattern string `json:"pattern"`
+
+	// Lines where this pattern appears
+	Lines []int `json:"lines"`
+
+	// Variables involved
+	Variables []string `json:"variables"`
+
+	// Category for semantic grouping.
+	// Examples: "user_input", "file_read", "sql_execution", "command_exec"
+	Category string `json:"category"`
+
+	// Description of what this pattern does
+	Description string `json:"description"`
+
+	// Severity (for sinks): CRITICAL, HIGH, MEDIUM, LOW
+	Severity string `json:"severity"`
+}
+
+// PropagatorOperation describes how data propagates.
+type PropagatorOperation struct {
+	// Type: "assignment", "function_call", "return"
+	Type string `json:"type"`
+
+	// Line number
+	Line int `json:"line"`
+
+	// Source variable
+	FromVar string `json:"from_var"`
+
+	// Destination variable
+	ToVar string `json:"to_var"`
+
+	// Function name (if Type == "function_call")
+	Function string `json:"function"`
+}
+
+// DataflowTestCase is a test case generated by the LLM.
+// This is what we validate our tool against.
+type DataflowTestCase struct {
+	// TestID for reference
+	TestID int `json:"test_id"`
+
+	// Description of what this test validates
+	Description string `json:"description"`
+
+	// Source information
+	Source TestCaseSource `json:"source"`
+
+	// Sink information
+	Sink TestCaseSink `json:"sink"`
+
+	// FlowPath is the sequence of variables/operations the data moves through
+	FlowPath []FlowStep `json:"flow_path"`
+
+	// SanitizersInPath lists sanitizers in the path (if any)
+	SanitizersInPath []string `json:"sanitizers_in_path"`
+
+	// ExpectedDetection is the expected detection result:
+	// true: our tool SHOULD detect this flow
+	// false: our tool should NOT detect it (e.g., sanitized)
+	ExpectedDetection bool `json:"expected_detection"`
+
+	// VulnerabilityType (if ExpectedDetection == true)
+	VulnerabilityType string `json:"vulnerability_type"`
+
+	// Confidence score (0.0-1.0)
+	Confidence float64 `json:"confidence"`
+
+	// Reasoning for this test case
+	Reasoning string `json:"reasoning"`
+}
+
+// TestCaseSource describes the source in a test case.
+type TestCaseSource struct {
+	Pattern  string `json:"pattern"` // e.g., "request.GET['cmd']"
+	Line     int    `json:"line"`
+	Variable string `json:"variable"`
+}
+
+// TestCaseSink describes the sink in a test case.
+type TestCaseSink struct {
+	Pattern  string `json:"pattern"` // e.g., "os.system"
+	Line     int    `json:"line"`
+	Variable string `json:"variable"`
+}
+
+// FlowStep describes one step in a dataflow path.
+type FlowStep struct {
+	Line      int    `json:"line"`
+	Variable  string `json:"variable"`
+	Operation string `json:"operation"` // "source", "assignment", "call", "sanitizer", "sink"
+}
+
+// VariableTrack shows how the LLM traced a variable.
+type VariableTrack struct {
+	Variable     string   `json:"variable"`
+	FirstDefined int      `json:"first_defined"`
+	LastUsed     int      `json:"last_used"`
+	Aliases      []string `json:"aliases"` // Other variable names that hold the same data
+	FlowsToLines []int    `json:"flows_to_lines"`
+	FlowsToVars  []string `json:"flows_to_vars"`
+}
+
+// AnalysisMetadata contains metadata about the LLM analysis.
+type AnalysisMetadata struct {
+	TotalSources    int      `json:"total_sources"`
+	TotalSinks      int      `json:"total_sinks"`
+	TotalSanitizers int      `json:"total_sanitizers"`
+	TotalFlows      int      `json:"total_flows"`
+	DangerousFlows  int      `json:"dangerous_flows"`
+	SafeFlows       int      `json:"safe_flows"`
+	Confidence      float64  `json:"confidence"`
+	Limitations     []string `json:"limitations"`
+	ProcessingTime  string   `json:"processing_time"`
+	ModelUsed       string   `json:"model_used"`
+}
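To see the tagged types end to end, here is a sketch of decoding a hand-written model response shaped like the prompt's example; the JSON literal is illustrative, not captured model output, and the function would live in a _test.go file:

```go
// ExampleLLMAnalysisResult decodes a hand-written response shaped like the
// prompt's example JSON. Sketch only (imports: encoding/json, fmt).
func ExampleLLMAnalysisResult() {
	raw := `{
		"discovered_patterns": {
			"sinks": [{"pattern": "os.system", "lines": [45], "category": "command_exec", "severity": "CRITICAL"}]
		},
		"dataflow_test_cases": [{
			"test_id": 1,
			"source": {"pattern": "request.GET['cmd']", "line": 10, "variable": "user_input"},
			"sink": {"pattern": "os.system", "line": 45, "variable": "command"},
			"expected_detection": true,
			"vulnerability_type": "COMMAND_INJECTION",
			"confidence": 0.95
		}],
		"analysis_metadata": {"total_sinks": 1, "total_flows": 1, "confidence": 0.95}
	}`

	var result LLMAnalysisResult
	if err := json.Unmarshal([]byte(raw), &result); err != nil {
		fmt.Println("unmarshal failed:", err)
		return
	}
	fmt.Println(result.DataflowTestCases[0].VulnerabilityType, result.DiscoveredPatterns.Sinks[0].Severity)
	// Output: COMMAND_INJECTION CRITICAL
}
```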