In [None]:
import pandas as pd
import requests
import json
from datetime import datetime, timedelta
from typing import Dict, List, Any
import re
import calendar

class ContractDataImporter:
    def __init__(self, api_base_url: str = "http://localhost:3000/api/contracts"):
        self.api_base_url = api_base_url
        self.contracts_cache = {}
        
    def load_csv(self, filepath: str) -> pd.DataFrame:
        """Load and parse the CSV file"""
        try:
            df = pd.read_csv(filepath)
            print(f"Loaded CSV with {len(df)} rows and columns: {list(df.columns)}")
            return df
        except Exception as e:
            print(f"Error loading CSV: {e}")
            return None
    
    def clean_and_validate_data(self, df: pd.DataFrame) -> pd.DataFrame:
        """Clean and validate the CSV data"""
        # Clean column names by stripping whitespace
        df.columns = df.columns.str.strip()
        
        # Print actual columns for debugging
        print(f"Actual CSV columns (after cleaning): {list(df.columns)}")
        
        # Map columns to find the right ones
        column_mapping = {}
        
        # Find Volume column (could have spaces or be at the end)
        volume_cols = [col for col in df.columns if 'volume' in col.lower() and 'cy' not in col.lower()]
        if volume_cols:
            column_mapping['Volume'] = volume_cols[0]
            print(f"Found Volume column: '{volume_cols[0]}'")
        
        # Find other columns
        required_base_columns = ['New_deal_name', 'State', 'type', 'Year', 'Month', 'Price_CY']
        
        for req_col in required_base_columns:
            matching_cols = [col for col in df.columns if req_col.lower() == col.lower()]
            if matching_cols:
                column_mapping[req_col] = matching_cols[0]
            else:
                print(f"Warning: Could not find column matching '{req_col}'")
        
        # Check if we found all required columns
        required_columns = ['New_deal_name', 'State', 'type', 'Year', 'Month', 'Volume', 'Price_CY']
        missing_cols = [col for col in required_columns if col not in column_mapping]
        
        if missing_cols:
            print(f"Missing columns after mapping: {missing_cols}")
            # Try to find similar column names for missing columns
            for missing_col in missing_cols:
                similar_cols = [col for col in df.columns if missing_col.lower() in col.lower()]
                if similar_cols:
                    print(f"  Similar columns found for '{missing_col}': {similar_cols}")
                    # Auto-map if there's only one similar column
                    if len(similar_cols) == 1:
                        column_mapping[missing_col] = similar_cols[0]
                        print(f"    Auto-mapped '{missing_col}' to '{similar_cols[0]}'")
            
            # Check again after auto-mapping
            missing_cols = [col for col in required_columns if col not in column_mapping]
            if missing_cols:
                print(f"Still missing columns: {missing_cols}")
                return pd.DataFrame()  # Return empty dataframe if critical columns missing
        
        # Clean data
        df_clean = df.copy()
        
        # Rename columns to standard names for easier processing
        df_clean = df_clean.rename(columns={v: k for k, v in column_mapping.items()})
        
        # Convert Volume to numeric, handling commas and dashes
        if 'Volume' in df_clean.columns:
            print(f"Processing Volume column...")
            df_clean['Volume'] = df_clean['Volume'].astype(str).str.replace(',', '').str.replace('-', '0').str.strip()
            df_clean['Volume'] = pd.to_numeric(df_clean['Volume'], errors='coerce').fillna(0)
            print(f"Sample Volume data: {df_clean['Volume'].head().tolist()}")
        
        # Convert Price_CY to numeric
        if 'Price_CY' in df_clean.columns:
            df_clean['Price_CY'] = df_clean['Price_CY'].astype(str).str.replace(',', '').str.strip()
            df_clean['Price_CY'] = pd.to_numeric(df_clean['Price_CY'], errors='coerce').fillna(0)
        
        # Remove rows with missing deal names
        df_clean = df_clean.dropna(subset=['New_deal_name'])
        
        print(f"Cleaned data: {len(df_clean)} rows remaining")
        return df_clean
    
    def parse_date_from_row(self, row) -> str:
        """Parse date from Year and Month columns"""
        try:
            year = int(row['Year'])
            month = int(row['Month'])
            # Create first day of the month as timestamp
            date = datetime(year, month, 1)
            return date.isoformat()
        except:
            return datetime.now().isoformat()
    
    def get_hours_in_period(self, year: int, month: int) -> int:
        """Get the number of hours in a given month"""
        days_in_month = calendar.monthrange(year, month)[1]
        return days_in_month * 24
    
    def convert_volume_for_contract_type(self, volume_mwh: float, contract_type: str, year: int, month: int) -> float:
        """Convert MWh to appropriate unit based on contract type"""
        if contract_type == 'wholesale':
            # Convert MWh to MW by dividing by hours in the period
            hours_in_period = self.get_hours_in_period(year, month)
            volume_mw = volume_mwh / hours_in_period
            return round(volume_mw, 3)  # Round to 3 decimal places for MW
        else:
            # Retail and offtake contracts keep MWh
            return round(volume_mwh, 0)  # Round to whole MWh
    
    def get_volume_unit(self, contract_type: str) -> str:
        """Get the appropriate volume unit based on contract type"""
        if contract_type == 'wholesale':
            return 'MW'
        else:
            return 'MWh'
    
    def map_contract_type(self, csv_type: str) -> str:
        """Map CSV type to contract type"""
        type_mapping = {
            'Retail': 'retail',
            'Wholesale': 'wholesale', 
            'Offtake': 'offtake'
        }
        return type_mapping.get(csv_type, 'retail')
    
    def map_contract_category(self, contract_type: str, deal_name: str, sub_type: str = None) -> str:
        """Determine contract category based on type, deal name, and sub_type"""
        if contract_type == 'retail':
            if 'Government' in deal_name or 'Gov' in deal_name:
                return 'Government Customer'
            elif 'Industrial' in deal_name or 'Mining' in deal_name:
                return 'Industrial Customer'
            else:
                return 'Retail Customer'
        elif contract_type == 'wholesale':
            # Use Sub_Type for wholesale contracts
            if sub_type and sub_type.strip():
                return sub_type.strip()
            else:
                return 'Swap'  # Default fallback
        else:  # offtake
            if 'Solar' in deal_name:
                return 'Solar Farm'
            elif 'Wind' in deal_name:
                return 'Wind Farm'
            else:
                return 'Solar Farm'
    
    def group_data_by_contract(self, df: pd.DataFrame) -> Dict[str, List[Dict]]:
        """Group data by unique deal name (aggregating all years) and unit type"""
        contracts_data = {}
        
        for _, row in df.iterrows():
            deal_name = row['New_deal_name']
            year = row['Year']
            unit = row.get('Unit', 'Energy')  # Default to Energy if Unit column not found
            sub_type = row.get('Sub_Type', '')  # Get Sub_Type for wholesale contracts
            contract_type = self.map_contract_type(row['type'])
            
            # Create unique key for contract (deal name + unit type)
            contract_key = f"{deal_name}_{unit}"
            
            if contract_key not in contracts_data:
                contracts_data[contract_key] = {
                    'deal_name': deal_name,
                    'unit': unit,
                    'sub_type': sub_type,
                    'state': row['State'],
                    'type': contract_type,
                    'buysell': row.get('buysell', 'Sell'),
                    'time_series_data': [],
                    'years_covered': set(),
                    'contract_info': {
                        'start_date': row.get('Start', ''),
                        'end_date': row.get('End', ''),
                        'time': row.get('Time', '')
                    }
                }
            
            # Add year to covered years
            contracts_data[contract_key]['years_covered'].add(year)
            
            # Convert volume based on contract type
            original_volume_mwh = float(row['Volume'])
            converted_volume = self.convert_volume_for_contract_type(
                original_volume_mwh, contract_type, year, row['Month']
            )
            
            # Add time series point
            timestamp = self.parse_date_from_row(row)
            period = f"{year}-{str(row['Month']).zfill(2)}"
            
            time_point = {
                'timestamp': timestamp,
                'volume': converted_volume,
                'price': round(float(row['Price_CY']), 2),  # Round to 2 decimal places
                'period': period,
                'periodType': 'monthly'
            }
            
            contracts_data[contract_key]['time_series_data'].append(time_point)
        
        # Convert years_covered set to sorted list for better display
        for contract_key in contracts_data:
            years_list = sorted(list(contracts_data[contract_key]['years_covered']))
            contracts_data[contract_key]['years_covered'] = years_list
            volume_unit = self.get_volume_unit(contracts_data[contract_key]['type'])
            print(f"Contract: {contracts_data[contract_key]['deal_name']} ({contracts_data[contract_key]['unit']}) - Type: {contracts_data[contract_key]['type']} - Unit: {volume_unit} - Years: {years_list}")
        
        print(f"Grouped data into {len(contracts_data)} unique contracts")
        return contracts_data
    
    def calculate_annual_volume(self, time_series: List[Dict], years_covered: List[int], contract_type: str) -> float:
        """Calculate volume for one full calendar year if data exists, otherwise return total"""
        if not time_series:
            return 0
        
        # Group time series by year
        yearly_volumes = {}
        for point in time_series:
            timestamp = datetime.fromisoformat(point['timestamp'])
            year = timestamp.year
            
            if year not in yearly_volumes:
                yearly_volumes[year] = {'volume': 0, 'months': set()}
            
            yearly_volumes[year]['volume'] += point['volume']
            yearly_volumes[year]['months'].add(timestamp.month)
        
        # Check if we have any complete year (12 months of data)
        complete_years = []
        for year, data in yearly_volumes.items():
            if len(data['months']) == 12:  # Full calendar year
                complete_years.append((year, data['volume']))
        
        if complete_years:
            # Return volume from the most recent complete year
            complete_years.sort(key=lambda x: x[0], reverse=True)  # Sort by year, newest first
            volume = complete_years[0][1]
        else:
            # No complete year found, return total volume for all available data
            volume = sum(point['volume'] for point in time_series)
        
        # Round appropriately based on contract type
        if contract_type == 'wholesale':
            return round(volume, 3)  # MW to 3 decimal places
        else:
            return round(volume, 0)  # MWh to whole numbers
    
    def clear_database(self) -> bool:
        """Delete every contract the API currently returns.

        Contracts are listed with a GET on the base URL and then removed one
        by one with ``DELETE <base_url>?id=<contract id>``. A running tally
        is printed.

        Returns:
            ``True`` only when the listing succeeded and every contract was
            deleted with a 200 response; ``False`` otherwise (including
            network errors and timeouts).
        """
        # Without a timeout a stalled server would block the import forever.
        timeout_seconds = 30

        try:
            # First fetch all contracts to get their IDs
            response = requests.get(self.api_base_url, timeout=timeout_seconds)
            if response.status_code != 200:
                print(f"Error fetching contracts for deletion: {response.status_code}")
                return False

            contracts = response.json()
            print(f"Found {len(contracts)} contracts to delete...")

            deleted_count = 0
            error_count = 0

            for contract in contracts:
                contract_id = contract.get('_id') or contract.get('id')
                if not contract_id:
                    error_count += 1
                    print(f"  ✗ No ID found for contract: {contract.get('name', 'Unknown')}")
                    continue

                try:
                    # params= lets requests URL-encode the id instead of
                    # splicing it into the query string by hand.
                    delete_response = requests.delete(
                        self.api_base_url,
                        params={'id': contract_id},
                        timeout=timeout_seconds,
                    )
                    if delete_response.status_code == 200:
                        deleted_count += 1
                        print(f"  ✓ Deleted: {contract.get('name', 'Unknown')}")
                    else:
                        error_count += 1
                        print(f"  ✗ Error deleting {contract.get('name', 'Unknown')}: {delete_response.status_code}")
                except Exception as e:
                    # One failed delete must not stop the remaining ones
                    error_count += 1
                    print(f"  ✗ Error deleting {contract.get('name', 'Unknown')}: {e}")

            print(f"\nDatabase clearing summary:")
            print(f"  Deleted: {deleted_count}")
            print(f"  Errors: {error_count}")
            print(f"  Total processed: {len(contracts)}")

            return error_count == 0

        except Exception as e:
            print(f"Error clearing database: {e}")
            return False
    
    def fetch_existing_contracts(self) -> List[Dict]:
        """Fetch the contracts currently stored behind the API.

        Returns:
            The decoded JSON body of a 200 response to ``GET <base_url>``,
            or an empty list on any other status, network error or timeout.
        """
        try:
            # Bounded wait: a stalled server must not hang the import.
            response = requests.get(self.api_base_url, timeout=30)
        except Exception as e:
            print(f"Error fetching contracts: {e}")
            return []

        if response.status_code != 200:
            print(f"Error fetching contracts: {response.status_code}")
            return []

        try:
            contracts = response.json()
        except Exception as e:
            # A 200 response with an undecodable body is reported the same way
            print(f"Error fetching contracts: {e}")
            return []

        print(f"Fetched {len(contracts)} existing contracts")
        return contracts
    
    def find_matching_contract(self, deal_name: str, existing_contracts: List[Dict]) -> Dict:
        """Find existing contract that matches the deal name"""
        for contract in existing_contracts:
            if contract['name'].lower() == deal_name.lower():
                return contract
        return None
    
    def create_new_contract(self, contract_data: Dict) -> Dict:
        """Create a new contract from grouped data"""
        deal_name = contract_data['deal_name']
        unit = contract_data['unit']
        sub_type = contract_data.get('sub_type', '')
        time_series = contract_data['time_series_data']
        years_covered = contract_data['years_covered']
        contract_type = contract_data['type']
        
        # Get appropriate volume unit
        volume_unit = self.get_volume_unit(contract_type)
        
        # Calculate aggregated values
        total_volume = sum(point['volume'] for point in time_series)
        avg_price = round(sum(point['price'] for point in time_series) / len(time_series), 2) if time_series else 0
        
        # Calculate annual volume - try to get one full calendar year if data exists
        annual_volume = self.calculate_annual_volume(time_series, years_covered, contract_type)
        
        # Determine start and end dates from time series
        timestamps = [datetime.fromisoformat(point['timestamp']) for point in time_series]
        start_date = min(timestamps).strftime('%Y-%m-%d') if timestamps else f"{min(years_covered)}-01-01"
        end_date = max(timestamps).strftime('%Y-%m-%d') if timestamps else f"{max(years_covered)}-12-31"
        
        # Create contract name that includes unit type
        contract_name = f"{deal_name} ({unit})" if unit != 'Energy' else deal_name
        
        new_contract = {
            'name': contract_name,
            'type': contract_type,
            'category': self.map_contract_category(contract_type, deal_name, sub_type),
            'state': contract_data['state'],
            'counterparty': deal_name,  # Use deal name as counterparty
            'startDate': start_date,
            'endDate': end_date,
            'annualVolume': annual_volume,  # Volume for one full CY if available, otherwise total for available data
            'strikePrice': avg_price,
            'unit': volume_unit,  # Use MW for wholesale, MWh for others
            'volumeShape': 'custom',
            'status': 'active',
            'indexation': 'Fixed',
            'referenceDate': start_date,
            'pricingType': 'timeseries',
            'timeSeriesData': time_series,
            'dataSource': 'csv_import',
            'yearsCovered': years_covered,
            'totalVolume': total_volume
        }
        
        return new_contract
    
    def update_contract_with_timeseries(self, contract_id: str, time_series_data: List[Dict]) -> bool:
        """Attach time series data to an existing contract via PATCH.

        Sends an ``updateTimeSeries`` action to the base URL with the
        contract id and the points.

        Args:
            contract_id: Identifier of the contract to update.
            time_series_data: Points to store on the contract.

        Returns:
            ``True`` on a 200 response; ``False`` on any other status,
            network error or timeout (the reason is printed).
        """
        payload = {
            'action': 'updateTimeSeries',
            'contractId': contract_id,
            'timeSeriesData': time_series_data,
            'dataSource': 'csv_import'
        }

        try:
            # Bounded wait: a stalled server must not hang the import.
            response = requests.patch(self.api_base_url, json=payload, timeout=30)
        except Exception as e:
            print(f"Error updating contract {contract_id}: {e}")
            return False

        if response.status_code == 200:
            print(f"Successfully updated contract {contract_id} with time series data")
            return True

        print(f"Error updating contract {contract_id}: {response.status_code} - {response.text}")
        return False
    
    def create_contract_via_api(self, contract_data: Dict) -> str:
        """POST a new contract to the API and return its identifier.

        Args:
            contract_data: Contract payload; its ``'name'`` is used in the
                progress messages.

        Returns:
            The id reported by the API (``'_id'``, falling back to ``'id'``)
            on a 201 response, or ``None`` on any other status, network
            error or timeout.
        """
        try:
            # Bounded wait: a stalled server must not hang the import.
            response = requests.post(self.api_base_url, json=contract_data, timeout=30)

            if response.status_code == 201:
                new_contract = response.json()
                # Accept either id key, as clear_database does when reading
                # contracts; otherwise a successful create answered with
                # 'id' would be counted as an error by the caller.
                contract_id = new_contract.get('_id') or new_contract.get('id')
                print(f"Successfully created contract: {contract_data['name']} (ID: {contract_id})")
                return contract_id

            print(f"Error creating contract {contract_data['name']}: {response.status_code} - {response.text}")
            return None
        except Exception as e:
            print(f"Error creating contract {contract_data['name']}: {e}")
            return None
    
    def process_csv_file(self, filepath: str, dry_run: bool = True, clear_db: bool = False):
        """Main method to process the CSV file and update contracts"""
        print(f"Starting CSV import process (dry_run={dry_run}, clear_db={clear_db})...")
        
        # Clear database if requested
        if clear_db:
            print("\n" + "="*50)
            print("CLEARING DATABASE")
            print("="*50)
            if not dry_run:
                success = self.clear_database()
                if not success:
                    print("Database clearing failed. Aborting import.")
                    return
                print("Database cleared successfully!")
            else:
                print("Would clear all existing contracts from database")
            print()
        
        # Load and clean data
        df = self.load_csv(filepath)
        if df is None:
            return
        
        df_clean = self.clean_and_validate_data(df)
        if len(df_clean) == 0:
            print("No valid data to process")
            return
        
        # Group data by contract
        contracts_data = self.group_data_by_contract(df_clean)
        
        # Fetch existing contracts (will be empty if database was cleared)
        existing_contracts = self.fetch_existing_contracts() if not clear_db or dry_run else []
        
        # Process each contract
        created_count = 0
        updated_count = 0
        error_count = 0
        
        for contract_key, contract_data in contracts_data.items():
            deal_name = contract_data['deal_name']
            unit = contract_data['unit']
            years_covered = contract_data['years_covered']
            time_series = contract_data['time_series_data']
            contract_type = contract_data['type']
            volume_unit = self.get_volume_unit(contract_type)
            
            print(f"\nProcessing: {deal_name} ({unit}) - Type: {contract_type} - Unit: {volume_unit} - Years: {years_covered} - {len(time_series)} data points")
            
            # Check if contract exists (look for deal name with or without unit suffix)
            existing_contract = self.find_matching_contract(deal_name, existing_contracts)
            if not existing_contract and unit != 'Energy':
                # Also try looking for contract name with unit suffix
                contract_name_with_unit = f"{deal_name} ({unit})"
                existing_contract = self.find_matching_contract(contract_name_with_unit, existing_contracts)
            
            if existing_contract:
                # Update existing contract with time series data
                print(f"  → Found existing contract: {existing_contract['name']}")
                
                if not dry_run:
                    success = self.update_contract_with_timeseries(
                        existing_contract['_id'], 
                        time_series
                    )
                    if success:
                        updated_count += 1
                    else:
                        error_count += 1
                else:
                    print(f"  → Would update existing contract with {len(time_series)} time series points")
                    updated_count += 1
            else:
                # Create new contract
                print(f"  → No existing contract found, will create new one")
                new_contract_data = self.create_new_contract(contract_data)
                
                if not dry_run:
                    contract_id = self.create_contract_via_api(new_contract_data)
                    if contract_id:
                        created_count += 1
                    else:
                        error_count += 1
                else:
                    print(f"  → Would create new contract: {new_contract_data['name']}")
                    print(f"     Type: {new_contract_data['type']}, State: {new_contract_data['state']}, Unit: {new_contract_data['unit']}")
                    print(f"     Annual Volume: {new_contract_data['annualVolume']:,.3f} {volume_unit}, Avg Price: ${new_contract_data['strikePrice']:.2f}")
                    print(f"     Total Volume: {new_contract_data['totalVolume']:,.3f} {volume_unit}, Time Series Points: {len(time_series)}, Years: {years_covered}")
                    created_count += 1
        
        # Summary
        print(f"\n{'='*50}")
        print(f"IMPORT SUMMARY ({'DRY RUN' if dry_run else 'LIVE RUN'})")
        print(f"{'='*50}")
        print(f"Contracts created: {created_count}")
        print(f"Contracts updated: {updated_count}")
        print(f"Errors: {error_count}")
        print(f"Total processed: {len(contracts_data)}")
        
        if dry_run:
            print(f"\nThis was a dry run. To execute the import, run with dry_run=False")

