In [8]:
# !pip install langgraph

Collecting langgraph
  Downloading langgraph-0.2.70-py3-none-any.whl.metadata (17 kB)
Collecting langgraph-checkpoint<3.0.0,>=2.0.10 (from langgraph)
  Downloading langgraph_checkpoint-2.0.10-py3-none-any.whl.metadata (4.6 kB)
Collecting langgraph-sdk<0.2.0,>=0.1.42 (from langgraph)
  Downloading langgraph_sdk-0.1.51-py3-none-any.whl.metadata (1.8 kB)
Collecting msgpack<2.0.0,>=1.1.0 (from langgraph-checkpoint<3.0.0,>=2.0.10->langgraph)
  Downloading msgpack-1.1.0-cp39-cp39-win_amd64.whl.metadata (8.6 kB)
Collecting orjson>=3.10.1 (from langgraph-sdk<0.2.0,>=0.1.42->langgraph)
  Downloading orjson-3.10.15-cp39-cp39-win_amd64.whl.metadata (42 kB)
     ---------------------------------------- 0.0/42.9 kB ? eta -:--:--
     ---------------------------------------- 42.9/42.9 kB 2.0 MB/s eta 0:00:00
Downloading langgraph-0.2.70-py3-none-any.whl (149 kB)
   ---------------------------------------- 0.0/149.7 kB ? eta -:--:--
   ---------------------------------------- 149.7/149.7 kB 4.4 MB/s 

  You can safely remove it manually.


In [26]:
import os
import re
import datetime
from typing import Optional, Dict, List, Any, Tuple
from langgraph.graph import StateGraph, END
from langchain_core.prompts import PromptTemplate
# from langchain_community.chat_models import ChatOpenAI
from langchain_openai import ChatOpenAI
from dotenv import load_dotenv
import yaml

In [110]:
# Load variables from a .env file (python-dotenv) into the process environment
# before the class below is instantiated; CodeTranslationGraph.__init__ reads
# OPENAI_API_KEY via os.getenv, so this call has to run first.
load_dotenv()

