import os
import sys
import requests
from pathlib import Path
from urllib.parse import urljoin
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
import argparse

class HiPSDownloader:
    """Mirror a HiPS survey directory tree from an HTTP server to local disk.

    Files are requested relative to ``base_url`` and stored under
    ``output_dir`` using the same relative path, e.g.
    ``Norder{order}/Dir{dir}/Npix{n}.{jpg|png}``. Files that already exist
    locally are never re-requested, so an interrupted run can be resumed.

    Attributes:
        base_url: Survey root URL, normalised to end with exactly one ``/``.
        output_dir: ``Path`` of the local survey root.
        max_order: Highest order (inclusive) fetched by ``download_survey``.
        max_workers: Thread-pool size used by ``download_order``.
        session: Shared ``requests.Session`` used for every request.
        downloaded / skipped / failed: Running totals of file outcomes.
            A 404 response is treated as "file not published" and is not
            counted in any of them.
    """

    # Outcome labels produced by ``_fetch`` and consumed by ``_record``.
    _DOWNLOADED = 'downloaded'
    _SKIPPED = 'skipped'
    _MISSING = 'missing'
    _FAILED = 'failed'

    def __init__(self, base_url, output_dir, max_order=3, max_workers=4):
        """Store the configuration and open the shared HTTP session."""
        # urljoin() only appends to the base when it ends with a slash.
        self.base_url = base_url.rstrip('/') + '/'
        self.output_dir = Path(output_dir)
        self.max_order = max_order
        self.max_workers = max_workers
        self.session = requests.Session()
        self.session.headers.update({'User-Agent': 'HiPS-Downloader/1.0'})

        self.downloaded = 0
        self.failed = 0
        self.skipped = 0

    def _fetch(self, url, local_path):
        """Fetch ``url`` into ``local_path`` without touching the counters.

        This is the part that runs in worker threads; it only reads instance
        state, so it needs no locking.

        Returns:
            One of the ``_DOWNLOADED``, ``_SKIPPED``, ``_MISSING`` (HTTP 404)
            or ``_FAILED`` (any other status, or any exception) labels.
        """
        # Stream into a sibling ".part" file and rename on success, so an
        # interrupted transfer can never be mistaken for a finished file by
        # the exists() check on a later run.
        tmp_path = local_path.with_name(local_path.name + '.part')
        try:
            if local_path.exists():
                return self._SKIPPED

            # The context manager releases the connection on every path,
            # including the early returns for non-200 responses.
            with self.session.get(url, timeout=30, stream=True) as response:
                status = response.status_code
                if status == 404:
                    return self._MISSING
                if status != 200:
                    print(f"Failed to download {url}: HTTP {status}")
                    return self._FAILED

                # Create directories only once there is something to store.
                local_path.parent.mkdir(parents=True, exist_ok=True)
                with open(tmp_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        f.write(chunk)

            tmp_path.replace(local_path)
            return self._DOWNLOADED

        except Exception as e:
            # Per-file boundary: report the problem and keep the run going.
            print(f"Error downloading {url}: {str(e)}")
            try:
                tmp_path.unlink()
            except OSError:
                pass  # Nothing was written, or it cannot be removed.
            return self._FAILED

    def _record(self, outcome):
        """Add one ``_fetch`` outcome to the totals; return True on success.

        Must be called from a single thread: the ``+=`` updates are not
        atomic.
        """
        if outcome == self._DOWNLOADED:
            self.downloaded += 1
        elif outcome == self._SKIPPED:
            self.skipped += 1
        elif outcome == self._FAILED:
            self.failed += 1
        return outcome in (self._DOWNLOADED, self._SKIPPED)

    def download_file(self, url, local_path):
        """Download one file and update the outcome counters.

        Args:
            url: Absolute URL to fetch.
            local_path: Destination ``Path``; parent directories are created
                as needed.

        Returns:
            True if the file was downloaded or already existed, False on a
            404, any other HTTP status, or any error (errors are printed,
            never raised).
        """
        return self._record(self._fetch(url, local_path))

    def download_properties(self):
        """Download the survey's ``properties`` file; return True on success."""
        print("Downloading properties file...")
        props_url = urljoin(self.base_url, 'properties')
        props_path = self.output_dir / 'properties'

        if self.download_file(props_url, props_path):
            print(f"✓ Properties file saved to {props_path}")
            return True
        print("✗ Failed to download properties file")
        return False

    def download_allsky(self):
        """Download ``Norder3/Allsky.jpg``; return True on success."""
        print("Downloading Allsky preview...")
        allsky_url = urljoin(self.base_url, 'Norder3/Allsky.jpg')
        allsky_path = self.output_dir / 'Norder3' / 'Allsky.jpg'

        if self.download_file(allsky_url, allsky_path):
            print("✓ Allsky preview saved")
            return True
        return False

    def get_tile_urls_for_order(self, order):
        """List every candidate tile file of one order.

        An order has ``12 * 4**order`` tiles and each tile is listed twice,
        once as ``.jpg`` and once as ``.png``, so the result has twice as
        many entries as there are tiles. Tiles are grouped into ``Dir``
        folders of 10000 tiles each.

        Returns:
            A list of ``(url, local_path)`` tuples ordered by tile number.
        """
        npix_max = 12 * (4 ** order)
        tile_urls = []

        for npix in range(npix_max):
            dir_num = (npix // 10000) * 10000

            for fmt in ('jpg', 'png'):
                tile_path = f"Norder{order}/Dir{dir_num}/Npix{npix}.{fmt}"
                url = urljoin(self.base_url, tile_path)
                tile_urls.append((url, self.output_dir / tile_path))

        return tile_urls

    def download_order(self, order):
        """Download every candidate tile file of ``order`` in parallel.

        Worker threads only fetch; outcomes are tallied in the calling
        thread, so the counters are never updated concurrently. On
        ``KeyboardInterrupt`` the tiles still waiting in the queue are
        cancelled before the exception is re-raised, so only downloads that
        are already in flight have to finish.
        """
        print(f"\nDownloading Order {order}...")
        tile_urls = self.get_tile_urls_for_order(order)
        total_tiles = len(tile_urls)

        print(f"  Attempting {total_tiles} tiles...")

        # Snapshot the running totals so the summary covers this order only.
        downloaded_before = self.downloaded
        skipped_before = self.skipped
        failed_before = self.failed

        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            futures = []
            try:
                for url, path in tile_urls:
                    futures.append(executor.submit(self._fetch, url, path))

                finished = enumerate(as_completed(futures), start=1)
                for completed, future in finished:
                    self._record(future.result())
                    if completed % 100 == 0:
                        print(f"  Progress: {completed}/{total_tiles} tiles processed")
            except KeyboardInterrupt:
                # Leaving the ``with`` block joins the pool; without this
                # it would first work through every queued tile.
                for future in futures:
                    future.cancel()
                raise

        new = self.downloaded - downloaded_before
        existing = self.skipped - skipped_before
        failed = self.failed - failed_before
        print(f"  Order {order} complete: {new} new, {existing} existing, {failed} failed")

    def download_survey(self):
        """Download the properties file, the Allsky preview and orders 0..max_order."""
        print(f"Starting HiPS download from {self.base_url}")
        print(f"Output directory: {self.output_dir}")
        print(f"Max order: {self.max_order}")
        print(f"Parallel workers: {self.max_workers}")
        print("="*60)

        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Failures of either file are reported but do not stop the run.
        self.download_properties()
        self.download_allsky()

        for order in range(self.max_order + 1):
            start_time = time.time()
            self.download_order(order)
            elapsed = time.time() - start_time
            print(f"  Order {order} took {elapsed:.1f} seconds")

        print("="*60)
        print("Download complete!")
        print(f"Total: {self.downloaded} downloaded, {self.skipped} skipped, {self.failed} failed")
        print("\nTo use in Stellarium:")
        print(f"1. Copy {self.output_dir} to your Stellarium surveys directory")
        print("2. Restart Stellarium")
        print("3. Enable HiPS Surveys in Configuration (F2 > Extras)")
        print("4. Select your survey from the Surveys menu")

def main(argv=None):
    """Parse command-line options and run the survey download.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (the default when None).

    Exits with status 1 when ``--max-order`` is outside 0-8 or ``--workers``
    is below 1, and with status 0 when the download is interrupted with
    Ctrl-C.
    """
    parser = argparse.ArgumentParser(
        description='Download HiPS survey tiles for offline Stellarium use',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Download orders 0-3 (low resolution, ~1000 tiles)
  python hips_downloader.py --max-order 3
  
  # Download orders 0-5 (medium resolution, ~16k tiles)
  python hips_downloader.py --max-order 5 --workers 8
  
  # Download full resolution (orders 0-8, millions of tiles, 100+ GB)
  python hips_downloader.py --max-order 8 --workers 16
        """
    )

    parser.add_argument(
        '--url',
        default='https://www.simg.de/nebulae3/dr0_1/tc8/',
        help='Base URL of HiPS survey'
    )

    parser.add_argument(
        '--output',
        default='./NSNS_DR0.1_tc8',
        help='Output directory for downloaded tiles'
    )

    parser.add_argument(
        '--max-order',
        type=int,
        default=3,
        help='Maximum HiPS order to download (0-8)'
    )

    parser.add_argument(
        '--workers',
        type=int,
        default=4,
        help='Number of parallel download threads'
    )

    args = parser.parse_args(argv)

    if not 0 <= args.max_order <= 8:
        print("ERROR: max-order must be between 0 and 8")
        sys.exit(1)

    # Reject this up front: a thread pool cannot be created with fewer than
    # one worker, and that would otherwise only surface mid-run.
    if args.workers < 1:
        print("ERROR: workers must be at least 1")
        sys.exit(1)

    downloader = HiPSDownloader(
        base_url=args.url,
        output_dir=args.output,
        max_order=args.max_order,
        max_workers=args.workers
    )

    try:
        downloader.download_survey()
    except KeyboardInterrupt:
        print("\n\nDownload interrupted by user")
        print(f"Progress: {downloader.downloaded} downloaded, {downloader.skipped} skipped")
        print("You can resume by running the script again")
        sys.exit(0)


# Run the command-line interface only when executed as a script, not on import.
if __name__ == '__main__':
    main()