package com.example.benchmark;

import org.eclipse.jdt.core.compiler.IProblem;
import org.eclipse.jdt.core.dom.AST;
import org.eclipse.jdt.core.dom.ASTParser;
import org.eclipse.jdt.core.dom.CompilationUnit;

import java.io.IOException;
import java.io.PrintWriter;
import java.nio.charset.StandardCharsets;
import java.nio.file.*;
import java.nio.file.attribute.BasicFileAttributes;
import java.util.*;

/**
 * Stand-alone benchmark that parses every {@code .java} file found under
 * {@link #PROJECT_DIR} with the Eclipse JDT {@link ASTParser} (binding resolution enabled),
 * prints wall-clock timings per round, and writes the compile errors reported during the
 * last measurement round to {@link #ERROR_REPORT_FILE}.
 *
 * <p>All configuration is done through the constants below; command-line arguments are
 * not read.
 */
public class JdtParserBenchmark {

    /** Root directory of the project to parse. The placeholder must be replaced before running. */
    private static final String PROJECT_DIR = "<PATH_TO_JAVA_PROJECT>";
    /** Text file listing one classpath entry per line; blank lines are ignored. */
    private static final String CLASSPATH_FILE = "classpath.txt";
    /** Value used for the JDT source, compliance and target-platform compiler options. */
    private static final String SOURCE_LEVEL = "11";
    /** File that receives the per-file error listing of the last measurement round. */
    private static final String ERROR_REPORT_FILE = "C:\\dv\\parser_compare\\errors.txt";

    /** Number of untimed rounds executed before measuring. */
    private static final int WARMUP_ROUNDS = 0;
    /** Number of timed rounds. */
    private static final int MEASURE_ROUNDS = 1;

    /** A source file that has already been read into memory, so file I/O is not timed. */
    private static class SourceFile {
        final String path;
        final char[] content;

        SourceFile(String path, char[] content) {
            this.path = path;
            this.content = content;
        }
    }

    /** The error messages reported for one source file. */
    private static class FileErrors {
        final String path;
        final List<String> errors;

        FileErrors(String path, List<String> errors) {
            this.path = path;
            this.errors = errors;
        }
    }

    /**
     * Runs the benchmark: loads the classpath and all sources, performs the warmup and
     * measurement rounds, prints the statistics and writes the error report.
     *
     * @param args ignored
     * @throws IOException if the classpath file or a source file cannot be read, or the
     *                     error report cannot be written
     */
    public static void main(String[] args) throws IOException {
        String[] classpathEntries = loadClasspath(Paths.get(CLASSPATH_FILE));

        // Build the source roots with Path.resolve instead of hard-coded "\\" separators so
        // the entries are valid on every platform, not only on Windows.
        Path root = Paths.get(PROJECT_DIR);
        String[] sourcepathEntries = {
                root.resolve("src").resolve("main").resolve("java").toString(),
                root.resolve("src").resolve("test").resolve("java").toString()
        };

        List<SourceFile> sourceFiles = collectSourceFiles(root);
        int fileCount = sourceFiles.size();
        System.out.println("Loaded " + fileCount + " Java source files for parsing.");

        // Warmup: errors are not collected (null collector).
        System.out.println("\nWarmup (" + WARMUP_ROUNDS + " rounds)...");
        for (int r = 0; r < WARMUP_ROUNDS; r++) {
            parseAll(sourceFiles, classpathEntries, sourcepathEntries, null);
            System.out.println("  Warmup round " + (r + 1) + " done.");
        }

        // Measurement: each round gets a fresh collector; only the last one is reported.
        System.out.println("\nMeasurement (" + MEASURE_ROUNDS + " rounds)...");
        long[] roundTimes = new long[MEASURE_ROUNDS];
        List<FileErrors> lastErrors = null;
        for (int r = 0; r < MEASURE_ROUNDS; r++) {
            List<FileErrors> errors = new ArrayList<>();
            long start = System.nanoTime();
            parseAll(sourceFiles, classpathEntries, sourcepathEntries, errors);
            long elapsed = System.nanoTime() - start;
            roundTimes[r] = elapsed;
            lastErrors = errors;
            double elapsedMs = elapsed / 1_000_000.0;
            System.out.printf("  Round %d: %.1f ms (%.3f ms/file)%n",
                    r + 1,
                    elapsedMs,
                    ratio(elapsedMs, fileCount));
        }

        printStatistics(roundTimes, fileCount);

        // lastErrors is only null when no measurement round ran.
        List<FileErrors> allErrors =
                lastErrors != null ? lastErrors : Collections.<FileErrors>emptyList();
        allErrors.sort(Comparator.comparing(e -> e.path));

        Path output = Paths.get(ERROR_REPORT_FILE);
        writeErrorReport(output, allErrors, fileCount);

        System.out.println("\n=== Error summary ===");
        System.out.println("Files with errors: " + allErrors.size() + " / " + fileCount);
        System.out.println("Details written to: " + output.toAbsolutePath());
    }

    /** Reads the classpath file and returns its trimmed, non-empty lines. */
    private static String[] loadClasspath(Path classpathFile) throws IOException {
        return Files.readAllLines(classpathFile, StandardCharsets.UTF_8).stream()
                .map(String::trim)
                .filter(s -> !s.isEmpty())
                .toArray(String[]::new);
    }

