<a href="https://colab.research.google.com/github/rumeshsmrr/GenAI_Based_Autograding/blob/technical_assignments/startCodeForProgramminEvaluvation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [19]:
import subprocess
import os
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
from transformers import RobertaTokenizer, RobertaForSequenceClassification
from tabulate import tabulate

# Load CodeBERT and T5 models
codebert_tokenizer = RobertaTokenizer.from_pretrained("microsoft/codebert-base")
codebert_model = RobertaForSequenceClassification.from_pretrained("microsoft/codebert-base", num_labels=1)

flan_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
flan_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-base")

def run_java_code(code):
    # Extract the class name from the code (this assumes class is defined as 'public class ClassName')
    class_name = None
    for line in code.splitlines():
        if line.strip().startswith("public class "):
            class_name = line.strip().split()[2]  # Get the class name
            break

    if not class_name:
        return None, "No public class found in the code."

    file_name = f"{class_name}.java"

    with open(file_name, "w") as file:
        file.write(code)

    # Compile the Java file
    compile_process = subprocess.run(
        ["javac", file_name],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE
    )

    # Check for compilation errors
    if compile_process.returncode != 0:
        return None, compile_process.stderr.decode()

    # Run the Java program
    run_process = subprocess.run(
        ["java", class_name],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE
    )

    # Check for runtime errors
    if run_process.returncode != 0:
        return None, run_process.stderr.decode()

    # Clean up the generated files
    os.remove(file_name)
    os.remove(f"{class_name}.class")

    return run_process.stdout.decode().strip(), None
def evaluate_java_code(student_code, reference_code):
    student_output, student_error = run_java_code(student_code)
    reference_output, reference_error = run_java_code(reference_code)

    if student_error:
        return 0, "Student code has errors: " + student_error
    if reference_error:
        return 0, "Reference code has errors: " + reference_error

    if student_output == reference_output:
        return 100, "Outputs match perfectly."

    total_length = max(len(student_output), len(reference_output))
    correct_length = sum(1 for s, r in zip(student_output, reference_output) if s == r)

    percentage_correct = (correct_length / total_length) * 100 if total_length > 0 else 0
    return percentage_correct, "Outputs differ."

def evaluate_functionality_with_codebert(student_code, reference_code, execution_score):
    # Calculate the similarity between student and reference code
    match_percentage = calculate_code_similarity(student_code, reference_code)

    # Matching score out of 0.2 based on percentage
    matching_score = 0.2 * (match_percentage / 100)

    # Functionality score via CodeBERT
    inputs = codebert_tokenizer(
        [student_code, reference_code],
        padding=True,
        truncation=True,
        return_tensors="pt"
    )

    with torch.no_grad():
        outputs = codebert_model(**inputs)

    logits = outputs.logits
    predicted_score = torch.sigmoid(logits[0]).squeeze().item()  # Convert the first element of the tensor to scalar float

    # If execution score is 100%, override and give full functionality score
    if execution_score == 100:
        predicted_score = 1.0  # Set functionality score to full if the output is correct

    # Combine functionality score with matching score
    final_functionality_score = predicted_score * 0.5 + matching_score

    # Ensure the score does not fall below a reasonable threshold
    final_functionality_score = min(final_functionality_score, 1.0)

    return final_functionality_score


def calculate_code_similarity(code1, code2):
    """
    Calculate the percentage of matching characters between two code snippets.
    This can be expanded with more sophisticated text comparison methods.
    """
    # Normalize the code by removing spaces and line breaks for better comparison
    normalized_code1 = "".join(code1.split())
    normalized_code2 = "".join(code2.split())

    # Calculate matching percentage
    match_length = sum(1 for a, b in zip(normalized_code1, normalized_code2) if a == b)
    max_length = max(len(normalized_code1), len(normalized_code2))

    return (match_length / max_length) * 100 if max_length > 0 else 0

