<a href="https://colab.research.google.com/github/ray-ik/Hello-world/blob/master/Create_streamlit_app.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import streamlit as st
import os
import torch
import chromadb
from neo4j import GraphDatabase
from transformers import AutoTokenizer, AutoModelForCausalLM
import uuid

# --- Setup and Initialization ---

st.set_page_config(page_title="COBOL Code Fixer", layout="wide")


@st.cache_resource
def _init_chroma_collection():
    """Create the ChromaDB collection once per Streamlit server process.

    Streamlit re-executes this whole script on every widget interaction;
    caching the resource reuses one client/collection across reruns instead
    of constructing a new one each time.

    NOTE: chromadb.Client() is an in-memory client, so stored fix pairs do
    not survive a restart of the server process.
    """
    client = chromadb.Client()
    return client.get_or_create_collection(name="cobol_code_fixes")


# Initialize ChromaDB client and collection
try:
    chroma_collection = _init_chroma_collection()
except Exception as e:
    st.error(f"Failed to initialize ChromaDB: {e}")
    chroma_collection = None

# Neo4j connection settings. Credentials are read from the environment so no
# secret lives in the notebook; URI and user fall back to the original values.
# A password used to be committed here in plain text -- treat it as leaked and
# rotate it in the Neo4j console.
NEO4J_URI = os.environ.get("NEO4J_URI", "neo4j+ssc://da1763a7.databases.neo4j.io")
NEO4J_USER = os.environ.get("NEO4J_USER", "neo4j")
NEO4J_PASSWORD = os.environ.get("NEO4J_PASSWORD")


@st.cache_resource
def _init_neo4j_driver(uri, user, password):
    """Create and verify a Neo4j driver once per (uri, user, password).

    GraphDatabase.driver() does not open a connection by itself, so
    verify_connectivity() is called to surface bad credentials or an
    unreachable host here instead of on the first write. A call that raises
    is not cached, so a failed attempt is retried on the next rerun.
    """
    driver = GraphDatabase.driver(uri, auth=(user, password))
    try:
        driver.verify_connectivity()
    except Exception:
        # Don't leak the connection pool of a driver we are discarding.
        driver.close()
        raise
    return driver


# Initialize Neo4j driver (None means Neo4j features are disabled).
neo4j_driver = None
if not NEO4J_PASSWORD:
    st.error("NEO4J_PASSWORD is not set. Set the NEO4J_PASSWORD environment variable to enable Neo4j storage.")
else:
    try:
        neo4j_driver = _init_neo4j_driver(NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD)
    except Exception as e:
        st.error(f"Failed to connect to Neo4j: {e}. Please check your database is running and credentials are correct.")

# Load StarCoder2 model and tokenizer
STARCODER2_MODEL_NAME = "bigcode/starcoder2-3b"


