In [None]:
!pip install langchain
!pip install sentence-transformers
!pip install faiss-cpu
!pip install boto3
!pip install tqdm

In [None]:
!pip install -U langchain-community

In [None]:
!pip install langchain[bedrock]

In [None]:
!pip install faiss-gpu

In [None]:
import json
import time
import random
from tqdm import tqdm
from typing import Optional, List
from langchain.embeddings import HuggingFaceEmbeddings
import boto3
from botocore.exceptions import ClientError
import torch
from transformers import AutoTokenizer, AutoModel
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain.vectorstores import FAISS
from langchain.embeddings.base import Embeddings
from langchain_core.language_models import BaseLLM
from langchain_core.outputs import Generation, LLMResult
from pydantic import Extra
from typing import Optional, List
from pydantic import Field
from langchain_core.language_models import BaseLLM
from langchain_core.outputs import Generation, LLMResult

# Module-level Bedrock Runtime client (us-west-2).
# NOTE(review): this global is not referenced in the visible code — BedrockLLM builds its own client.
bedrock = boto3.client(service_name="bedrock-runtime", region_name="us-west-2")


class BedrockLLM(BaseLLM):
    """LangChain LLM that sends plain-text prompts to Amazon Bedrock via ``invoke_model``.

    The request body uses the ``prompt`` / ``max_gen_len`` / ``temperature`` / ``top_p``
    schema, and the generated text is read from ``generation`` (or ``outputs[0].text``).
    NOTE(review): this matches Meta Llama-style models, including custom models imported
    into Bedrock — confirm before pointing it at other model families.

    Constructor args:
        model_arn: Value sent as ``modelId`` to ``invoke_model``; either a foundation-model
            ID (e.g. ``meta.llama3-1-70b-instruct-v1:0``) or an imported-model ARN.
        region: AWS region of the Bedrock Runtime endpoint.
        **kwargs: Any other field of this class or of ``BaseLLM`` (e.g. ``max_gen_len``,
            ``max_attempts``, ``callbacks``).
    """

    model_id: str = Field(...)
    region: str = Field(...)
    # Generation settings sent with every request (defaults are the previously hard-coded values).
    max_gen_len: int = Field(default=512)
    temperature: float = Field(default=0.0)
    top_p: float = Field(default=0.9)
    # Total number of invoke attempts while Bedrock reports ModelNotReadyException.
    max_attempts: int = Field(default=5)

    def __init__(self, model_arn: str, region: str = "us-west-2", **kwargs):
        super().__init__(model_id=model_arn, region=region, **kwargs)
        # `client` is not a declared pydantic field, so it is attached with
        # object.__setattr__ to bypass pydantic's "declared fields only" assignment check.
        object.__setattr__(self, "client", boto3.client("bedrock-runtime", region_name=region))

    @property
    def _llm_type(self) -> str:
        return "bedrock-custom"

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager=None,
        **kwargs,
    ) -> LLMResult:
        """Generate one completion per prompt.

        ``run_manager`` and extra ``kwargs`` are accepted to match the ``BaseLLM._generate``
        signature; they are not used. When ``stop`` is given, each completion is truncated
        at the earliest stop sequence.

        Raises:
            RuntimeError: if any prompt fails (see ``_invoke``), so the result never
                silently contains fewer generations than prompts.
        """
        generations = [
            [Generation(text=self._truncate_at_stop(self._invoke(prompt), stop))]
            for prompt in prompts
        ]
        return LLMResult(generations=generations)

    def _invoke(self, prompt: str) -> str:
        """Send a single prompt to Bedrock and return the generated text.

        Retries with exponential backoff plus jitter while the model raises
        ``ModelNotReadyException``; any other error is wrapped in ``RuntimeError``.

        Raises:
            RuntimeError: on a non-retryable error, or when the model is still not
                ready after ``max_attempts`` attempts.
        """
        body = json.dumps({
            "prompt": prompt,
            "max_gen_len": self.max_gen_len,
            "temperature": self.temperature,
            "top_p": self.top_p,
        }).encode("utf-8")
        attempts = max(1, self.max_attempts)
        last_error = None
        for attempt in range(attempts):
            try:
                response = self.client.invoke_model(
                    modelId=self.model_id,  # foundation-model ID or imported-model ARN
                    body=body,
                    contentType="application/json",
                    accept="application/json",
                )
                return self._extract_text(json.loads(response["body"].read()))
            except self.client.exceptions.ModelNotReadyException as e:
                last_error = e
                if attempt == attempts - 1:
                    break  # final attempt failed: give up instead of sleeping again
                wait = 2 ** attempt + random.uniform(0, 1)
                print(f"Model not ready. Retrying in {wait:.1f} seconds...")
                time.sleep(wait)
            except Exception as e:
                raise RuntimeError(f"Client error during invoke: {str(e)}") from e
        # Previously this case fell through silently and the prompt got no generation.
        raise RuntimeError(
            f"Model {self.model_id} still not ready after {attempts} attempts"
        ) from last_error

    @staticmethod
    def _extract_text(result: dict) -> str:
        """Return the generated text from a parsed Bedrock response body.

        Reads ``generation`` first, then ``outputs[0]["text"]``; if neither is present,
        returns ``str(result)`` so the raw response is still recorded.
        """
        generation = result.get("generation")
        if generation is not None:
            # Compared with None (not truthiness) so an empty generation stays "".
            return generation
        # `or [{}]` also covers an empty outputs list, which used to raise IndexError.
        outputs = result.get("outputs") or [{}]
        return outputs[0].get("text", str(result))

    @staticmethod
    def _truncate_at_stop(text: str, stop: Optional[List[str]]) -> str:
        """Cut ``text`` at the earliest occurrence of any stop sequence (no-op without stops)."""
        if not stop:
            return text
        hits = [pos for pos in (text.find(s) for s in stop if s) if pos != -1]
        return text[: min(hits)] if hits else text


