<a href="https://colab.research.google.com/github/ozzaii/AI-ScratchBook/blob/main/Working_of_new_business_rag_deployment.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Install the notebook's runtime dependencies via a shell escape.
# NOTE(review): no versions are pinned, so a fresh run may resolve different
# releases than the ones recorded in the output below.
# NOTE(review): `spacy` (plus the `en_core_web_sm` model it loads) and
# `sentence_transformers` are imported later in this notebook but are not
# listed here — presumably already present on the runtime; verify.
!pip install torch transformers unsloth langchain_huggingface langchain_community faiss-cpu pandas tqdm pickle-mixin

Collecting unsloth
  Downloading unsloth-2024.10.7-py3-none-any.whl.metadata (56 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.8/56.8 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting langchain_huggingface
  Downloading langchain_huggingface-0.1.0-py3-none-any.whl.metadata (1.3 kB)
Collecting langchain_community
  Downloading langchain_community-0.3.3-py3-none-any.whl.metadata (2.8 kB)
Collecting faiss-cpu
  Downloading faiss_cpu-1.9.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.4 kB)
Collecting pickle-mixin
  Downloading pickle-mixin-1.0.2.tar.gz (5.1 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting unsloth-zoo (from unsloth)
  Downloading unsloth_zoo-2024.10.4-py3-none-any.whl.metadata (48 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m48.2/48.2 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting xformers>=0.0.27.post2 (from unsloth)
  Downloading xformers-0.0.28.post2-

In [2]:
# Adds the plotting libraries and the GPU build of FAISS.
# NOTE(review): the previous cell already installed `faiss-cpu`; both
# distributions appear to supply the same `faiss` import — confirm which
# build is intended and drop the other.
!pip install plotly seaborn faiss-gpu

Collecting faiss-gpu
  Downloading faiss_gpu-1.7.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.4 kB)
Downloading faiss_gpu-1.7.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (85.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m85.5/85.5 MB[0m [31m26.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: faiss-gpu
Successfully installed faiss-gpu-1.7.2


In [None]:
# Import statements
import os
import pickle
import faiss
import numpy as np
from typing import List, Dict, Any, Tuple, Set
import pandas as pd
import torch
from dataclasses import dataclass
from datetime import datetime
import calendar
import re
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification
import spacy
from dateutil.parser import parse
from dateutil.relativedelta import relativedelta
from tqdm.auto import tqdm
import json
import gc
import logging
from sentence_transformers import SentenceTransformer
from unsloth import FastLanguageModel
import plotly.graph_objects as go

# --- Storage layout ---------------------------------------------------
# One project root; every other location is a "/"-joined child of it.
BASE_PATH = "/content/drive2/MyDrive/trade_analysis_model"
CHECKPOINT_PATH = BASE_PATH + "/checkpoint-45"
CSV_PATH = BASE_PATH + "/export_data.csv"
CHART_STORE_PATH = BASE_PATH + "/charts"
INDEX_PATH = BASE_PATH + "/rag_indices"

# --- Model settings ---------------------------------------------------
# Sentence-embedding model loaded by IntelligentRAG.
EMBEDDING_MODEL_NAME = 'sentence-transformers/all-mpnet-base-v2'
# Used as the language model's context length by EnhancedTradeAnalysisSystem.
MAX_LENGTH = 2048
BATCH_SIZE = 1000

# Product keywords used by IntelligentRAG to filter and index records.
# Each category maps to lowercase keywords that are matched as substrings of
# the lowercased record description.
PRODUCT_KEYWORDS = {
    # 'برنج' was previously truncated to 'برن'; the full word matches the
    # spelling used in QueryParser.product_patterns and avoids matching
    # unrelated words that merely start with the same three letters.
    'rice': ['rice', 'basmati', 'irri', 'sella', 'paddy', 'chawal', 'برنج', 'چاول'],
    'cotton': ['cotton', 'yarn', 'textile', 'fabric', 'کپاس', 'روئی'],
    'wheat': ['wheat', 'flour', 'grain'],
    'sugar': ['sugar', 'raw sugar', 'refined sugar', 'brown sugar']
    # Add more product categories and their keywords as needed
}

@dataclass
class TradeRecord:
    """Typed container describing a single trade transaction.

    NOTE(review): nothing in the visible part of this file instantiates this
    class; the retrieval code works on plain dicts with similarly named keys.
    """
    # Transaction date.
    date: datetime
    # Free-text product description.
    description: str
    # Origin of the goods (presumably a country name — confirm).
    origin: str
    # Exporting party.
    exporter: str
    # Importing party.
    importer: str
    # Traded quantity, expressed in `unit`.
    quantity: float
    # Transaction value (presumably Pakistani rupees, per the field name).
    value_pkr: float
    # Unit of measure for `quantity`.
    unit: str
    # NOTE(review): purpose not shown in this file — presumably the raw source text of the record.
    original_text: str
class IntelligentRAG:
    def __init__(self, base_path):
        self.base_path = base_path
        self.index_path = f"{base_path}/rag_indices"
        self.semantic_index_path = f"{self.index_path}/semantic_index"

        # Create directories if they don't exist
        os.makedirs(self.semantic_index_path, exist_ok=True)

        # Initialize sentence transformer
        self.embed_model = SentenceTransformer(EMBEDDING_MODEL_NAME)
        self.embed_dim = self.embed_model.get_sentence_embedding_dimension()

        # Initialize indices
        self.faiss_index = None
        self.record_embeddings = None
        self.records = []

    def _build_semantic_index(self, force_rebuild: bool = False):
        """Build and save semantic search index"""
        index_files = {
            'faiss': f"{self.semantic_index_path}/faiss.index",
            'embeddings': f"{self.semantic_index_path}/embeddings.npy",
            'records': f"{self.semantic_index_path}/records.pkl"
        }

        # Check if index exists
        if not force_rebuild and all(os.path.exists(f) for f in index_files.values()):
            print("Loading existing semantic index...")
            try:
                # Load FAISS index
                self.faiss_index = faiss.read_index(index_files['faiss'])
                # Load embeddings
                self.record_embeddings = np.load(index_files['embeddings'])
                # Load records mapping
                with open(index_files['records'], 'rb') as f:
                    self.records = pickle.load(f)
                print(f"Loaded semantic index with {len(self.records)} records")
                return
            except Exception as e:
                print(f"Error loading semantic index: {e}")
                print("Rebuilding index...")

        print("Building new semantic index...")
        texts = []

        # Create rich text representations
        for record in tqdm(self.records, desc="Creating text representations"):
            text = f"""
            Product: {record['description']}
            Origin: {record['origin']}
            Exporter: {record['exporter']}
            Importer: {record['importer']}
            Value: {record['value_pkr']} PKR
            Quantity: {record['quantity']} {record['unit']}
            """
            texts.append(text)

        # Generate embeddings in batches
        print("Generating embeddings...")
        batch_size = 32
        all_embeddings = []

        for i in tqdm(range(0, len(texts), batch_size), desc="Generating embeddings"):
            batch = texts[i:i + batch_size]
            with torch.no_grad():
                embeddings = self.embed_model.encode(batch, convert_to_tensor=True)
                all_embeddings.append(embeddings.cpu().numpy())

            # Clear CUDA cache periodically
            if torch.cuda.is_available() and i % (batch_size * 10) == 0:
                torch.cuda.empty_cache()

        self.record_embeddings = np.vstack(all_embeddings)

        # Initialize and build FAISS index
        print("Building FAISS index...")
        self.faiss_index = faiss.IndexFlatL2(self.embed_dim)
        self.faiss_index.add(self.record_embeddings)

        # Save all components
        print("Saving semantic index...")
        try:
            # Save FAISS index
            faiss.write_index(self.faiss_index, index_files['faiss'])
            # Save embeddings
            np.save(index_files['embeddings'], self.record_embeddings)
            # Save records mapping
            with open(index_files['records'], 'wb') as f:
                pickle.dump(self.records, f)
            print(f"Successfully saved semantic index with {len(self.records)} records")
        except Exception as e:
            print(f"Error saving semantic index: {e}")

    def retrieve(self, query: str, time_info: Dict = None, country: str = None,
                product_type: str = None, k: int = 20) -> List[Dict]:
        """Semantic search with filtering"""
        try:
            # Get query embedding
            query_embedding = self.embed_model.encode([query])

            # Perform semantic search
            D, I = self.faiss_index.search(query_embedding, k * 2)  # Get more candidates for filtering

            # Get candidate records
            candidates = [self.records[i] for i in I[0]]

            # Apply filters
            filtered_records = []
            for record in candidates:
                # Time filter
                if time_info and time_info.get('time_range'):
                    start_date, end_date = time_info['time_range']
                    if not (start_date <= record['date'] <= end_date):
                        continue

                # Country filter
                if country and record['origin'].lower() != country.lower():
                    continue

                # Product filter
                if product_type:
                    product_keywords = PRODUCT_KEYWORDS.get(product_type, [])
                    if not any(keyword in record['description'].lower() for keyword in product_keywords):
                        continue

                filtered_records.append(record)
                if len(filtered_records) >= k:
                    break

            print(f"Found {len(filtered_records)} relevant records after filtering")
            return filtered_records

        except Exception as e:
            print(f"Error during retrieval: {e}")
            return []

    def load_data(self, csv_path: str, force_rebuild: bool = False):
        """Load data and build semantic search index"""
        print("Loading trade data...")
        self.df = pd.read_csv(csv_path, low_memory=False)

        # Convert dates and clean data
        self.df['DATE'] = pd.to_datetime(self.df['DATE'], format='mixed', errors='coerce')
        self.df = self.df.dropna(subset=['DATE'])

        # Convert records to list of dicts for RAG
        self.records = []
        for _, row in tqdm(self.df.iterrows(), desc="Processing records", total=len(self.df)):
            record = {
                'date': row['DATE'],
                'description': str(row['DESCRIPTION']),
                'origin': str(row['ORIGIN']),
                'exporter': str(row['EXPORTER']),
                'importer': str(row['IMPORTER']),
                'quantity': float(row['QTY']),
                'value_pkr': float(row['VALUE (PKR)']),
                'unit': str(row['UNIT'])
            }
            self.records.append(record)

        print(f"Loaded {len(self.records)} records")

        # Build or load semantic search index
        self._build_semantic_index(force_rebuild=force_rebuild)

    def _create_record_text(self, row: pd.Series) -> str:
        """Create rich text representation for embedding"""
        return f"""
        Trade Record:
        Product: {row['DESCRIPTION']}
        Origin Country: {row['ORIGIN']}
        Date: {row['DATE']}
        Quantity: {row['QTY']} {row['UNIT']}
        Value: {row['VALUE (PKR)']} PKR
        Exporter: {row['EXPORTER']}
        Importer: {row['IMPORTER']}
        """

    def _update_indices(self, record: Dict, idx: int):
        """Update all index structures"""
        # Time index
        year = record['date'].year
        month = record['date'].month
        if year not in self.time_index:
            self.time_index[year] = {}
        if month not in self.time_index[year]:
            self.time_index[year][month] = []
        self.time_index[year][month].append(idx)

        # Country index
        country = record['origin'].lower()
        if country not in self.country_index:
            self.country_index[country] = []
        self.country_index[country].append(idx)

        # Product index
        desc = record['description'].lower()
        for product_type, keywords in PRODUCT_KEYWORDS.items():
            if any(keyword in desc for keyword in keywords):
                if product_type not in self.product_index:
                    self.product_index[product_type] = []
                self.product_index[product_type].append(idx)

    def format_results(self, results: List[Dict]) -> Dict:
        """Format retrieval results for analysis"""
        if not results:
            print("No results to format")
            return None

        try:
            # Compute aggregations
            total_value = sum(record['value_pkr'] for record in results)
            total_quantity = sum(record['quantity'] for record in results)

            # Group by origin
            origin_values = {}
            for record in results:
                origin = record['origin']
                value = record['value_pkr']
                origin_values[origin] = origin_values.get(origin, 0) + value

            # Sort origins by value
            top_origins = sorted(origin_values.items(), key=lambda x: x[1], reverse=True)

            # Group by product description
            product_values = {}
            for record in results:
                desc = record['description']
                value = record['value_pkr']
                product_values[desc] = product_values.get(desc, 0) + value

            # Sort products by value
            top_products = sorted(product_values.items(), key=lambda x: x[1], reverse=True)

            # Format the results for analysis
            formatted = {
                'records': [{
                    'date': record['date'],
                    'description': record['description'],
                    'origin': record['origin'],
                    'exporter': record['exporter'],
                    'importer': record['importer'],
                    'quantity': record['quantity'],
                    'value_pkr': record['value_pkr'],
                    'unit': record['unit']
                } for record in results],
                'summary': {
                    'total_records': len(results),
                    'total_value_pkr': total_value,
                    'total_quantity': total_quantity,
                    'date_range': f"{min(r['date'] for r in results)} to {max(r['date'] for r in results)}",
                    'top_origins': top_origins[:5],
                    'top_products': top_products[:5]
                },
                'analysis': {
                    'average_value': total_value / len(results) if results else 0,
                    'origin_distribution': {
                        origin: (value, (value / total_value * 100) if total_value else 0)
                        for origin, value in origin_values.items()
                    },
                    'product_distribution': {
                        desc: (value, (value / total_value * 100) if total_value else 0)
                        for desc, value in product_values.items()
                    }
                }
            }

            # Print summary for debugging
            print("\nResults Summary:")
            print(f"Total Records: {formatted['summary']['total_records']}")
            print(f"Total Value: {formatted['summary']['total_value_pkr']:,.2f} PKR")
            print(f"Date Range: {formatted['summary']['date_range']}")

            print("\nTop Origins by Value:")
            for origin, value in formatted['summary']['top_origins']:
                percentage = (value / total_value * 100) if total_value else 0
                print(f"- {origin.capitalize()}: {value:,.2f} PKR ({percentage:.1f}%)")

            print("\nTop Products by Value:")
            for product, value in formatted['summary']['top_products']:
                percentage = (value / total_value * 100) if total_value else 0
                print(f"- {product.capitalize()}: {value:,.2f} PKR ({percentage:.1f}%)")

            return formatted

        except Exception as e:
            print(f"Error formatting results: {e}")
            return None

    def analyze_query(self, query: str, time_info: Dict = None, country: str = None,
                    product_type: str = None, k: int = 20) -> Dict:
        """Complete analysis pipeline"""
        try:
            # Retrieve relevant records
            results = self.retrieve(query, time_info, country, product_type, k)

            # Format results
            formatted = self.format_results(results)

            if not formatted:
                return None

            # Add query analysis
            formatted['query_info'] = {
                'original_query': query,
                'time_info': time_info,
                'country': country,
                'product_type': product_type
            }

            return formatted

        except Exception as e:
            print(f"Error in analysis pipeline: {e}")
            return None

class QueryParser:
    """Advanced query understanding and parsing"""

    def __init__(self):
        import spacy
        # Uncomment the following lines to enable GPU for spaCy
        # spacy.prefer_gpu()

        self.nlp = spacy.load("en_core_web_sm")
        # Common trade-related terms
        self.product_terms = {
            'rice': ['basmati', 'white rice', 'brown rice', 'long grain', 'broken rice'],
            'cotton': ['raw cotton', 'cotton yarn', 'cotton fabric'],
            'wheat': ['wheat flour', 'wheat grain'],
            'sugar': ['raw sugar', 'refined sugar', 'brown sugar']
        }
        self.time_terms = {
            'months': list(calendar.month_name)[1:] + list(calendar.month_abbr)[1:],
            'quarters': ['q1', 'q2', 'q3', 'q4', 'quarter 1', 'quarter 2', 'quarter 3', 'quarter 4'],
            'periods': ['year', 'month', 'week', 'day']
        }

        # Time-related patterns
        self.month_map = {month.lower(): num for num, month in enumerate(calendar.month_name) if month}
        self.quarter_map = {
            'q1': [1, 2, 3],
            'q2': [4, 5, 6],
            'q3': [7, 8, 9],
            'q4': [10, 11, 12]
        }

        # Product categories and their variations
        self.product_patterns = {
            'rice': {
                'keywords': ['rice', 'basmati', 'irri', 'sella', 'paddy', 'chawal', 'برنج', 'چاول'],
                'variants': {
                    'basmati': ['super basmati', 'basmati sella', 'brown basmati', 'premium basmati'],
                    'irri': ['irri-6', 'irri6', 'irri 6', 'irri-9', 'irri9', 'irri 9'],
                    'broken': ['broken rice', 'rice broken', '100% broken', 'double broken'],
                    'white': ['white rice', 'milled rice', 'polished rice']
                }
            },
            'cotton': {
                'keywords': ['cotton', 'yarn', 'textile', 'fabric', 'کپاس', 'روئی'],
                'variants': {
                    'raw': ['raw cotton', 'seed cotton', 'cotton lint'],
                    'processed': ['cotton yarn', 'cotton fabric', 'cotton textile']
                }
            }
            # Add other product categories as needed
        }

        # Trade action patterns
        self.trade_patterns = {
            'import': [
                r'import[s|ed|ing]?',
                r'bring[s|ing]? in',
                r'bought from',
                r'purchased from',
                r'receiving from',
                r'incoming',
                r'coming in'
            ],
            'export': [
                r'export[s|ed|ing]?',
                r'send[s|ing]? to',
                r'sold to',
                r'shipping to',
                r'outgoing',
                r'going to'
            ]
        }

        # Time expressions
        self.time_patterns = {
            'explicit_date': [
                r'(?P<month>january|february|march|april|may|june|july|august|september|october|november|december)\s+(?P<year>20\d{2})',
                r'(?P<quarter>q[1-4])\s+(?P<year>20\d{2})',
                r'(?P<year>20\d{2})\s+(?P<quarter>q[1-4])',
            ],
            'relative_date': [
                r'last (?P<unit>month|quarter|year)',
                r'past (?P<number>\d+) (?P<unit>month|quarter|year)s?',
                r'previous (?P<unit>month|quarter|year)',
                r'current (?P<unit>month|quarter|year)',
                r'this (?P<unit>month|quarter|year)'
            ],
            'year_only': r'\b20\d{2}\b',
            'month_only': '|'.join(calendar.month_name[1:] + [m[:3] for m in calendar.month_name[1:]])
        }

    def _extract_time_info(self, query: str) -> Dict:
        """Parse query into structured components"""
        query_lower = query.lower()

        time_info = {
            'year': None,
            'month': None,
            'quarter': None,
            'time_range': None,
            'is_relative': False
        }

        # Match quarters (e.g., "q1 2024", "Q2")
        quarter_pattern = r'q([1-4])(?:\s*20)?(\d{2,4})?'
        quarter_match = re.search(quarter_pattern, query_lower)

        # Match months (e.g., "april 2024", "jan", "december 23")
        month_pattern = r'(jan(?:uary)?|feb(?:ruary)?|mar(?:ch)?|apr(?:il)?|may|jun(?:e)?|jul(?:y)?|aug(?:ust)?|sep(?:tember)?|oct(?:ober)?|nov(?:ember)?|dec(?:ember)?)\s*(?:20)?(\d{2,4})?'
        month_match = re.search(month_pattern, query_lower)

        if quarter_match:
            quarter = int(quarter_match.group(1))
            year = quarter_match.group(2)

            if year:
                year = int('20' + year if len(year) == 2 else year)
            else:
                year = datetime.now().year

            time_info['year'] = year
            time_info['quarter'] = quarter

            start_month = 3 * quarter - 2
            end_month = 3 * quarter
            time_info['time_range'] = (
                datetime(year, start_month, 1),
                datetime(year, end_month, calendar.monthrange(year, end_month)[1])
            )

        elif month_match:
            month_str = month_match.group(1)
            year = month_match.group(2)

            # Convert month string to number
            month_map = {
                'jan': 1, 'january': 1,
                'feb': 2, 'february': 2,
                'mar': 3, 'march': 3,
                'apr': 4, 'april': 4,
                'may': 5,
                'jun': 6, 'june': 6,
                'jul': 7, 'july': 7,
                'aug': 8, 'august': 8,
                'sep': 9, 'september': 9,
                'oct': 10, 'october': 10,
                'nov': 11, 'november': 11,
                'dec': 12, 'december': 12
            }

            month = month_map[month_str[:3]]

            if year:
                year = int('20' + year if len(year) == 2 else year)
            else:
                year = datetime.now().year

            time_info['year'] = year
            time_info['month'] = month
            time_info['time_range'] = (
                datetime(year, month, 1),
                datetime(year, month, calendar.monthrange(year, month)[1])
            )

        # Rest of the parsing code remains the same...

        return time_info

    def _extract_trade_type(self, query: str) -> str:
        """Intelligently determine trade direction from query"""
        query = query.lower()

        # Check explicit patterns
        for trade_type, patterns in self.trade_patterns.items():
            if any(re.search(pattern, query) for pattern in patterns):
                return trade_type

        # Context-based inference
        doc = self.nlp(query)

        # Look for directional prepositions and their objects
        for token in doc:
            if token.dep_ == 'prep' and token.text in ['to', 'from']:
                if token.text == 'from':
                    return 'import'
                else:
                    return 'export'

        # Default to 'export' if direction is ambiguous
        return 'export'

    def _extract_product_info(self, query: str) -> Dict:
        """Extract detailed product information with variants"""
        query = query.lower()
        product_info = {
            'main_category': None,
            'subcategory': None,
            'variants': set()
        }

        # Check each product category
        for category, patterns in self.product_patterns.items():
            # Check main keywords
            if any(keyword in query for keyword in patterns['keywords']):
                product_info['main_category'] = category

                # Check variants
                for subcategory, variant_patterns in patterns['variants'].items():
                    if any(variant in query for variant in variant_patterns):
                        product_info['subcategory'] = subcategory
                        product_info['variants'].add(next(variant for variant in variant_patterns if variant in query))

                break

        return product_info

    def _extract_location_info(self, query: str) -> Dict:
        """Extract location information with context"""
        doc = self.nlp(query)
        location_info = {
            'countries': set(),
            'regions': set(),
            'context': None  # 'source', 'destination', or None
        }

        for ent in doc.ents:
            if ent.label_ in ['GPE', 'LOC']:
                # Try to determine if it's a source or destination
                context = None
                for token in ent.subtree:
                    if token.text in ['from', 'in']:
                        context = 'source'
                    elif token.text in ['to', 'for']:
                        context = 'destination'

                if context:
                    location_info['context'] = context

                if ent.label_ == 'GPE':
                    location_info['countries'].add(ent.text.lower())
                else:
                    location_info['regions'].add(ent.text.lower())

        return location_info

    def parse_query(self, query: str) -> Dict:
        """Main parsing function that combines all extractors"""
        parsed = {
            'time_info': self._extract_time_info(query),
            'trade_type': self._extract_trade_type(query),
            'product_info': self._extract_product_info(query),
            'location_info': self._extract_location_info(query),
            'original_query': query
        }

        # Add confidence scores
        parsed['confidence'] = self._calculate_confidence(parsed)

        # Add query interpretation
        parsed['interpretation'] = self._generate_interpretation(parsed)

        return parsed

    def _calculate_confidence(self, parsed: Dict) -> Dict:
        """Calculate confidence scores for each component"""
        confidence = {
            'time': 0.0,
            'trade_type': 0.0,
            'product': 0.0,
            'location': 0.0,
            'overall': 0.0
        }

        # Time confidence
        if parsed['time_info']['time_range']:
            confidence['time'] = 1.0
        elif parsed['time_info']['year']:
            confidence['time'] = 0.7

        # Trade type confidence
        if parsed['trade_type']:
            confidence['trade_type'] = 0.8

        # Product confidence
        if parsed['product_info']['main_category']:
            confidence['product'] = 0.7
            if parsed['product_info']['subcategory']:
                confidence['product'] = 0.9

        # Location confidence
        if parsed['location_info']['countries']:
            confidence['location'] = 0.9
        elif parsed['location_info']['regions']:
            confidence['location'] = 0.7

        # Calculate overall confidence
        weights = {'time': 0.3, 'trade_type': 0.2, 'product': 0.25, 'location': 0.25}
        confidence['overall'] = sum(confidence[k] * weights[k] for k in weights)

        return confidence

    def _generate_interpretation(self, parsed: Dict) -> str:
        """Generate human-readable interpretation of the query"""
        parts = []

        # Add trade type
        if parsed['trade_type']:
            parts.append(f"Looking for {parsed['trade_type']}s")

        # Add product info
        if parsed['product_info']['main_category']:
            product_desc = parsed['product_info']['main_category']
            if parsed['product_info']['subcategory']:
                product_desc += f" ({parsed['product_info']['subcategory']})"
            parts.append(f"of {product_desc}")

        # Add location info
        if parsed['location_info']['countries']:
            countries = ', '.join(parsed['location_info']['countries'])
            context = f"{parsed['location_info']['context']} " if parsed['location_info']['context'] else ""
            parts.append(f"{context}{countries}")

        # Add time info
        if parsed['time_info']['time_range']:
            start, end = parsed['time_info']['time_range']
            parts.append(f"during {start.strftime('%B %Y')} to {end.strftime('%B %Y')}")

        return ' '.join(parts).capitalize()

class EnhancedTradeAnalysisSystem:
    def __init__(self):
        print("Initializing Enhanced Trade Analysis System...")
        self._setup_logging()
        self.model_context_length = MAX_LENGTH
        self.max_input_length = 1500

        # Create directories if they don't exist
        for path in [BASE_PATH, CHART_STORE_PATH]:
            os.makedirs(path, exist_ok=True)

        # Initialize components
        self.query_parser = QueryParser()
        self.model, self.tokenizer = self._init_model()

        # Initialize IntelligentRAG
        self.rag = IntelligentRAG(BASE_PATH)

        try:
            # Load data and build indices
            self.rag.load_data(CSV_PATH)
            print("Successfully loaded all indices!")
        except Exception as e:
            self.logger.error(f"Error loading data: {e}")
            raise

        # Enhanced system prompt
        self.system_prompt = """You are an expert trade data analyst. Analyze the provided trade data and focus on:
1. Import/Export patterns and trends
2. Key trading partners and their significance
3. Product-specific insights and market dynamics
4. Time-based trends and seasonality
5. Value chain analysis and pricing patterns

Provide analysis in clear, structured format with:
- Key findings and insights
- Specific calculations and metrics
- Market implications and recommendations
- Notable patterns or anomalies

Base all analysis strictly on the provided data. If data is insufficient, clearly state limitations."""

    def _setup_logging(self):
        """Configure console + file logging and set ``self.logger``.

        The log file is named ``trade_analysis_<YYYYmmdd_HHMMSS>.log`` and is
        placed in BASE_PATH.

        Note:
            ``logging.basicConfig`` leaves the root logger untouched when it
            already has handlers, so in such an environment the file handler
            below is not attached.
        """
        # FileHandler does not create missing parent directories; make sure
        # the log directory exists before the file is opened.
        os.makedirs(BASE_PATH, exist_ok=True)

        log_filename = f'{BASE_PATH}/trade_analysis_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
            handlers=[
                logging.StreamHandler(),
                logging.FileHandler(log_filename)
            ]
        )
        self.logger = logging.getLogger('TradeAnalysis')

    def _init_model(self):
        """Load the fine-tuned model and tokenizer from CHECKPOINT_PATH.

        Returns:
            ``(model, tokenizer)`` with the model switched to inference mode.

        Raises:
            Exception: Any loading error is logged with its traceback and
                re-raised.
        """
        self.logger.info("Initializing language model from checkpoint-45...")
        try:
            # Load the model and tokenizer from the specified checkpoint
            model, tokenizer = FastLanguageModel.from_pretrained(
                model_name=CHECKPOINT_PATH,
                max_seq_length=self.model_context_length,
                load_in_4bit=True,  # Adjust based on your model's requirements
                trust_remote_code=True  # Set to True if the model uses custom code
            )

            # Prepare the model for inference. Unsloth's documented usage calls
            # for_inference() for its effect on the model and ignores the
            # return value, so only adopt the returned object when there is
            # one; otherwise keep the model that was just loaded.
            prepared = FastLanguageModel.for_inference(model)
            if prepared is not None:
                model = prepared

            self.logger.info("Language model initialized successfully from checkpoint-45.")
            return model, tokenizer
        except Exception as e:
            self.logger.error(f"Model initialization error: {e}", exc_info=True)
            raise

    def _create_document(self, row: pd.Series) -> Tuple[str, Dict]:
        """Create a document from a dataframe row"""
        try:
            text = f"{row['DESCRIPTION']} from {row['ORIGIN']} on {row['DATE']}"
            metadata = {
                'date': row['DATE'],
                'description': row['DESCRIPTION'],
                'origin': row['ORIGIN'],
                'value': row['VALUE (PKR)']
            }
            return text, metadata
        except Exception as e:
            self.logger.error(f"Error creating document from row: {e}", exc_info=True)
            return None, None

    def analyze(self, query: str):
        """Analyze trade data using IntelligentRAG"""
        try:
            # Parse query
            parsed_query = self.query_parser.parse_query(query)
            print("\nParsed Query Components:")
            print(f"Time Info: {parsed_query.get('time_info')}")
            print(f"Location Info: {parsed_query.get('location_info')}")
            print(f"Product Info: {parsed_query.get('product_info')}")

            # Use IntelligentRAG for retrieval
            retrieved_records = self.rag.retrieve(
                query=query,
                time_info=parsed_query['time_info'],
                country=next(iter(parsed_query['location_info']['countries']), None),
                product_type=parsed_query['product_info']['main_category']
            )

            if not retrieved_records:
                print("No relevant data found for the query.")
                return None

            # Format the retrieved results
            formatted_results = self.rag.format_results(retrieved_records)

            if not formatted_results:
                return None

            # Generate analysis prompt
            analysis_prompt = self._create_analysis_prompt(query, formatted_results['summary'])

            # Generate AI analysis
            ai_analysis = self._generate_analysis(analysis_prompt)

            # Combine everything into final result
            result = {
                'raw_data': formatted_results['records'],
                'summary': formatted_results['summary'],
                'ai_analysis': ai_analysis,
                'charts': self.create_visualizations(pd.DataFrame(formatted_results['records']))
            }

            return result

        except Exception as e:
            self.logger.error(f"Error during analysis: {e}", exc_info=True)
            print(f"Error during analysis: {str(e)}")
            return None

    def _create_analysis_prompt(self, query: str, summary: Dict) -> str:
        """Create analysis prompt using summary data in the training format"""
        formatted_data = {
            "analysis_request": query,
            "trade_data": {
                "monetary_metrics": {
                    "total_trade_value_pkr": summary['total_value_pkr'],
                    "total_quantity": summary['total_quantity']
                },
                "time_period": str(summary['date_range']),
                "product_analysis": [
                    {"product": p, "value_pkr": v, "share": (v/summary['total_value_pkr'])*100}
                    for p, v in summary['top_products']
                ],
                "geographical_distribution": [
                    {"country": o, "value_pkr": v, "share": (v/summary['total_value_pkr'])*100}
                    for o, v in summary['top_origins']
                ]
            }
        }

        prompt = f"""<|im_start|>system
You are a mathematical trade analysis expert. Your task is to analyze trade data and provide precise, data-driven insights. Focus solely on the provided data - do not make assumptions about data not shown. Use exact numbers and percentages from the data. Do not engage in follow-up discussions or ask questions.
<|im_end|>
<|im_start|>human
Analyze the following trade data to answer this specific question: {query}

The complete dataset for analysis:
{json.dumps(formatted_data, indent=2)}

Provide a single, comprehensive analysis based only on this data. Include relevant calculations and percentages to support your insights. Do not speculate beyond the provided information.
<|im_end|>
<|im_start|>assistant"""

        return prompt

    def _generate_analysis(self, prompt: str) -> str:
        """Generate analysis using the FastLanguageModel"""
        try:
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=self.max_input_length
            ).to("cuda" if torch.cuda.is_available() else "cpu")

            # Generate response using the model
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=1000,
                temperature=0.7,
                top_p=0.95,
                do_sample=True,
                num_beams=1,  # Set to 1 to avoid beam search
                repetition_penalty=1.2,
                no_repeat_ngram_size=3,
                pad_token_id=self.tokenizer.eos_token_id,
                eos_token_id=self.tokenizer.eos_token_id
            )

            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Clean up response
            response = self._clean_response(response)

            return response

        except Exception as e:
            self.logger.error(f"Error generating analysis: {e}", exc_info=True)
            # Fallback response if generation fails
            sample_records = self.df.head(5).to_dict(orient='records')
            fallback = "I apologize, but I encountered an error while generating the analysis. Here are some sample records:\n"
            for record in sample_records:
                date_str = record['DATE'].strftime('%Y-%m-%d') if isinstance(record['DATE'], pd.Timestamp) else record['DATE']
                fallback += f"- {date_str}: {record['DESCRIPTION']} from {record['ORIGIN']}\n"
            return fallback

    def _clean_response(self, response: str) -> str:
        """Clean and format the model's response"""
        # Remove system/human prompts if present
        if "system" in response.lower():
            try:
                response = response.split("assistant")[-1].strip()
            except:
                pass

        # Remove markdown code blocks
        response = re.sub(r'```json\s*|\s*```', '', response)

        # Clean up whitespace
        response = re.sub(r'\s+', ' ', response).strip()

        return response

    def _prepare_analysis_data(self, df: pd.DataFrame) -> Dict:
        """Prepare comprehensive analysis data from the DataFrame"""
        try:
            # Extract basic metrics
            values = df['VALUE (PKR)'].astype(float).tolist()
            dates = df['date'].dropna().tolist()
            origins = df['origin'].astype(str).str.lower().tolist()
            products = df['description'].astype(str).str.lower().tolist()

            # Calculate aggregates
            analysis_data = {
                'total_records': len(df),
                'total_value': sum(values),
                'avg_value': sum(values) / len(values) if values else 0,
                'date_range': f"{min(dates).strftime('%Y-%m-%d') if not pd.isnull(min(dates)) else 'Unknown'} to {max(dates).strftime('%Y-%m-%d') if not pd.isnull(max(dates)) else 'Unknown'}",
                'top_origins': self._get_top_items(origins, values),
                'top_products': self._get_top_items(products, values)
            }

            return analysis_data

        except Exception as e:
            self.logger.error(f"Error preparing analysis data: {e}", exc_info=True)
            return self._get_default_analysis_data()

    def _get_top_items(self, items: List[str], values: List[float], top_n: int = 5) -> List[Tuple[str, float]]:
        """Get top items by value with proper aggregation"""
        item_values = {}
        for item, value in zip(items, values):
            item_values[item] = item_values.get(item, 0) + value

        return sorted(
            item_values.items(),
            key=lambda x: x[1],
            reverse=True
        )[:top_n]

    def _get_default_analysis_data(self) -> Dict:
        """Return default analysis data structure"""
        return {
            'total_records': 0,
            'total_value': 0,
            'avg_value': 0,
            'date_range': 'Unknown',
            'top_origins': [],
            'top_products': []
        }

    def create_visualizations(self, df: pd.DataFrame) -> Dict[str, str]:
        """Create visualizations with proper column handling"""
        try:
            charts = {}

            if df.empty:
                return charts

            # Convert date column if exists (handle both 'DATE' and 'date')
            date_col = 'date' if 'date' in df.columns else 'DATE' if 'DATE' in df.columns else None
            if date_col:
                df['date'] = pd.to_datetime(df[date_col], errors='coerce')
            else:
                print("Warning: No date column found for time series visualization")
                return charts

            # Time series chart
            fig = go.Figure()
            df_sorted = df.sort_values('date')

            # Value over time
            value_col = 'value_pkr' if 'value_pkr' in df.columns else 'VALUE (PKR)' if 'VALUE (PKR)' in df.columns else None
            if value_col:
                fig.add_trace(go.Scatter(
                    x=df_sorted['date'],
                    y=df_sorted[value_col],
                    mode='lines+markers',
                    name='Trade Value'
                ))

                fig.update_layout(
                    title='Trade Value Over Time',
                    xaxis_title='Date',
                    yaxis_title='Value (PKR)',
                    template='plotly_dark'
                )
                charts['time_series'] = fig

            # Save charts
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            for name, fig in charts.items():
                try:
                    fig.write_html(f"{self.chart_store_path}/{name}_{timestamp}.html")
                except Exception as e:
                    print(f"Warning: Could not save chart {name}: {e}")

            return charts

        except Exception as e:
            self.logger.error(f"Error creating visualizations: {e}", exc_info=True)
            return {}

    def filter_trade_data(self, parsed_query: Dict) -> pd.DataFrame:
        """Narrow ``self.df`` by the date range, countries and product in a parsed query.

        Progress is printed after every filter that is applied, followed by
        up to three sample matches.

        Args:
            parsed_query: Parsed query; the optional keys 'time_info',
                'location_info' and 'product_info' are consulted.

        Returns:
            The matching rows, sorted by 'DATE' ascending and then
            'VALUE (PKR)' descending.
        """
        matches = self.df.copy()

        # Print initial shape for debugging
        print(f"\nInitial records: {len(matches)}")

        # 1) Date window (inclusive on both ends)
        window = parsed_query.get('time_info', {}).get('time_range')
        if window:
            start_date, end_date = window
            matches['DATE'] = pd.to_datetime(matches['DATE'])
            in_window = (matches['DATE'] >= start_date) & (matches['DATE'] <= end_date)
            matches = matches[in_window]
            print(f"After date filter: {len(matches)} records")

        # 2) Origin country, compared case-insensitively
        wanted_countries = [
            country.lower()
            for country in parsed_query.get('location_info', {}).get('countries', [])
        ]
        if wanted_countries:
            matches = matches[matches['ORIGIN'].str.lower().isin(wanted_countries)]
            print(f"After location filter: {len(matches)} records")

        # 3) Product category, matched through its PRODUCT_KEYWORDS terms
        product_type = parsed_query.get('product_info', {}).get('main_category')
        if product_type in PRODUCT_KEYWORDS:
            terms = PRODUCT_KEYWORDS.get(product_type, [])
            # Alternation of the escaped terms: a description matches when it
            # contains any one of them
            pattern = '|'.join(map(re.escape, terms))
            has_term = matches['DESCRIPTION'].str.lower().str.contains(pattern, na=False)
            matches = matches[has_term]
            print(f"After product filter: {len(matches)} records")

        # Sort by date and value
        matches = matches.sort_values(['DATE', 'VALUE (PKR)'], ascending=[True, False])

        # Print example records for debugging
        if not matches.empty:
            print("\nSample matches:")
            for _, row in matches.head(3).iterrows():
                if isinstance(row['DATE'], pd.Timestamp):
                    date_str = row['DATE'].strftime('%Y-%m-%d')
                else:
                    date_str = row['DATE']
                print(f"Date: {date_str}, Origin: {row['ORIGIN']}, Description: {row['DESCRIPTION']}")

        return matches

    def load_data(self, csv_path: str):
        """Load and preprocess trade data"""
        self.df = pd.read_csv(csv_path)

        # Convert DATE column to datetime
        self.df['DATE'] = pd.to_datetime(self.df['DATE'])

        # Create a clean date column without time
        self.df['date_clean'] = self.df['DATE'].dt.date

        # Ensure DESCRIPTION is string type
        self.df['DESCRIPTION'] = self.df['DESCRIPTION'].astype(str)

    def filter_records(self, time_info: Dict, product_info: Dict, location_info: Dict) -> pd.DataFrame:
        """Filter records based on query components"""
        filtered_df = self.df.copy()

        # Time filtering
        if time_info.get('time_range'):
            start_date, end_date = time_info['time_range']
            filtered_df = filtered_df[
                (filtered_df['DATE'] >= start_date) &
                (filtered_df['DATE'] <= end_date)
            ]

        # Product filtering
        if product_info.get('main_category'):
            category_keywords = PRODUCT_KEYWORDS[product_info['main_category']]
            product_mask = filtered_df['DESCRIPTION'].str.lower().apply(
                lambda x: any(keyword in x.lower() for keyword in category_keywords)
            )
            filtered_df = filtered_df[product_mask]

        # Location filtering
        if location_info.get('countries'):
            filtered_df = filtered_df[
                filtered_df['ORIGIN'].str.lower().isin(
                    [c.lower() for c in location_info['countries']]
                )
            ]

        return filtered_df

def _print_retrieved_records(records, limit=5):
    """Print up to ``limit`` retrieved records (date, product, origin, value)."""
    print("\nRetrieved Records:")
    print("="*80)

    for i, record in enumerate(records[:limit], 1):
        date_str = record['date'].strftime('%Y-%m-%d') if isinstance(record['date'], pd.Timestamp) else record['date']
        print(f"\nRecord {i}:")
        print(f"Date: {date_str}")
        print(f"Product: {record['description']}")
        print(f"Origin: {record['origin']}")
        print(f"Value: {float(record['value_pkr']):,.2f} PKR")
        print("-"*40)


def _print_summary(metadata):
    """Print the totals plus the top origins and products with their value share."""
    total_value = metadata['total_value_pkr']

    print("\nSummary Statistics:")
    print("="*80)
    print(f"Total Records: {metadata['total_records']}")
    print(f"Total Value: {total_value:,.2f} PKR")
    print(f"Date Range: {metadata['date_range']}")

    for heading, key in (("\nTop Origins by Value:", 'top_origins'),
                         ("\nTop Products by Value:", 'top_products')):
        print(heading)
        for name, value in metadata[key]:
            # A zero total would divide by zero; show 0% instead.
            percentage = (value / total_value * 100) if total_value else 0
            print(f"- {name.capitalize()}: {value:,.2f} PKR ({percentage:.1f}%)")


def main():
    """Run the interactive console loop of the trade analysis system.

    Initialises ``EnhancedTradeAnalysisSystem`` and then answers queries read
    from standard input until the user types 'quit', the input stream ends,
    or the user interrupts. Errors raised while handling a single query are
    printed and the loop continues.
    """
    try:
        # Initialize system
        system = EnhancedTradeAnalysisSystem()
    except Exception as e:
        print(f"System initialization error: {str(e)}")
        logging.error(f"System initialization error: {str(e)}")
        return

    print("\nEnhanced Trade Analysis System Ready!")
    print("Enter your queries about trade data. The system will provide detailed analysis.")
    print("Type 'quit' to exit.")

    while True:
        try:
            query = input("\nEnter your analysis question: ").strip()
        except (EOFError, KeyboardInterrupt):
            # End of input (non-interactive run) or Ctrl-C: leave cleanly
            # instead of reporting it as an initialization error.
            print("\nExiting Enhanced Trade Analysis System. Goodbye!")
            break

        if query.lower() == 'quit':
            print("Exiting Enhanced Trade Analysis System. Goodbye!")
            break

        if not query:
            print("Please enter a valid query.")
            continue

        try:
            result = system.analyze(query)

            if result is None:
                continue

            # Print retrieved records (limiting to first 5 for brevity)
            _print_retrieved_records(result['raw_data'])

            # Print summary stats
            _print_summary(result['summary'])

            # Print AI analysis
            print("\nAI Analysis:")
            print("="*80)
            print(result['ai_analysis'])

            # Print visualization info
            if result.get('charts'):
                print("\nVisualizations have been saved to:", CHART_STORE_PATH)

        except Exception as e:
            print(f"Error during analysis: {str(e)}")

# Start the interactive loop only when this code is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()



🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
Initializing Enhanced Trade Analysis System...
Are you certain you want to do remote code execution?
==((====))==  Unsloth 2024.10.7: Fast Qwen2 patching. Transformers = 4.44.2.
   \\   /|    GPU: Tesla T4. Max memory: 14.748 GB. Platform = Linux.
O^O/ \_/ \    Pytorch: 2.5.0+cu121. CUDA = 7.5. CUDA Toolkit = 12.1.
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.28.post2. FA2 = False]
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Unsloth: We fixed a gradient accumulation bug, but it seems like you don't have the latest transformers version!
Please update transformers, TRL and unsloth via:
`pip install --upgrade --no-cache-dir unsloth git+https://github.com/huggingface/transformers.git git+https://github.com/huggingface/trl.git`
Unsloth 2024.10.7 patched 28 layers with 0 QKV layers, 28 O layers and 28 MLP layers.


