<a href="https://colab.research.google.com/github/susub2/Embedded-project/blob/main/%08%08GAI_PROJECT_FINAL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# NOTE: `typing`, `fitz`, `frontend` and `tools` are deliberately NOT installed.
# The PyPI `fitz` package is unrelated to PyMuPDF and shadows its module (the
# `frontend`/`tools` packages were only a workaround for that), and `typing`
# is part of the standard library. PyMuPDF itself is installed as `pymupdf`.
!pip install gradio vllm transformers triton PyPDF2 Pillow sentence_transformers numpy faiss-gpu spacy pymupdf pymupdf4llm semchunk

In [None]:
import gradio as gr
import faiss
import numpy as np
import spacy
from sentence_transformers import SentenceTransformer
import os
import time
import semchunk
import pymupdf as fitz
import pymupdf4llm
from vllm import LLM, SamplingParams
from typing import List, Tuple, Dict, Optional
from PIL import Image
import hashlib
import logging
import torch
import gc

# Initialize the module-level globals shared by every processor class below.
# Vision-language model served through vLLM (dtype='half', max_model_len=8192).
llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf", dtype='half', max_model_len=8192)
# Sampling settings reused for every generation request.
sampling_params = SamplingParams(temperature=0.7, max_tokens=512)
logging.basicConfig(level=logging.INFO)
# Sentence embedding model used to vectorize PDF chunks and queries.
embedder = SentenceTransformer('paraphrase-multilingual-mpnet-base-v2')
# Chunker built from the 'gpt-4' tokenizer name with a chunk size of 200
# (presumably measured in tokens — confirm against the semchunk docs).
chunker = semchunk.chunkerify('gpt-4', 200)

