In [2]:
import tkinter as tk
from tkinter import ttk, scrolledtext, messagebox
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
import json
from googletrans import Translator
from difflib import SequenceMatcher
import re

class Gemma3QAInference:
    """Dataset-backed question answering helper.

    Construction loads a causal language model and tokenizer (trying a
    fallback checkpoint if the primary one fails), creates a Google
    Translate client, and reads QA pairs from ``qa_train.json``.

    NOTE: no method of this class runs generation with the loaded model;
    ``ask_question`` returns the supplied context as the answer.
    """

    # Hugging Face identifiers tried in order by ``__init__``.
    # NOTE(review): the log messages say "Gemma3" while the primary id names
    # a Gemma 2 checkpoint -- confirm which one is intended.
    PRIMARY_MODEL = "google/gemma-2-9b-it"
    FALLBACK_MODEL = "microsoft/DialoGPT-medium"

    # Compiled once; both patterns are used inside per-example loops.
    _NON_WORD_RE = re.compile(r'[^\w\s]')
    # A keyword is a run of word characters, optionally hyphen-joined
    # (keeps reduplicated forms such as "ciri-ciri" intact).
    _KEYWORD_RE = re.compile(r'\w+(?:-\w+)*')

    def __init__(self, model_path: str = "gemma3"):
        """Load model, tokenizer, translator and the QA dataset.

        Args:
            model_path: Accepted for backward compatibility; it is not read
                anywhere in this class.

        Raises:
            Exception: Whatever the fallback model load raised, when both
                the primary and the fallback checkpoint fail to load.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        # Load Gemma3 model and tokenizer
        print("Loading Gemma3 model and tokenizer...")
        try:
            self._load_model(self.PRIMARY_MODEL)
            print("Gemma3 model loaded successfully!")
        except Exception as e:
            print(f"Error loading Gemma3 model: {e}")
            # Fallback to a smaller model if available
            try:
                self._load_model(self.FALLBACK_MODEL)
                print("Fallback model loaded successfully!")
            except Exception as e2:
                print(f"Error loading fallback model: {e2}")
                raise

        # Initialize translator
        self.translator = Translator()

        # Load QA dataset (Malay only)
        self.qa_data = self.load_qa_data("qa_train.json")

    def _load_model(self, model_name):
        """Load tokenizer and model for *model_name* onto ``self.device`` in eval mode."""
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name)
        self.model.to(self.device)
        self.model.eval()

    @classmethod
    def _normalize(cls, text):
        """Lower-case *text* and drop every character that is not a word char or whitespace."""
        return cls._NON_WORD_RE.sub('', text.lower())

    @staticmethod
    def _is_qa_pair(item):
        """Return True if *item* is a dict with string 'question' and 'context' values."""
        return (
            isinstance(item, dict)
            and isinstance(item.get('question'), str)
            and isinstance(item.get('context'), str)
        )

    def translate(self, text, src='auto', dest='ms'):
        """Translate *text* with Google Translate, best-effort.

        Args:
            text: Text to translate.
            src: Source language code (``'auto'`` lets the service detect it).
            dest: Destination language code.

        Returns:
            The translated text, or *text* unchanged if it is empty or the
            translation call raises.
        """
        if not text:
            # Nothing to translate; skip the remote call.
            return text
        try:
            return self.translator.translate(text, src=src, dest=dest).text
        except Exception as e:
            print(f"Translation error: {e}")
            return text

    def find_similar_question(self, question, threshold=0.3):
        """Find the most similar question in the training data.

        The score of each training example is the mean of the
        ``SequenceMatcher`` ratio between the normalized questions and the
        fraction of the input's distinct words that also occur in the
        training question.

        Args:
            question: The input question.
            threshold: Minimum combined score for a match to be returned.

        Returns:
            The best matching training example, or None if there is no data
            or no example reaches *threshold*.
        """
        if not self.qa_data:
            return None

        question_clean = self._normalize(question)
        # Loop-invariant: computed once instead of once per training example.
        question_words = set(question_clean.split())
        word_count = max(len(question_words), 1)  # avoid division by zero

        best_match = None
        best_score = 0
        for example in self.qa_data:
            train_clean = self._normalize(example['question'])

            similarity = SequenceMatcher(None, question_clean, train_clean).ratio()
            shared = question_words & set(train_clean.split())
            keyword_overlap = len(shared) / word_count

            combined_score = (similarity + keyword_overlap) / 2
            if combined_score > best_score:
                best_score = combined_score
                best_match = example

        return best_match if best_score >= threshold else None

    def load_qa_data(self, file_path):
        """Load Malay QA data from a JSON file.

        A top-level JSON object is treated as a single QA pair. Entries that
        are not dicts with string 'question' and 'context' values are
        dropped, because the lookup methods index those keys directly.

        Args:
            file_path: Path of the UTF-8 encoded JSON file.

        Returns:
            A list of QA dicts; empty if the file cannot be read or parsed,
            or if its top-level value is neither a list nor an object.
        """
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError, TypeError) as e:
            # ValueError covers both JSON syntax errors and decoding errors.
            print(f"Error loading QA data: {e}")
            return []

        if isinstance(data, dict):
            data = [data]  # If single QA pair, make it a list
        elif not isinstance(data, list):
            print(f"Error loading QA data: unsupported top-level type {type(data).__name__}")
            return []

        valid = [item for item in data if self._is_qa_pair(item)]
        skipped = len(data) - len(valid)
        if skipped:
            print(f"Error loading QA data: skipped {skipped} malformed entries")
        return valid

    def get_context_for_question(self, question: str):
        """Pick the training example (or a generic stand-in) for *question*.

        Lookup order:
            1. the best match from ``find_similar_question``;
            2. the first example whose context or question contains any
               keyword (longer than 3 characters) of the input;
            3. a hard-coded generic context/answer pair.

        Returns:
            A dict with 'context', 'question' and 'answer' keys in case 3;
            otherwise the matching training example as stored.
        """
        similar_example = self.find_similar_question(question)
        if similar_example:
            return similar_example

        # Tokenize with a regex so surrounding punctuation ("tanglung?")
        # does not prevent a substring match against the training text.
        keywords = [
            word
            for word in self._KEYWORD_RE.findall(question.lower())
            if len(word) > 3  # Only consider words longer than 3 characters
        ]

        if keywords:
            for example in self.qa_data:
                haystacks = (example['context'].lower(), example['question'].lower())
                if any(keyword in text for keyword in keywords for text in haystacks):
                    return example

        # Final fallback: return a generic context
        return {
            'context': "Budaya Cina Malaysia adalah sebahagian penting dalam masyarakat Malaysia yang telah diwarisi dari generasi ke generasi.",
            'question': question,
            'answer': "Budaya Cina Malaysia merangkumi pelbagai aspek termasuk perayaan, makanan, dan adat istiadat yang telah diwarisi dari generasi ke generasi."
        }

    def ask_question(self, question: str, context: str, max_length: int = 512):
        """Return the context as the answer (no LLM generation).

        Args:
            question: The question being answered (echoed in the result).
            context: Text returned verbatim as the answer.
            max_length: Unused; kept so existing callers keep working.

        Returns:
            A result dict with 'question', 'context', 'answer', 'confidence',
            'start_confidence', 'end_confidence', 'answer_start' and
            'answer_end'. Confidences are 1.0 on success. If *context* has
            no length (e.g. None), an apology answer with 0.0 confidences is
            returned instead.
        """
        try:
            # Only this call can fail; keep the try body to it.
            answer_end = len(context)
        except TypeError as e:
            print(f"Error generating answer: {e}")
            return {
                "question": question,
                "context": context,
                "answer": "Maaf, berlaku ralat semasa memproses soalan anda.",
                "confidence": 0.0,
                "start_confidence": 0.0,
                "end_confidence": 0.0,
                "answer_start": 0,
                "answer_end": 0
            }

        return {
            "question": question,
            "context": context,
            "answer": context,  # Use context as the answer
            "confidence": 1.0,  # 100% confidence since we're using the exact context
            "start_confidence": 1.0,
            "end_confidence": 1.0,
            "answer_start": 0,
            "answer_end": answer_end
        }

class QAGUI:
    """Tkinter front end for ``Gemma3QAInference``.

    Builds a window with a question entry, an "Ask Question" button, a
    scrolling results box and an example-question picker. The QA backend is
    created synchronously during construction.
    """

    # Emoji used in labels, written as escapes so the file's on-disk
    # encoding cannot corrupt them.
    # NOTE(review): some older Tk builds cannot render characters outside
    # the BMP -- confirm on the target platform.
    _ICON_ROBOT = "\U0001F916"   # robot face
    _ICON_CHECK = "\u2705"       # white heavy check mark
    _ICON_BOOKS = "\U0001F4DA"   # books
    _ICON_CROSS = "\u274C"       # cross mark
    _ICON_MEMO = "\U0001F4DD"    # memo
    _ICON_TARGET = "\U0001F3AF"  # direct hit
    _ICON_CHART = "\U0001F4CA"   # bar chart

    # Upper bound on buttons shown in the example window.
    _MAX_EXAMPLES = 10

    # Shown when the backend has no usable training questions.
    _FALLBACK_EXAMPLES = (
        "Apakah Tahun Baru Cina?",
        "Apakah Angpau?",
        "Apakah Kuih Bulan?",
        "Bagaimana Pesta Tanglung disambut?",
        "Apakah kepentingan Masakan Cina Malaysia?",
        "Bagaimana Wayang Kulit dipersembahkan?",
        "Apakah maksud Kongsi Raya?",
        "Bagaimana budaya Cina diwarisi?",
        "Apakah ciri-ciri perayaan Cina?",
        "Bagaimana makanan tradisional Cina disediakan?",
    )

    def __init__(self, root):
        """Build the widgets inside *root* and load the QA backend."""
        self.root = root
        self.root.title("Gemma3 Question Answering System")
        self.root.geometry("800x600")

        # Initialize QA system
        self.qa_system = None
        self.model_loaded = False

        self.setup_ui()
        # NOTE: runs synchronously, so construction blocks until loading ends.
        self.load_model()

    def setup_ui(self):
        """Setup the user interface."""
        # Main frame
        main_frame = ttk.Frame(self.root, padding="10")
        main_frame.grid(row=0, column=0, sticky=(tk.W, tk.E, tk.N, tk.S))

        # Configure grid weights
        self.root.columnconfigure(0, weight=1)
        self.root.rowconfigure(0, weight=1)
        main_frame.columnconfigure(0, weight=1)
        # Spare vertical space goes to the results box (row 3): it is the
        # widget configured to stretch in both directions.
        main_frame.rowconfigure(3, weight=1)

        # Title
        title_label = ttk.Label(
            main_frame,
            text=f"{self._ICON_ROBOT} Gemma3 Question Answering System",
            font=("Arial", 16, "bold"),
        )
        title_label.grid(row=0, column=0, pady=(0, 20))

        # Model status
        self.status_label = ttk.Label(main_frame, text="Loading model...",
                                      foreground="orange")
        self.status_label.grid(row=1, column=0, pady=(0, 20))

        # Question input
        question_frame = ttk.LabelFrame(main_frame, text="Ask Your Question", padding="10")
        question_frame.grid(row=2, column=0, sticky=(tk.W, tk.E, tk.N, tk.S), pady=5)
        question_frame.columnconfigure(0, weight=1)

        ttk.Label(question_frame, text="Question:").grid(row=0, column=0, sticky=tk.W, pady=5)
        self.question_entry = ttk.Entry(question_frame, width=80, font=("Arial", 12))
        self.question_entry.grid(row=1, column=0, sticky=(tk.W, tk.E), pady=5)

        # Ask button
        self.ask_button = ttk.Button(question_frame, text="Ask Question",
                                     command=self.ask_question, style="Accent.TButton")
        self.ask_button.grid(row=2, column=0, pady=20)

        # Results area
        results_frame = ttk.LabelFrame(main_frame, text="Answer", padding="10")
        results_frame.grid(row=3, column=0, sticky=(tk.W, tk.E, tk.N, tk.S), pady=5)
        results_frame.columnconfigure(0, weight=1)
        results_frame.rowconfigure(0, weight=1)

        self.results_text = scrolledtext.ScrolledText(results_frame, height=15, width=80,
                                                      font=("Arial", 11))
        self.results_text.grid(row=0, column=0, sticky=(tk.W, tk.E, tk.N, tk.S))

        # Example questions button
        ttk.Button(main_frame, text="Load Example Questions",
                   command=self.load_examples).grid(row=4, column=0, pady=10)

    def load_model(self):
        """Load the Gemma3 model.

        On success the status label reports the number of training examples;
        on failure it shows the error and the Ask button is disabled.
        """
        try:
            self.qa_system = Gemma3QAInference()
            self.model_loaded = True

            # Show model and training data status
            training_count = len(self.qa_system.qa_data) if self.qa_system.qa_data else 0
            status_text = (
                f"{self._ICON_CHECK} Gemma3 model loaded successfully! | "
                f"{self._ICON_BOOKS} {training_count} training examples loaded"
            )
            self.status_label.config(text=status_text, foreground="green")

        except Exception as e:
            self.status_label.config(
                text=f"{self._ICON_CROSS} Error loading model: {str(e)}",
                foreground="red",
            )
            self.ask_button.config(state="disabled")

    def ask_question(self):
        """Ask a question and display the result.

        Reads the entry widget, answers through the English or the
        native-language path depending on the detected language, and writes
        the formatted result to the results box. Any error is reported in a
        message box; the Ask button is always re-enabled afterwards.
        """
        if not self.model_loaded:
            messagebox.showerror("Error", "Model not loaded!")
            return

        question = self.question_entry.get().strip()

        if not question:
            messagebox.showwarning("Warning", "Please enter a question!")
            return

        try:
            # Show loading indicator and disable button
            self.results_text.delete("1.0", tk.END)
            self.results_text.insert("1.0", "Loading...\n")
            self.ask_button.config(state="disabled")
            self.root.update_idletasks()

            if self._detect_language(question) == 'en':
                result = self._answer_english(question)
            else:
                result = self._answer_native(question)

            # Ensure we show the original question in results
            result['question'] = question

            # Display result
            self.display_result(result)

        except Exception as e:
            # Do not leave the stale "Loading..." placeholder on screen.
            self.results_text.delete("1.0", tk.END)
            error_msg = f"Error processing question: {str(e)}"
            messagebox.showerror("Error", error_msg)
        finally:
            # Single place that re-enables the button on every exit path.
            self.ask_button.config(state="normal")

    def _detect_language(self, question):
        """Return the detected language code of *question*, or None if detection fails.

        Detection is best-effort, mirroring ``Gemma3QAInference.translate``:
        a failing detection call falls through to the native-language path
        instead of aborting the whole request.
        """
        try:
            return self.qa_system.translator.detect(question).lang
        except Exception as e:
            print(f"Language detection error: {e}")
            return None

    def _answer_english(self, question):
        """Answer an English question: translate to Malay, look up, translate back."""
        qa = self.qa_system
        question_ms = qa.translate(question, src='en', dest='ms')
        context_data = qa.get_context_for_question(question_ms)
        result_ms = qa.ask_question(question_ms, context_data['context'])

        result = result_ms.copy()
        result['context'] = qa.translate(result_ms['context'], src='ms', dest='en')
        if result_ms['answer'] == result_ms['context']:
            # Same text: reuse the translation instead of a second remote call.
            result['answer'] = result['context']
        else:
            result['answer'] = qa.translate(result_ms['answer'], src='ms', dest='en')
        result['source'] = 'model'
        return result

    def _answer_native(self, question):
        """Answer a Malay (or undetected-language) question without translation."""
        qa = self.qa_system
        context_data = qa.get_context_for_question(question)
        result = qa.ask_question(question, context_data['context'])
        result['source'] = 'model'

        # If model confidence is low, try a better answer from training data.
        # The lookup only runs when its result can actually be used.
        if result['confidence'] < 0.5:
            similar_example = qa.find_similar_question(question)
            if similar_example:
                result['answer'] = similar_example['answer']
                result['context'] = similar_example['context']
                result['confidence'] = 0.8  # Higher confidence for training data
                result['start_confidence'] = 0.8
                result['end_confidence'] = 0.8
                result['answer_start'] = 0
                result['answer_end'] = len(similar_example['answer'])
                result['source'] = 'training_data'
        return result

    def _format_result(self, result):
        """Render *result* as the multi-line text shown in the results box."""
        if result.get('source') == 'model':
            source_indicator = f"{self._ICON_ROBOT} Model Prediction"
        else:
            source_indicator = f"{self._ICON_BOOKS} Training Data"

        lines = [
            f"{self._ICON_MEMO} Question: {result['question']}",
            "",
            f"{self._ICON_CHECK} Answer: {result['answer']}",
            "",
            f"{self._ICON_TARGET} Confidence: {result['confidence']:.2%}",
            source_indicator,
            "",
            f"{self._ICON_CHART} Details:",
            f"- Start Confidence: {result['start_confidence']:.2%}",
            f"- End Confidence: {result['end_confidence']:.2%}",
            f"- Answer Span: {result['answer_start']} to {result['answer_end']}",
            "",  # trailing newline
        ]
        return "\n".join(lines)

    def display_result(self, result):
        """Display the result in the results area.

        Args:
            result: Dict with 'question', 'answer', 'confidence',
                'start_confidence', 'end_confidence', 'answer_start' and
                'answer_end'; an optional 'source' of 'model' selects the
                "Model Prediction" label, anything else "Training Data".
        """
        self.results_text.delete("1.0", tk.END)
        self.results_text.insert("1.0", self._format_result(result))

    def _example_questions(self):
        """Return up to ``_MAX_EXAMPLES`` distinct training questions, or the built-in list."""
        unique_questions = []
        seen_questions = set()

        if self.qa_system and self.qa_system.qa_data:
            for example in self.qa_system.qa_data:
                # Skip entries that carry no question instead of raising.
                question = example.get('question') if isinstance(example, dict) else None
                if not question or question in seen_questions:
                    continue
                unique_questions.append(question)
                seen_questions.add(question)
                if len(unique_questions) >= self._MAX_EXAMPLES:
                    break

        # Fallback examples if no training data
        return unique_questions or list(self._FALLBACK_EXAMPLES)

    def load_examples(self):
        """Load example questions from training data.

        Opens a secondary window with one button per example; clicking a
        button copies that example into the question entry.
        """
        # Create example window
        example_window = tk.Toplevel(self.root)
        example_window.title("Example Questions")
        example_window.geometry("600x400")

        # Example list
        example_frame = ttk.Frame(example_window, padding="10")
        example_frame.pack(fill=tk.BOTH, expand=True)

        ttk.Label(example_frame, text="Click an example to load it:",
                  font=("Arial", 12, "bold")).pack(pady=(0, 10))

        for i, example in enumerate(self._example_questions(), 1):
            # Bind the current example as a default to avoid late binding.
            btn = ttk.Button(example_frame, text=f"{i}. {example}",
                             command=lambda ex=example: self.load_example(ex))
            btn.pack(fill=tk.X, pady=2)

    def load_example(self, example):
        """Load an example into the question field."""
        self.question_entry.delete(0, tk.END)
        self.question_entry.insert(0, example)

def main():
    """Create the Tk root window, attach the QA interface and run the event loop."""
    window = tk.Tk()
    # Keep the GUI object bound to a name while the event loop runs.
    gui = QAGUI(window)
    window.mainloop()


# Start the application only when executed as a script, not on import.
if __name__ == "__main__":
    main()

Using device: cpu
Loading Gemma3 model and tokenizer...
Error loading Gemma3 model: You are trying to access a gated repo.
Make sure to have access to it at https://huggingface.co/google/gemma-2-9b-it.
401 Client Error. (Request ID: Root=1-685bf297-39b141082b15cbf22e7b43df;8c7c9439-0e6b-4e48-9ceb-e53c6596b7fa)

Cannot access gated repo for url https://huggingface.co/google/gemma-2-9b-it/resolve/main/config.json.
Access to model google/gemma-2-9b-it is restricted. You must have access to it and be authenticated to access it. Please log in.
Fallback model loaded successfully!
