In [None]:
from langchain_core.output_parsers import StrOutputParser
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import PromptTemplate

import torch
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split
from torchvision import transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import os
from torchvision import transforms
from PIL import Image
from torchvision import models
from dotenv import load_dotenv

In [None]:
# Read settings from a local .env file via python-dotenv.
# NOTE(review): presumably this supplies the OpenAI API key that the
# ChatOpenAI / OpenAIEmbeddings calls in do_all need — confirm.
load_dotenv()

In [5]:
# Reference text files that back the retrieval step, keyed by diagnosis.
_KNOWLEDGE_FILES = {
    "diabetic_retinopathy": "diabetic_retinopathy.txt",
    "glaucoma": "glaucoma.txt",
    "cataract": "cataract.txt",
}

# Output languages that do_all knows how to print.
_SUPPORTED_LANGUAGES = ("en", "hi")


def _build_retriever(path, k=4):
    """Build a FAISS similarity retriever over a UTF-8 reference text file.

    The file is split into ~1000-character chunks with 200 characters of
    overlap, embedded with the ``text-embedding-3-small`` model and indexed
    in an in-memory FAISS store.

    Args:
        path: Path to the reference text file.
        k: Number of chunks the retriever returns per query.

    Returns:
        A LangChain retriever over the indexed chunks.
    """
    with open(path, "r", encoding="utf-8") as fh:
        text = fh.read()
    # Collapse blank lines into single newlines. This must be a normal string:
    # the raw form r"\n\n" matches a literal backslash-n sequence and never
    # matches real line breaks.
    text = text.replace("\n\n", "\n")

    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    chunks = splitter.create_documents([text])
    embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
    vector_store = FAISS.from_documents(chunks, embeddings)
    return vector_store.as_retriever(search_type="similarity", search_kwargs={"k": k})


def do_all(diagnosis, history, language="en"):
    """Print LLM-generated causes and treatment for a fundus diagnosis.

    For ``"normal"`` a fixed congratulation message is printed. For any other
    supported diagnosis, do_all works in four steps:

    1. Index the matching reference text.
    2. Have the LLM summarise the patient history.
    3. Retrieve passages relevant to that summary and ask the LLM for causes.
    4. Retrieve treatment passages and ask the LLM for a treatment plan.

    With ``language="hi"`` both answers are translated to Hindi before printing.

    Args:
        diagnosis: ``"normal"`` or a key of ``_KNOWLEDGE_FILES``.
        history: Free-text patient history.
        language: Output language, ``"en"`` (default) or ``"hi"``.

    Returns:
        None. Results are printed to stdout.

    Raises:
        ValueError: If ``language`` or ``diagnosis`` is not supported.
    """
    # Validate before any network call. An unknown language used to run every
    # LLM call and then print nothing. An unknown diagnosis used to fail with
    # UnboundLocalError.
    if language not in _SUPPORTED_LANGUAGES:
        raise ValueError(
            f"Unsupported language {language!r}; expected one of {_SUPPORTED_LANGUAGES}"
        )

    if diagnosis == "normal":
        if language == "hi":
            print("बधाई हो, आप बिलकुल ठीक हैं")
        else:
            print("Congratulations, you are normal!")
        return

    if diagnosis not in _KNOWLEDGE_FILES:
        raise ValueError(
            f"Unknown diagnosis {diagnosis!r}; expected 'normal' or one of "
            f"{sorted(_KNOWLEDGE_FILES)}"
        )

    parser = StrOutputParser()
    llm = ChatOpenAI()
    retriever = _build_retriever(_KNOWLEDGE_FILES[diagnosis])

    # Condense the free-text history into pointers that work as a retrieval query.
    summary_prompt = PromptTemplate(
        template = """
        You have to summarize the patient's history into short pointers(only the parts useful for querying documents): {history}
        """,
        input_variables = ['history']
    )
    summary_chain = summary_prompt | llm | parser
    summary = summary_chain.invoke({"history": history})

    # Infer likely causes from the patient summary plus retrieved passages.
    cause_docs = retriever.invoke(summary)
    causes_academia = "\n\n\n".join(doc.page_content for doc in cause_docs)

    causes_prompt = PromptTemplate(
        template = """
        You are an assistant who gives the causes given the following diagnosis: {diagnosis}, 
        And the following patient history : {context}
        You have the following medical academia to infer a cause : {academia}
        """,
        input_variables=['diagnosis', 'context', 'academia']
    )
    cause_chain = causes_prompt | llm | parser
    cause = cause_chain.invoke(
        {"diagnosis": diagnosis, "context": summary, "academia": causes_academia}
    )

    # Propose treatment from the inferred causes plus treatment-focused passages.
    treatment_docs = retriever.invoke(f'Treatment methods for {diagnosis}')
    treatment_academia = "\n\n".join(doc.page_content for doc in treatment_docs)

    treatment_prompt = PromptTemplate(
        template = """
        You are a helpful assistant and your role is the give the treatment for the following diagnosis: {diagnosis}
        You are given the following sets of causes: {causes}
        Use the following academia as reference: {reference}
        """,
        input_variables = ['causes', 'diagnosis', 'reference']
    )
    treatment_chain = treatment_prompt | llm | parser
    treatment = treatment_chain.invoke(
        {'causes': cause, "diagnosis": diagnosis, "reference": treatment_academia}
    )

    if language == "hi":
        # Only build (and pay for) the translation chain when it is needed.
        translate_prompt = PromptTemplate(
            template = "Convert the following text to hindi: {text}",
            input_variables = ['text']
        )
        translator = translate_prompt | llm | parser
        cause = translator.invoke({"text": cause})
        treatment = translator.invoke({'text': treatment})
        print("कारण: \n")
        print(cause, "\n\n")
        print("इलाज: \n")
        print(treatment, "\n\n")
    else:
        print('Cause: \n')
        print(cause)
        print('Treatment: \n')
        print(treatment)