class CodeTranslationGraph:
    """LLM-driven source-to-source code translation built on a LangGraph workflow.

    Pipeline (see ``_build_workflow``)::

        analyze_requirements -> parse_analysis -> initial_translation
            -> validate_code --(issues & iterations left)--> improve_code -> validate_code
                             --(otherwise)-----------------> finalize_output -> END

    Every node receives the current state dict and returns a *complete* state
    dict (the incoming state plus its own additions), so the pipeline behaves
    the same whether the graph replaces or merges node outputs.

    Attributes:
        knowledge_base: Mapping loaded from the YAML rules file; top-level keys
            are target-language names, values map a rule category to a list of
            rules. The ``analysis_rules`` category is used only by the analysis
            prompt; every other category is used for validation/improvement.
        max_iterations: Upper bound on improve/validate rounds.
        execution_log: Chronological list of step records written by ``_log_step``.
        llm: Chat model used by every prompt chain.
        workflow: Compiled LangGraph graph.
    """

    def __init__(self, kb_path: Optional[str] = None,
                 model: str = "gpt-3.5-turbo",
                 max_iterations: int = 3):
        """Create the translator.

        Args:
            kb_path: Optional path to the YAML rules file. When falsy, no rules
                are applied.
            model: Chat model name passed to ``ChatOpenAI``.
            max_iterations: Maximum number of improvement rounds.
        """
        self.knowledge_base = self.load_knowledge_base(kb_path) if kb_path else {}
        self.max_iterations = max_iterations
        self.execution_log: List[Dict[str, Any]] = []

        # Initialize LLM (the key is read from the environment, never hard-coded)
        self.llm = ChatOpenAI(
            temperature=0,
            model=model,
            openai_api_key=os.getenv("OPENAI_API_KEY")
        )

        # Build workflow
        self.workflow = self._build_workflow()

    def load_knowledge_base(self, file_path: str) -> Dict:
        """Load the YAML knowledge base.

        Args:
            file_path: Path to a UTF-8 encoded YAML file.

        Returns:
            The parsed mapping; an empty dict when the file is empty.

        Raises:
            ValueError: If the YAML document is not a mapping.
            OSError: If the file cannot be opened.
        """
        with open(file_path, 'r', encoding='utf-8') as f:
            data = yaml.safe_load(f)
        if data is None:
            # safe_load returns None for an empty document; keep the "no rules" contract.
            return {}
        if not isinstance(data, dict):
            raise ValueError(
                f"Knowledge base {file_path!r} must contain a mapping at top level, "
                f"got {type(data).__name__}"
            )
        return data

    def _rules_for(self, target_lang: str) -> Dict:
        """Return the rule mapping for ``target_lang`` (``{}`` when there is none).

        An exact key match wins; otherwise the lookup is retried ignoring case
        and surrounding whitespace, because the language name is produced by
        the LLM and its casing is not guaranteed.
        """
        kb = self.knowledge_base or {}
        rules = kb.get(target_lang)
        if rules is None and isinstance(target_lang, str):
            wanted = target_lang.strip().lower()
            for name, candidate in kb.items():
                if isinstance(name, str) and name.strip().lower() == wanted:
                    rules = candidate
                    break
        return rules if isinstance(rules, dict) else {}

    def _get_code_rules(self, target_lang: str) -> str:
        """Render the code rules of the target language as prompt text.

        Every category except ``analysis_rules`` becomes a ``# CATEGORY``
        heading followed by one ``- rule`` bullet per item.

        Returns:
            The rendered text, or ``""`` when no rules are known.
        """
        sections = []
        for cat, items in self._rules_for(target_lang).items():
            if cat == "analysis_rules":
                continue
            if isinstance(items, str):
                # A single rule written as a scalar; do not iterate its characters.
                items = [items]
            bullets = "\n".join(f"- {item}" for item in (items or []))
            sections.append(f"# {str(cat).upper()}\n" + bullets)
        return "\n".join(sections)

    def _get_analysis_rules(self, target_lang: str) -> str:
        """Render the ``analysis_rules`` as prompt text.

        Args:
            target_lang: Language whose analysis rules are wanted. When empty
                (the target is not known yet during requirement analysis) the
                rules of all languages are combined, without duplicates.

        Returns:
            ``"Special Rules:"`` followed by one bullet per rule, or ``""``
            when there are no rules, so the template's
            ``{% if analysis_rules %}`` guard skips the section.
        """
        if not self.knowledge_base:
            return ""

        if target_lang:
            analysis_rules = list(self._rules_for(target_lang).get("analysis_rules") or [])
        else:
            analysis_rules = []
            for entry in self.knowledge_base.values():
                if not isinstance(entry, dict):
                    continue
                for rule in entry.get("analysis_rules") or []:
                    if rule not in analysis_rules:
                        analysis_rules.append(rule)

        if not analysis_rules:
            return ""
        return "Special Rules:\n" + "\n".join(f"- {rule}" for rule in analysis_rules)

    def _log_step(self, step_name: str, input_data: dict, output_data: Any):
        """Append one record to ``execution_log``.

        The input is stored as a shallow copy so later changes to the state do
        not rewrite history, and any ``execution_log`` key is dropped so the
        log never nests (or references) itself.
        """
        if isinstance(input_data, dict):
            logged_input = {k: v for k, v in input_data.items() if k != "execution_log"}
        else:
            logged_input = input_data
        self.execution_log.append({
            "step": step_name,
            "timestamp": datetime.datetime.now().isoformat(),
            "input": logged_input,
            "output": output_data
        })

    def _build_workflow(self) -> Any:
        """Build and compile the LangGraph workflow.

        Returns:
            The object returned by ``StateGraph.compile()`` (the compiled,
            runnable graph, not the ``StateGraph`` builder itself).
        """
        workflow = StateGraph(state_schema=dict)

        # Define nodes
        workflow.add_node("analyze_requirements", self._analyze_requirements)
        workflow.add_node("parse_analysis", self._parse_analysis)
        workflow.add_node("initial_translation", self._initial_translation)
        workflow.add_node("validate_code", self._validate_code)
        workflow.add_node("improve_code", self._improve_code)
        workflow.add_node("finalize_output", self._finalize_output)

        # Set up the initial process
        workflow.set_entry_point("analyze_requirements")
        workflow.add_edge("analyze_requirements", "parse_analysis")
        workflow.add_edge("parse_analysis", "initial_translation")
        workflow.add_edge("initial_translation", "validate_code")

        # Set up the validation loop
        workflow.add_conditional_edges(
            "validate_code",
            self._should_improve,
            {"improve": "improve_code", "final": "finalize_output"}
        )
        workflow.add_edge("improve_code", "validate_code")

        # Set the final node
        workflow.add_edge("finalize_output", END)

        return workflow.compile()

    def _analyze_requirements(self, state: Dict) -> Dict:
        """Requirement analysis node.

        Asks the LLM to detect source/target language, extract the code and
        list conversion risks from ``state["user_input"]``.

        Returns:
            The incoming state plus ``"analysis"`` (raw LLM text).
        """
        analysis_template = """You are a senior code analysis expert. Perform these tasks:
        1. Identify source programming language (C/C++/FORTRAN/CUDA/OpenMP/JAX)
        2. Identify target language (C/C++/FORTRAN/CUDA/OpenMP/JAX)
        3. Extract code content needing conversion
        4. Analyze potential conversion challenges
        5. Generate code conversion task description
        
        {% if analysis_rules %}
        {{ analysis_rules }}
        {% endif %}

        User input: {{user_input}}

        Respond in this format:
        Source Language: [detected source language]
        Target Language: [detected target language]
        Code Content: [extracted code block]
        Potential Issues: 
        - [Issue1 description]
        - [Issue2 description]
        Task Description: "Convert the following [source] code to [target]:\n[code]"
        """

        prompt = PromptTemplate(
            template=analysis_template,
            input_variables=["user_input"],
            # The target language is unknown at this stage, so "" requests the
            # combined analysis rules of every language in the knowledge base.
            partial_variables={"analysis_rules": self._get_analysis_rules("")},
            template_format="jinja2"
        )

        chain = prompt | self.llm
        result = chain.invoke({"user_input": state["user_input"]})

        self._log_step("analyze_requirements", state, result.content)
        # Return the full state like every other node, so "user_input" and
        # "iteration" are not lost if the graph replaces state with node output.
        return {**state, "analysis": result.content}

    def _parse_analysis(self, state: Dict) -> Dict:
        """Parse the analysis text into structured state fields.

        Recognised headers (case-insensitive): ``Source Language:``,
        ``Target Language:``, ``Code Content:``, ``Potential Issues:`` and
        ``Task Description:``. Lines after ``Code Content:`` are collected
        verbatim (indentation kept) until the next header; ``-``/``*`` bullets
        after ``Potential Issues:`` are collected until the next header.

        Returns:
            A new state dict with ``source_lang``, ``target_lang``,
            ``code_content`` and ``potential_issues`` set; values already
            present in ``state`` are kept when parsing yields nothing.
            The input dict is not modified.

        Raises:
            ValueError: If no target language could be determined.
        """
        analysis = state.get("analysis", "") or ""

        source_lang = ""
        target_lang = ""
        code_lines: List[str] = []
        issues: List[str] = []
        section = None  # None | "code" | "issues"

        for line in analysis.split('\n'):
            raw_line = line.rstrip()  # Keep original indentation
            clean_line = raw_line.strip()

            if re.match(r"^source[ _]*lang(uage)?\s*:", clean_line, re.I):
                source_lang = clean_line.partition(":")[2].strip()
                section = None
            elif re.match(r"^target[ _]*lang(uage)?\s*:", clean_line, re.I):
                target_lang = clean_line.partition(":")[2].strip()
                section = None
            elif re.match(r"^code[ _]*content\s*:", clean_line, re.I):
                inline = clean_line.partition(":")[2].strip()
                code_lines = [inline] if inline else []
                section = "code"
            elif re.match(r"^potential[ _]*issues?\s*:", clean_line, re.I):
                section = "issues"
            elif re.match(r"^task[ _]*description\s*:", clean_line, re.I):
                # The task description repeats the code; closing the section here
                # keeps its "-"/"*" lines from being collected as issues.
                section = None
            elif section == "issues":
                if clean_line.startswith(('-', '*')):
                    issues.append(clean_line[1:].strip())
            elif section == "code":
                code_lines.append(raw_line)

        parsed_data = {
            "source_lang": source_lang,
            "target_lang": target_lang,
            # Drop leading/trailing blank lines only; inner indentation is kept.
            "code_content": "\n".join(code_lines).strip("\n"),
            "potential_issues": issues
        }

        # Merge into a copy of the state instead of overwriting or mutating it
        new_state = dict(state)
        new_state.update({
            "source_lang": parsed_data["source_lang"] or state.get("source_lang", ""),
            "target_lang": parsed_data["target_lang"] or state.get("target_lang", ""),
            "code_content": parsed_data["code_content"] or state.get("code_content", ""),
            "potential_issues": parsed_data["potential_issues"] or state.get("potential_issues", [])
        })

        # Enforce validation of required fields
        if not new_state["target_lang"]:
            error_msg = (
                "Failed to analyze the target language! Please make sure that the analysis result contains a clear Target Language field\n"
                f"Original analysis result:\n{analysis}\n"
                f"Current state: {new_state}"
            )
            self._log_step("parse_error", new_state, error_msg)
            raise ValueError(error_msg)

        self._log_step("parse_analysis", new_state, parsed_data)
        return new_state

    def _initial_translation(self, state: Dict) -> Dict:
        """First translation node.

        Sends ``code_content`` with the detected languages to the LLM.

        Returns:
            A copy of the state plus ``"translated_code"``.
        """
        # Keep original state and add translation results
        new_state = state.copy()

        translation_template = """You are an HPC code conversion expert. Convert this {{source_lang}} code to {{target_lang}}:
        Requirements:
        1. Maintain identical algorithmic logic
        2. Follow target language's performance best practices
        3. Add necessary comments explaining modifications
        4. Ensure syntactic correctness

        {{code_input}}

        Return ONLY converted code without explanations.
        """

        prompt = PromptTemplate(
            template=translation_template,
            input_variables=["source_lang", "target_lang", "code_input"],
            template_format="jinja2"
        )

        chain = prompt | self.llm
        result = chain.invoke({
            "source_lang": new_state["source_lang"],
            "target_lang": new_state["target_lang"],
            "code_input": new_state["code_content"]
        })

        new_state["translated_code"] = result.content
        self._log_step("initial_translation", state, result.content)
        return new_state

    def _validate_code(self, state: Dict) -> Dict:
        """Validation node.

        Asks the LLM to review ``translated_code`` against the target
        language's code rules and stores a normalised, single-line report.

        Returns:
            A copy of the state plus ``"validation_result"``.

        Raises:
            ValueError: If ``target_lang`` or ``translated_code`` is missing.
        """
        new_state = state.copy()

        # Pre-check
        required_keys = ["target_lang", "translated_code"]
        for key in required_keys:
            if key not in new_state:
                raise ValueError(f"Required parameters are missing during the verification phase:{key}")

        code_rules = self._get_code_rules(new_state["target_lang"])
        validation_template = """Review this {{target_lang}} code:
        {{code}}
        
        {% if code_rules %}
        Code Rules:
        {{code_rules}}
        {% endif %}

        Format your findings as:
        Issues Found: [Yes/No]
        Rule Violations:
        - [Rule1] violation description (line X)
        - [Rule2] violation description (line Y)
        Suggestions: 
        - [Suggestion1]
        - [Suggestion2]
        """

        prompt = PromptTemplate(
            template=validation_template,
            input_variables=["target_lang", "code"],
            partial_variables={"code_rules": code_rules},
            template_format="jinja2"
        )

        chain = prompt | self.llm
        result = chain.invoke({
            "target_lang": new_state["target_lang"],
            "code": new_state["translated_code"]
        })

        raw_validation = result.content
        # Clean up Markdown formatting and special symbols
        clean_validation = re.sub(r"\*\*|`", "", raw_validation)
        # Collapse all whitespace so the report is a single line
        clean_validation = re.sub(r"\s+", " ", clean_validation)
        # Canonicalise the verdict label (the value's casing is left untouched)
        clean_validation = re.sub(r"Issues?\s*Found\s*:\s*(\w+)",
                                  r"Issues Found: \1",
                                  clean_validation,
                                  flags=re.IGNORECASE)
        self._log_step("validate_code", new_state, clean_validation)
        new_state["validation_result"] = clean_validation
        return new_state

    def _improve_code(self, state: Dict) -> Dict:
        """Code improvement node.

        Asks the LLM to fix ``translated_code`` according to
        ``validation_result`` and the target language's code rules, and bumps
        the ``iteration`` counter.

        Returns:
            A copy of the state with ``iteration`` incremented and
            ``translated_code`` replaced by the improved code.

        Raises:
            ValueError: If ``target_lang`` is missing or empty.
        """
        improvement_template = """Strictly modify the code according to the following requirements:
        1. Modify the problem parts pointed out in the key verification report
        2. Keep the original functions and code structure unchanged
        {% if code_rules %}
        3. Must follow these rules:
        {{code_rules}}
        {% endif %}
        
        Verification Result:
        {{validation_result}}

        Original Code:
        {{current_code}}

        Return the complete corrected code without any additional explanations.
        """
        new_state = state.copy()
        new_state["iteration"] = new_state.get("iteration", 0) + 1

        target_lang = new_state.get("target_lang")
        if not target_lang:
            error_msg = "The target_lang field is missing in the status, please check the output of the analyze phase"
            self._log_step("improve_code_error", new_state, error_msg)
            raise ValueError(error_msg)

        code_rules = self._get_code_rules(target_lang)

        prompt = PromptTemplate(
            template=improvement_template,
            input_variables=["validation_result", "current_code"],
            partial_variables={"code_rules": code_rules},
            template_format="jinja2"
        )

        chain = prompt | self.llm
        result = chain.invoke({
            "validation_result": new_state["validation_result"],
            "current_code": new_state["translated_code"]
        })

        # Stamp blocks that follow a "// Good" marker with the modification time.
        # A callable replacement is used so the timestamp is actually inserted
        # (a plain replacement string would emit the literal text "{timestamp}").
        modified_at = datetime.datetime.now().isoformat()
        improved_code = re.sub(
            r"(\/\/ Good\n)(.*?)\n\n",
            lambda m: f"{m.group(1)}// Modified: {modified_at}\n{m.group(2)}\n",
            result.content,
            flags=re.DOTALL
        )

        # Log the code that is actually stored in the state
        self._log_step("improve_code", new_state, improved_code)

        print("==================================================")
        print("Input Code Result:")
        print(new_state["translated_code"])
        print("==================================================")
        print("Improve Code Result:")
        new_state["translated_code"] = improved_code
        print(new_state["translated_code"])
        return new_state

    def _should_improve(self, state: Dict) -> str:
        """Routing function of the validation loop.

        Returns:
            ``"final"`` when the report says no issues were found or when
            ``max_iterations`` improvement rounds have been used; otherwise
            ``"improve"``.
        """
        validation_text = state.get("validation_result", "") or ""
        iteration = state.get("iteration", 0)

        print("\n=== Validation Debug ===")
        print(f"Iteration: {iteration}")
        print(f"Validation Text:\n{validation_text[:500]}...")

        # The verdict value comes from the LLM, so accept any casing, an
        # optional "[" and "None" as well as "No".
        no_issues = re.search(
            r"Issues?\s*Found\s*:\s*\[?\s*(?:no|none)\b",
            validation_text,
            flags=re.IGNORECASE
        )
        if no_issues:
            print("Validation passed, finalizing...")
            return "final"
        if iteration >= self.max_iterations:
            print("Max iterations reached")
            return "final"
        print("Validation failed, needs improvement")
        return "improve"

    def _finalize_output(self, state: Dict) -> Dict:
        """Final output node.

        Returns:
            Dict with ``source_language``, ``target_language``,
            ``original_code``, ``translated_code`` and ``execution_log``.
            ``execution_log`` is a snapshot (copy) of the log including this
            step; the logged record itself does not embed the log, so the
            structure contains no reference cycle. Missing state fields
            default to ``""``.
        """
        summary = {
            "source_language": state.get("source_lang", ""),
            "target_language": state.get("target_lang", ""),
            "original_code": state.get("code_content", ""),
            "translated_code": state.get("translated_code", "")
        }
        self._log_step("finalize_output", state, summary)
        return {**summary, "execution_log": list(self.execution_log)}

    def process_request(self, user_input: str) -> Dict:
        """Execute the conversion process for one request.

        Args:
            user_input: Free-form request containing the code to convert.

        Returns:
            ``{"final_output": <result of the finalize step>,
            "execution_log": <list of step records for this request>}``.
        """
        # Start every request with an empty log so runs do not bleed into each other
        self.execution_log = []

        initial_state = {
            "user_input": user_input,
            "iteration": 0
        }

        main_state = dict(initial_state)
        final_output = None

        # Each streamed step maps a node name (or "__end__") to that node's output
        for step in self.workflow.stream(initial_state):
            if not isinstance(step, dict):
                continue
            for node_key, node_state in step.items():
                if not isinstance(node_state, dict):
                    continue
                main_state.update(node_state)
                if node_key == "finalize_output":
                    final_output = node_state

        if final_output is None:
            # The finalize node's output was not observed in the stream;
            # build the result from the accumulated state instead.
            final_output = self._finalize_output(main_state)

        return {
            "final_output": final_output,
            "execution_log": list(self.execution_log)
        }