    /** Walks {@code root} recursively and loads every {@code .java} file as UTF-8 text. */
    private static List<SourceFile> collectSourceFiles(Path root) throws IOException {
        List<SourceFile> sourceFiles = new ArrayList<>();
        Files.walkFileTree(root, new SimpleFileVisitor<Path>() {
            @Override
            public FileVisitResult visitFile(Path file, BasicFileAttributes attrs) throws IOException {
                String path = file.toString();
                if (path.endsWith(".java")) {
                    byte[] bytes = Files.readAllBytes(file);
                    sourceFiles.add(new SourceFile(path, new String(bytes, StandardCharsets.UTF_8).toCharArray()));
                }
                return FileVisitResult.CONTINUE;
            }
        });
        return sourceFiles;
    }

    /** Prints average, minimum and maximum round times plus the average time per file. */
    private static void printStatistics(long[] roundTimes, int fileCount) {
        long sum = 0;
        long min = Long.MAX_VALUE;
        long max = Long.MIN_VALUE;
        for (long t : roundTimes) {
            sum += t;
            min = Math.min(min, t);
            max = Math.max(max, t);
        }
        if (roundTimes.length == 0) {
            // Without any round the sentinels above would be printed as bogus timings.
            min = 0;
            max = 0;
        }
        double avgMs = ratio(sum / 1_000_000.0, roundTimes.length);
        double avgPerFile = ratio(avgMs, fileCount);

        System.out.println("\n=== Results ===");
        System.out.printf("Rounds:            %d%n", roundTimes.length);
        System.out.printf("Files per round:   %d%n", fileCount);
        System.out.printf("Avg round time:    %.1f ms%n", avgMs);
        System.out.printf("Min round time:    %.1f ms%n", min / 1_000_000.0);
        System.out.printf("Max round time:    %.1f ms%n", max / 1_000_000.0);
        System.out.printf("Avg time/file:     %.3f ms%n", avgPerFile);
    }

    /** Returns {@code total / count}, or {@code 0.0} when {@code count} is not positive. */
    private static double ratio(double total, int count) {
        return count > 0 ? total / count : 0.0;
    }

    /**
     * Writes the per-file error listing followed by a summary to {@code output}.
     *
     * @throws IOException if the parent directory cannot be created or writing fails
     */
    private static void writeErrorReport(Path output, List<FileErrors> allErrors, int totalFiles)
            throws IOException {
        // Create the target directory up front: failing here would otherwise throw away the
        // results of a complete benchmark run.
        Path parent = output.getParent();
        if (parent != null && !Files.isDirectory(parent)) {
            Files.createDirectories(parent);
        }

        try (PrintWriter pw = new PrintWriter(Files.newBufferedWriter(output, StandardCharsets.UTF_8))) {
            for (FileErrors fe : allErrors) {
                pw.println("FILE: " + fe.path);
                for (String msg : fe.errors) {
                    pw.println("  ERROR: " + msg);
                }
                pw.println();
            }
            pw.println("=== SUMMARY ===");
            pw.println("Total files parsed:        " + totalFiles);
            pw.println("Files with errors:         " + allErrors.size());
            pw.println("Files parsed successfully: " + (totalFiles - allErrors.size()));

            // PrintWriter never throws on write failures; surface them explicitly.
            if (pw.checkError()) {
                throw new IOException("Failed to write error report: " + output);
            }
        }
    }

    /**
     * Parses every source file with a freshly configured {@link ASTParser} and optionally
     * records the problems flagged as errors.
     *
     * @param sourceFiles       in-memory sources to parse
     * @param classpathEntries  classpath entries handed to {@code setEnvironment}
     * @param sourcepathEntries source roots handed to {@code setEnvironment}
     * @param errorCollector    receives one {@link FileErrors} per file that has at least one
     *                          error; may be {@code null} to skip error collection
     */
    private static void parseAll(List<SourceFile> sourceFiles, String[] classpathEntries,
                                 String[] sourcepathEntries, List<FileErrors> errorCollector) {
        // The options are identical for every file, so build them once per call.
        Map<String, String> options = new HashMap<>();
        options.put("org.eclipse.jdt.core.compiler.source", SOURCE_LEVEL);
        options.put("org.eclipse.jdt.core.compiler.compliance", SOURCE_LEVEL);
        options.put("org.eclipse.jdt.core.compiler.codegen.targetPlatform", SOURCE_LEVEL);

        for (SourceFile sf : sourceFiles) {
            ASTParser parser = ASTParser.newParser(AST.JLS17);
            parser.setKind(ASTParser.K_COMPILATION_UNIT);
            parser.setResolveBindings(true);
            parser.setBindingsRecovery(false);
            parser.setStatementsRecovery(false);
            parser.setCompilerOptions(options);

            // null encodings, and true to include the running VM's boot classpath.
            parser.setEnvironment(classpathEntries, sourcepathEntries, null, true);
            parser.setUnitName(sf.path);
            parser.setSource(sf.content);

            CompilationUnit cu = (CompilationUnit) parser.createAST(null);

            if (errorCollector != null) {
                List<String> errors = errorMessages(cu);
                if (!errors.isEmpty()) {
                    errorCollector.add(new FileErrors(sf.path, errors));
                }
            }
        }
    }

    /** Formats the problems of {@code cu} that are errors (warnings are skipped). */
    private static List<String> errorMessages(CompilationUnit cu) {
        List<String> errors = new ArrayList<>();
        IProblem[] problems = cu.getProblems();
        if (problems != null) {
            for (IProblem p : problems) {
                if (p.isError()) {
                    errors.add("line " + p.getSourceLineNumber() + ": " + p.getMessage());
                }
            }
        }
        return errors;
    }
}
