## 🧬 **Introduction**

Welcome to VariantAInalyser!

VariantAInalyser is a comprehensive all-in-one platform that combines traditional bioinformatics with generative AI to streamline genomic variant analysis from VCF files, eliminating the need to juggle multiple specialised tools.

## 👩🏻‍🔬 **The Challenge**

Genomic variants are essential biomarkers for understanding diseases, drug responses, and creating personalised treatment plans. However, traditional analysis workflows force researchers to:

- Switch between multiple disconnected tools (VCF parsers, annotation software, database interfaces)
- Master different user interfaces and data formats
- Manually integrate results across platforms
- Spend valuable time on technical tasks rather than interpretation

This fragmented approach creates inefficiencies, increases the potential for errors, and significantly extends analysis time.

## 💡 **The Solution: VariantAInalyser!**

VariantAInalyser revolutionises genomic analysis by unifying the entire workflow in a single, intuitive interface. This integrated pipeline:

- **Consolidates multiple tools into one platform**, eliminating the need to switch between systems
- **Automates the complete workflow** from raw VCF data to clinical interpretation
- **Requires minimal technical expertise** to operate effectively

Researchers and clinicians can now:
- Process VCF files to extract critical variant data
- Run SegmentNT analysis to identify genomic regions
- Query ClinVar for clinical significance
- Generate comprehensive reports
- Ask questions in natural language

All without ever leaving the interface or needing to reformat data between tools!

## 🤖 **Generative AI Capabilities**

VariantAInalyser's unified approach leverages cutting-edge AI to replace what traditionally required multiple specialised tools:

- **Retrieval Augmented Generation**: VariantAInalyser combines Google's Gemini 2.0 Flash model with contextual retrieval from genomic data. As shown in the methods of the ScientificTextGenerator class (i.e. generate_report(), ask_agent(), and generate_comparison_report()), the system retrieves relevant genomic information from the analysis pipeline and integrates it into the prompt, enabling Gemini to provide more accurate and contextually relevant responses based on the specific variants being analysed.

- **Grounding**: Anchors Gemini's responses in the actual genomic data processed by the system rather than generic information. The implementation supplies the AI with concrete facts from VCF files, ClinVar database results, and SegmentNT predictions, which reduces the risk of hallucinations and supports scientific accuracy in the generated reports. This grounding mechanism creates a direct link between specific variant characteristics and the AI's interpretations of their clinical significance.

- **Document Understanding**: VariantAInalyser demonstrates strong document understanding by parsing and extracting key information from complex genomic data files (like VCFs). It leverages bioinformatics libraries (pysam and BioPython) to preprocess and structure the data, enabling the LLM to reason about the variants and their significance.

- **Structured Output Generation**: The pipeline automatically produces organised, clinically relevant reports from raw VCF input, SegmentNT outputs, and queried ClinVar data, ready for clinical review. The LLM is prompted to generate responses in a controlled, consistent format, making the results easy to interpret and use.
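
As a rough illustration of the retrieval and grounding ideas above, the sketch below assembles a context block from whichever analysis results are available and embeds it in a prompt. The function name `build_grounded_prompt` and the example values are hypothetical and not part of VariantAInalyser; the notebook builds its actual prompts inside the ScientificTextGenerator class.

```python
from typing import Any, Dict, Optional


def build_grounded_prompt(
    question: str,
    variant: Dict[str, Any],
    genomic_classes: Optional[Dict[str, float]] = None,
    clinvar: Optional[Dict[str, Any]] = None,
) -> str:
    """Assemble a prompt whose context comes only from pipeline results."""
    lines = [f"Variant Data: {variant}"]
    # Optional sources are included only when the pipeline produced them
    if genomic_classes:
        lines.append(f"Genomic Classes: {genomic_classes}")
    if clinvar:
        lines.append(f"ClinVar Data: {clinvar}")
    context = "\n".join(lines)
    return (
        f'The user asked: "{question}"\n'
        "Answer using only the analysis results below.\n\n"
        f"Analysis Results:\n{context}\n"
    )


# Made-up values, for illustration only
prompt = build_grounded_prompt(
    "Is this variant in a coding region?",
    {"Chromosome": "17", "Position": 100, "Reference Allele": "A", "Alternate Allele": "G"},
    genomic_classes={"exon": 0.93},
)
print(prompt)
```

Leaving out sources that were not retrieved keeps the model from being pointed at empty fields, which is one simple way to keep the answer tied to data that actually exists.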


## 🏁 **Getting Started**

Before using VariantAInalyser, you'll need:

1. **Google API key**
   - Generate it from [AI Studio](https://aistudio.google.com/app/apikey)
   - Add to the designated field in the interface

2. **ClinVar API key** (optional)
   - Create a free [NCBI account](https://account.ncbi.nlm.nih.gov/signup/?back_url=https%3A%2F%2Fwww.ncbi.nlm.nih.gov%2F)
   - Navigate to "Account Settings"
   - Generate an API key in the "API Key Management" section
   - Add to the corresponding field in the interface

🗒️ **Note**: No worries if you don't have a ClinVar API key; you will still be able to use the VariantAInalyser interface. The key helps avoid rate limiting when running multiple queries, but for typical use cases the system should work fine without it :)
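
For context on the rate limiting mentioned in the note: NCBI's E-utilities allow roughly 3 requests per second without an API key and 10 per second with one. The sketch below shows how an optional key could be folded into the query parameters and how a minimum delay between requests could be derived. The helper names `eutils_params` and `min_request_interval` and the example search term are hypothetical and not part of the notebook.

```python
from typing import Dict, Optional


def eutils_params(term: str, api_key: Optional[str] = None) -> Dict[str, str]:
    """Build ESearch query parameters for the ClinVar database."""
    params = {"db": "clinvar", "term": term, "retmode": "json"}
    if api_key:  # the key is optional; omit the field entirely when absent
        params["api_key"] = api_key
    return params


def min_request_interval(api_key: Optional[str] = None) -> float:
    """Seconds to wait between requests to stay under NCBI's documented limits."""
    requests_per_second = 10 if api_key else 3
    return 1.0 / requests_per_second


print(eutils_params("NC_000017.11:g.100A>G"))   # illustrative term only
print(min_request_interval())                   # without a key
print(min_request_interval("placeholder-key"))  # with a key
```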


## ⚙️ **Set-Up and Configuration**

This section **installs required packages**, **imports libraries**, and **configures logging** to track the execution of the pipeline and diagnose any issues that might occur.


In [None]:
## Install Required Packages

!pip install google-genai --upgrade  # provides `from google import genai` and `genai.Client`
!pip install biopython
!pip install pysam
!pip install transformers
!pip install matplotlib seaborn
!pip install ipywidgets

print("\n", "-" * 75)
print("\n ⚙️ Required packages were successfully installed !")

## Import Libraries

import json
import os
from datetime import datetime
import re
import tempfile
import gzip
from typing import Dict, Any, List, Optional
import logging

import numpy as np
import torch
import requests
import matplotlib.pyplot as plt
import seaborn as sns
import ipywidgets as widgets
from IPython.display import display, HTML, clear_output, Markdown
from Bio import SeqIO
from transformers import AutoTokenizer, AutoModel
import pysam
from google import genai
from google.genai import types
import transformers

print("\n", "-" * 75)
print("\n ⚙️ Required libraries were successfully imported !")

## Set Up Logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("VariantAnalysisPipeline")

print("\n", "-" * 75)
print("\n ⚙️ Logging was successfully configured !")

## 💻 Definition of Core Classes

In this section, we define the classes that implement the core functionality of VariantAInalyser.

### 1. ScientificTextGenerator 🖺
This class serves as the bridge between raw genomic data and human-readable insights:


- Leverages Google's Gemini 2.0 Flash model

- Constructs prompts that combine variant data (from VCF files), genomic classifications (from SegmentNT results), and queried ClinVar data

- Generates detailed scientific explanations tailored to the user's specific questions

### 2. VariantAnalysisPipeline 📈
This class is the core analysis engine that orchestrates the entire workflow:


- Extracts variant information from VCF files

- Loads and runs the SegmentNT model to identify genomic regions

- Queries the ClinVar database for clinical significance

- Integrates all results into a comprehensive analysis
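
To make the sequence-handling part of this workflow concrete, here is a minimal sketch of substituting a VCF variant into a reference sequence and cutting a window around it. It uses a toy sequence and hypothetical helper names (`apply_variant`, `window_around`); it is not the pipeline's own code, which reads GRCh38 chromosome FASTA files.

```python
from typing import Tuple


def apply_variant(ref_seq: str, pos: int, ref: str, alt: str) -> str:
    """Replace `ref` with `alt` at 1-based VCF position `pos`."""
    start = pos - 1  # VCF positions are 1-based; Python strings are 0-based
    if ref_seq[start:start + len(ref)] != ref:
        raise ValueError("Reference allele does not match the sequence")
    # Skip the whole reference allele so that deletions shorten the sequence
    return ref_seq[:start] + alt + ref_seq[start + len(ref):]


def window_around(seq: str, pos: int, flank: int, length: int) -> Tuple[str, int]:
    """Return a window starting `flank` bases before `pos`, plus the variant's index in it."""
    start = max(0, pos - 1 - flank)       # clamp at the start of the sequence
    stop = min(len(seq), start + length)  # clamp at the end of the sequence
    return seq[start:stop], (pos - 1) - start


reference = "ACGTACGTACGT"
altered = apply_variant(reference, pos=5, ref="AC", alt="A")  # deletes one base
window, idx = window_around(altered, pos=5, flank=2, length=6)
print(altered, window, idx)
```

Returning the variant's offset together with the window keeps the two in sync even when the window is clamped at a chromosome boundary.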

### 3. VariantAnalysisUI 👩🏻‍💻
This class defines the interactive interface that makes the system accessible to users:


- Provides intuitive controls for uploading VCF files and selecting variants

- Displays variant information in a structured format

- Offers a chat-like interface for natural language interaction

- Allows downloading of results and conversation history
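
As one possible shape for the downloadable conversation history, the sketch below serialises a list of chat turns to JSON. The record fields (`role`, `text`), the `exported_at` timestamp, and the function name `export_history` are assumptions for illustration; the notebook's UI class may store its history differently.

```python
import json
from datetime import datetime
from typing import Dict, List


def export_history(history: List[Dict[str, str]]) -> str:
    """Serialise chat turns to a JSON string suitable for download."""
    payload = {
        "exported_at": datetime.now().isoformat(timespec="seconds"),
        "turns": history,
    }
    # ensure_ascii=False keeps any non-ASCII characters readable in the file
    return json.dumps(payload, indent=2, ensure_ascii=False)


history = [
    {"role": "user", "text": "What does this variant do?"},
    {"role": "assistant", "text": "(model reply would go here)"},
]
exported = export_history(history)
print(exported)
```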

In [None]:
class ScientificTextGenerator:
    """
    A class for generating scientific text reports about genomic variants
    using Generative AI (Google's Gemini model).

    This class handles connection to the Gemini API and provides methods for
    generating various types of genomic reports and answering variant-related queries.

    Attributes:
        client: Google Generative AI client for making API calls
    """

    def __init__(self, api_key: Optional[str] = None):
        """
        Initialize the text generator with API credentials.

        Args:
            api_key (str): Google API key for Gemini access. If None, will attempt
                          to use GOOGLE_API_KEY environment variable.

        Raises:
            ValueError: If no API key is provided and no environment variable is set.
        """
        # Check for API key in arguments or environment variables
        if api_key:
            os.environ["GOOGLE_API_KEY"] = api_key
            self.client = genai.Client(api_key=api_key)
        elif "GOOGLE_API_KEY" in os.environ:
            self.client = genai.Client(api_key=os.environ["GOOGLE_API_KEY"])
        else:
            raise ValueError("No API key provided. Please provide an API key or set GOOGLE_API_KEY environment variable.")

    def generate_report(self, variant_data: Dict, clinvar_data: Dict,
                       segment_nt_classes: Dict, protein_change: str,
                       gene_name: str, user_question: str) -> str:
        """
        Generate a scientific report about a variant based on analyzed data.

        Combines information from multiple sources to create a comprehensive
        clinical report tailored to the user's specific question.

        Args:
            variant_data (Dict): Dictionary containing basic variant information
            clinvar_data (Dict): Dictionary containing clinical significance data from ClinVar
            segment_nt_classes (Dict): Dictionary of genomic classes with probabilities from SegmentNT
            protein_change (str): String describing the protein change caused by the variant
            gene_name (str): Name of the gene containing the variant
            user_question (str): User's specific question about the variant

        Returns:
            str: Detailed scientific report addressing the user's question about the variant

        Raises:
            Exception: If there's an error generating the report via the Gemini API
        """
        # Format input data for the prompt
        input_data = f"Variant Data: {variant_data}\n"

        # Add optional data if available
        if segment_nt_classes:
            input_data += f"Genomic Classes: {segment_nt_classes}\n"
        if clinvar_data:
            input_data += f"ClinVar Data: {clinvar_data}\n"
        if protein_change:
            input_data += f"Protein Change: {protein_change}\n"
        if gene_name:
            input_data += f"Gene: {gene_name}\n"

        # Craft the prompt for Gemini with specific instructions
        prompt = f"""
            You are an AI agent, helping out a biologist. The user asked the following question: "{user_question}".
            Based on the question and the provided analysis results below, please generate a clinical report for the variant.

            Analysis Results:
            {input_data}

            Make sure the report is well structured, professional, and accurate. Provide helpful technical information and
            make sure to combine/relate together data from vcf file, genomic results (from segmentNT) and clinvar queries to give
            most exhaustive, insightful and clinically useful report.
            Address the user's question directly.
            Specifically mention the identified genomic classes and their associated probabilities
            listed under "Genomic Classes" above (exon, intron, 5UTR, splice sites, etc.) and their relevance to the variant.
        """

        # Generate report using Gemini API
        try:
            # Initialize chat with system instructions
            chat = self.client.chats.create(
                model="gemini-2.0-flash",
                config=types.GenerateContentConfig(
                    system_instruction=prompt
                ),
            )
            # Send the user question to get response
            resp = chat.send_message(user_question)
            logger.info("Successfully generated report with Gemini")
            return resp.text
        except Exception as e:
            logger.error(f"Error generating report: {e}")
            return f"Error generating report: {str(e)}"


    def ask_agent(self, query: str, analysis_data: Optional[Dict] = None) -> str:
        """
        Generate a response using the AI agent for general genomics questions.

        Uses the Gemini API to provide an informed response based on both general
        knowledge and specific analysis results when available.

        Args:
            query (str): User's query text about genomics or variant analysis
            analysis_data (Dict, optional): Current analysis data to incorporate into the response

        Returns:
            str: AI-generated response addressing the user's question
        """
        # Extract reports from analysis data if available
        input_data = analysis_data.get("reports", {}) if analysis_data else {}

        # Create a prompt that instructs the AI to provide both general knowledge and specific analysis
        prompt = f"""
            You are an AI agent, helping out a biologist. The user asked the following question: "{query}".
            Based on the question, internet searches (especially on PubMed, ClinVar and other reputable genomics sources)
            and the provided analysis results below, please generate a response.
            Unless the user specifies they want you to base your reply only on the results,
            make sure to always first answer the question generally (based on knowledge or internet search results),
            and then more specifically based on the results provided below.

            Analysis Results:
            {input_data}
        """

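        try:
            # Initialize chat with system instructions
            chat = self.client.chats.create(
                model="gemini-2.0-flash",
                config=types.GenerateContentConfig(
                    system_instruction=prompt
                )
            )
            # Send user query and return the response
            resp = chat.send_message(query)
            return resp.text
        except Exception as e:
            # Mirror the error handling used by the report-generation methods
            logger.error(f"Error generating response: {e}")
            return f"Error generating response: {str(e)}"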

    def generate_comparison_report(self, query: str, reports: List[str], indices: List[int]) -> str:
        """
        Generate a report comparing multiple variants.

        Takes individual variant reports and produces a comparative analysis
        addressing the user's specific question about the variants.

        Args:
            query (str): User's query about the variants comparison
            reports (List[str]): List of individual variant reports to compare
            indices (List[int]): Indices of the variants being compared

        Returns:
            str: Generated comparison report highlighting differences and similarities

        Raises:
            Exception: If there's an error generating the comparison report
        """
        # Create a prompt that instructs the AI to compare variants
        augmented_prompt = f"""
        You are an AI agent, helping out a biologist. The user asked the following question: "{query}".
        Here are the reports for two variants:

        Report for Variant {indices[0]+1}:
        {reports[0]}

        Report for Variant {indices[1]+1}:
        {reports[1]}

        Based on these reports, please provide a comparison of the two variants,
        addressing the user's question.
        """

        try:
            # Initialize chat with system instructions
            chat = self.client.chats.create(
                model="gemini-2.0-flash",
                config=types.GenerateContentConfig(
                    system_instruction=augmented_prompt
                )
            )
            # Send query and return response
            resp = chat.send_message(query)
            return resp.text
        except Exception as e:
            logger.error(f"Error generating comparison report: {e}")
            return f"Error generating comparison report: {str(e)}"

print("\n", "-" * 75)
print("\n 💻 ScientificTextGenerator Class was successfully defined ! (1/3)")

In [None]:
class VariantAnalysisPipeline:
    """
    An end-to-end pipeline for analyzing genomic variants.

    This pipeline handles the complete workflow of variant analysis:
    - Extract variant information from VCF files
    - Create altered genome sequences with variants
    - Analyze genomic features using SegmentNT neural network
    - Query ClinVar for clinical significance
    - Generate scientific reports using AI

    The pipeline maintains state between operations and allows for analysis
    of multiple variants from the same VCF file.

    Attributes:
        vcf_file_path (str): Path to the VCF file containing variants
        segment_nt_model (str): Model ID for the SegmentNT model to use
        clinvar_api_key (str): API key for accessing ClinVar database
        tokenizer: Tokenizer for the neural network model
        model: Neural network model for genomic analysis
        text_generator: ScientificTextGenerator instance for report generation
        variants (list): List of extracted variants from VCF
        current_variant_index (int): Index of the currently selected variant
        current_analysis (dict): Dictionary containing analysis results for current variant
    """

    def __init__(self, vcf_file_path: str, api_key: str,
                 segment_nt_model: str = "InstaDeepAI/segment_nt",
                 clinvar_api_key: Optional[str] = None):
        """
        Initialize the Variant Analysis Pipeline.

        Sets up the pipeline with necessary components and configuration for
        analyzing genomic variants from VCF files.

        Args:
            vcf_file_path (str): Path to VCF file containing variants
            api_key (str): Google API key for Gemini model access
            segment_nt_model (str): Model ID for SegmentNT neural network
            clinvar_api_key (str, optional): API key for ClinVar database
        """
        self.vcf_file_path = vcf_file_path
        self.api_key = api_key
        self.segment_nt_model = segment_nt_model
        self.clinvar_api_key = clinvar_api_key

        # Initialize neural network components
        self.tokenizer = None
        self.model = None
        # Initialize text generator for reports
        self.text_generator = ScientificTextGenerator(api_key)

        # Initialize data storage
        self.variants = []
        self.current_variant_index = 0
        self.current_analysis = {}

        logger.info("Variant Analysis Pipeline initialized")

    def load_model(self):
        """
        Load the SegmentNT neural network model for genomic analysis.

        Downloads and initializes the SegmentNT model and tokenizer from HuggingFace,
        configuring it for genomic sequence analysis. Handles model configuration
        for appropriate sequence lengths.

        Returns:
            self: The current instance for method chaining

        Raises:
            Exception: If there's an error downloading or loading the model
        """
        logger.info(f"Loading SegmentNT model: {self.segment_nt_model}")
        try:
            # Set cache directory explicitly and force download
            cache_dir = "./model_cache"
            os.makedirs(cache_dir, exist_ok=True)

            # Set HF_HUB_OFFLINE to 0 to ensure online mode
            os.environ['HF_HUB_OFFLINE'] = '0'

            print(f"Attempting to download tokenizer from {self.segment_nt_model}")
            # Load tokenizer with appropriate configuration
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.segment_nt_model,
                trust_remote_code=True,
                cache_dir=cache_dir,
                local_files_only=False,
                use_auth_token=False  # Set to your HF token if using private models
            )

            print(f"Tokenizer loaded, now loading model from {self.segment_nt_model}")
            # Load neural network model
            self.model = AutoModel.from_pretrained(
                self.segment_nt_model,
                trust_remote_code=True,
                cache_dir=cache_dir,
                local_files_only=False,
                use_auth_token=False)

            # Configure model for appropriate sequence length processing
            max_num_dna_tokens = 1668
            if max_num_dna_tokens + 1 > 5001:
                # Calculate rescaling factor for rotary embeddings
                inference_rescaling_factor = (max_num_dna_tokens + 1) / 2048
                num_layers = len(self.model.esm.encoder.layer)
                # Apply rescaling to each layer
                for layer in range(num_layers):
                   self.model.esm.encoder.layer[layer].attention.self.rotary_embeddings.rescaling_factor = inference_rescaling_factor

            logger.info("Model loaded successfully")

        except Exception as e:
            logger.error(f"Error loading model: {e}")
            print(f"Detailed error: {str(e)}")
            raise
        return self


    def extract_variants(self, max_variants: int = 10):
        """
        Extract variant information from VCF file.

        Parses VCF file and extracts essential information about genetic variants,
        including position, reference/alternate alleles, and SPDI notation.
        Limits extraction to the specified maximum number of variants.

        Args:
            max_variants (int): Maximum number of variants to extract from the VCF

        Returns:
            self: The current instance for method chaining

        Raises:
            Exception: If there's an error reading or parsing the VCF file
        """
        logger.info(f"Extracting variants from {self.vcf_file_path}")
        try:
            # Open VCF file with pysam
            vcf = pysam.VariantFile(self.vcf_file_path)  # auto-detects format

            output_list = []
            for rec in vcf:
                # Read the HGVS expression from the CLNHGVS INFO field
                # (stored under the "Canonical SPDI" key for downstream use)
                spdi = rec.info.get('CLNHGVS')
                if spdi:
                    spdi = spdi[0]
                    # Drop a trailing three-letter lowercase suffix (e.g. "dup"), if present
                    reformatted_spdi = spdi
                    count = 0
                    for letter in spdi[-3:]:
                        if letter.islower():
                            count += 1
                    if count == 3:
                        reformatted_spdi = spdi[:-3]
                else:
                    reformatted_spdi = None

                # Take the first alternate allele
                alt_allele = rec.alts[0]
                # Strip any surrounding quotation marks
                reformatted_alt = alt_allele.strip("'\"")

                # Create variant record with essential information
                output = {
                    "Type": rec.info.get('MC'),
                    "Chromosome": rec.contig,
                    "Position": rec.pos,
                    "Reference Allele": rec.ref,
                    "Alternate Allele": reformatted_alt,
                    "Canonical SPDI": reformatted_spdi
                }

                output_list.append(output)
                # Stop if we've reached the maximum number of variants
                if len(output_list) >= max_variants:
                    break

            vcf.close()
            self.variants = output_list

            # Set the first variant as current by default if any variants were found
            if self.variants:
                self.current_analysis["variant"] = self.variants[0]
                logger.info(f"Set initial variant in current_analysis: {self.current_analysis['variant']}")

            logger.info(f"Extracted {len(output_list)} variants")

        except Exception as e:
            logger.error(f"Error extracting variants: {e}")
            raise

        return self

    def prepare_altered_genome(self, variant_index: int):
        """
        Prepare the altered genome by inserting the variant into the reference sequence.

        Retrieves the reference genome sequence and creates an altered version
        by inserting the variant at the appropriate position. This creates the
        genome sequence that will be analyzed by SegmentNT.

        Args:
            variant_index (int): Index of the variant in the variants list to analyze

        Returns:
            self: The current instance for method chaining

        Raises:
            FileNotFoundError: If reference genome file is not found
            Exception: For other errors in sequence preparation

        Note:
            If no variants have been extracted or the index is invalid, an error
            is logged and the instance is returned unchanged.
        """
        # Validate that variants have been extracted
        if not self.variants:
            logger.error("No variants extracted. Call extract_variants() first.")
            return self

        # Validate variant index (negative indices are rejected too)
        if not 0 <= variant_index < len(self.variants):
            logger.error(f"Invalid variant index {variant_index}. Only {len(self.variants)} variants available.")
            return self

        # Set current variant and extract information
        self.current_variant_index = variant_index
        variant = self.variants[variant_index]

        logger.info(f"Preparing altered genome for variant at position {variant['Position']}")

        # Extract variant details
        chromosome = variant["Chromosome"]
        position = variant["Position"]
        ref_allele = variant["Reference Allele"]
        alt_allele = variant["Alternate Allele"]
        alt_allele_string = ''.join(str(x) for x in alt_allele)

        # Access reference genome - handles both .fa and .fa.gz files
        fasta_path = f"/content/Homo_sapiens.GRCh38.dna.chromosome.{chromosome}.fa"
        if not os.path.exists(fasta_path):
            fasta_path = f"/content/Homo_sapiens.GRCh38.dna.chromosome.{chromosome}.fa.gz"

        try:
            # Determine if file is gzipped
            is_gzipped = fasta_path.endswith('.gz')

            # Open file with appropriate method
            opener = gzip.open if is_gzipped else open
            mode = 'rt' if is_gzipped else 'r'

            # Read reference genome sequence
            with opener(fasta_path, mode) as handle:
                record = next(SeqIO.parse(handle, "fasta"))
                ref_sequence = str(record.seq)

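            # Alter the reference sequence: replace the full reference allele with the
            # alternate allele so multi-base REF alleles (e.g. deletions) are handled
            ref_end = position - 1 + len(ref_allele)
            altered_seq = ref_sequence[:position-1] + alt_allele_string + ref_sequence[ref_end:]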

            # Store in current analysis
            self.current_analysis["altered_genome"] = altered_seq
            self.current_analysis["variant"] = variant
            self.current_analysis["position"] = position

            logger.info("Altered genome prepared successfully")

        except FileNotFoundError:
            logger.error(f"Reference genome file not found: {fasta_path}")
            raise
        except (OSError, IOError) as e:
            logger.error(f"Error reading genome file: {e}")
            raise
        except Exception as e:
            logger.error(f"Error preparing altered genome: {e}")
            raise

        return self

    def run_segment_nt(self, flank_size: int = 5000):
        """
        Run SegmentNT neural network analysis on the altered genome.

        Analyzes the genomic context around the variant to identify functional
        elements and their probabilities. Processes a sequence window centered
        on the variant position and calculates probabilities for various
        genomic features.

        Args:
            flank_size (int): Size of the flanking region to analyze around the variant

        Returns:
            self: The current instance for method chaining

        Raises:
            Exception: For errors during model inference or analysis

        Note:
            If the altered genome has not been prepared or the model has not been
            loaded, an error is logged and the instance is returned unchanged.
        """
        # Validate that altered genome has been prepared
        if "altered_genome" not in self.current_analysis:
            logger.error("Altered genome not prepared. Call prepare_altered_genome() first.")
            return self

        # Validate that model has been loaded
        if self.model is None or self.tokenizer is None:
            logger.error("SegmentNT model not loaded. Call load_model() first.")
            return self

        logger.info("Running SegmentNT analysis")

        try:
            # Get altered sequence and variant position
            altered_seq = self.current_analysis["altered_genome"]
            position = self.current_analysis["position"]

            # Calculate start and end indices for the sequence window
            max_num_dna_tokens = 1668
            # Clamp the start first so the window keeps its full length near the sequence start
            idx_start = max(0, position - 1 - flank_size)
            # Each token covers 6 nucleotides; clamp the end to the sequence length
            idx_stop = min(len(altered_seq), idx_start + max_num_dna_tokens * 6)

            # Extract sequence window and tokenize for model input
            sequences = [altered_seq[idx_start:idx_stop]]
            tokens = self.tokenizer.batch_encode_plus(
                sequences,
                return_tensors="pt",
                padding="max_length",
                max_length=max_num_dna_tokens
            )["input_ids"]

            # Run inference on GPU if available, otherwise CPU
            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
            self.model.to(device)
            tokens = tokens.to(device)
            attention_mask = (tokens != self.tokenizer.pad_token_id)

            # Perform model inference
            with torch.no_grad():
                outputs = self.model(
                    tokens,
                    attention_mask=attention_mask,
                )

            # Process model outputs
            logits = outputs["logits"]
            probabilities = np.asarray(torch.nn.functional.softmax(logits, dim=-1).cpu())[...,-1]

            # Get feature list from model config
            features = self.model.config.features

            # Analyze probabilities at variant position
            # (0-based offset of the variant from the start of the window)
            index_of_position = (position - 1) - idx_start
            if 0 <= index_of_position < probabilities.shape[1]:
                probabilities_at_position = probabilities[0][index_of_position]

                # Print probabilities for variant position
                logger.info(f"Probabilities at position {position}:")

                # Extract classes with probabilities
                class_probabilities = {}
                for i, feature in enumerate(features):
                    prob = probabilities_at_position[i]
                    logger.info(f"  {feature}: {prob:.4f}")
                    class_probabilities[feature] = float(prob)  # Convert to float for JSON serialization

                # Extract significant genomic classes (probability > 0.8)
                genomic_classes_with_prob = {}
                for class_name, probability in class_probabilities.items():
                    if probability > 0.8:
                        genomic_classes_with_prob[class_name] = probability

                # Compute length of analyzed genome sequence
                seq_length = probabilities.shape[-2]

                # Store results in current analysis
                self.current_analysis["sequence_length"] = seq_length
                self.current_analysis["probabilities_all_segNT"] = probabilities[0]
                self.current_analysis["features_all_segNT"] = features
                self.current_analysis["likely_genomic_classes_with_prob"] = genomic_classes_with_prob

                logger.info(f"Identified genomic classes with probabilities: {genomic_classes_with_prob}")

            else:
                logger.warning(f"Position index {index_of_position} out of range for probabilities array")

        except Exception as e:
            logger.error(f"Error running SegmentNT analysis: {e}")
            raise

        return self

    def query_clinvar(self):
        """
        Query ClinVar API for clinical significance information about the variant.

        Retrieves pathogenicity, allelic frequency, protein change, and gene information
        from NCBI's ClinVar database using the variant's SPDI notation as identifier.

        Returns:
            self: The current instance for method chaining

        Note:
            Errors are not raised to the caller. If no variant has been selected,
            the method logs an error and returns without querying. Request and
            parsing failures are logged and recorded under
            current_analysis["clinvar_data"]["error"].
        """
        # Validate that a variant has been selected
        if "variant" not in self.current_analysis:
            logger.error("No variant selected. Call prepare_altered_genome() first.")
            return self

        # Get variant and SPDI notation
        variant = self.current_analysis["variant"]
        spdi = variant.get('Canonical SPDI')

        # Check if SPDI notation is available
        if not spdi:
            logger.warning("No SPDI notation available for this variant. Skipping ClinVar query.")
            self.current_analysis["clinvar_data"] = {
                "pathogenicity": None,
                "allelic_freq": None,
                "protein_change": None,
                "gene_name": None
            }
            return self

        logger.info(f"Querying ClinVar for variant: {spdi}")

        # Check API key availability
        if not self.clinvar_api_key:
            logger.warning("No ClinVar API key provided. Using unauthenticated requests.")

        try:
            # Construct API query to search for variant
            base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
            params = {
                "db": "clinvar",
                "term": f"{spdi}",
                "retmode": "json",
                "retmax": 3
            }

            # Add API key if available
            if self.clinvar_api_key:
                params["api_key"] = self.clinvar_api_key

            # Send search request
            response = requests.get(base_url, params=params, timeout=30)  # timeout avoids hanging the UI on a stalled request
            response.raise_for_status()
            results = response.json()

            # Check if we got any results
            if ("esearchresult" not in results or
                "idlist" not in results["esearchresult"] or
                not results["esearchresult"]["idlist"]):
                logger.warning(f"No ClinVar records found for variant: {spdi}")
                self.current_analysis["clinvar_data"] = {
                    "pathogenicity": None,
                    "allelic_freq": None,
                    "protein_change": None,
                    "gene_name": None
                }
                return self

            # Get the first ID from the list of results
            variant_id = results["esearchresult"]["idlist"][0]

            # Get detailed information using the variant ID
            esummary_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esummary.fcgi"
            esummary_params = {
                "db": "clinvar",
                "id": variant_id,
                "retmode": "json"
            }

            # Add API key if available
            if self.clinvar_api_key:
                esummary_params["api_key"] = self.clinvar_api_key

            # Send summary request
            esummary_response = requests.get(esummary_url, params=esummary_params, timeout=30)  # timeout avoids hanging the UI on a stalled request
            esummary_response.raise_for_status()
            esummary_results = esummary_response.json()

            # Extract relevant clinical information
            result_data = esummary_results['result'][variant_id]
            pathogenicity = result_data.get('germline_classification', {}).get('description')
            allelic_freq = result_data.get('allele_freq_set')
            protein_change = result_data.get('protein_change')
            gene_name = result_data.get('gene_sort')

            # Store ClinVar data in current analysis
            self.current_analysis["clinvar_data"] = {
                "pathogenicity": pathogenicity,
                "allelic_freq": allelic_freq,
                "protein_change": protein_change,
                "gene_name": gene_name
            }

            logger.info(f"ClinVar data retrieved: {pathogenicity}, {gene_name}")

        except requests.exceptions.RequestException as e:
            # Handle API request errors
            logger.error(f"ClinVar API request failed: {e}")
            self.current_analysis["clinvar_data"] = {
                "error": str(e),
                "pathogenicity": None,
                "allelic_freq": None,
                "protein_change": None,
                "gene_name": None
            }
        except KeyError as e:
            # Handle response parsing errors
            logger.error(f"Error parsing ClinVar response: {e}")
            self.current_analysis["clinvar_data"] = {
                "error": f"Error parsing response: {str(e)}",
                "pathogenicity": None,
                "allelic_freq": None,
                "protein_change": None,
                "gene_name": None
            }
        except Exception as e:
            # Handle unexpected errors
            logger.error(f"Unexpected error querying ClinVar: {e}")
            self.current_analysis["clinvar_data"] = {
                "error": f"Unexpected error: {str(e)}",
                "pathogenicity": None,
                "allelic_freq": None,
                "protein_change": None,
                "gene_name": None
            }

        return self

    def generate_reports(self, user_question: str, output_dir: Optional[str] = None):
        """
        Generate scientific reports for the variant using AI text generation.

        Creates a comprehensive scientific report based on all collected
        analysis data, addressing the user's specific question about the variant.
        Can optionally save the report to a file.

        Args:
            user_question (str): Question from the user about the variant
            output_dir (str, optional): Directory to save the reports to

        Returns:
            str: The generated report text

        Note:
            Errors are not raised to the caller. If no variant has been selected
            or report generation fails, the error is logged and returned as a
            message string in place of the report.
        """
        # Validate that a variant has been selected
        if "variant" not in self.current_analysis:
            logger.error("No variant selected. Pipeline must be run first.")
            return "Error: No variant selected. Please run the analysis pipeline first."

        try:
            # Gather all analysis data
            variant_data = self.current_analysis["variant"]  # Data extracted from the uploaded VCF file
            genomic_classes_with_prob = self.current_analysis.get("likely_genomic_classes_with_prob", {})  # Outputs of the SegmentNT model
            clinvar_data = self.current_analysis.get("clinvar_data", {})  # Data queried from the ClinVar database

            # Extract ClinVar data
            protein_change = clinvar_data.get("protein_change")
            gene_name = clinvar_data.get("gene_name")

            logger.info(f"Generating reports for variant at position {variant_data['Position']}")

            # Generate report using text generator
            response = self.text_generator.generate_report(
                variant_data, clinvar_data, genomic_classes_with_prob, protein_change,
                gene_name, user_question)

            # Store report in current analysis
            self.current_analysis["reports"] = {
                "response": response
            }

            # Save report to file if output directory is provided
            if output_dir:
                os.makedirs(output_dir, exist_ok=True)
                # Create a filename based on variant ID
                variant_id = str(variant_data.get("Canonical SPDI", "variant")).replace(":", "_")
                response_path = os.path.join(output_dir, f"report_{variant_id}.txt")

                # Write report to file
                with open(response_path, "w", encoding="utf-8") as f:
                    f.write(response)

                logger.info(f"Reports saved to {output_dir}")

            return response

        except Exception as e:
            logger.error(f"Error generating reports: {e}")
            return f"Error generating reports: {str(e)}"

print("\n", "-" * 75)
print("\n 💻 VariantAnalysisPipeline Class was successfully defined ! (2/3)")
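
The field extraction at the end of `query_clinvar()` can be read in isolation. The following is a minimal sketch, assuming the esummary JSON has the shape the method indexes (`result`, then the variant ID, then the individual fields). The helper name `parse_clinvar_summary` and the sample payload are hypothetical; the payload uses placeholder values, not real ClinVar data.

```python
from typing import Any, Dict, Optional


def parse_clinvar_summary(esummary_results: Dict[str, Any], variant_id: str) -> Dict[str, Optional[Any]]:
    """Hypothetical helper mirroring the field extraction in query_clinvar()."""
    # Missing keys fall back to empty dicts so the lookup never raises KeyError
    record = esummary_results.get("result", {}).get(variant_id, {})
    classification = record.get("germline_classification") or {}
    return {
        "pathogenicity": classification.get("description"),
        "allelic_freq": record.get("allele_freq_set"),
        "protein_change": record.get("protein_change"),
        "gene_name": record.get("gene_sort"),
    }


# Placeholder payload for illustration only; not real ClinVar data
sample_response = {
    "result": {
        "12345": {
            "germline_classification": {"description": "PLACEHOLDER_CLASSIFICATION"},
            "protein_change": "PLACEHOLDER_CHANGE",
            "gene_sort": "PLACEHOLDER_GENE",
        }
    }
}

parsed = parse_clinvar_summary(sample_response, "12345")
print(parsed)
```

Unlike `query_clinvar()`, which lets a missing `result` key raise `KeyError` and records it as a parsing error, this sketch returns `None` for every field it cannot find.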

In [None]:
class VariantAnalysisUI:
    """
    User interface for the Variant Analysis Pipeline.

    This class provides a complete UI for genomic variant analysis, featuring:
    - VCF file upload functionality
    - Google API integration for genomic queries
    - Variant information display
    - Interactive chat interface for querying variant data
    - Report generation and download capabilities

    The UI is designed for biologists and researchers working with genomic variants
    and provides intuitive access to complex analysis functions.
    """

    def __init__(self):
        """
        Initialize the UI components and state.

        Sets up all UI elements, defines styling, and prepares event handlers.
        Query submission has no effect until a VCF file is uploaded and a pipeline is initialized.
        """
        # Core state variables
        self.pipeline = None
        self.chat_history = []
        self.output_dir = "/content/output"
        os.makedirs(self.output_dir, exist_ok=True)

        # Create styling for UI components
        self.style = self._create_ui_styles()

        # Define download button for chat history
        self.download_chat_button = widgets.Button(
            description='Download Chat History',
            disabled=False,
            icon='download'
        )

        # Set up UI components, event handlers, and chat display
        self._create_ui_components()
        self._setup_event_handlers()
        self._initialize_chat_display()

    def _create_ui_styles(self):
        """
        Create and return CSS styles for UI components.

        Defines the complete styling for the application including colors, fonts,
        layout, and responsive design elements for all UI components.

        Returns:
            str: HTML/CSS styling for the UI components
        """
        return """
        <style>
        /* Main container and overall styling */
        .genomic-app {
            font-family: 'Roboto', sans-serif;
            max-width: 1200px;
            margin: 0 auto;
            background-color: #f8fafb;
            padding: 20px;
            border-radius: 8px;
            box-shadow: 0 2px 10px rgba(0,0,0,0.05);
        }

        /* Header styling */
        .app-header {
            background: linear-gradient(135deg, #4285F4, #34A853);
            color: white;
            padding: 15px 20px;
            border-radius: 8px 8px 0 0;
            margin-bottom: 20px;
            text-align: center;
        }

        .app-header h1 {
            margin: 0;
            font-size: 28px;
            font-weight: 500;
        }

        .app-header p {
            margin: 5px 0 0;
            opacity: 0.9;
            font-size: 16px;
        }

        /* Section styling */
        .section {
            background-color: white;
            border-radius: 6px;
            box-shadow: 0 1px 3px rgba(0,0,0,0.1);
            padding: 15px;
            margin-bottom: 20px;
        }

        .section-header {
            font-size: 18px;
            font-weight: 500;
            color: #4285F4;
            border-bottom: 2px solid #e0e0e0;
            padding-bottom: 8px;
            margin-bottom: 15px;
        }

        /* Chat container */
        .chat-container {
            max-height: 400px;
            overflow-y: auto;
            border: 1px solid #e0e0e0;
            padding: 10px;
            margin-bottom: 15px;
            background-color: #f9f9f9;
            border-radius: 8px;
            box-shadow: inset 0 1px 3px rgba(0,0,0,0.05);
        }

        .user-message {
            background-color: #E3F2FD;
            padding: 12px;
            border-radius: 18px 18px 3px 18px;
            margin-bottom: 12px;
            max-width: 80%;
            margin-left: auto;
            word-wrap: break-word;
            color: #37474F;
            box-shadow: 0 1px 0.5px rgba(0,0,0,0.1);
        }

        .bot-message {
            background-color: #FFFFFF;
            padding: 12px;
            border-radius: 18px 18px 18px 3px;
            margin-bottom: 12px;
            max-width: 80%;
            border-left: 4px solid #4285F4;
            word-wrap: break-word;
            color: #212121;
            box-shadow: 0 1px 0.5px rgba(0,0,0,0.1);
        }

        /* VCF info display */
        .vcf-info {
            background-color: #F5F7FF;
            padding: 15px;
            border-radius: 6px;
            margin-top: 10px;
            margin-bottom: 10px;
            border-left: 4px solid #7986CB;
        }

        .variant-card {
            background-color: white;
            border-radius: 6px;
            padding: 12px;
            margin-bottom: 10px;
            box-shadow: 0 1px 2px rgba(0,0,0,0.1);
            transition: all 0.2s ease;
        }

        .variant-card:hover {
            box-shadow: 0 3px 8px rgba(0,0,0,0.15);
        }

        /* Titles and text */
        .title {
            font-size: 22px;
            font-weight: bold;
            margin-bottom: 12px;
            color: #4285F4;
        }

        .subtitle {
            font-size: 16px;
            font-weight: 500;
            margin-top: 15px;
            margin-bottom: 8px;
            color: #5F6368;
            border-bottom: 1px solid #e0e0e0;
            padding-bottom: 5px;
        }

        /* Status indicators */
        .loading {
            color: #5F6368;
            margin-top: 10px;
            display: flex;
            align-items: center;
        }

        .loading:before {
            content: '';
            display: inline-block;
            width: 16px;
            height: 16px;
            margin-right: 8px;
            border: 2px solid #4285F4;
            border-radius: 50%;
            border-top-color: transparent;
            animation: spin 1s linear infinite;
        }

        @keyframes spin {
            to { transform: rotate(360deg); }
        }

        /* Button styling */
        .button-primary {
            background-color: #4285F4;
            color: white;
            border: none;
            padding: 10px 16px;
            border-radius: 4px;
            cursor: pointer;
            font-weight: 500;
            transition: background-color 0.2s;
        }

        .button-primary:hover {
            background-color: #3367D6;
        }

        .button-secondary {
            background-color: #5F6368;
            color: white;
            border: none;
            padding: 10px 16px;
            border-radius: 4px;
            cursor: pointer;
            font-weight: 500;
            transition: background-color 0.2s;
        }

        .button-secondary:hover {
            background-color: #494C50;
        }

        /* Input fields */
        .input-field {
            border: 1px solid #dadce0;
            border-radius: 4px;
            padding: 10px 12px;
            font-size: 14px;
            width: 100%;
            transition: border 0.2s;
        }

        .input-field:focus {
            border-color: #4285F4;
            outline: none;
        }

        /* Pathogenicity indicators */
        .pathogenic {
            color: #D32F2F;
            font-weight: 500;
        }

        .benign {
            color: #388E3C;
            font-weight: 500;
        }

        .uncertain {
            color: #FFA000;
            font-weight: 500;
        }

        /* Genomic class labels */
        .genomic-class {
            display: inline-block;
            padding: 3px 8px;
            border-radius: 12px;
            margin-right: 5px;
            margin-bottom: 5px;
            font-size: 12px;
            font-weight: 500;
            background-color: #E8F0FE;
            color: #185ABC;
        }
        </style>
        """

    def _create_ui_components(self):
        """
        Create and configure all UI widgets.

        Initializes all UI components including:
        - File upload widget
        - Status and output displays
        - Query input field
        - Action buttons
        - API key input field
        - ClinVar API key input field

        Each component is configured with appropriate properties and layout settings.
        """
        # File upload component
        self.file_upload = widgets.FileUpload(
            description='Upload VCF file:',
            accept='.vcf, .vcf.gz',
            multiple=False
        )

        # Status and output components
        self.status_output = widgets.Output()
        self.vcf_info_output = widgets.Output()
        self.chat_output = widgets.Output()
        self.print_output = widgets.Output()

        # Input components
        self.query_input = widgets.Text(
            placeholder='Ask a question about the uploaded variant data...',
            description='Query:',
            disabled=False,
            layout=widgets.Layout(width='80%')
        )

        # Action buttons
        self.send_button = widgets.Button(
            description='Send',
            disabled=False,
            button_style='primary',
            tooltip="Send your query",
            icon='paper-plane'
        )

        self.clear_chat_button = widgets.Button(
            description='Clear Chat',
            disabled=False,
            button_style='danger',
            tooltip="Clear the chat history",
            icon='trash'
        )

        # API key input and button
        self.api_key_input = widgets.Text(
            placeholder='Enter your Google API key',
            description='API Key:',
            layout=widgets.Layout(width='50%')
        )

        self.apply_api_key_button = widgets.Button(
            description='Apply API Key(s)',
            button_style='primary',
            icon='check'
        )

        # ClinVar API key input (optional)
        self.clinvar_key_input = widgets.Text(
            placeholder='Enter your ClinVar API key (optional)',
            description='ClinVar Key:',
            layout=widgets.Layout(width='50%')
        )


    def _setup_event_handlers(self):
        """
        Set up event handlers for UI components.

        Connects UI components to their respective handler methods:
        - File upload widget to on_file_upload_change
        - Send button to on_send_button_click
        - Clear chat button to on_clear_chat_button_click
        - Apply API key button to on_apply_api_key_button_click
        - Download chat button to on_download_chat_button_click

        This establishes the interactive behavior of the UI.
        """
        self.file_upload.observe(self.on_file_upload_change, names='value')
        self.send_button.on_click(self.on_send_button_click)
        self.clear_chat_button.on_click(self.on_clear_chat_button_click)
        self.apply_api_key_button.on_click(self.on_apply_api_key_button_click)
        self.download_chat_button.on_click(self.on_download_chat_button_click)

    def _initialize_chat_display(self):
        """
        Initialize the chat display widget.

        Creates an empty chat container with appropriate styling.
        This prepares the UI for user interaction with the chat interface.
        """
        with self.chat_output:
            clear_output()
            display(HTML(self.style + "<div class='chat-container' id='chat-container'></div>"))

    def on_file_upload_change(self, change):
        """
        Handle file upload events.

        Processes the uploaded VCF file and initializes the analysis pipeline.
        Validates that an API key has been provided before processing the file.
        Saves the uploaded file to a temporary location, initializes the
        pipeline, extracts variants, and displays variant information.

        Args:
            change (dict): Change event data containing the uploaded file information
        """
        if change['type'] == 'change' and change['name'] == 'value' and change['new']:
            # Check if API key is provided
            api_key = self.api_key_input.value.strip()
            if not api_key:
                with self.print_output:
                    clear_output()
                    print("Please enter your Google API key and click 'Apply API Key(s)' before uploading a file.")
                    return

            try:
                # Get the uploaded file data
                uploaded_files = change['new']
                if not uploaded_files:
                    return

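                # Get the first file. ipywidgets 7 returns a dict keyed by filename;
                # ipywidgets 8 returns a tuple of dicts with 'name' and 'content' keys.
                if isinstance(uploaded_files, dict):
                    file_name = next(iter(uploaded_files))  # Get the filename
                    file_content = uploaded_files[file_name]['content']  # Get the content
                else:
                    first_file = uploaded_files[0]
                    file_name = first_file['name']
                    file_content = first_file['content']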

                # Create directory and save file
                upload_dir = "/content/uploads"
                os.makedirs(upload_dir, exist_ok=True)
                file_path = os.path.join(upload_dir, file_name)

                # Write the file content
                with open(file_path, 'wb') as f:
                    f.write(file_content)

                with self.print_output:
                    clear_output()
                    print(f"File uploaded: {file_name}")
                    print("Initializing pipeline and extracting variants...")

                # Initialize the pipeline
                try:
                    self.pipeline = self._initialize_pipeline(file_path, api_key)

                    if self.pipeline is None:
                        return

                    # Extract variants
                    self.pipeline.extract_variants()

                    # Display variant information
                    if self.pipeline.variants:
                        self._display_variant_info()

                        # Enable the query components
                        self.query_input.disabled = False
                        self.send_button.disabled = False
                        self.clear_chat_button.disabled = False

                        # Add a welcome message to the chat
                        self.add_message_to_chat(
                            "Welcome! I've analyzed your VCF file. You can now ask questions about the variants.",
                            is_user=False
                        )
                    else:
                        with self.print_output:
                            clear_output()
                            print("No variants found in the uploaded file.")

                except Exception as e:
                    with self.print_output:
                        clear_output()
                        print(f"Error processing file: {str(e)}")

            except Exception as e:
                with self.print_output:
                    clear_output()
                    print(f"Error handling file upload: {str(e)}")
                    print("Please make sure you're uploading a valid VCF file.")

    def on_apply_api_key_button_click(self, b):
        """
        Handle apply API key button click events.

        Validates and applies the provided Google API key.
        Displays a success message or error message based on input validation.

        Args:
            b (Button): Button widget that triggered the event
        """
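        api_key = self.api_key_input.value.strip()

        with self.print_output:
            clear_output()
            # Reject empty input before contacting the API
            if not api_key:
                print("❌ Please enter your Google API key.")
                return
            try:
                # Test API key validity with a minimal request.
                # The model name follows the Gemini 2.0 Flash model named in the introduction.
                genai.configure(api_key=api_key)
                model = genai.GenerativeModel('gemini-2.0-flash')
                model.generate_content("Test")
                print("✅ API Key validated successfully!")
                print("You can now upload your VCF file.")
            except Exception as e:
                print("❌ API key validation failed. Please check your credentials.")
                print(f"Error: {str(e)}")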

    def _initialize_pipeline(self, vcf_file_path, api_key=None):
        """
        Initialize the variant analysis pipeline with the uploaded file.

        Creates an instance of VariantAnalysisPipeline with the provided
        VCF file path and API key. Optionally includes ClinVar API key
        if provided.

        Args:
            vcf_file_path (str): Path to the uploaded VCF file
            api_key (str, optional): Google API key for accessing services

        Returns:
            VariantAnalysisPipeline: Initialized pipeline object or None if error occurs
        """
        if not api_key:
            with self.print_output:
                clear_output()
                print("Please enter your Google API key.")
            return None

        try:
            # Get ClinVar API key if provided
            clinvar_key = self.clinvar_key_input.value.strip() or None

            pipeline = VariantAnalysisPipeline(
                vcf_file_path=vcf_file_path,
                api_key=api_key,
                clinvar_api_key=clinvar_key
            )
            return pipeline
        except Exception as e:
            with self.print_output:
                clear_output()
                print(f"Error initializing pipeline: {str(e)}")
            return None

    def _display_variant_info(self):
        """
        Display variant information in the UI.

        Creates a formatted HTML display of all identified variants.
        Clears any previous variant information and shows the new data
        in a structured, user-friendly format.
        """
        with self.vcf_info_output:
            clear_output()
            display(HTML(self.style))
            variant_html = "<div class='vcf-info'>"
            variant_html += "<div class='title'>Variant Information</div>"

            for i, variant in enumerate(self.pipeline.variants):
                variant_html += f"<div class='subtitle'>Variant {i+1}</div>"
                variant_html += "<ul>"
                for key, value in variant.items():
                    variant_html += f"<li><strong>{key}:</strong> {value}</li>"
                variant_html += "</ul>"

            variant_html += "</div>"
            display(HTML(variant_html))

    def on_send_button_click(self, b):
        """
        Handle send button click events.

        Processes the user's query and generates appropriate responses.
        Detects the intent of the query, extracts relevant variant indices,
        and performs the appropriate action based on the query intent.
        Adds the user's query and the system's response to the chat history.

        Args:
            b (Button): Button widget that triggered the event
        """
        query = self.query_input.value
        if query and self.pipeline:

            # Add the user's message to the chat
            self.add_message_to_chat(query, is_user=True)

            # Clear the input
            self.query_input.value = ''

            # Process the query
            with self.print_output:
                clear_output()
                print("Processing your query...")

            # Detect intent
            intent = self._detect_intent(query)
            try:
                if intent["compare_variants"]:
                    # Compare the 2 specified variants
                    indices = self._extract_variant_indices_for_comparison(query, len(self.pipeline.variants)) # extract indices of specified variants to compare
                    if len(indices) == 2:
                        reports = self._get_reports_for_variants(indices)
                        if isinstance(reports, int):
                            # An integer return value is treated as the index of the variant with no report
                            idx = reports
                            missing_report_message = (
                                f"Report not found for variant {idx+1}. "
                                f"To run the comparison, first ask me to analyse it "
                                f"by writing 'Analyse variant {idx+1}.'"
                            )
                            with self.print_output:
                                clear_output()
                                print(missing_report_message)
                            self.add_message_to_chat(missing_report_message, is_user=False)
                            return
                        if all(reports):
                            response = self.pipeline.text_generator.generate_comparison_report(query, reports, indices)
                            self.add_message_to_chat(response, is_user=False)
                        else:
                            self.add_message_to_chat("I couldn't find reports for both variants.", is_user=False)
                    else:
                        self.add_message_to_chat("Please specify two variants to compare.", is_user=False)

                elif intent["plot_segment_nt_results"]:
                    # Plot the last SegmentNT results
                    with self.print_output:
                        self.plot_segment_nt_results()
                        self.add_message_to_chat(f"The SegmentNT results have been plotted and saved to: {self.output_dir} directory", is_user=False)

                elif intent["variant_analysis"]:
                    # Run variant analysis pipeline on specified variant
                    idx = self._extract_variant_index_from_query(query, len(self.pipeline.variants)) # extract index of specified variant
                    if not hasattr(self.pipeline, 'model') or self.pipeline.model is None:
                        with self.print_output:
                            print("Loading SegmentNT model...")
                            self.pipeline.load_model()

                    with self.print_output:
                        clear_output()
                        self.add_message_to_chat(
                                f"Analyzing Variant {idx + 1} based on your query...",
                                is_user=False)
                        print("Preparing altered genome based on variant selected...")
                        self.pipeline.prepare_altered_genome(idx)

                        print("Running SegmentNT analysis...")
                        self.pipeline.run_segment_nt()

                        print("Querying additional info about the variant to ClinVar...")
                        self.pipeline.query_clinvar()

                        # Generate a report
                        print("Generating analysis report...")
                        report = self.pipeline.generate_reports(query, output_dir=self.output_dir)
                    with self.print_output:
                        clear_output()
                        print("Analysis complete!")
                        display(Markdown(report))  # display() returns None, so it must not be wrapped in print()
                        self.add_message_to_chat(f"The report has been generated and saved to {self.output_dir}.", is_user=False)

                else:
                    search_results = self.pipeline.text_generator.ask_agent(query, self.pipeline.current_analysis)
                    self.add_message_to_chat(search_results, is_user=False)

            except Exception as e:
                with self.print_output:
                    clear_output()
                    print(f"Error processing query: {str(e)}")
                # Add an error message to the chat
                self.add_message_to_chat(
                    f"I'm sorry, I encountered an error while processing your query: {str(e)}",
                    is_user=False)

    def on_clear_chat_button_click(self, b):
        """
        Handle clear chat button click events.

        Clears the chat history and resets the chat display.
        Removes all messages from the chat container and initializes it again.

        Args:
            b (Button): Button widget that triggered the event
        """
        self.chat_history = []

        with self.chat_output:
            clear_output()
            display(HTML(self.style + "<div class='chat-container' id='chat-container'></div>"))

    def _format_chat_history_for_download(self):
        """
        Format chat history for download.

        Converts the chat history to plain text, one entry per message,
        labelled User or Assistant. Each entry is prefixed with the export time
        (a single timestamp taken when this method runs), not the time the
        message was sent.

        Returns:
            str: Formatted chat history as text
        """
        formatted_chat = []
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

        for msg in self.chat_history:
            user = "User" if msg['is_user'] else "Assistant"
            formatted_chat.append(f"[{timestamp}] {user}: {msg['content']}")

        return "\n\n".join(formatted_chat)

    def _download_text_file(self, text, filename):
        """
        Create a downloadable text file using HTML anchor and Blob.

        Generates JavaScript code that creates a Blob object from the text,
        creates a download link, and triggers the download automatically.

        Args:
            text (str): Text content to download
            filename (str): Name of the file to download

        Returns:
            Javascript: JavaScript code for file download
        """
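        # Convert text to a JavaScript-safe format
        # (backslashes are escaped first so the escapes added afterwards are not doubled)
        js_safe_text = (
            text.replace("\\", "\\\\")
            .replace("\n", "\\n")
            .replace("\r", "\\r")
            .replace("'", "\\'")
            .replace('"', '\\"')
        )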

        js_code = f"""
        (function() {{
            var text = "{js_safe_text}";
            var blob = new Blob([text], {{type: 'text/plain'}});
            var anchor = document.createElement('a');
            anchor.href = window.URL.createObjectURL(blob);
            anchor.download = "{filename}";
            anchor.style.display = 'none';
            document.body.appendChild(anchor);
            anchor.click();
            document.body.removeChild(anchor);
            window.URL.revokeObjectURL(anchor.href);
        }})();
        """

        return Javascript(js_code)

    def on_download_chat_button_click(self, b):
        """
        Handle download chat button click events.

        Saves the chat history to a file and provides download instructions.
        Formats the chat history, saves it to a file in the output directory,
        and informs the user of the file location.

        Args:
            b (Button): Button widget that triggered the event
        """
        with self.print_output:
            clear_output()
            print("Saving chat history...")

        with self.print_output:
            # Get formatted chat history text
            chat_text = self._format_chat_history_for_download()

            # Save to system
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = os.path.join(self.output_dir, f"chat_history_{timestamp}.txt")

            with open(filename, "w") as f:
                f.write(chat_text)

        with self.print_output:
            clear_output()
            print(f"Chat history saved to '{filename}' in the notebook directory.")
            print("You can download it from the 'Output' tab in the sidebar.")

    def add_message_to_chat(self, message, is_user=True):
        """
        Add a message to the chat history and update the display.

        Appends a new message to the chat history, updates the chat display
        with the complete history, and applies appropriate styling based on
        whether the message is from the user or the assistant.

        Args:
            message (str): Message content to add
            is_user (bool, optional): True if the message is from the user, False otherwise
        """
        # Add the message to the history
        self.chat_history.append({
            'content': message,
            'is_user': is_user
        })

        # Update the chat display
        with self.chat_output:
            clear_output()

            chat_html = self.style
            chat_html += "<div class='chat-container' id='chat-container'>"

            for msg in self.chat_history:
                if msg['is_user']:
                    chat_html += f"<div class='user-message'>You: {msg['content']}</div>"
                else:
                    chat_html += f"<div class='bot-message'>Assistant: {msg['content']}</div>"

            chat_html += "</div>"

            display(HTML(chat_html))

    def _detect_intent(self, query):
        """
        Detect the intent of a user query.

        Analyzes the query text to determine what kind of operation the user
        is requesting. Identifies intents like variant analysis, asking the
        agent questions, plotting results, or comparing variants.

        Args:
            query (str): User's query text

        Returns:
            dict: Intent classification dictionary with boolean values for each intent
        """
        query = query.lower()

        intent = {
            "variant_analysis": any(word in query for word in ['analyse', 'analyze']) and any(word in query for word in ['variant', 'variants']),
            "ask_agent": any(word in query for word in ['search', 'what', 'where', 'how', 'why']) or '?' in query,
            # The query was lower-cased above, so the keyword must be lower-case as well
            "plot_segment_nt_results": "plot" in query and ("segmentnt" in query or "results" in query),
            "compare_variants": "compare" in query and any(word in query for word in ['variant', 'variants'])
        }
        return intent

    def _extract_variant_index_from_query(self, query, total_variants):
        """
        Extract a variant index (0-based) from a user query.

        Identifies variant references like "variant 1", "first variant", etc.
        Supports both numeric and ordinal variant references.

        Args:
            query (str): User's query text
            total_variants (int): Total number of available variants

        Returns:
            int: 0-based variant index (defaults to 0 if no valid index is found)
        """
        # Match "variant 1", "variant 2", etc.
        match = re.search(r'variant\s*(\d+)', query, re.IGNORECASE)
        if match:
            idx = int(match.group(1)) - 1
            if 0 <= idx < total_variants:
                return idx

        # Handle "first", "second", "third", etc.
        ordinal_map = {
            "first": 0, "second": 1, "third": 2, "fourth": 3,
            "fifth": 4, "sixth": 5, "seventh": 6, "eighth": 7, "ninth": 8, "tenth": 9
        }
        for word, idx in ordinal_map.items():
            if word in query.lower() and idx < total_variants:
                return idx

        # Fallback: use the first variant
        return 0

    def _extract_variant_indices_for_comparison(self, query, total_variants):
        """
        Extract variant indices (0-based) for comparison from a query.

        Identifies and extracts multiple variant references for comparison.
        Supports numeric variant references like "variant 1 and variant 2".

        Args:
            query (str): User's query text
            total_variants (int): Total number of available variants

        Returns:
            list: List of up to 2 variant indices (0-based) for comparison
        """
        matches = re.findall(r'(?:variant(?:s)?\s*)?(\d+)', query, re.IGNORECASE)
        indices = []
        for match in matches:
            idx = int(match) - 1
            if 0 <= idx < total_variants:
                indices.append(idx)
        return indices[:2]  # Return up to 2 indices for comparison

    def _get_reports_for_variants(self, indices):
        """
        Retrieve stored reports for specified variants.

        Loads previously generated reports for specified variants from
        the output directory.

        Args:
            indices (list): List of variant indices

        Returns:
            list: List of report contents for each variant
        """
        reports = []
        for idx in indices:
            variant_id = str(self.pipeline.variants[idx].get("Canonical SPDI", "variant")).replace(":", "_")
            response_path = os.path.join(self.output_dir, f"report_{variant_id}.txt")
            try:
                with open(response_path, "r") as f:
                    reports.append(f.read())
            except FileNotFoundError:
                return idx
        return reports

    def plot_segment_nt_results(self):
        """
        Plot the SegmentNT results.

        This function visualizes the predicted probabilities for various genomic features
        from a SegmentNT analysis across a sequence window centered on a variant position.

        The plot includes:
        - Multiple subplots with 2 features per subplot
        - Color-coded probability lines for each genomic feature
        - A vertical red line indicating the variant position
        - Proper labeling and formatting

        The function retrieves analysis data from the pipeline, arranges features in a
        specific order to match Figure 3 from the paper, and saves the plot to the output directory.
        It also prints the probability values at the variant position.

        Returns:
            None. Displays the plot and saves it to the output directory.
        """
        try:
            if self.pipeline and self.pipeline.current_analysis:
                if 'probabilities_all_segNT' not in self.pipeline.current_analysis:
                    print("SegmentNT data is not available. Please run SegmentNT analysis first.")
                    return

                # Get necessary data from pipeline
                predicted_probabilities_all = self.pipeline.current_analysis.get("probabilities_all_segNT", [])
                features = self.pipeline.current_analysis.get("features_all_segNT", [])
                position = self.pipeline.current_analysis.get("position")
                seq_length = self.pipeline.current_analysis.get("sequence_length")

                # Filter order_to_plot to only include features we actually have
                # order_to_plot = [feat for feat in features_rearranged if feat in features]
                # if not order_to_plot:
                #     print("No features to plot.")
                #     return

                sc = 1.8
                n_panels = len(features)
                panels_per_subplot = 2
                n_subplots = (n_panels + panels_per_subplot - 1) // panels_per_subplot  # Ceiling division

                # set colors
                colors = sns.color_palette("Set2").as_hex()
                colors2 = sns.color_palette("husl").as_hex()

                fig_width=8

                # Create figure with appropriate dimensions
                _, axes = plt.subplots(n_subplots, 1, figsize=(int(fig_width) * sc, (n_subplots + 4) * sc))

                # Make sure axes is always an array, even if there's only one subplot
                if n_subplots == 1:
                    axes = [axes]

                # Normalise the variant position to an int; fall back to the window centre if it is missing
                position = int(position) if position is not None else seq_length // 2

                for n, feat in enumerate(features):
                    feat_id = features.index(feat)
                    prob_dist = predicted_probabilities_all[:, feat_id]
                    # Use the appropriate subplot
                    ax = axes[n // panels_per_subplot]
                    try:
                        id_color = colors[feat_id]
                    except IndexError:
                        # The first palette is exhausted; continue with the second one
                        id_color = colors2[feat_id - len(colors)]

                    # Create x-axis values (nucleotide indices) from position - 5000 to position + 5000
                    x_values = np.arange(position - 5000, position + 5000)
                    ax.plot(
                        x_values[:len(prob_dist)],  # x-coordinates, truncated to the number of probabilities
                        prob_dist[:len(x_values)],  # corresponding probability values
                        color=id_color,
                        label=feat,
                        linestyle="-",
                        linewidth=1.5,
                    )
                    ax.grid(False)
                    ax.spines['bottom'].set_color('black')
                    ax.spines['top'].set_color('black')
                    ax.spines['right'].set_color('black')
                    ax.spines['left'].set_color('black')

                # Set the x and y-axis limits
                for a in range(0, n_subplots):
                    axes[a].set_xlim(position - 5000, position + 5000)
                    axes[a].set_ylim(0, 1.05)
                    axes[a].set_ylabel("Prob.")
                    axes[a].legend(loc="upper left", bbox_to_anchor=(1, 1), borderaxespad=0)

                    # Add vertical line to highlight the variant nucleotide at x = position
                    axes[a].axvline(x=position, color='red', linestyle='--')

                    if a != (n_subplots-1):
                        axes[a].tick_params(axis='x', which='both', bottom=True, top=False, labelbottom=False)

                # Set common x-axis label
                axes[-1].set_xlabel("Nucleotides")
                axes[n_subplots-1].grid(False)
                axes[n_subplots-1].tick_params(axis='y', which='both', left=True, right=False, labelleft=True, labelright=False)

                axes[0].set_title("Probabilities predicted over all genomic features", fontweight="bold")

                figure=plt.gcf()
                plt.tight_layout()

                # Get Canonical SPDI for filename
                canonical_spdi = self.pipeline.current_analysis["variant"].get("Canonical SPDI", "variant")
                canonical_spdi = canonical_spdi.replace(":", "_")

                # Save the plot
                plot_filename = os.path.join(self.output_dir, f"plot_{canonical_spdi}.png")
                figure.savefig(plot_filename, bbox_inches='tight')

                # Print probabilities for nucleotide at x = position
                index_of_position = position - (position - 5000) # Calculate the index corresponding to the position
                probabilities_at_position = predicted_probabilities_all[index_of_position]

                with self.print_output:
                    print(f"Probabilities at position {position}:")
                    for i, feature in enumerate(features):
                        print(f"  {feature}: {probabilities_at_position[i]:.4f}")
                    # Display figure
                    plt.show()
                    plt.close(figure)  # Close the figure to free memory

                print(f"The SegmentNT results have been plotted and saved to: {plot_filename}")

            else:
                print("Pipeline or analysis data not available.")

        except Exception as e:
            import traceback
            traceback.print_exc()
            print(f"Error plotting SegmentNT results: {str(e)}")

    def display_ui(self):
        """
        Display the complete UI.

        This function creates and renders the entire user interface for the VariantAInalyser
        application using ipywidgets. It organizes the interface into several logical sections:

        1. Header section with app title and description
        2. API Configuration section for API keys
        3. File Upload section for VCF files
        4. Variant Information section to display variant details
        5. Analysis Assistant section for chatbot interaction
        6. Pipeline Output section for logging messages

        Each section is styled with CSS classes for consistent appearance. The function also
        adds JavaScript code to handle Enter key presses in the query input field.

        The UI components include:
        - Input fields for API keys
        - File upload widget
        - Text output areas
        - Chat interface with history
        - Action buttons

        Returns:
            None. Displays the UI in the current Kaggle notebook cell.
        """
        # Create HTML header with improved styling
        header_html = """
        <div class="app-header">
            <h1>🧬 VariantAInalyser</h1>
            <p>Advanced Genomic Variant Analysis with AI</p>
            </div>
            """

        # Create layout sections with new CSS classes
        api_key_section = widgets.VBox([
            widgets.HTML('<div class="section-header">API Configuration</div>'),
            widgets.HBox([self.api_key_input, self.apply_api_key_button]),
            widgets.HBox([self.clinvar_key_input])  # Added ClinVar key input
        ])
        api_key_section.add_class("section")

        upload_section = widgets.VBox([
            widgets.HTML('<div class="section-header">Upload VCF File</div>'),
            self.file_upload
        ])
        upload_section.add_class("section")

        variant_section = widgets.VBox([
            widgets.HTML('<div class="section-header">Variant Information</div>'),
            self.vcf_info_output
        ])
        variant_section.add_class("section")

        chat_section = widgets.VBox([
            widgets.HTML('<div class="section-header">Analysis Assistant</div>'),
            self.chat_output,
            widgets.HBox([self.download_chat_button]),
            widgets.HBox([self.query_input, self.send_button, self.clear_chat_button])
        ])
        chat_section.add_class("section")

        print_section = widgets.VBox([
            widgets.HTML('<div class="section-header">Pipeline Output</div>'),
            self.print_output
        ])
        print_section.add_class("section")

        # Apply additional styles to widgets
        self.query_input.add_class("input-field")
        self.send_button.add_class("button-primary")
        self.clear_chat_button.add_class("button-secondary")

        # Container for all components
        main_container = widgets.VBox([
            widgets.HTML(header_html),
            api_key_section,
            upload_section,
            variant_section,
            chat_section,
            print_section  # Added print section
        ])
        main_container.add_class("genomic-app")

        # Apply overall styles
        display(HTML(self.style))
        display(main_container)

        # Add JavaScript handler for Enter key
        js_code = """
        document.addEventListener('keydown', function(event) {
            if (event.key === 'Enter' && document.activeElement === document.querySelector('.input-field')) {
                event.preventDefault();
                IPython.notebook.kernel.execute("ui.on_send_button_click(None)");
            }
        });
        """
        display(HTML(f"<script>{js_code}</script>"))

print("\n", "-" * 75)
print("\n 💻 VariantAnalysisUI Class was successfully defined ! (3/3)")

## 🧬 **Try out VariantAInalyser!**

Now that all the classes have been successfully defined, run the cell below to experiment with the VariantAInalyser interface.

Here is some guidance on how to use the interface to make the most out of it:


1. **Setup**: Enter your Google API key and, optionally, your ClinVar API key into their corresponding boxes and click "Apply API Key/s".
2. **Upload Data**: Upload a VCF file containing genetic variants by clicking on the "Upload VCF file" button. You can download an example gzipped VCF file containing multiple variants from the official NCBI ClinVar site: https://ftp.ncbi.nlm.nih.gov/pub/clinvar/vcf_GRCh38/clinvar_papu.vcf.gz  
**NOTE**: To create the altered genome, a reference genome (i.e. one without the variant) is needed. Most of the variants in the VCF file linked above are located on the Y chromosome, so I have uploaded the reference genome of chromosome Y as a FASTA file in the Input directory under the "/kaggle/input/example-chromosome-genome-fasta-file" folder. The notebook reads this file directly whenever it needs the reference genome, so there is no need to change the code. However, if you would like to test a different VCF file with variants on other chromosomes, make sure to upload the genomes of those chromosomes and update the path in the prepare_altered_genome() method accordingly.
3. **Explore Variants**: View extracted variant information displayed on the interface
4. **Ask Questions**: Use the chat to ask questions about your variants.
   Examples of questions to ask (you can ask all of them, in this order, if you would like to experience every feature offered by the interface):
   - Please analyse in detail variant 1.
   - Analyze variant 7. (The different spelling is deliberate, to check how spelling-proof the intent function is.)
   - Please plot the SegmentNT results.
   - Can you explain the results generated in a simple and concise manner?
   - You mentioned that the variant was located in the 5'UTR, what is that?
   - Please compare variants 1 and 7. Is one more likely to be pathogenic than the other?
5. **Download Results**: The generated analysis reports and plots are automatically saved to the output directory. You can also download the chat history for future reference!
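
To give a feel for how the example questions above are routed, here is a minimal standalone sketch of the keyword rules used by `_detect_intent()` and of the index extraction done by `_extract_variant_indices_for_comparison()`. The function names `detect_intent` and `extract_variant_indices` are hypothetical and exist only for this illustration; the logic is a simplified restatement, not the UI class itself.

```python
import re


def detect_intent(query: str) -> dict:
    """Simplified, standalone restatement of the keyword rules (illustrative only)."""
    q = query.lower()
    has_variant = "variant" in q  # also matches "variants"
    return {
        "variant_analysis": any(w in q for w in ("analyse", "analyze")) and has_variant,
        "ask_agent": any(w in q for w in ("search", "what", "where", "how", "why")) or "?" in q,
        "plot_segment_nt_results": "plot" in q and ("segmentnt" in q or "results" in q),
        "compare_variants": "compare" in q and has_variant,
    }


def extract_variant_indices(query: str, total_variants: int) -> list:
    """Turn 1-based numbers in the query into at most two valid 0-based indices."""
    candidates = [int(number) - 1 for number in re.findall(r"\d+", query)]
    return [idx for idx in candidates if 0 <= idx < total_variants][:2]


if __name__ == "__main__":
    examples = [
        "Please analyse in detail variant 1.",
        "Analyze variant 7.",
        "Please plot the SegmentNT results.",
        "Please compare variants 1 and 7. Is one more likely to be pathogenic than the other?",
    ]
    for question in examples:
        active = [name for name, hit in detect_intent(question).items() if hit]
        print(f"{question!r} -> {active}")
```

Because the rules are plain substring checks, a single question can trigger several intents at once (for example, a comparison request that ends with a question mark), and any number in the query is treated as a candidate variant number. Keeping each request short and explicit therefore gives the most predictable routing.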

In [None]:
### Run this cell to instantiate the VariantAInalyser interface and start using it !

def setup_variant_analysis_ui():
    """Set up and return the Variant Analysis UI."""
    ui = VariantAnalysisUI()
    ui.display_ui()
    return ui

ui = setup_variant_analysis_ui()