def evaluate_non_functional_code(student_code):
    # First use T5 model for readability and structure evaluation
    inputs = flan_tokenizer(
        f"Evaluate the code for readability, comments, and structure: {student_code}",
        return_tensors="pt",
        padding=True,
        truncation=True
    )

    with torch.no_grad():
        outputs = flan_model.generate(**inputs)

    evaluation = flan_tokenizer.decode(outputs[0], skip_special_tokens=True)
    score_from_model = extract_score_from_evaluation(evaluation)

    # Basic heuristic evaluation (presence of comments, consistent indentation)
    comments_present = "//" in student_code or "/*" in student_code  # Checking for comments
    indentation_consistent = check_indentation(student_code)  # Check if indentation is consistent

    # Add extra points based on these heuristic checks
    heuristic_score = 0.5 if comments_present else 0.0
    heuristic_score += 0.5 if indentation_consistent else 0.0

    # Combine T5 model score with heuristic score
    final_structure_score = min(1.0, score_from_model + heuristic_score)

    return final_structure_score

def extract_score_from_evaluation(evaluation_text):
    if "good" in evaluation_text.lower():
        return 0.5
    elif "average" in evaluation_text.lower():
        return 0.25
    elif "poor" in evaluation_text.lower():
        return 0.0
    return 0.0  # Default score if nothing matches

def check_indentation(student_code):
    """Heuristic check for consistent indentation (basic)"""
    lines = student_code.split('\n')
    indent_levels = [len(line) - len(line.lstrip()) for line in lines if line.strip() != ""]
    return len(set(indent_levels)) <= 3  # Allow up to 3 different indent levels (basic check)

def extract_marking_scheme_from_llm(prompt):
    # Example marking scheme extraction (mocked)
    return {
        "functionality": 0.5,
        "execution": 0.3,
        "structure": 0.2
    }

def grade_student_code(student_code, reference_code, marking_scheme):
    # Execution Evaluation
    execution_score, execution_message = evaluate_java_code(student_code, reference_code)

    # If output is perfect and code similarity is above 90%, award full marks
    match_percentage = calculate_code_similarity(student_code, reference_code)
    if match_percentage >= 90 and execution_score == 100:
        final_score = 100
        feedback = "Perfect match with the reference code. Full marks awarded!"
    else:
        # Functionality Evaluation using CodeBERT, adjusted for perfect execution score
        functionality_score = evaluate_functionality_with_codebert(student_code, reference_code, execution_score)

        # Non-Functional Evaluation using T5 and heuristics
        structure_score = evaluate_non_functional_code(student_code)

        # Combine scores based on the provided marking scheme
        final_score = (functionality_score * marking_scheme["functionality"]) + \
                      (execution_score * marking_scheme["execution"]) + \
                      (structure_score * marking_scheme["structure"])

        feedback = (f"Functionality Score: {functionality_score}, "
                    f"Execution Score: {execution_score}, "
                    f"Structure Score: {structure_score}. "
                    f"Message: {execution_message}")

    return final_score, feedback