In [None]:
class RAGPipeline:
    """Pipeline for retrieval-augmented generation (RAG) over PDF files.

    PDFs are converted to markdown, split into chunks, embedded and stored in
    a FAISS L2 index. ``self.chunks[i]`` is always the text whose embedding is
    stored at row ``i`` of ``self.index``; the two are only ever updated
    together so they cannot drift apart.
    """

    def __init__(self):
        # Reuse the module-level LLM and sampling parameters.
        self.llm = llm
        self.sampling_params = sampling_params

        # Embedding / retrieval state.
        self.embedder = embedder
        self.chunker = chunker
        self.index = faiss.IndexFlatL2(self.embedder.get_sentence_embedding_dimension())
        self.chunks = []
        self.processed_files = {}  # {file_hash: file_path}

    def get_file_hash(self, file_path: str) -> str:
        """Return the MD5 hex digest of the file at *file_path*.

        The file is read in 1 MiB blocks so large PDFs are not loaded into
        memory at once. The digest is only used to detect duplicate uploads.
        """
        digest = hashlib.md5()
        with open(file_path, "rb") as f:
            for block in iter(lambda: f.read(1 << 20), b""):
                digest.update(block)
        return digest.hexdigest()

    def indexing_pdf(self, pdf_path: List[str]):
        """Extract, chunk, embed and index every PDF in *pdf_path*.

        Files whose content hash was already indexed are skipped. A failure on
        one file is logged and does not stop the remaining files; the failed
        file is not recorded as processed, so it can be retried later.
        """
        for pdf in pdf_path:
            try:
                file_hash = self.get_file_hash(pdf)
                if file_hash in self.processed_files:
                    logging.info(f"{pdf} has already been processed before")
                    continue

                logging.info(f"Processing new file: {pdf}")

                doc = fitz.open(pdf)
                try:
                    markdown_text = pymupdf4llm.to_markdown(doc)
                finally:
                    # Release the document even when the conversion fails.
                    doc.close()

                chunks = self.chunker(markdown_text)
                if chunks:
                    # Embed and add to the index first: if either step raises,
                    # self.chunks is left untouched and stays aligned with the
                    # index rows.
                    chunks_embeddings = self.embedder.encode(chunks)
                    self.index.add(chunks_embeddings)
                    self.chunks.extend(chunks)
                else:
                    logging.warning(f"No text could be extracted from {pdf}")

                # Register the file only after it was handled successfully.
                self.processed_files[file_hash] = pdf
            except Exception as e:
                logging.error(f"Error in indexing {pdf}: {e}")

        logging.info(f"Processed {len(pdf_path)} files. Total unique files: {len(self.processed_files)}")

    def process_query(self, query: str, top_k: int = 5) -> List[str]:
        """Return up to *top_k* indexed chunks closest to *query*.

        Returns an empty list when nothing has been indexed yet.
        """
        if not self.chunks or top_k <= 0:
            return []

        query_embedding = self.embedder.encode([query])
        _, indices = self.index.search(query_embedding, top_k)
        # FAISS pads the result with -1 when the index holds fewer than top_k
        # vectors; a raw self.chunks[-1] lookup would silently return the last
        # chunk instead, so out-of-range ids are dropped.
        return [self.chunks[i] for i in indices[0] if 0 <= i < len(self.chunks)]

    def prompt_template(self, query: str, context: List[str]) -> str:
        """Build the instruction prompt from *query* and the retrieved *context*."""
        system_message = """You are an AI assistant tasked with answering questions based on provided context. Your role is to:
                            1. Carefully analyze the given context
                            2. Provide accurate and relevant information
                            3. Synthesize a coherent response
                            4. Maintain objectivity and clarity
                            If the context doesn't contain sufficient information, state so clearly."""

        context_str = "\n".join([f"Context {i+1}: {ctx}" for i, ctx in enumerate(context)])

        prompt = f"""[INST] {system_message}

            Relevant information:
            {context_str}

            User's Quetion: {query}

            Instructions:
            - Answer the query using only the information provided in the context.
            - If the context doesn't contain enough information to fully answer the query, acknowledge this limitation in your response.
            - Provide a concise yet comprehensive answer.
            - Do not introduce information not present in the given context.
            - Provide in complete sentences in English always.
            - Check once again your response so that the user can be provided precise information.

            Please provide your response below:
            [/INST]"""

        return prompt

    def generate_response(self, query: str, context: List[str]) -> str:
        """Generate an answer to *query* grounded in *context* and return its text."""
        prompt = self.prompt_template(query, context)
        output = self.llm.generate([prompt], self.sampling_params)
        return output[0].outputs[0].text

    def answer_query(self, query: str, top_k: int = 5) -> str:
        """Retrieve the *top_k* most relevant chunks and answer *query* with them."""
        retrieved_contexts = self.process_query(query, top_k)
        return self.generate_response(query, retrieved_contexts)

class LLaVAImageQAProcessor:
    """LLaVA-based processor that answers questions about a single image."""

    # Questions that mean "just describe the picture". They are compared after
    # trimming whitespace and trailing punctuation, so the polite form the UI
    # sends by default ("이 이미지에 대해 설명해주세요.") is recognized as well.
    _DESCRIBE_REQUESTS = frozenset({
        "이 이미지에 대해 설명해줘",
        "이미지를 설명해줘",
        "이 이미지를 설명해줘",
        "이 이미지에 대해 설명해주세요",
        "이미지를 설명해주세요",
        "이 이미지를 설명해주세요",
    })

    def __init__(self):
        # Reuse the module-level LLM and sampling parameters.
        self.llm = llm
        self.sampling_params = sampling_params

    def get_prompt(self, question: str):
        """Return the LLaVA prompt for *question*.

        A generic "describe this image" request (or an empty question) gets a
        description prompt; anything else gets a prompt focused on the
        question itself.
        """
        normalized = (question or "").strip().rstrip(".!?。").strip().lower()

        # Generic description request.
        if not normalized or normalized in self._DESCRIBE_REQUESTS:
            return """[INST] <image>
                    Describe this image comprehensively in bullet points.
                    Focus on:
                    - Main subjects and their characteristics
                    - Setting and background
                    - Overall mood or atmosphere
                    Your response should be in complete sentences.
                    [/INST]"""

        # Question about a specific aspect of the image.
        return f"""[INST] <image>
                    Focus specifically on answering this question: {question}
                    Provide a detailed response about exactly what was asked.
                    Stay focused on the specific aspect mentioned in the question.
                    Your response should be in complete sentences and bullet points.
                    [/INST]"""

    def process_image(self, image: "Image.Image", question: str) -> str:
        """Ask the model *question* about *image* and return the stripped answer.

        Returns a fixed failure message when the model produces no output.
        Cached GPU memory is released after every call, including on error.
        """
        prompt = self.get_prompt(question)
        try:
            inputs = {"prompt": prompt, "multi_modal_data": {"image": image}}
            outputs = self.llm.generate(inputs, self.sampling_params)
            return outputs[0].outputs[0].text.strip() if outputs else "Failed to generate response."
        finally:
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