# Example usage
if __name__ == "__main__":
    importer = ContractDataImporter("http://localhost:3000/api/contracts")

    # First run a dry run to see what would happen (with database clearing).
    # dry_run=True matches the message below: nothing is deleted or written.
    print("Running dry run with database clear...")
    importer.process_csv_file("contracts.csv", dry_run=True, clear_db=True)

    # Uncomment the line below to execute the actual import with database clearing
    # importer.process_csv_file("contracts.csv", dry_run=False, clear_db=True)

Running dry run with database clear...
Starting CSV import process (dry_run=False, clear_db=True)...

CLEARING DATABASE
Found 309 contracts to delete...
  ✓ Deleted: BVC TOU
  ✓ Deleted: BVC TOU (LGC)
  ✓ Deleted: CSIRO TOU
  ✓ Deleted: CSIRO TOU (LGC)
  ✓ Deleted: Fairfield Non-TOU
  ✓ Deleted: Fairfield Non-TOU (LGC)
  ✓ Deleted: Fairfield TOU
  ✓ Deleted: Fairfield TOU (LGC)
  ✓ Deleted: SBS TOU
  ✓ Deleted: SBS TOU (LGC)
  ✓ Deleted: SSROC Non-TOU
  ✓ Deleted: SSROC Non-TOU (LGC)
  ✓ Deleted: SSROC TOU
  ✓ Deleted: SSROC TOU (LGC)
  ✓ Deleted: ISPT TOU
  ✓ Deleted: ISPT TOU (LGC)
  ✓ Deleted: Adelaide Metro Non-TOU
  ✓ Deleted: Adelaide Metro Non-TOU (LGC)
  ✓ Deleted: Adelaide Metro TOU
  ✓ Deleted: Adelaide Metro TOU (LGC)
  ✓ Deleted: Bunnings TOU
  ✓ Deleted: Bunnings TOU (LGC)
  ✓ Deleted: CBUS TOU
  ✓ Deleted: CBUS TOU (LGC)
  ✓ Deleted: Hentley Farm TOU
  ✓ Deleted: Hentley Farm TOU (LGC)
  ✓ Deleted: KMART&TARGET TOU
  ✓ Deleted: KMART&TARGET TOU (LGC)
  ✓ Deleted: SA Gover