# Load the evaluation dataset: one JSON object per line (JSONL).
input_path = "sagemaker_ft500vers2 (1) (1).jsonl"
with open(input_path, "r", encoding="utf-8") as f:
    # Skip blank/whitespace-only lines (e.g. a trailing empty line); json.loads rejects them.
    dataset = [json.loads(line) for line in f if line.strip()]

# Embedding model used to encode retrieval queries.
# NOTE(review): presumably the same model the `e5largev2_rag_db` index was built with
# (the folder name suggests e5-large-v2) — confirm, since query and index vectors must match.
# NOTE(review): e5 models are usually used with "query: "/"passage: " prefixes, which this
# wrapper does not add — confirm how the index was built.
embedding_model = HuggingFaceEmbeddings(model_name="intfloat/e5-large-v2")

# Load the prebuilt FAISS index from disk and expose it as a retriever.
# SECURITY: allow_dangerous_deserialization=True unpickles the stored docstore; only load
# index folders you created yourself or fully trust.
retriever = FAISS.load_local(
    folder_path="e5largev2_rag_db",
    embeddings=embedding_model,
    allow_dangerous_deserialization=True,
).as_retriever()

# LLM under evaluation: a custom model imported into Bedrock (passed by imported-model ARN).
# Alternatives kept for quick switching:
#   "meta.llama3-1-70b-instruct-v1:0"  (Bedrock foundation-model ID)
#   "arn:aws:bedrock:us-west-2:471112783210:imported-model/1tfsf1vs44wd"
llm = BedrockLLM(
    "arn:aws:bedrock:us-west-2:471112783210:imported-model/y9xa363z3o7r",
    region="us-west-2",
)

# Prompt templates for the four-step reasoning chain.

# Step 1: condense the raw case/question into its relevant clinical findings.
prompt_step1 = PromptTemplate(
    template="""
Given the following clinical case and question, write a concise summary focusing only on the relevant clinical findings:

{question}

SUMMARY:
""",
    input_variables=["question"],
)

# Step 2: most likely diagnosis plus differentials, reasoning from the step-1 summary.
prompt_step2 = PromptTemplate(
    template="""
You are a medical expert. Based on the clinical summary below, provide the most likely diagnosis with 2–3 reasonable differential diagnoses. Explain your reasoning step by step.

CLINICAL SUMMARY:
{summary}

DIAGNOSIS:
""",
    input_variables=["summary"],
)

# Step 3: management plan from summary, diagnosis and retrieved guideline text.
# NOTE(review): unlike the other steps, this template ends without an output label.
prompt_step3 = PromptTemplate(
    template="""
Based on the following:
- Summary: {summary}
- Most likely diagnosis: {diagnosis}
- Relevant guideline info: {context}

Provide short-term and long-term management recommendations, including risk scores if applicable.
""",
    input_variables=["summary", "diagnosis", "context"],
)

# Step 4: choose the final answer option, starting the reply with its letter.
prompt_step4 = PromptTemplate(
    template="""
Based on the following case and options:

CASE:
{question}

SUMMARY:
{summary}

MANAGEMENT PLAN:
{management}

Select the best answer choice and justify why it's correct and why the others are incorrect. Begin your answer with the letter (e.g., A.)

FINAL ANSWER:
""",
    input_variables=["question", "summary", "management"],
)

# Evaluation: run the four-step chain on every example and write one JSON line per example.
output_path = "validation_answers_e5large_multistep_reasoning_ft&nocontext.jsonl"

# The chains do not depend on the example, so build them once instead of per iteration.
summary_chain = LLMChain(llm=llm, prompt=prompt_step1)
diagnosis_chain = LLMChain(llm=llm, prompt=prompt_step2)
management_chain = LLMChain(llm=llm, prompt=prompt_step3)
answer_chain = LLMChain(llm=llm, prompt=prompt_step4)

with open(output_path, "w", encoding="utf-8") as fout:
    for example in tqdm(dataset):
        question = example["model_input"]

        # Step 1: clinical summary; Step 2: diagnosis + differentials.
        step1_output = summary_chain.run(question=question)
        step2_output = diagnosis_chain.run(summary=step1_output)

        # Retrieve guideline passages with summary + diagnosis as the query and keep the top 3.
        # The newline keeps the last word of the summary from fusing with the diagnosis text.
        docs = retriever.get_relevant_documents(step1_output + "\n" + step2_output)
        context = "\n".join(doc.page_content for doc in docs[:3])

        # Step 3: management plan (previously ran prompt_step4 here by mistake).
        step3_output = management_chain.run(
            summary=step1_output,
            diagnosis=step2_output,
            context=context,
        )

        # Step 4: final answer, fed with the step-3 management plan
        # (previously referenced the undefined prompt_step5 and step4_output).
        step4_output = answer_chain.run(
            question=question,
            summary=step1_output,
            management=step3_output,
        )

        output = {
            "question": question,
            "target_output": example.get("target_output", ""),
            "step1_summary": step1_output,
            "step2_diagnosis": step2_output,
            "step3_management": step3_output,
            "step4_final_answer": step4_output,
        }
        fout.write(json.dumps(output) + "\n")