@st.cache_resource
def _load_starcoder2_cached(model_name):
    """Load tokenizer and model for ``model_name`` once per server process.

    Errors propagate instead of being converted to ``(None, None)``.
    Streamlit does not cache a call that raises, so a transient failure
    (e.g. a network hiccup while downloading) is retried on the next rerun
    rather than being remembered for the life of the process.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # bfloat16 halves memory versus float32; anything converting model
    # outputs to NumPy must upcast first (NumPy has no bfloat16 dtype).
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
    return tokenizer, model


def load_starcoder2_model(model_name=STARCODER2_MODEL_NAME):
    """Load the StarCoder2 tokenizer and model, reporting failures in the UI.

    Args:
        model_name: Hugging Face model id; defaults to "bigcode/starcoder2-3b".

    Returns:
        ``(tokenizer, model)`` on success, or ``(None, None)`` if loading
        failed (an error is shown in the app).
    """
    try:
        return _load_starcoder2_cached(model_name)
    except Exception as e:
        st.error(f"Failed to load StarCoder2 model: {e}. Check your internet connection or model availability.")
        return None, None


tokenizer, model = load_starcoder2_model()

# --- Functions for Logic ---

def get_code_embedding(code_snippet):
    """Return a proxy embedding for ``code_snippet`` from StarCoder2 hidden states.

    The vector is the final-layer hidden state of the last input token,
    returned as a 1-D float32 NumPy array.

    Returns:
        The embedding, or ``None`` when the model is not loaded, the snippet
        is empty/whitespace-only, or inference fails (the error is shown in
        the UI).
    """
    if tokenizer is None or model is None:
        return None
    # An empty snippet tokenizes to zero tokens, leaving no "last token" to read.
    if not code_snippet or not code_snippet.strip():
        return None
    try:
        # Keep the inputs on the same device as the model weights.
        inputs = tokenizer(code_snippet, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)
        last_hidden_state = outputs.hidden_states[-1]
        last_token_embedding = last_hidden_state[0, -1, :]
        # The model is loaded in bfloat16, which Tensor.numpy() rejects;
        # upcast to float32 and move to CPU before converting.
        return last_token_embedding.float().cpu().numpy()
    except Exception as e:
        st.error(f"Error generating embedding: {e}")
        return None

def store_in_chroma(erroneous_code, fixed_code, unique_id):
    """Store one (erroneous, fixed) code pair in ChromaDB.

    The erroneous code is both the stored document and the source of the
    embedding. The fixed code is kept in the record's metadata under the
    ``"fixed_code"`` key.

    Args:
        erroneous_code: source text of the broken program.
        fixed_code: source text of the corrected program.
        unique_id: Chroma record id for this pair.

    Returns:
        True if the pair was added. False if ChromaDB is unavailable, no
        embedding could be produced, or the insert failed.
    """
    if chroma_collection is None:
        return False

    # get_code_embedding handles its own exceptions and returns None on failure.
    erroneous_embedding = get_code_embedding(erroneous_code)
    if erroneous_embedding is None:
        return False

    try:
        chroma_collection.add(
            documents=[erroneous_code],
            metadatas=[{"fixed_code": fixed_code}],
            embeddings=[erroneous_embedding.tolist()],
            ids=[unique_id],
        )
    except Exception as e:
        st.error(f"Error storing in ChromaDB for ID {unique_id}: {e}")
        return False
    return True

def store_in_neo4j(erroneous_id, fixed_id, relationships=None):
    """Record a broken->fixed program pair, and optional module calls, in Neo4j.

    All writes run in one explicit transaction, so a failure part-way through
    leaves no half-written pair behind.

    Args:
        erroneous_id: id of the ``Program`` node with ``type: 'broken'``.
        fixed_id: id of the ``Program`` node with ``type: 'fixed'``.
        relationships: optional iterable of module names that the broken
            program calls. Each becomes a
            ``(:Program)-[:CALLS]->(:Module {name})`` edge.

    Returns:
        True on success. False if Neo4j is unavailable or the write failed
        (the error is shown in the UI).
    """
    if neo4j_driver is None:
        return False
    try:
        with neo4j_driver.session() as session:
            with session.begin_transaction() as tx:
                tx.run(
                    "MERGE (a:Program {id: $id1, type: 'broken'}) "
                    "MERGE (b:Program {id: $id2, type: 'fixed'}) "
                    "MERGE (a)-[:FIXED_BY]->(b)",
                    id1=erroneous_id,
                    id2=fixed_id,
                )
                if relationships:
                    # MERGE (not MATCH) the Module node: nothing else creates
                    # Module nodes, so a MATCH would find none and silently
                    # write no CALLS edge.
                    tx.run(
                        "MATCH (a:Program {id: $id1}) "
                        "UNWIND $mod_names AS mod_name "
                        "MERGE (m:Module {name: mod_name}) "
                        "MERGE (a)-[:CALLS]->(m)",
                        id1=erroneous_id,
                        mod_names=list(relationships),
                    )
                # Explicit commit; leaving the block via an exception rolls back.
                tx.commit()
        return True
    except Exception as e:
        st.error(f"Error storing in Neo4j for ID {erroneous_id}: {e}")
    return False

def _build_fix_prompt(example_broken, example_fixed, new_broken):
    """Build the one-shot prompt that asks StarCoder2 to fix ``new_broken``.

    Listings are inserted verbatim starting at column 0. Fixed-format COBOL
    is column-sensitive, and an indented template (such as a triple-quoted
    f-string inside a function) would shift only the first line of each
    listing. The prompt ends with an opened ```cobol fence for the model to
    fill in.
    """
    return (
        "Given the following broken COBOL code and a historical example of a fix, "
        "generate the corrected code.\n"
        "\n"
        "### Historical Example\n"
        "Broken Code:\n"
        f"```cobol\n{example_broken.rstrip()}\n```\n"
        "Fixed Code:\n"
        f"```cobol\n{example_fixed.rstrip()}\n```\n"
        "\n"
        "### New Broken Code\n"
        f"```cobol\n{new_broken.rstrip()}\n```\n"
        "\n"
        "### Corrected Code\n"
        "```cobol\n"
    )