class GeneralChatProcessor:
    """Text-only chat processor that keeps a running conversation transcript."""

    def __init__(self):
        # Transcript lines of the form "User: ..." / "Assistant: ...".
        self.history = []
        # Shared module-level model and sampling settings.
        self.sampling_params = sampling_params
        self.llm = llm

    def _build_prompt(self, context: str, query: str) -> str:
        """Fill the instruction template with the transcript excerpt and the question."""
        return f"""[INST] You are a friendly and knowledgeable AI assistant. Please respond to the user's question following these guidelines:

                      1. Carefully read and understand the question.
                      2. Consider the conversation history for context: {context}
                      3. Utilize relevant information and knowledge to provide accurate and informative answers.
                      4. Write your response clearly and concisely, but include sufficient explanation when necessary.
                      5. Provide only fact-based information, and explicitly state when something is uncertain.
                      6. Maintain a polite and courteous attitude, considering the user's feelings.
                      7. When needed, provide examples or step-by-step instructions.
                      8. Only provide information that is ethical and legally appropriate.
                      9. Show a welcoming attitude towards additional questions or requests for clarification from the user.

                      User's question: {query}

                      Please respond following the above guidelines.

                      [/INST]"""

    def process_query(self, query: str) -> str:
        """Answer *query* using the last five transcript lines as context.

        The question and the stripped model answer are appended to the
        transcript before the answer is returned.
        """
        recent_lines = self.history[-5:]
        prompt = self._build_prompt("\n".join(recent_lines), query)

        generations = self.llm.generate([prompt], self.sampling_params)
        answer = generations[0].outputs[0].text.strip()

        self.history.extend((f"User: {query}", f"Assistant: {answer}"))
        return answer

class SessionManager:
    """Registry of chat sessions and the name of the currently active one.

    Each session is a dict holding its chat history, chat mode, mode lock,
    the three per-session processors and the currently selected image.
    """

    def __init__(self):
        self.sessions = {"Example": self._new_session_state()}
        self.current_session = "Example"

    @staticmethod
    def _new_session_state() -> dict:
        """Return the state dict of a brand-new session with its own processors."""
        return {
            "history": [],
            "mode": "General Chat",
            "mode_locked": False,
            "rag_pipeline": RAGPipeline(),
            "img_processor": LLaVAImageQAProcessor(),
            "general_processor": GeneralChatProcessor(),
            "current_image": None,
            "selected_image_name": ""
        }

    def create_session(self, session_name: str) -> bool:
        """Create *session_name* and make it the active session.

        Returns False, changing nothing, when the name is already taken.
        """
        if session_name in self.sessions:
            return False

        self.sessions[session_name] = self._new_session_state()
        self.current_session = session_name
        return True

    def delete_session(self, session_name: str) -> Tuple[Optional[str], Optional[dict]]:
        """Delete *session_name* and return ``(active_name, active_session)``.

        Returns ``(None, None)`` when the session does not exist or when it is
        the only remaining session (the last session is never deleted). The
        active session only changes when the deleted session was the active
        one; it then falls back to the first remaining session.
        """
        if len(self.sessions) <= 1 or session_name not in self.sessions:
            return None, None

        del self.sessions[session_name]
        if self.current_session not in self.sessions:
            self.current_session = next(iter(self.sessions))
        return self.current_session, self.sessions[self.current_session]

    def get_session(self, session_name: str) -> Optional[dict]:
        """Return the state dict of *session_name*, or None if it does not exist."""
        return self.sessions.get(session_name)