# Example Usage
student_submissions = {
    "student_1": """
   import java.util.Scanner;

public class Calculator {
    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);

        System.out.println("Enter first number:");
        double num1 = scanner.nextDouble();

        System.out.println("Enter second number:");
        double num2 = scanner.nextDouble();

        System.out.println("Enter an operator (+, -, *, /):");
        char operator = scanner.next().charAt(0);

        double result = 0;
        switch (operator) {
            case '+':
                result = num1 + num2;
                break;
            case '-':
                result = num1 - num2;
                break;
            case '*':
                result = num1 * num2;
                break;
            case '/':
                if (num2 != 0) {
                    result = num1 / num2;
                } else {
                    System.out.println("Error! Division by zero.");
                    return;
                }
                break;
            default:
                System.out.println("Invalid operator");
                return;
        }

        System.out.println("Result: " + result);
    }
}""",

    "student_2": """public class TempCode {
        public static void main(String[] args) {
            System.out.println("Hello, World!"); // Extra comment
        }
    }""",

    "student_3": """public class TempCode {
        public static void main(String[] args) {
            System.out.println("Hello, Wrong World!"); // Wrong output
        }
    }""",

    "student_4": """public class Calculator {
        public static void main(String[] args) {
            Scanner sc = new Scanner(System.in);
            System.out.println("Enter two numbers and an operator");
            double num1 = sc.nextDouble();
            double num2 = sc.nextDouble();
            char operator = sc.next().charAt(0);
            double result = 0;
            switch (operator) {
                case '+':
                    result = num1 + num2;
                    break;
                case '-':
                    result = num1 - num2;
                    break;
                case '*':
                    result = num1 * num2;
                    break;
                case '/':
                    if (num2 == 0) {
                        System.out.println("Cannot divide by zero!");
                        return;
                    }
                    result = num1 / num2;
                    break;
                default:
                    System.out.println("Invalid operator");
                    return;
            }
            System.out.println("The result is: " + result);
        }
    }""",

    "student_5": """public class Calculator {
        public static void main(String[] args) {
            Scanner sc = new Scanner(System.in);
            double num1 = sc.nextDouble();
            double num2 = sc.nextDouble();
            char operator = sc.next().charAt(0);
            double result = 1; // Incorrect initialization
            switch (operator) {
                case '+':
                    result = num1 + num2;
                    break;
                case '-':
                    result = num1 - num2;
                    break;
                case '*':
                    result = num1 * num2;
                    break;
                case '/':
                    result = num1 / num2;
                    break;
                default:
                    System.out.println("Invalid operator");
                    return;
            }
            System.out.println("Result: " + result);
        }
    }""",

    "student_6": """public class Calculator {
        public static void main(String[] args) {
            Scanner sc = new Scanner(System.in);
            double num1 = sc.nextDouble();
            double num2 = sc.nextDouble();
            char operator = sc.next().charAt(0);
            double result = 0;
            switch (operator) {
                case '+':
                    result = num1 + num2;
                    break;
                case '-':
                    result = num1 - num2;
                    break;
                case '*':
                    result = num1 * num2;
                    break;
                case '/':
                    result = num1 / num2; // No check for zero division
                    break;
                default:
                    System.out.println("Invalid operator");
                    return;
            }
            System.out.println("Result: " + result);
        }
    }""",

}


reference_code = """import java.util.Scanner;

public class Calculator {
    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);

        System.out.println("Enter first number:");
        double num1 = scanner.nextDouble();

        System.out.println("Enter second number:");
        double num2 = scanner.nextDouble();

        System.out.println("Enter an operator (+, -, *, /):");
        char operator = scanner.next().charAt(0);

        double result = 0;
        switch (operator) {
            case '+':
                result = num1 + num2;
                break;
            case '-':
                result = num1 - num2;
                break;
            case '*':
                result = num1 * num2;
                break;
            case '/':
                if (num2 != 0) {
                    result = num1 / num2;
                } else {
                    System.out.println("Error! Division by zero.");
                    return;
                }
                break;
            default:
                System.out.println("Invalid operator");
                return;
        }

        System.out.println("Result: " + result);
    }
}
"""  # Example correct answer

# Get marking scheme from LLM
marking_scheme_prompt = "Extract the marking scheme for evaluating Java code based on functionality, execution, and structure."
marking_scheme = extract_marking_scheme_from_llm(marking_scheme_prompt)

# Collect results in a list for tabulation
results = []

# Evaluate each student code
for student_id, student_code in student_submissions.items():
    final_score, feedback = grade_student_code(student_code, reference_code, marking_scheme)
    results.append([student_id, f"{final_score:.2f}", feedback])

# Print results in a table format
headers = ["Student ID", "Final Score", "Feedback"]
print(tabulate(results, headers, tablefmt="grid"))


Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at microsoft/codebert-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


+--------------+---------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| Student ID   |   Final Score | Feedback                                                                                                                                                                           |
| student_1    |          0.24 | Functionality Score: 0.4748866379261017, Execution Score: 0, Structure Score: 0.0. Message: Student code has errors: Exception in thread "main" java.util.NoSuchElementException   |
|              |               | 	at java.base/java.util.Scanner.throwFor(Scanner.java:937)                                                                                                                                                                                    |
|              |               | 	at java.base/java.util.Scanner.next(Scanner.java:15