In [6]:
# Sample consultation used to exercise do_all: the output language, the
# predicted diagnosis, and the patient's own description of their symptoms.
language = "en"
diagnosis = "glaucoma"
history = """
Doctor, I’ve been having trouble with my eyes lately, and I’m worried it might be something serious.

Over the past few months, I’ve noticed that my peripheral vision isn’t as sharp as it used to be—it’s like I’m looking through a tunnel sometimes. I’ve also been getting frequent headaches, especially around my brow and temples, and my eyes often feel achy or strained.

A few times, I’ve had sudden blurry vision in one eye, along with halos around lights, particularly at night. My right eye seems worse—colors don’t look as vivid, and sometimes there’s a dull pain behind it.
"""

In [7]:
# Run the RAG pipeline on the sample case. Pass the configured `language`
# rather than a hard-coded "en" so that changing the setting takes effect.
do_all(diagnosis=diagnosis, history=history, language=language)

ImportError: Could not import faiss python package. Please install it with `pip install faiss-gpu` (for CUDA supported GPU) or `pip install faiss-cpu` (depending on Python version).

In [None]:
# Restore the trained fundus classifier and its label mappings from disk.
CHECKPOINT_PATH = 'fundus_classifier(test-89.38, train-96.74).pth'

# `device` is used below but is not defined in any visible cell; pick the GPU
# when available so this cell also runs on a fresh, CPU-only kernel.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# map_location remaps tensors that were saved on a GPU onto `device`; without
# it, loading a CUDA-saved checkpoint on a CPU-only machine raises an error.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)

# NOTE(review): `CNN` is not defined in any visible cell; it must match the
# architecture the checkpoint was trained with — TODO define or import it.
model = CNN()
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()

# Label mappings stored alongside the weights in the same checkpoint.
idx_to_label = checkpoint['idx_to_label']
label_to_idx = checkpoint['label_to_idx']

In [None]:
# Classify one sample fundus image with the model restored in the previous cell.
# NOTE(review): `F` is not used in this cell. Presumably do_inference (not
# defined in any visible cell) relies on it as a global. TODO: move this import
# to the imports cell and define do_inference before this cell.
import torch.nn.functional as F
do_inference('dataset/glaucoma/_10_1472170.jpg', idx_to_label, model)