In [None]:
# GUI implementation
def create_ui():
    """Build and return the Gradio Blocks app.

    The layout has a session list on the left and a chat panel on the right
    with three modes (General Chat, Image Chat, RAG Chat). All sessions live
    in one SessionManager created here and captured by the event handlers.
    """
    session_manager = SessionManager()

    custom_css = """
    .message-box {
        display: flex;
        align-items: center;
        gap: 0.5rem;
    }
    .file-btn {
        max-width: 40px;
    }
    .send-btn {
        max-width: 40px;
    }
    .selected-file {
        margin: 0.5rem 0;
        padding: 0.3rem;
        background: #f0f0f0;
        border-radius: 4px;
        font-size: 0.9em;
    }
    """

    with gr.Blocks(css=custom_css) as demo:
        with gr.Row():
            # Left panel: session management
            with gr.Column(scale=1):
                new_session_btn = gr.Button("+ New Session")
                session_title_input = gr.Textbox(
                    label="Session Title",
                    visible=False
                )
                with gr.Column(elem_classes="session-container"):
                    gr.Markdown("Sessions")
                    session_list = gr.Radio(
                        choices=["Example"],
                        value="Example",
                        label=""
                    )
                    delete_btn = gr.Button("🗑️ Delete Session")

            # Right panel: chat interface
            with gr.Column(scale=3):
                current_title = gr.Markdown("## Example")

                with gr.Row():
                    chat_mode = gr.Radio(
                        choices=["General Chat", "Image Chat", "RAG Chat"],
                        value="General Chat",
                        label=""
                    )

                chatbot = gr.Chatbot(
                    height=400,
                    render_markdown=True,
                    show_copy_button=True,
                    bubble_full_width=False
                )

                # Message input area
                with gr.Row():
                    # Image upload for Image Chat mode (small button)
                    with gr.Column(scale=1, visible=False, min_width=50) as image_chat:
                        file_upload_image = gr.UploadButton(
                            "📎",
                            file_types=[".jpg", ".jpeg", ".png"],
                            scale=1
                        )

                    # Message text box
                    with gr.Column(scale=8):
                        msg = gr.Textbox(
                            show_label=False,
                            placeholder="메시지를 입력하세요...",
                            container=False
                        )

                    # Send button
                    with gr.Column(scale=1, min_width=50):
                        send_btn = gr.Button("↑")

                # PDF upload for RAG Chat mode (same width as Clear Chat)
                with gr.Row():
                    with gr.Column(scale=1, visible=False) as rag_chat:
                        file_upload_pdf = gr.File(
                            label="PDF Upload",
                            file_types=[".pdf"],
                            file_count="multiple"
                        )

                with gr.Row():
                    clear_btn = gr.Button("Clear Chat")

                with gr.Row(visible=False) as general_file_info:
                    selected_image = gr.Textbox(
                        label="Selected Image",
                        interactive=False
                    )

                with gr.Row(visible=False) as rag_file_info:
                    selected_pdf = gr.Textbox(
                        label="Selected PDF",
                        interactive=False
                    )

        def process_message(message, file_image, files_pdf, mode, history, session_name):
            """Route *message* to the processor of the session's mode and return the reply text.

            Any exception is logged and converted into an error string so the
            chat always receives a reply.
            """
            try:
                session = session_manager.get_session(session_name)
                if not session:
                    return "세션을 찾을 수 없습니다."

                current_mode = session["mode"]

                if current_mode == "Image Chat":
                    if file_image:
                        with Image.open(file_image) as opened:
                            # convert()/copy() return a new image that owns its
                            # pixel data, so it stays usable after the `with`
                            # block closes the opened file.
                            if opened.mode != 'RGB':
                                image_for_process = opened.convert('RGB')
                            else:
                                image_for_process = opened.copy()
                        session["current_image"] = image_for_process
                        session["selected_image_name"] = file_image.name
                    elif session["current_image"] is not None:
                        image_for_process = session["current_image"]
                    else:
                        return "이미지를 먼저 업로드해주세요."

                    question = message if message.strip() else "이 이미지에 대해 설명해주세요."
                    return session["img_processor"].process_image(image_for_process, question)

                elif current_mode == "RAG Chat":
                    pipeline = session["rag_pipeline"]
                    if files_pdf:
                        pdf_paths = [f.name for f in files_pdf]
                        pipeline.indexing_pdf(pdf_paths)
                        return f"{len(pdf_paths)}개의 PDF가 성공적으로 처리되었습니다. 이제 문서에 대해 질문할 수 있습니다."

                    # Nothing indexed yet: ask for a document instead of
                    # querying an empty index.
                    if not pipeline.chunks:
                        return "PDF를 먼저 업로드해주세요."

                    return pipeline.answer_query(message)

                else:  # General Chat
                    return session["general_processor"].process_query(message)

            except Exception as e:
                logging.error(f"메시지 처리 중 오류: {str(e)}")
                return f"오류가 발생했습니다: {str(e)}"

        def chat_mode_change(mode, session_name):
            """Apply a mode selection and toggle the mode-specific upload widgets.

            Once a session's mode is locked the radio is reset to the locked mode.
            """
            session = session_manager.get_session(session_name)
            if not session:
                return [gr.update()] * 6
            if session["mode_locked"]:
                current_mode = session["mode"]
                # This handler also runs when the radio is updated
                # programmatically (e.g. on a session switch), so only warn
                # when a different mode was actually requested.
                if mode != current_mode:
                    gr.Warning("대화가 시작된 후에는 모드를 변경할 수 없습니다. 새 세션을 만들어주세요.")
            else:
                session["mode"] = mode
                current_mode = mode

            is_image = current_mode == "Image Chat"
            is_rag = current_mode == "RAG Chat"

            return [
                gr.update(value=current_mode),
                gr.update(),  # msg is always shown
                gr.update(visible=is_image),
                gr.update(visible=is_rag),
                gr.update(visible=is_image),
                gr.update(visible=is_rag)
            ]

        def send_message(message, file_image, files_pdf, session_name, mode, history):
            """Handle a send action: lock the mode, get the reply and update the chat.

            Returns new values for (chatbot, msg, image upload, PDF upload,
            selected image label, selected PDF label).
            """
            message = message or ""
            history = history or []
            if not message.strip() and not (file_image or files_pdf):
                return history, "", None, None, "", ""

            try:
                session = session_manager.get_session(session_name)
                if not session:
                    return history, "", None, None, "", ""

                # The first message fixes the session's mode.
                if not session["mode_locked"]:
                    session["mode_locked"] = True
                    session["mode"] = mode

                current_mode = session["mode"]
                response = process_message(message, file_image if current_mode == "Image Chat" else None,
                                          files_pdf if current_mode == "RAG Chat" else None,
                                          current_mode, history, session_name)

                # Chatbot entries are (user, bot) pairs: uploads and the user's
                # text go in the user slot, the model reply in the bot slot.
                if current_mode == "Image Chat" and file_image:
                    history.append(((file_image.name,), None))
                elif current_mode == "RAG Chat" and files_pdf:
                    pdf_names = [f.name for f in files_pdf]
                    history.append((f"Uploaded PDFs: {', '.join(pdf_names)}", None))

                history.append((message if message.strip() else None, response))
                session["history"] = history

                return (history, "", None, None,
                        session["selected_image_name"] if current_mode == "Image Chat" else "",
                        ", ".join(f.name for f in files_pdf) if files_pdf else "")

            except Exception as e:
                logging.error(f"메시지 전송 중 오류: {str(e)}")
                return history, "", None, None, "", ""

        # Session management handlers
        def add_session(title):
            """Create a session named *title*; hide and clear the title box."""
            title = (title or "").strip()
            hide_input = gr.update(visible=False, value="")
            if not title:
                return hide_input, gr.update(choices=list(session_manager.sessions.keys()))

            if session_manager.create_session(title):
                return hide_input, gr.update(choices=list(session_manager.sessions.keys()), value=title)
            else:
                gr.Warning("이미 존재하는 세션 이름입니다.")
                return hide_input, gr.update(choices=list(session_manager.sessions.keys()))

        def switch_session(session_name):
            """Make *session_name* active and show its title, history and mode."""
            session = session_manager.get_session(session_name)
            if session:
                session_manager.current_session = session_name
                return (
                    f"## {session_name}",
                    session["history"],
                    session["mode"]
                )
            # Unknown session: leave title and mode untouched, show an empty chat.
            return gr.update(), [], gr.update()

        def delete_session(session_name):
            """Delete *session_name* and switch the UI to the session that is active afterwards."""
            next_session, session_data = session_manager.delete_session(session_name)
            if next_session is None:
                gr.Warning("마지막 세션은 삭제할 수 없습니다")
                # Nothing was deleted: keep title, chat and mode as they are.
                return (
                    gr.update(choices=list(session_manager.sessions.keys()), value=session_name),
                    gr.update(),
                    gr.update(),
                    gr.update()
                )

            return (
                gr.update(choices=list(session_manager.sessions.keys()), value=next_session),
                f"## {next_session}",
                session_data["history"],
                session_data["mode"]
            )

        def update_selected_file(file):
            """Show the name of a freshly uploaded file in the matching info row."""
            if file is None:
                return gr.update(), gr.update(), gr.update(), gr.update()

            session = session_manager.get_session(session_manager.current_session)
            if session:
                session["selected_image_name"] = file.name

            is_image = file.name.lower().endswith(('.jpg', '.jpeg', '.png'))
            return (
                file.name if is_image else "",  # selected_image
                "" if is_image else file.name,  # selected_pdf
                gr.update(visible=is_image),    # general_file_info visibility
                gr.update(visible=not is_image) # rag_file_info visibility
            )

        def clear_chat():
            """Clear the visible chat and the active session's stored history and image."""
            session = session_manager.get_session(session_manager.current_session)
            if session:
                # Also reset the stored history; otherwise the cleared chat
                # reappears when the user switches away and back.
                session["history"] = []
                session["current_image"] = None
                session["selected_image_name"] = ""
            return [], "", ""

        # Event bindings
        new_session_btn.click(
            lambda: gr.update(visible=True),
            outputs=session_title_input
        )

        session_title_input.submit(
            add_session,
            inputs=[session_title_input],
            outputs=[session_title_input, session_list]
        )

        session_list.change(
            switch_session,
            inputs=[session_list],
            outputs=[current_title, chatbot, chat_mode]
        )

        delete_btn.click(
            delete_session,
            inputs=[session_list],
            outputs=[session_list, current_title, chatbot, chat_mode]
        )

        send_btn.click(
            send_message,
            inputs=[msg, file_upload_image, file_upload_pdf, session_list, chat_mode, chatbot],
            outputs=[chatbot, msg, file_upload_image, file_upload_pdf, selected_image, selected_pdf]
        )

        msg.submit(
            send_message,
            inputs=[msg, file_upload_image, file_upload_pdf, session_list, chat_mode, chatbot],
            outputs=[chatbot, msg, file_upload_image, file_upload_pdf, selected_image, selected_pdf]
        )

        file_upload_image.upload(
            update_selected_file,
            inputs=[file_upload_image],
            outputs=[
                selected_image,
                selected_pdf,
                general_file_info,
                rag_file_info
            ]
        )

        chat_mode.change(
            chat_mode_change,
            inputs=[chat_mode, session_list],
            outputs=[chat_mode, msg, image_chat, rag_chat, general_file_info, rag_file_info]
        )

        clear_btn.click(
            clear_chat,
            outputs=[chatbot, selected_image, selected_pdf]
        )

        return demo

# Launch the GUI: build the Blocks app and start the Gradio server in debug mode.
demo = create_ui()
demo.launch(debug=True)