<a href="https://colab.research.google.com/github/rephrain/Chat-With-CSV/blob/main/llm_(gpu).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install transformers accelerate bitsandbytes

Collecting bitsandbytes
  Downloading bitsandbytes-0.46.0-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=2.0.0->accelerate)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=2.0.0->accelerate)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=2.0.0->accelerate)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=2.0.0->accelerate)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=2.0.0->accelerate)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2

In [2]:
!pip install streamlit

Collecting streamlit
  Downloading streamlit-1.45.1-py3-none-any.whl.metadata (8.9 kB)
Collecting watchdog<7,>=2.1.5 (from streamlit)
  Downloading watchdog-6.0.0-py3-none-manylinux2014_x86_64.whl.metadata (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.3/44.3 kB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
Collecting pydeck<1,>=0.8.0b4 (from streamlit)
  Downloading pydeck-0.9.1-py2.py3-none-any.whl.metadata (4.1 kB)
Downloading streamlit-1.45.1-py3-none-any.whl (9.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.9/9.9 MB[0m [31m60.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pydeck-0.9.1-py2.py3-none-any.whl (6.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.9/6.9 MB[0m [31m89.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading watchdog-6.0.0-py3-none-manylinux2014_x86_64.whl (79 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m79.1/79.1 kB[0m [31m7.1 MB/s[0m eta [36m0:00:00[0m
[?25hInst

In [8]:
import streamlit as st
import pandas as pd
import numpy as np
import io
import json
import re
from typing import Dict, Any, List, Tuple, Optional, Union
import sqlite3
from datetime import datetime
import plotly.express as px
import plotly.graph_objects as go
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
import warnings
from dataclasses import dataclass
from enum import Enum
from datetime import datetime, timedelta
import spacy
from collections import defaultdict, Counter
import json
warnings.filterwarnings('ignore')
import os
torch.classes.__path__ = [os.path.join(torch.__path__[0], torch.classes.__file__)]

class CSVProcessor:
    """Load an uploaded CSV, clean it, and expose it to pandas and SQLite.

    NOTE: this cell defines ``CSVProcessor`` twice with the same interface; the
    later definition rebinds the name, so keep the two in sync (or delete one).

    Attributes:
        df: Cleaned DataFrame from the most recent successful load, else None.
        metadata: Dict built by ``_generate_metadata`` (shape, columns, dtypes,
            column groupings, missing-value counts, numeric summary stats).
        sql_connection: In-memory SQLite connection whose table ``data`` holds
            the cleaned DataFrame; None before the first successful load.
    """

    # Tried in order. latin-1 maps every byte to a character and therefore
    # never fails, so it has to be the last resort; cp1252 must come before it
    # or it would never be attempted.
    _ENCODINGS = ('utf-8', 'cp1252', 'latin-1')

    def __init__(self):
        self.df = None
        self.metadata = {}
        self.sql_connection = None

    def load_and_process_csv(self, uploaded_file) -> Optional[pd.DataFrame]:
        """Read, clean and register an uploaded CSV.

        Args:
            uploaded_file: Seekable binary file-like object accepted by
                ``pd.read_csv`` (the app passes the ``st.file_uploader`` value).

        Returns:
            The cleaned DataFrame, or None when anything fails. Failures are
            reported through ``st.error`` rather than raised.
        """
        try:
            for encoding in self._ENCODINGS:
                try:
                    uploaded_file.seek(0)  # a failed attempt may have consumed the stream
                    df = pd.read_csv(uploaded_file, encoding=encoding)
                    break
                except UnicodeDecodeError:
                    continue
            else:
                raise ValueError("Could not decode file with any supported encoding")

            # Drop punctuation from headers, then trim. Trimming last matters:
            # "price ($)" would otherwise become "price " with a trailing space.
            df.columns = (
                df.columns.astype(str)
                .str.replace(r'[^\w\s]', '', regex=True)
                .str.strip()
            )

            df = self._infer_and_convert_types(df)
            self._generate_metadata(df)
            self._create_sql_connection(df)

            self.df = df
            return df

        except Exception as e:
            st.error(f"Error processing CSV: {str(e)}")
            return None

    @staticmethod
    def _is_text_column(series: pd.Series) -> bool:
        """True for object-dtype and dedicated string-dtype columns."""
        return pd.api.types.is_object_dtype(series) or pd.api.types.is_string_dtype(series)

    def _infer_and_convert_types(self, df: pd.DataFrame) -> pd.DataFrame:
        """Convert mostly-numeric / mostly-date columns and tidy text columns.

        A column is converted when more than half of *all* rows parse as the
        target type; values that do not parse become NaN/NaT. ``df`` is
        modified in place and also returned.
        """
        majority = len(df) * 0.5

        for col in df.columns:
            series = df[col]

            if pd.api.types.is_numeric_dtype(series):
                continue

            numeric_series = pd.to_numeric(series, errors='coerce')
            if numeric_series.notna().sum() > majority:
                df[col] = numeric_series
                continue

            if not self._is_text_column(series):
                continue

            # Best effort: a column that cannot be parsed as dates simply stays
            # text. ``infer_datetime_format`` is intentionally not passed; it is
            # deprecated in pandas 2.
            try:
                datetime_series = pd.to_datetime(series, errors='coerce')
            except Exception:
                datetime_series = None
            if datetime_series is not None and datetime_series.notna().sum() > majority:
                df[col] = datetime_series
                continue

            # Strip surrounding whitespace but keep missing values missing
            # (a plain astype(str) would turn them into the text 'nan'/'None').
            stripped = series.astype(str).str.strip()
            df[col] = stripped.where(series.notna(), series)

        return df

    def _generate_metadata(self, df: pd.DataFrame):
        """Populate ``self.metadata`` with structure and summary statistics."""
        numeric_columns = df.select_dtypes(include=[np.number]).columns.tolist()

        self.metadata = {
            'shape': df.shape,
            'columns': list(df.columns),
            'dtypes': df.dtypes.to_dict(),
            'numeric_columns': numeric_columns,
            'categorical_columns': df.select_dtypes(include=['object']).columns.tolist(),
            'datetime_columns': df.select_dtypes(include=['datetime64']).columns.tolist(),
            'missing_values': df.isnull().sum().to_dict(),
            'summary_stats': {},
        }

        # Plain Python numbers so the stats can be formatted/serialised easily.
        for col in numeric_columns:
            values = df[col]
            self.metadata['summary_stats'][col] = {
                'mean': float(values.mean()),
                'median': float(values.median()),
                'std': float(values.std()),
                'min': float(values.min()),
                'max': float(values.max()),
                'count': int(values.count()),
            }

    def _create_sql_connection(self, df: pd.DataFrame):
        """(Re)create the in-memory SQLite database with ``df`` as table ``data``."""
        # Release the database of a previously loaded file instead of leaking it.
        if self.sql_connection is not None:
            self.sql_connection.close()

        self.sql_connection = sqlite3.connect(':memory:')
        df.to_sql('data', self.sql_connection, index=False, if_exists='replace')

    def execute_sql_query(self, query: str) -> pd.DataFrame:
        """Run ``query`` against the ``data`` table.

        Returns:
            The result set, or an empty DataFrame when the query fails (the
            error is shown through ``st.error``).
        """
        try:
            return pd.read_sql_query(query, self.sql_connection)
        except Exception as e:
            st.error(f"SQL Error: {str(e)}")
            return pd.DataFrame()

class QueryType(Enum):
    """Intent categories a natural-language question can be classified into.

    In the code visible in this cell, ``LLMQueryProcessor._classify_intent``
    only ever returns AGGREGATION, SUMMARY, GROUPBY or FILTER; no visible code
    produces the remaining members.

    NOTE: an identical ``QueryType`` is defined again later in this cell and
    rebinds the name.
    """
    AGGREGATION = "aggregation"
    FILTER = "filter"
    COMPARISON = "comparison"
    SUMMARY = "summary"
    GROUPBY = "groupby"
    CORRELATION = "correlation"
    TEMPORAL = "temporal"
    STATISTICAL = "statistical"
    RANKING = "ranking"
    COMPLEX_MULTI_STEP = "complex_multi_step"
    PREDICTION = "prediction"
    ANOMALY_DETECTION = "anomaly_detection"

@dataclass
class QueryContext:
    """Parsed description of one user question.

    Built by ``LLMQueryProcessor.parse_complex_query``, which currently fills
    only ``intent``, ``columns`` and ``operations``; every other field is set
    to an empty list/dict/string there.

    NOTE: an identical ``QueryContext`` is defined again later in this cell and
    rebinds the name.
    """
    # Intent assigned by LLMQueryProcessor._classify_intent.
    intent: QueryType
    # Always [] in the visible code.
    entities: List[Dict]
    # Dataset column names found in the question (see _extract_columns).
    columns: List[str]
    # Always {} in the visible code.
    conditions: Dict[str, Any]
    # SQL aggregate names such as 'AVG' or 'SUM' (see _extract_operations).
    operations: List[str]
    # Always {} in the visible code.
    temporal_info: Dict[str, Any]
    # Always {} in the visible code.
    statistical_params: Dict[str, Any]
    # Always [] in the visible code.
    dependencies: List[str]
    # Always [] in the visible code.
    sub_queries: List[Dict]
    # Always "" in the visible code.
    visualization_hint: str

class CSVProcessor:
    """Load an uploaded CSV, clean it, and expose it to pandas and SQLite.

    NOTE: this cell also defines ``CSVProcessor`` earlier with the same
    interface; this later definition is the one the name refers to once the
    cell has run, so keep the two in sync (or delete the earlier one).

    Attributes:
        df: Cleaned DataFrame from the most recent successful load, else None.
        metadata: Dict built by ``_generate_metadata`` (shape, columns, dtypes,
            column groupings, missing-value counts, numeric summary stats).
        sql_connection: In-memory SQLite connection whose table ``data`` holds
            the cleaned DataFrame; None before the first successful load.
    """

    # Tried in order. latin-1 maps every byte to a character and therefore
    # never fails, so it has to be the last resort; cp1252 must come before it
    # or it would never be attempted.
    _ENCODINGS = ('utf-8', 'cp1252', 'latin-1')

    def __init__(self):
        self.df = None
        self.metadata = {}
        self.sql_connection = None

    def load_and_process_csv(self, uploaded_file) -> Optional[pd.DataFrame]:
        """Read, clean and register an uploaded CSV.

        Args:
            uploaded_file: Seekable binary file-like object accepted by
                ``pd.read_csv`` (the app passes the ``st.file_uploader`` value).

        Returns:
            The cleaned DataFrame, or None when anything fails. Failures are
            reported through ``st.error`` rather than raised.
        """
        try:
            for encoding in self._ENCODINGS:
                try:
                    uploaded_file.seek(0)  # a failed attempt may have consumed the stream
                    df = pd.read_csv(uploaded_file, encoding=encoding)
                    break
                except UnicodeDecodeError:
                    continue
            else:
                raise ValueError("Could not decode file with any supported encoding")

            # Drop punctuation from headers, then trim. Trimming last matters:
            # "price ($)" would otherwise become "price " with a trailing space.
            df.columns = (
                df.columns.astype(str)
                .str.replace(r'[^\w\s]', '', regex=True)
                .str.strip()
            )

            df = self._infer_and_convert_types(df)
            self._generate_metadata(df)
            self._create_sql_connection(df)

            self.df = df
            return df

        except Exception as e:
            st.error(f"Error processing CSV: {str(e)}")
            return None

    @staticmethod
    def _is_text_column(series: pd.Series) -> bool:
        """True for object-dtype and dedicated string-dtype columns."""
        return pd.api.types.is_object_dtype(series) or pd.api.types.is_string_dtype(series)

    def _infer_and_convert_types(self, df: pd.DataFrame) -> pd.DataFrame:
        """Convert mostly-numeric / mostly-date columns and tidy text columns.

        A column is converted when more than half of *all* rows parse as the
        target type; values that do not parse become NaN/NaT. ``df`` is
        modified in place and also returned.
        """
        majority = len(df) * 0.5

        for col in df.columns:
            series = df[col]

            if pd.api.types.is_numeric_dtype(series):
                continue

            numeric_series = pd.to_numeric(series, errors='coerce')
            if numeric_series.notna().sum() > majority:
                df[col] = numeric_series
                continue

            if not self._is_text_column(series):
                continue

            # Best effort: a column that cannot be parsed as dates simply stays
            # text. ``infer_datetime_format`` is intentionally not passed; it is
            # deprecated in pandas 2.
            try:
                datetime_series = pd.to_datetime(series, errors='coerce')
            except Exception:
                datetime_series = None
            if datetime_series is not None and datetime_series.notna().sum() > majority:
                df[col] = datetime_series
                continue

            # Strip surrounding whitespace but keep missing values missing
            # (a plain astype(str) would turn them into the text 'nan'/'None').
            stripped = series.astype(str).str.strip()
            df[col] = stripped.where(series.notna(), series)

        return df

    def _generate_metadata(self, df: pd.DataFrame):
        """Populate ``self.metadata`` with structure and summary statistics."""
        numeric_columns = df.select_dtypes(include=[np.number]).columns.tolist()

        self.metadata = {
            'shape': df.shape,
            'columns': list(df.columns),
            'dtypes': df.dtypes.to_dict(),
            'numeric_columns': numeric_columns,
            'categorical_columns': df.select_dtypes(include=['object']).columns.tolist(),
            'datetime_columns': df.select_dtypes(include=['datetime64']).columns.tolist(),
            'missing_values': df.isnull().sum().to_dict(),
            'summary_stats': {},
        }

        # Plain Python numbers so the stats can be formatted/serialised easily.
        for col in numeric_columns:
            values = df[col]
            self.metadata['summary_stats'][col] = {
                'mean': float(values.mean()),
                'median': float(values.median()),
                'std': float(values.std()),
                'min': float(values.min()),
                'max': float(values.max()),
                'count': int(values.count()),
            }

    def _create_sql_connection(self, df: pd.DataFrame):
        """(Re)create the in-memory SQLite database with ``df`` as table ``data``."""
        # Release the database of a previously loaded file instead of leaking it.
        if self.sql_connection is not None:
            self.sql_connection.close()

        self.sql_connection = sqlite3.connect(':memory:')
        df.to_sql('data', self.sql_connection, index=False, if_exists='replace')

    def execute_sql_query(self, query: str) -> pd.DataFrame:
        """Run ``query`` against the ``data`` table.

        Returns:
            The result set, or an empty DataFrame when the query fails (the
            error is shown through ``st.error``).
        """
        try:
            return pd.read_sql_query(query, self.sql_connection)
        except Exception as e:
            st.error(f"SQL Error: {str(e)}")
            return pd.DataFrame()

class QueryType(Enum):
    """Intent categories a natural-language question can be classified into.

    In the code visible in this cell, ``LLMQueryProcessor._classify_intent``
    only ever returns AGGREGATION, SUMMARY, GROUPBY or FILTER; no visible code
    produces the remaining members.

    NOTE: an identical ``QueryType`` is defined earlier in this cell; this
    later definition is the one the name refers to once the cell has run.
    """
    AGGREGATION = "aggregation"
    FILTER = "filter"
    COMPARISON = "comparison"
    SUMMARY = "summary"
    GROUPBY = "groupby"
    CORRELATION = "correlation"
    TEMPORAL = "temporal"
    STATISTICAL = "statistical"
    RANKING = "ranking"
    COMPLEX_MULTI_STEP = "complex_multi_step"
    PREDICTION = "prediction"
    ANOMALY_DETECTION = "anomaly_detection"

@dataclass
class QueryContext:
    """Parsed description of one user question.

    Built by ``LLMQueryProcessor.parse_complex_query``, which currently fills
    only ``intent``, ``columns`` and ``operations``; every other field is set
    to an empty list/dict/string there.

    NOTE: an identical ``QueryContext`` is defined earlier in this cell; this
    later definition is the one the name refers to once the cell has run.
    """
    # Intent assigned by LLMQueryProcessor._classify_intent.
    intent: QueryType
    # Always [] in the visible code.
    entities: List[Dict]
    # Dataset column names found in the question (see _extract_columns).
    columns: List[str]
    # Always {} in the visible code.
    conditions: Dict[str, Any]
    # SQL aggregate names such as 'AVG' or 'SUM' (see _extract_operations).
    operations: List[str]
    # Always {} in the visible code.
    temporal_info: Dict[str, Any]
    # Always {} in the visible code.
    statistical_params: Dict[str, Any]
    # Always [] in the visible code.
    dependencies: List[str]
    # Always [] in the visible code.
    sub_queries: List[Dict]
    # Always "" in the visible code.
    visualization_hint: str

class LLMQueryProcessor:
    """Translate natural-language questions into SQL for the ``data`` table.

    SQL is produced by the SQLCoder causal LM (``defog/sqlcoder-7b-2``, loaded
    in 4-bit). When the model is unavailable or generation fails,
    ``generate_advanced_sql`` falls back to a small rule-based generator.

    Return annotations naming ``QueryContext`` / ``QueryType`` are written as
    strings so defining this class does not depend on definition order.
    """

    # Question keyword -> SQL aggregate function.
    _OPERATION_KEYWORDS = {
        'average': 'AVG',
        'mean': 'AVG',
        'sum': 'SUM',
        'total': 'SUM',
        'count': 'COUNT',
        'maximum': 'MAX',
        'max': 'MAX',
        'minimum': 'MIN',
        'min': 'MIN',
    }
    _AGGREGATES = ('AVG', 'SUM', 'COUNT', 'MAX', 'MIN')

    # Phrases removed from questions before keyword analysis.
    _FILLER_PHRASES = ('please', 'can you', 'could you', 'would you',
                       'i want to', 'i need to', 'help me')

    # pandas dtype name -> SQL column type used in the prompt schema.
    _SQL_TYPES = {
        'object': 'TEXT',
        'int64': 'INTEGER',
        'int32': 'INTEGER',
        'float64': 'REAL',
        'float32': 'REAL',
        'bool': 'INTEGER',
        'datetime64[ns]': 'TEXT',
        'category': 'TEXT',
    }

    _MAX_NEW_TOKENS = 512      # budget for the generated SQL only
    _FALLBACK_ROW_LIMIT = 100  # row cap for non-aggregate fallback queries

    def __init__(self):
        self.tokenizer = None
        self.model = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._initialize_model()

    def _initialize_model(self):
        """Load tokenizer and 4-bit quantized SQLCoder model.

        On any failure both ``self.model`` and ``self.tokenizer`` are left as
        None (never half-initialised) and the problem is reported via
        Streamlit; SQL generation then uses the rule-based fallback.
        """
        try:
            bnb_config = BitsAndBytesConfig(load_in_4bit=True)
            model_name = "defog/sqlcoder-7b-2"

            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                quantization_config=bnb_config,
                device_map="auto"
            )
            model.eval()

        except Exception as e:
            self.tokenizer = None
            self.model = None
            st.error(f"Failed to load SQLCoder model: {str(e)}")
            st.warning("Falling back to rule-based SQL generation")
        else:
            self.tokenizer = tokenizer
            self.model = model

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _sql_identifier(col) -> str:
        """Column name with every non-word character replaced by ``_``."""
        return re.sub(r'[^\w]', '_', str(col))

    @staticmethod
    def _quote_identifier(col) -> str:
        """Column name as a double-quoted SQL identifier."""
        return '"' + str(col).replace('"', '""') + '"'

    @staticmethod
    def _words(text: str) -> set:
        """Set of lower-cased word tokens in ``text``."""
        return set(re.findall(r'\w+', text.lower()))

    @staticmethod
    def _build_prompt(query: str, table_schema: str) -> str:
        """Prompt in the layout SQLCoder expects; no line is indented."""
        return (
            "### Task\n"
            f"Generate a SQL query to answer [QUESTION]{query}[/QUESTION]\n"
            "\n"
            "### Database Schema\n"
            "The query will run on a database with the following schema:\n"
            f"{table_schema}\n"
            "\n"
            "### Answer\n"
            "Given the database schema, here is the SQL query that "
            f"[QUESTION]{query}[/QUESTION]\n"
            "[SQL]"
        )

    # ----------------------------------------------------------- preprocessing

    def _clean_user_prompt(self, query: str) -> str:
        """Normalise a question for keyword analysis.

        Collapses whitespace, drops characters other than word characters,
        whitespace and ``- . , ? ! : ;``, lower-cases, and removes polite
        filler phrases (whole phrases only, never parts of other words).
        """
        query = re.sub(r'\s+', ' ', query.strip())
        query = re.sub(r'[^\w\s\-\.\,\?\!\:\;]', '', query)
        query = query.lower()

        filler_pattern = r'\b(?:' + '|'.join(re.escape(p) for p in self._FILLER_PHRASES) + r')\b'
        query = re.sub(filler_pattern, '', query)

        return re.sub(r'\s+', ' ', query.strip())

    def _generate_table_schema(self, metadata: Dict) -> str:
        """Build the ``CREATE TABLE data (...)`` statement shown to the model.

        Column names are sanitised with ``_sql_identifier``; dtypes without an
        entry in ``_SQL_TYPES`` are declared as TEXT.
        """
        dtypes = metadata.get('dtypes', {})

        column_definitions = []
        for col in metadata.get('columns', []):
            sql_type = self._SQL_TYPES.get(str(dtypes.get(col, 'object')), 'TEXT')
            column_definitions.append(f"{self._sql_identifier(col)} {sql_type}")

        return f"CREATE TABLE data ({', '.join(column_definitions)})"

    # -------------------------------------------------------------- generation

    def _generate_sql_with_sqlcoder(self, query: str, metadata: Dict) -> str:
        """Generate SQL for ``query`` with the SQLCoder model.

        Raises:
            RuntimeError: if the model or tokenizer was not loaded.
            Exception: anything raised during tokenisation/generation, or a
                ValueError when the model yields no SQL; these are reported
                through ``st.error`` and re-raised.
        """
        if self.model is None or self.tokenizer is None:
            raise RuntimeError("SQLCoder model not initialized")

        try:
            table_schema = self._generate_table_schema(metadata)
            input_prompt = self._build_prompt(query, table_schema)

            inputs = self.tokenizer(input_prompt, return_tensors="pt").to(self.model.device)

            with torch.no_grad():
                # Greedy decoding. max_new_tokens (not max_length) is used so a
                # long schema in the prompt cannot consume the output budget;
                # sampling knobs are omitted because do_sample=False ignores them.
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=self._MAX_NEW_TOKENS,
                    do_sample=False,
                    num_return_sequences=1
                )

            # The decoded text contains the prompt; the SQL follows the last tag.
            generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            generated_sql = generated_text.split("[SQL]")[-1].strip()
            generated_sql = self._clean_generated_sql(generated_sql, metadata)

            if not generated_sql:
                raise ValueError("SQLCoder returned an empty query")
            return generated_sql

        except Exception as e:
            st.error(f"Error generating SQL with SQLCoder: {str(e)}")
            raise

    def _clean_generated_sql(self, sql: str, metadata: Dict) -> str:
        """Normalise model output so it can run against the ``data`` table.

        * collapses whitespace and ensures a trailing ``;``
        * rewrites every ``FROM <name>`` to ``FROM data``
        * maps sanitised column names used in the prompt schema (for example
          ``first_name``) back to the real, quoted column (``"first name"``)

        Returns an empty string for empty input.
        """
        sql = re.sub(r'\s+', ' ', sql.strip())
        if not sql:
            return ''

        if not sql.endswith(';'):
            sql += ';'

        sql = re.sub(r'\bFROM\s+\w+', 'FROM data', sql, flags=re.IGNORECASE)

        columns = metadata.get('columns', [])
        for col in columns:
            alias = self._sql_identifier(col)
            # Nothing to map when the name is already a plain identifier, and
            # do not hijack an alias that is itself a real column.
            if alias == col or alias in columns:
                continue
            quoted = self._quote_identifier(col)
            # Also swallow quoting the model may have put around the alias.
            pattern = r'["`\[]?\b' + re.escape(alias) + r'\b["`\]]?'
            sql = re.sub(pattern, lambda _m, q=quoted: q, sql, flags=re.IGNORECASE)

        return sql

    def _fallback_sql_generation(self, query: str, metadata: Dict) -> str:
        """Rule-based SQL used when the model cannot produce a query.

        Args:
            query: Space-separated aggregate names (``AVG``, ``SUM``, ...) and
                column names, as assembled by ``generate_advanced_sql``.
            metadata: Dataset metadata; ``columns`` and ``numeric_columns``
                are consulted.

        Returns:
            An aggregate query when aggregates were requested, otherwise a
            row-limited ``SELECT`` of the mentioned columns (or ``*``).
        """
        operations = list(dict.fromkeys(
            token for token in query.split() if token in self._AGGREGATES
        ))
        columns = self._extract_columns(query, metadata)

        if operations:
            expressions = []
            if not columns and 'COUNT' in operations:
                expressions.append('COUNT(*) AS "row_count"')
                operations = [op for op in operations if op != 'COUNT']

            # Without an explicit column, aggregate over all numeric columns.
            targets = columns or list(metadata.get('numeric_columns', []))
            for op in operations:
                for col in targets:
                    alias = f"{op.lower()}_{self._sql_identifier(col)}"
                    expressions.append(
                        f'{op}({self._quote_identifier(col)}) AS {self._quote_identifier(alias)}'
                    )

            if expressions:
                return f"SELECT {', '.join(expressions)} FROM data;"

        selected = ', '.join(self._quote_identifier(col) for col in columns) or '*'
        return f"SELECT {selected} FROM data LIMIT {self._FALLBACK_ROW_LIMIT};"

    # ----------------------------------------------------------------- parsing

    def parse_complex_query(self, query: str, metadata: Dict) -> "QueryContext":
        """Derive intent, mentioned columns and aggregate operations from ``query``.

        All other ``QueryContext`` fields are returned empty.
        """
        cleaned_query = self._clean_user_prompt(query)

        return QueryContext(
            intent=self._classify_intent(cleaned_query),
            entities=[],
            columns=self._extract_columns(cleaned_query, metadata),
            conditions={},
            operations=self._extract_operations(cleaned_query),
            temporal_info={},
            statistical_params={},
            dependencies=[],
            sub_queries=[],
            visualization_hint=""
        )

    def _classify_intent(self, query: str) -> "QueryType":
        """Classify ``query`` by whole-word keywords.

        Rules are checked in order (aggregation, summary, group-by, filter);
        AGGREGATION is the default. Matching whole words keeps "summary" from
        being read as the aggregate "sum".
        """
        words = self._words(query)

        rules = (
            (('average', 'mean', 'sum', 'count', 'max', 'min'), QueryType.AGGREGATION),
            (('summary', 'describe', 'overview'), QueryType.SUMMARY),
            (('group', 'by'), QueryType.GROUPBY),
            (('where', 'filter', 'condition'), QueryType.FILTER),
        )
        for keywords, intent in rules:
            if any(keyword in words for keyword in keywords):
                return intent
        return QueryType.AGGREGATION

    def _extract_columns(self, query: str, metadata: Dict) -> List[str]:
        """Columns from ``metadata['columns']`` named in ``query``.

        Matching is case-insensitive and requires the name to stand alone, so
        a column ``age`` is not detected inside "average".
        """
        query_lower = query.lower()

        mentioned_columns = []
        for col in metadata.get('columns', []):
            name = str(col).lower()
            if not name:
                continue
            if re.search(r'(?<!\w)' + re.escape(name) + r'(?!\w)', query_lower):
                mentioned_columns.append(col)

        return mentioned_columns

    def _extract_operations(self, query: str) -> List[str]:
        """SQL aggregates requested in ``query``, each listed once.

        Keywords are matched as whole words and results keep the order of
        ``_OPERATION_KEYWORDS``.
        """
        words = self._words(query)

        operations = []
        for keyword, operation in self._OPERATION_KEYWORDS.items():
            if keyword in words and operation not in operations:
                operations.append(operation)

        return operations

    # ------------------------------------------------------------- entry point

    def generate_advanced_sql(self, context: "QueryContext", metadata: Dict, query) -> Dict[str, Any]:
        """Produce SQL for ``query``, preferring the model over the fallback.

        Returns:
            Dict with ``main_query`` (SQL string), ``context`` and ``success``.
            ``success`` is False, and ``error`` holds the failure message, when
            the rule-based fallback had to be used.
        """
        try:
            sql_query = self._generate_sql_with_sqlcoder(query, metadata)

            return {
                'main_query': sql_query,
                'context': context,
                'success': True
            }

        except Exception as e:
            st.error(f"Error generating SQL: {str(e)}")

            fallback_query = self._fallback_sql_generation(
                " ".join(context.operations + context.columns),
                metadata
            )

            return {
                'main_query': fallback_query,
                'context': context,
                'success': False,
                'error': str(e)
            }

class ChatCSVApp:
    """Main Streamlit application class"""

    def __init__(self):
        """Create the processors and make sure chat history exists.

        The LLM processor is kept in ``st.session_state`` and reused: building
        it loads the SQLCoder model, and ``_render_chat_interface`` triggers
        ``st.rerun()`` after every question. (How this class is instantiated is
        not visible in this cell; if it is built once per script run, this
        avoids reloading the model on each run.) A processor whose model
        failed to load is rebuilt so loading is retried.
        """
        self.csv_processor = CSVProcessor()

        cached = st.session_state.get('llm_processor')
        if cached is None or cached.model is None:
            cached = LLMQueryProcessor()
            st.session_state['llm_processor'] = cached
        self.llm_processor = cached

        # Initialize conversation history in session state
        if 'conversation_history' not in st.session_state:
            st.session_state.conversation_history = []

    def run(self):
        """Render the page: header, upload sidebar, then chat + preview or a hint.

        The chat and data-preview columns are shown once a file has been
        uploaded and successfully processed; otherwise an info message and a
        list of example questions are displayed.
        """
        st.set_page_config(page_title="Chat with CSV (LLM)", layout="wide")

        # Centered page header (raw HTML, hence unsafe_allow_html below).
        banner_html = """
                <div style='text-align: center;'>
                    <h1>Chat with CSV (LLM)</h1>
                    <p style='font-size: 18px;'>Upload a CSV file and ask questions in natural language to get intelligent insights!</p>
                </div>
                """
        with st.container():
            st.markdown(banner_html, unsafe_allow_html=True)

        # Sidebar: upload widget; a chosen file is processed right away.
        with st.sidebar:
            st.header("📁 File Upload")
            uploaded_file = st.file_uploader("Choose a CSV file", type="csv")

            if uploaded_file is not None:
                st.success("File uploaded successfully!")
                with st.spinner("Processing CSV..."):
                    self.csv_processor.load_and_process_csv(uploaded_file)

        data_ready = uploaded_file is not None and self.csv_processor.df is not None

        if not data_ready:
            st.info("Please upload a CSV file to get started!")

            st.subheader("Example Queries")
            example_queries = (
                "What is the average age?",
                "How many records are there?",
                "Show me the maximum salary",
                "What's the total revenue?",
                "Give me a summary of the data",
            )
            for example_query in example_queries:
                st.code(example_query)
            return

        # Main area: chat on the left, data preview on the right.
        chat_column, preview_column = st.columns([1, 1])
        with chat_column:
            self._render_chat_interface()
        with preview_column:
            self._render_data_preview()

    def _render_chat_interface(self):
        """Draw the chat column: stored exchanges first, then the question form.

        A submitted, non-empty question is answered via ``_process_query``,
        appended to ``st.session_state.conversation_history`` and followed by
        ``st.rerun()`` so the new exchange is drawn with the history.
        """
        st.subheader("Chat Interface")

        # Replay every stored (question, answer, table, debug) exchange.
        for past_query, past_answer, past_table, past_debug in st.session_state.conversation_history:
            with st.chat_message("user"):
                st.markdown(past_query)

            with st.chat_message("assistant"):
                st.markdown(past_answer)

                if past_table is not None:
                    st.dataframe(past_table, use_container_width=True)

                if past_debug:
                    with st.expander("🧠 Debug Info", expanded=False):
                        for label, detail in past_debug.items():
                            if label != "Generated SQL":
                                st.markdown(f"- **{label}**: `{detail}`")
                            else:
                                st.code(detail, language="sql")

            st.markdown("---")

        # A form with clear_on_submit empties the text box after each send.
        with st.form(key="query_form", clear_on_submit=True):
            query = st.text_input(
                "Ask a question about your data:",
                placeholder="e.g., What is the average age?",
            )
            submitted = st.form_submit_button("Send", type="primary")

        if not (submitted and query):
            return

        with st.spinner("Analyzing your query..."):
            answer, table, debug = self._process_query(query)

            # Store the exchange, then rerun so it is rendered by the loop above.
            st.session_state.conversation_history.append((query, answer, table, debug))
            st.rerun()

    def _process_query(self, query: str) -> Tuple[str, Optional[pd.DataFrame], dict]:
        """Answer one question.

        Returns:
            ``(response_text, table_df, debug_info)``. ``table_df`` is None
            unless a multi-cell result should be shown. On an unexpected error
            the text describes the error and ``debug_info`` is empty.

        The dataset summary is used when the question was classified as
        ``QueryType.SUMMARY`` and SQL produced nothing usable. (The previous
        check read ``context.analysis_type``, which ``QueryContext`` does not
        define, so the summary could never be returned.)
        """
        try:
            context = self.llm_processor.parse_complex_query(query, self.csv_processor.metadata)

            # Collect debug info
            debug_info = {
                "Query": query,
                "Intent": getattr(context, 'intent', 'N/A'),
                "Columns": getattr(context, 'columns', []),
                "Operations": getattr(context, 'operations', []),
            }
            wants_summary = getattr(context, 'intent', None) == QueryType.SUMMARY

            sql_result = self.llm_processor.generate_advanced_sql(context, self.csv_processor.metadata, query)
            sql_query = sql_result.get('main_query')

            if sql_query:
                debug_info["Generated SQL"] = sql_query
                result_df = self.csv_processor.execute_sql_query(sql_query)

                if not result_df.empty:
                    # A 1x1 result is reported inline instead of as a table.
                    if len(result_df) == 1 and len(result_df.columns) == 1:
                        value = result_df.iloc[0, 0]
                        value_str = f"**{value:.2f}**" if isinstance(value, (int, float)) else f"**{value}**"
                        return f"The result is {value_str}", None, debug_info
                    return "Here are the results:", result_df, debug_info

                if wants_summary:
                    return self._generate_summary_response(), None, debug_info
                return "No results found for your query.", None, debug_info

            if wants_summary:
                return self._generate_summary_response(), None, debug_info

            return "I need more specific information to give an accurate answer.", None, debug_info

        except Exception as e:
            return f"I encountered an error while processing your query:\n\n`{str(e)}`", None, {}

    def _generate_summary_response(self) -> str:
        """Generate a comprehensive summary of the dataset"""
        metadata = self.csv_processor.metadata

        summary = f"""
**Dataset Summary:**

**Basic Info:**
- Total rows: {metadata['shape'][0]:,}
- Total columns: {metadata['shape'][1]}

**Column Types:**
- Numeric columns: {len(metadata['numeric_columns'])}
- Text columns: {len(metadata['categorical_columns'])}
- Date columns: {len(metadata['datetime_columns'])}

**Key Statistics:**
"""

        for col, stats in metadata['summary_stats'].items():
            summary += f"\n**{col}:** Mean = {stats['mean']:.2f}, Range = {stats['min']:.2f} to {stats['max']:.2f}"

        return summary

    def _render_data_preview(self):
        """Render comprehensive data preview and statistics"""
        st.subheader("Data Preview")

        if self.csv_processor.df is not None:
            df = self.csv_processor.df

            # Create tabs for better organization
            preview_tab, stats_tab, viz_tab, quality_tab = st.tabs([
                "Data Sample", "Statistics", "Visualizations", "Data Quality"
            ])

            with preview_tab:
                # Dataset overview
                col1, col2, col3, col4 = st.columns(4)
                with col1:
                    st.metric("Total Rows", f"{len(df):,}")
                with col2:
                    st.metric("Total Columns", len(df.columns))
                with col3:
                    st.metric("Memory Usage", f"{df.memory_usage(deep=True).sum() / 1024**2:.1f} MB")
                with col4:
                    st.metric("Duplicate Rows", df.duplicated().sum())

                # Show customizable data sample
                st.subheader("Data Sample")
                sample_size = st.slider("Number of rows to display:", 5, min(50, len(df)), 10)
                sample_type = st.radio("Sample type:", ["First rows", "Random sample", "Last rows"], horizontal=True)

                if sample_type == "First rows":
                    sample_df = df.head(sample_size)
                elif sample_type == "Random sample":
                    sample_df = df.sample(min(sample_size, len(df))) if len(df) > 0 else df
                else:
                    sample_df = df.tail(sample_size)

                st.dataframe(sample_df, use_container_width=True)

                # Column information
                st.subheader("Column Information")
                col_info = pd.DataFrame({
                    'Column': df.columns,
                    'Data Type': df.dtypes.astype(str),
                    'Non-Null Count': df.count(),
                    'Null Count': df.isnull().sum(),
                    'Null %': (df.isnull().sum() / len(df) * 100).round(2),
                    'Unique Values': df.nunique(),
                    'Unique %': (df.nunique() / len(df) * 100).round(2)
                })
                st.dataframe(col_info, use_container_width=True)

            with stats_tab:
                numeric_cols = self.csv_processor.metadata['numeric_columns']
                categorical_cols = [col for col in df.columns if col not in numeric_cols]

                # Numeric statistics
                if numeric_cols:
                    st.subheader("Numeric Columns Statistics")
                    numeric_stats = df[numeric_cols].describe()
                    st.dataframe(numeric_stats, use_container_width=True)

                    # Additional statistics
                    st.subheader("Additional Numeric Statistics")
                    additional_stats = pd.DataFrame({
                        'Column': numeric_cols,
                        'Skewness': [df[col].skew() for col in numeric_cols],
                        'Kurtosis': [df[col].kurtosis() for col in numeric_cols],
                        'Variance': [df[col].var() for col in numeric_cols],
                        'Range': [df[col].max() - df[col].min() for col in numeric_cols],
                        'IQR': [df[col].quantile(0.75) - df[col].quantile(0.25) for col in numeric_cols]
                    })
                    st.dataframe(additional_stats.round(3), use_container_width=True)

                # Categorical statistics
                if categorical_cols:
                    st.subheader("Categorical Columns Statistics")
                    cat_stats = []
                    for col in categorical_cols:
                        if df[col].dtype == 'object' or df[col].dtype.name == 'category':
                            mode_val = df[col].mode().iloc[0] if not df[col].mode().empty else 'N/A'
                            cat_stats.append({
                                'Column': col,
                                'Unique Values': df[col].nunique(),
                                'Most Frequent': mode_val,
                                'Frequency': df[col].value_counts().iloc[0] if len(df[col].value_counts()) > 0 else 0,
                                'Frequency %': (df[col].value_counts().iloc[0] / len(df) * 100).round(2) if len(df[col].value_counts()) > 0 else 0
                            })

                    if cat_stats:
                        cat_df = pd.DataFrame(cat_stats)
                        st.dataframe(cat_df, use_container_width=True)

            with viz_tab:
                st.subheader("Data Visualizations")

                # Visualization options
                viz_type = st.selectbox("Choose visualization type:", [
                    "Distribution Analysis", "Correlation Analysis", "Missing Data Pattern", "Category Analysis"
                ])

                if viz_type == "Distribution Analysis" and numeric_cols:
                    selected_col = st.selectbox("Select numeric column:", numeric_cols)

                    if selected_col:
                        col1, col2 = st.columns(2)

                        with col1:
                            # Histogram
                            fig_hist = px.histogram(df, x=selected_col,
                                                title=f"Distribution of {selected_col}",
                                                marginal="box")
                            st.plotly_chart(fig_hist, use_container_width=True)

                        with col2:
                            # Box plot
                            fig_box = px.box(df, y=selected_col,
                                            title=f"Box Plot of {selected_col}")
                            st.plotly_chart(fig_box, use_container_width=True)

                elif viz_type == "Correlation Analysis" and len(numeric_cols) > 1:
                    # Correlation heatmap
                    corr_matrix = df[numeric_cols].corr()
                    fig_corr = px.imshow(corr_matrix,
                                    title="Correlation Heatmap",
                                    color_continuous_scale="RdBu_r",
                                    aspect="auto")
                    fig_corr.update_layout(width=600, height=500)
                    st.plotly_chart(fig_corr, use_container_width=True)

                    # Pairwise scatter plot option
                    if len(numeric_cols) >= 2:
                        st.subheader("Pairwise Relationship")
                        col1, col2 = st.columns(2)
                        with col1:
                            x_col = st.selectbox("X-axis:", numeric_cols, key="x_axis")
                        with col2:
                            y_col = st.selectbox("Y-axis:", [col for col in numeric_cols if col != x_col], key="y_axis")

                        if x_col and y_col:
                            fig_scatter = px.scatter(df, x=x_col, y=y_col,
                                                title=f"{x_col} vs {y_col}")
                            st.plotly_chart(fig_scatter, use_container_width=True)

                elif viz_type == "Missing Data Pattern":
                    # Missing data visualization
                    missing_data = df.isnull().sum()
                    missing_data = missing_data[missing_data > 0].sort_values(ascending=False)

                    if not missing_data.empty:
                        fig_missing = px.bar(x=missing_data.values, y=missing_data.index,
                                        orientation='h',
                                        title="Missing Data by Column",
                                        labels={'x': 'Number of Missing Values', 'y': 'Columns'})
                        st.plotly_chart(fig_missing, use_container_width=True)

                        # Missing data percentage
                        missing_pct = (missing_data / len(df) * 100).round(2)
                        fig_pct = px.bar(x=missing_pct.values, y=missing_pct.index,
                                    orientation='h',
                                    title="Missing Data Percentage by Column",
                                    labels={'x': 'Percentage Missing', 'y': 'Columns'})
                        st.plotly_chart(fig_pct, use_container_width=True)
                    else:
                        st.success("No missing data found in the dataset!")

                elif viz_type == "Category Analysis":
                    categorical_cols_viz = [col for col in df.columns if df[col].dtype == 'object' or df[col].nunique() < 20]

                    if categorical_cols_viz:
                        selected_cat = st.selectbox("Select categorical column:", categorical_cols_viz)

                        if selected_cat:
                            value_counts = df[selected_cat].value_counts().head(20)  # Top 20 categories

                            col1, col2 = st.columns(2)
                            with col1:
                                # Bar chart
                                fig_bar = px.bar(x=value_counts.index, y=value_counts.values,
                                            title=f"Distribution of {selected_cat}")
                                fig_bar.update_xaxes(tickangle=45)
                                st.plotly_chart(fig_bar, use_container_width=True)

                            with col2:
                                # Pie chart (for top categories)
                                top_categories = value_counts.head(10)
                                fig_pie = px.pie(values=top_categories.values, names=top_categories.index,
                                            title=f"Top 10 Categories in {selected_cat}")
                                st.plotly_chart(fig_pie, use_container_width=True)
                    else:
                        st.info("No suitable categorical columns found for visualization.")

            with quality_tab:
                st.subheader("Data Quality Assessment")

                # Overall data quality score
                total_cells = len(df) * len(df.columns)
                missing_cells = df.isnull().sum().sum()
                duplicate_rows = df.duplicated().sum()

                quality_score = max(0, 100 - (missing_cells / total_cells * 50) - (duplicate_rows / len(df) * 30))

                st.metric("Data Quality Score", f"{quality_score:.1f}/100")

                # Quality issues breakdown
                col1, col2 = st.columns(2)

                with col1:
                    st.subheader("Quality Issues")
                    issues = []

                    if missing_cells > 0:
                        issues.append(f"• {missing_cells:,} missing values ({missing_cells/total_cells*100:.1f}% of all data)")

                    if duplicate_rows > 0:
                        issues.append(f"• {duplicate_rows:,} duplicate rows ({duplicate_rows/len(df)*100:.1f}% of total rows)")

                    # Check for potential outliers in numeric columns
                    outlier_cols = []
                    for col in numeric_cols:
                        Q1 = df[col].quantile(0.25)
                        Q3 = df[col].quantile(0.75)
                        IQR = Q3 - Q1
                        outliers = df[(df[col] < (Q1 - 1.5 * IQR)) | (df[col] > (Q3 + 1.5 * IQR))][col].count()
                        if outliers > 0:
                            outlier_cols.append(f"{col}: {outliers} outliers")

                    if outlier_cols:
                        issues.append("• Potential outliers detected:")
                        for outlier_info in outlier_cols:
                            issues.append(f"  - {outlier_info}")

                    if not issues:
                        st.success("No major data quality issues detected!")
                    else:
                        for issue in issues:
                            st.warning(issue)

                with col2:
                    st.subheader("Recommendations")
                    recommendations = []

                    if missing_cells > total_cells * 0.05:  # More than 5% missing
                        recommendations.append("• Consider data imputation strategies for missing values")

                    if duplicate_rows > 0:
                        recommendations.append("• Remove or investigate duplicate rows")

                    if len(outlier_cols) > 0:
                        recommendations.append("• Investigate potential outliers in numeric columns")

                    # Check for high cardinality categorical columns
                    high_card_cols = [col for col in df.columns if df[col].dtype == 'object' and df[col].nunique() > len(df) * 0.8]
                    if high_card_cols:
                        recommendations.append(f"• Consider feature engineering for high-cardinality columns: {', '.join(high_card_cols)}")

                    if not recommendations:
                        st.success("Data appears to be in good condition!")
                    else:
                        for rec in recommendations:
                            st.info(rec)

                # Detailed column analysis
                st.subheader("Detailed Column Analysis")
                problematic_cols = []

                for col in df.columns:
                    issues = []

                    # Check missing values
                    missing_pct = df[col].isnull().sum() / len(df) * 100
                    if missing_pct > 20:
                        issues.append(f"High missing rate: {missing_pct:.1f}%")

                    # Check constant values
                    if df[col].nunique() == 1:
                        issues.append("Constant values (no variation)")

                    # Check high cardinality for categorical
                    if df[col].dtype == 'object' and df[col].nunique() > len(df) * 0.9:
                        issues.append("Very high cardinality")

                    if issues:
                        problematic_cols.append({
                            'Column': col,
                            'Issues': ', '.join(issues),
                            'Data Type': str(df[col].dtype),
                            'Missing %': f"{missing_pct:.1f}%"
                        })

                if problematic_cols:
                    prob_df = pd.DataFrame(problematic_cols)
                    st.dataframe(prob_df, use_container_width=True)
                else:
                    st.success("All columns appear to be in good condition!")

        else:
            st.warning("No data available. Please upload a CSV file first.")

def main():
    """Entry point: construct the ChatCSVApp and hand control to its run loop."""
    ChatCSVApp().run()

# Launch the app only when this code executes as the top-level module
# (__name__ is "__main__"); importing it from elsewhere will not call main().
if __name__ == "__main__":
    main()

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]



In [None]:
# Shell escape: start the Streamlit server from inside the notebook.
# NOTE(review): the script passed to `streamlit run` is Colab's kernel launcher,
# not a file containing the app code from the cell above — confirm this really
# serves the app. The usual approach is to save the app cell to a file (e.g. with
# %%writefile) and point `streamlit run` at that file instead.
!streamlit run /usr/local/lib/python3.11/dist-packages/colab_kernel_launcher.py


Collecting usage statistics. To deactivate, set browser.gatherUsageStats to false.
[0m
[0m
[34m[1m  You can now view your Streamlit app in your browser.[0m
[0m
[34m  Local URL: [0m[1mhttp://localhost:8501[0m
[34m  Network URL: [0m[1mhttp://172.28.0.12:8501[0m
[34m  External URL: [0m[1mhttp://34.169.144.122:8501[0m
[0m