def _extract_code_from_completion(completion):
    """Return the code the model wrote before closing its ```cobol fence.

    Blank lines at either end are dropped, but each remaining line keeps its
    leading spaces, because COBOL column positions are significant. Returns
    "" when the model produced no code.
    """
    code = completion.split("```", 1)[0]
    lines = code.splitlines()
    while lines and not lines[0].strip():
        lines.pop(0)
    while lines and not lines[-1].strip():
        lines.pop()
    return "\n".join(lines)


def fix_new_code(new_broken_code):
    """Suggest a fix for ``new_broken_code`` using retrieval-augmented generation.

    1. Embed the new code and fetch the most similar stored broken program,
       together with its recorded fix, from ChromaDB.
    2. Prompt StarCoder2 with that pair as a one-shot example and sample a fix.

    Returns:
        The suggested COBOL source, or a human-readable message explaining
        why no fix could be produced.
    """
    if chroma_collection is None or tokenizer is None or model is None:
        st.warning("Prerequisites not met. Cannot fix code.")
        return "Could not fix the code. Please ensure databases are running and the model is loaded."

    with st.spinner("Searching for similar fixes..."):
        # Skip the (expensive) embedding when there is nothing to search.
        if chroma_collection.count() == 0:
            return "No similar historical code fixes found in the database."

        new_embedding = get_code_embedding(new_broken_code)
        if new_embedding is None:
            return "Could not generate an embedding for the new code."

        search_results = chroma_collection.query(
            query_embeddings=[new_embedding.tolist()],
            n_results=1,
            include=["documents", "metadatas"],
        )

        # Results are nested per query embedding, so "no match" shows up as an
        # empty inner list ([[]]) rather than an empty outer list.
        documents = search_results.get("documents") or [[]]
        metadatas = search_results.get("metadatas") or [[]]
        if not documents[0] or not metadatas[0]:
            return "No similar historical code fixes found in the database."

        relevant_broken_code = documents[0][0]
        relevant_fixed_code = (metadatas[0][0] or {}).get("fixed_code")
        if relevant_fixed_code is None:
            return "No similar historical code fixes found in the database."

    # Use the retrieved example to prompt StarCoder2.
    with st.spinner("Generating fix with StarCoder2..."):
        prompt = _build_fix_prompt(relevant_broken_code, relevant_fixed_code, new_broken_code)
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        prompt_length = inputs["input_ids"].shape[1]
        output_ids = model.generate(
            **inputs,  # passes attention_mask along with input_ids
            max_new_tokens=500,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,
            temperature=0.7,
            top_k=50,
        )
        # Decode only the continuation. Nothing from the prompt (including the
        # historical example) can then end up in the returned code.
        completion = tokenizer.decode(output_ids[0][prompt_length:], skip_special_tokens=True)

    fixed_code = _extract_code_from_completion(completion)
    if not fixed_code:
        return "Failed to extract the fixed code from the model's response."
    return fixed_code