In [113]:
if __name__ == "__main__":
    # Demo driver: translate one sample program and print the result and the log.
    system = CodeTranslationGraph("KB/code_rules.yaml")

    # Alternative sample input (FORTRAN -> CUDA):
#     user_input = """
#     Please help me convert the following FORTRAN code into CUDA code:
#     PROGRAM VECTOR_ADD
#     INTEGER, PARAMETER :: N = 1000000
#     REAL :: A(N), B(N), C(N)
#     DO I = 1, N
#         C(I) = A(I) + B(I)
#     END DO
#     END PROGRAM
#     """
    # The "\n" escapes below are turned into real line breaks by Python, which is
    # intended for the code layout. The one inside the printf format string must
    # stay a literal backslash-n (hence "\\n"), otherwise the C string literal is
    # split across two lines and the submitted program is no longer valid C.
    user_input = """
    Please help me convert the following C++ code into FORTRAN code:
    #include <stdio.h>\nint main(int argc, char* argv[])\n{\n int i;\n int len=100;\n int a[100], b[100];\n\n for (i=0;i<len;i++)\n {\n a[i]=i;\n b[i]=i+1;\n }\n\n#pragma omp simd \n for (i=0;i<len-1;i++)\n a[i+1]=a[i]+b[i];\n\n for (i=0;i<len;i++)\n printf("i=%d a[%d]=%d\\n",i,i,a[i]);\n return 0;\n}\n
    """
    result = system.process_request(user_input)
    final_output = result['final_output']

    # Show every improvement round recorded in the log
    for log_entry in final_output['execution_log']:
        if log_entry['step'] == 'improve_code':
            print(f"Iteration {log_entry['input']['iteration']} Improve Result:")
            print(log_entry['output'])
            print("\n---\n")

    # Remove a surrounding Markdown code fence together with its language tag,
    # whatever the tag is (not only "cuda"). Unfenced output only loses stray
    # backticks at either end.
    raw_code = final_output['translated_code'].strip()
    fenced = re.fullmatch(r"```[^\n`]*\n(.*?)\n?```", raw_code, flags=re.DOTALL)
    if fenced:
        optimized_code = fenced.group(1).strip()
    else:
        optimized_code = raw_code.strip('`').strip()

    print("Final Conversion Results:")
    print(optimized_code)

    print("\n=== Full Execution Log ===")
    for log in final_output["execution_log"]:
        print(f"\n[{log['step']}]")
        print("Input:", log.get("input"))
        print("Output:", log.get("output"))