Loading trade data...


Processing records:   0%|          | 0/41398 [00:00<?, ?it/s]

Loaded 41398 records
Building new semantic index...


Creating text representations:   0%|          | 0/41398 [00:00<?, ?it/s]

Generating embeddings...


Generating embeddings:   0%|          | 0/1294 [00:00<?, ?it/s]

Building FAISS index...
Saving semantic index...
Successfully saved semantic index with 41398 records
Successfully loaded all indices!

Enhanced Trade Analysis System Ready!
Enter your queries about trade data. The system will provide detailed analysis.
Type 'quit' to exit.

Enter your analysis question: give me the main trends of rice products in q1 2024

Parsed Query Components:
Time Info: {'year': 2024, 'month': None, 'quarter': 1, 'time_range': (datetime.datetime(2024, 1, 1, 0, 0), datetime.datetime(2024, 3, 31, 0, 0)), 'is_relative': False}
Location Info: {'countries': set(), 'regions': set(), 'context': None}
Product Info: {'main_category': 'rice', 'subcategory': None, 'variants': set()}
Found 15 relevant records after filtering

Results Summary:
Total Records: 15
Total Value: 1,782,931,016.20 PKR
Date Range: 2024-01-05 00:00:00 to 2024-03-30 00:00:00

Top Origins by Value:
- Indonesia: 895,160,000.00 PKR (50.2%)
- United kingdom: 336,920,000.00 PKR (18.9%)
- Cameroon: 169,921,18

In [2]:
!pip install --upgrade unsloth transformers tiktoken

Collecting transformers
  Using cached transformers-4.46.0-py3-none-any.whl.metadata (44 kB)
Collecting tiktoken
  Downloading tiktoken-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.6 kB)
Downloading tiktoken-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m17.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: tiktoken
Successfully installed tiktoken-0.8.0


In [4]:
pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu121

Looking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/cu121


In [1]:
# google.colab is part of the Colab runtime, so mounting needs no install.
# The PyPI package "drive" is an unrelated project and is not used here.
from google.colab import drive
drive.mount('/content/drive2', force_remount=True)

Mounted at /content/drive2