# --- Streamlit UI ---

st.title("COBOL Legacy Code Fixer 🚀")
st.markdown("Use this tool to upload pairs of broken and fixed COBOL code listings to train a knowledge base. Then, upload a new broken program to get a suggested fix.")

# Messages that fix_new_code returns in place of code when it cannot produce a
# fix. The UI checks against them so it doesn't announce success for a failure.
_FIX_FAILURE_MESSAGES = frozenset({
    "Could not fix the code. Please ensure databases are running and the model is loaded.",
    "Could not generate an embedding for the new code.",
    "No similar historical code fixes found in the database.",
    "Failed to extract the fixed code from the model's response.",
})


def _read_uploaded_code(uploaded_file):
    """Decode an uploaded listing as UTF-8, falling back to Latin-1.

    Legacy source listings may not be valid UTF-8. Latin-1 maps every byte to
    a character, so the fallback never raises.
    """
    raw = uploaded_file.getvalue()
    try:
        return raw.decode("utf-8")
    except UnicodeDecodeError:
        return raw.decode("latin-1")


st.header("1. Train the Knowledge Base")
st.markdown("Upload multiple pairs of erroneous and fixed code listings. **The order of files must match.** For example, the first file in the 'Erroneous' list should correspond to the first file in the 'Fixed' list.")
with st.container():
    col1, col2 = st.columns(2)
    with col1:
        st.subheader("Erroneous Code Listings")
        erroneous_files = st.file_uploader("Select .cbl files with bugs", type=["cbl"], accept_multiple_files=True, key="erroneous_upload")
    with col2:
        st.subheader("Fixed Code Listings")
        fixed_files = st.file_uploader("Select .cbl files with fixes", type=["cbl"], accept_multiple_files=True, key="fixed_upload")

    if st.button("Store these Fixes"):
        if not erroneous_files or not fixed_files:
            st.warning("Please upload at least one pair of files.")
        elif len(erroneous_files) != len(fixed_files):
            st.error("The number of erroneous files must match the number of fixed files. Please check your selection.")
        else:
            total_pairs = len(erroneous_files)
            success_count = 0
            with st.spinner(f"Storing {total_pairs} fix pairs..."):
                # Files are paired positionally, as the instructions above require.
                pairs = zip(erroneous_files, fixed_files)
                for pair_number, (erroneous_file, fixed_file) in enumerate(pairs, start=1):
                    erroneous_code = _read_uploaded_code(erroneous_file)
                    fixed_code = _read_uploaded_code(fixed_file)
                    unique_id = str(uuid.uuid4())

                    chroma_success = store_in_chroma(erroneous_code, fixed_code, unique_id)
                    neo4j_success = store_in_neo4j(unique_id, f"fixed-{unique_id}")

                    if chroma_success and neo4j_success:
                        success_count += 1
                    else:
                        st.warning(f"Failed to store pair {pair_number} ({erroneous_file.name} / {fixed_file.name}).")

            summary = f"Successfully stored {success_count} of {total_pairs} fix pairs."
            # Only report plain success when every pair was stored.
            if success_count == total_pairs:
                st.success(summary)
            else:
                st.warning(summary)


st.divider()

st.header("2. Fix a New Program")
new_program_file = st.file_uploader(
    "Upload the new broken COBOL program to fix (.cbl)",
    type=["cbl"]
)

if st.button("Fix Program"):
    if not new_program_file:
        st.warning("Please upload a COBOL file to fix.")
    else:
        new_program_code = _read_uploaded_code(new_program_file)
        fixed_program = fix_new_code(new_program_code)

        if fixed_program in _FIX_FAILURE_MESSAGES:
            # fix_new_code signals failure with a message, not an exception.
            st.error(fixed_program)
        else:
            st.subheader("Suggested Fix")
            st.code(fixed_program, language="cobol")

            st.success("Fix generated successfully!")