=== Validation Debug ===
Iteration: 0
Validation Text:
Issues Found: Yes Rule Violations: - F-DEC-001: Always use IMPLICIT NONE (line 2) - F-DEC-003: Always initialize variables (line 3) - F-ARR-001: Prefer intrinsic array operations over explicit loops (line 11) Suggestions: - Add IMPLICIT NONE at the beginning of the program. - Initialize variable len in line 3. - Refactor the loop in line 11 to use intrinsic array operations....
Validation failed, needs improvement
Input Code Result:
program main
  implicit none
  integer :: i, len
  integer, dimension(100) :: a, b

  len = 100

  do i = 1, len
    a(i) = i - 1
    b(i) = i
  end do

  !$omp simd
  do i = 1, len - 1
    a(i + 1) = a(i) + b(i)
  end do

  do i = 1, len
    print *, "i=", i, " a(", i, ")=", a(i)
  end do

end program main
Improve Code Result:
program main
  implicit none
  integer :: i, len
  integer, dimension(100) :: a, b

  len = 100

  do i = 1, len
    a(i) = i - 1
    b(i) = i
  end do

  !$omp simd
  a(2:) = a

In [114]:
# latest_code = result['final_output']['translated_code']
    
# for log_entry in result['final_output']['execution_log']:
#     if log_entry['step'] == 'improve_code':
#         print(f"Iteration {log_entry['input']['iteration']} Improve Result:")
#         print(log_entry['output'])
#         print("\n---\n")
    
# optimized_code = result['final_output']['translated_code'].strip('```').strip()
# if optimized_code.startswith('cuda'):
#     optimized_code = optimized_code[4:].lstrip()
    
# print("Conversion Results:")
# print(optimized_code)
    
# print("\n=== Full Execution Log ===")
# for log in result['final_output']["execution_log"]:
#     print(f"\n[{log['step']}]")
#     print("Input:", log.get("input"))
#     print("Output:", log.get("output"))