In [None]:
# images_crawler.py
from typing import Optional, Dict, Any, List
import asyncio
import aiohttp
import pandas as pd
import json
import os
import time
import logging
from datetime import datetime
from collections import Counter
import traceback

from config import DATA_DIRS

In [None]:
# images_crawler.py

from typing import Optional, Dict, Any, List, Set
import asyncio
import aiohttp
import pandas as pd
import json
import os
import time
import logging
from datetime import datetime
from collections import Counter
import traceback

from config import DATA_DIRS

class RateLimitManager:
    """Client-side sliding-window rate limiter shared by coroutines.

    Timestamps of recent requests are kept in ``self.requests``. Once the
    number of requests inside the last ``time_window`` seconds reaches
    ``safety_ratio * max_requests``, callers sleep until enough old requests
    have dropped out of the window.

    Args:
        max_requests: Nominal request budget per window.
        time_window: Window length in seconds.
        safety_ratio: Fraction of ``max_requests`` at which throttling
            starts (0.8 backs off before the hard limit is reached).
    """

    def __init__(self, max_requests: int = 100, time_window: int = 60,
                 safety_ratio: float = 0.8):
        self.max_requests = max_requests
        self.time_window = time_window
        self.safety_ratio = safety_ratio
        self.requests = []  # time.time() stamps of requests still inside the window
        self.lock = asyncio.Lock()

    def _prune(self, now: float) -> None:
        """Drop timestamps that have fallen out of the sliding window."""
        self.requests = [t for t in self.requests if now - t < self.time_window]

    async def wait_if_needed(self):
        """Wait until another request fits in the window, then record it.

        The lock is deliberately held while sleeping, so concurrent callers
        queue behind the throttled one instead of all firing at once.
        """
        threshold = self.max_requests * self.safety_ratio
        async with self.lock:
            self._prune(time.time())
            # Re-check after every sleep: the oldest request expiring frees
            # only one slot, not the whole window, so the list must not be
            # cleared wholesale.
            while self.requests and len(self.requests) >= threshold:
                wait_time = self.time_window - (time.time() - self.requests[0])
                if wait_time > 0:
                    print(f"Approaching rate limit. Waiting {wait_time:.2f} seconds...")
                    await asyncio.sleep(wait_time)
                self._prune(time.time())
            # Record when the request is actually released (after any sleep),
            # not when we entered this method.
            self.requests.append(time.time())

class CivitaiImageCrawler:
    """Crawl Civitai image metadata per model version into an append-only CSV.

    Image ids already present in the CSV are loaded at start-up and skipped,
    so an interrupted crawl can be resumed by running it again.
    """

    def __init__(self, api_key: Optional[str] = None, retry_delay: int = 2):
        self.base_url = "https://civitai.com/api/v1"
        self.headers = {"Content-Type": "application/json"}
        if api_key:
            self.headers["Authorization"] = f"Bearer {api_key}"

        self.retry_delay = retry_delay  # base delay (seconds) for exponential back-off
        self.rate_limiter = RateLimitManager()
        self.logger = self._setup_logger()

        # Output file paths (renamed to "version_images")
        self.images_csv = os.path.join(DATA_DIRS['csv'], 'version_images.csv')
        # NOTE(review): no method of this class writes images_json; kept for callers.
        self.images_json = os.path.join(DATA_DIRS['models'], 'version_images.json')

        # Ids of images already saved to images_csv, used for de-duplication.
        self.processed_images = set(self._load_processed_images())

        # Concurrency cap. NOTE(review): not currently used by any method here.
        self.semaphore = asyncio.Semaphore(5)

    def _setup_logger(self):
        """Return the shared 'CivitaiImageCrawler' logger, attaching handlers once.

        Records go both to a dated file under DATA_DIRS['logs'] and to the console.
        """
        logger = logging.getLogger('CivitaiImageCrawler')
        logger.setLevel(logging.INFO)

        # getLogger returns a process-wide singleton; only add handlers the first time.
        if not logger.handlers:
            log_dir = DATA_DIRS['logs']
            # FileHandler raises if the directory does not exist yet.
            os.makedirs(log_dir, exist_ok=True)
            log_file = os.path.join(
                log_dir, f'images_crawler_{datetime.now().strftime("%Y%m%d")}.log'
            )
            formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
            for handler in (logging.FileHandler(log_file, encoding='utf-8'),
                            logging.StreamHandler()):
                handler.setLevel(logging.INFO)
                handler.setFormatter(formatter)
                logger.addHandler(handler)

        return logger

    def _load_processed_images(self) -> Set[int]:
        """Return the image ids already stored in images_csv (empty set if none/unreadable)."""
        if not os.path.exists(self.images_csv):
            return set()
        try:
            # Only the id column is needed; skipping the bulky 'meta' column keeps start-up cheap.
            ids = pd.read_csv(self.images_csv, usecols=['id'])['id']
        except Exception as e:
            self.logger.error(f"Error loading processed images: {e}")
            return set()
        # tolist() yields plain Python scalars, which compare equal to the API's ids.
        return set(ids.dropna().tolist())

    @staticmethod
    def _parse_retry_after(value: Any, default: float) -> float:
        """Interpret a Retry-After header value as seconds.

        Falls back to *default* when the header is missing, negative, not finite,
        or in the HTTP-date form (which float() cannot parse).
        """
        try:
            seconds = float(value)
        except (TypeError, ValueError):
            return float(default)
        if not 0 <= seconds < float('inf'):
            return float(default)
        return seconds

    async def _make_request(self, url: str, params: Dict = None) -> Optional[Dict]:
        """GET *url* and return the decoded JSON body, retrying with exponential back-off.

        Up to 5 attempts. HTTP 429 waits Retry-After * 2**attempt seconds; any
        other error waits retry_delay * 2**attempt seconds.

        Raises:
            The last error once retries are exhausted — including an
            aiohttp.ClientResponseError when the final attempt is still rate
            limited (previously this case silently returned None).
        """
        max_retries = 5
        base_delay = self.retry_delay

        for attempt in range(max_retries):
            is_last_attempt = attempt == max_retries - 1
            try:
                async with aiohttp.ClientSession() as session:
                    async with session.get(url, headers=self.headers, params=params) as response:
                        if response.status == 429 and not is_last_attempt:
                            retry_after = self._parse_retry_after(
                                response.headers.get('Retry-After'), base_delay
                            )
                            wait_time = retry_after * (2 ** attempt)
                            self.logger.warning(f"Rate limited. Waiting {wait_time:.2f} seconds...")
                            await asyncio.sleep(wait_time)
                            continue

                        # On the last attempt a 429 falls through to here and raises.
                        response.raise_for_status()
                        return await response.json()

            except Exception as e:
                if is_last_attempt:
                    raise
                wait_time = base_delay * (2 ** attempt)
                self.logger.error(f"Request error: {e}. Retrying in {wait_time} seconds...")
                await asyncio.sleep(wait_time)

        return None  # only reachable if max_retries <= 0

    @staticmethod
    def _to_record(img: Dict[str, Any], version_id: int) -> Dict[str, Any]:
        """Flatten one API image item into a CSV row dict."""
        # 'stats' may be missing *or* explicitly null in the payload.
        stats = img.get('stats') or {}
        return {
            'id': img.get('id'),
            'url': img.get('url'),
            'hash': img.get('hash'),
            'width': img.get('width'),
            'height': img.get('height'),
            'model_version_id': version_id,
            'model_id': img.get('modelId'),
            'post_id': img.get('postId'),
            'nsfw': img.get('nsfw'),
            'nsfw_level': img.get('nsfwLevel'),
            'created_at': img.get('createdAt'),
            'cry_count': stats.get('cryCount'),
            'laugh_count': stats.get('laughCount'),
            'heart_count': stats.get('heartCount'),
            'like_count': stats.get('likeCount'),
            'comment_count': stats.get('commentCount'),
            'username': img.get('username'),
            'meta': json.dumps(img.get('meta', {})),
        }

    async def collect_images_by_model_version(
        self, version_id: int, start_page: int = 1
    ) -> List[Dict[str, Any]]:
        """Collect every not-yet-seen image of one model version, page by page.

        Each page's new rows are appended to images_csv immediately, so progress
        survives a crash part-way through.

        Args:
            version_id: Civitai model-version id to query.
            start_page: First page to request.

        Returns:
            The rows saved during this call (already-processed images are skipped).

        Raises:
            Re-raises any request or save error after logging it.
        """
        self.logger.info(f"Starting to collect images for model version: {version_id}")
        all_images = []
        page = start_page

        try:
            while True:
                await self.rate_limiter.wait_if_needed()

                params = {
                    # Callers may pass numpy integers (e.g. from pandas), which
                    # some yarl versions reject as query values; send a built-in int.
                    "modelVersionId": int(version_id),
                    "page": page,
                    "limit": 100,
                    "sort": "Newest"
                }

                response_data = await self._make_request(f"{self.base_url}/images", params)
                if not response_data or 'items' not in response_data:
                    break

                images = response_data.get('items') or []
                if not images:
                    break

                new_images = []
                for img in images:
                    image_id = img.get('id')
                    if image_id in self.processed_images:
                        continue
                    new_images.append(self._to_record(img, version_id))
                    self.processed_images.add(image_id)

                if new_images:
                    self._save_batch_to_csv(new_images)
                    all_images.extend(new_images)
                    self.logger.info(f"Saved {len(new_images)} new images for version {version_id}")

                metadata = response_data.get('metadata') or {}
                if not metadata.get('nextPage'):
                    break

                page += 1
                await asyncio.sleep(1)  # brief pause between pages

            return all_images

        except Exception as e:
            self.logger.error(f"Error collecting images for version {version_id}: {str(e)}")
            raise

    def _save_batch_to_csv(self, images: List[Dict]):
        """Append rows to images_csv; write the header only when the file is new or empty."""
        if not images:
            return

        directory = os.path.dirname(self.images_csv)
        if directory:
            os.makedirs(directory, exist_ok=True)
        # An existing-but-empty file (e.g. from an interrupted first write) still needs a header.
        write_header = (not os.path.exists(self.images_csv)
                        or os.path.getsize(self.images_csv) == 0)
        pd.DataFrame(images).to_csv(
            self.images_csv,
            mode='a',
            header=write_header,
            index=False
        )

In [None]:
async def main():
    """Crawl images for every model version listed in all_versions.csv.

    Reads ``DATA_DIRS['csv']/all_versions.csv`` (needs 'version_id' and
    'model_id' columns) and collects images one version at a time. A failure on
    a single version is printed and the crawl moves on after a longer pause;
    any other error is printed together with its traceback.
    """
    try:
        print("Starting the Civitai image crawler...")

        image_crawler = CivitaiImageCrawler()

        # The version list is produced by the model crawler.
        versions_csv = os.path.join(DATA_DIRS['csv'], 'all_versions.csv')
        if not os.path.exists(versions_csv):
            raise FileNotFoundError("Versions CSV file not found. Please run model crawler first.")

        versions_df = pd.read_csv(versions_csv)
        total_versions = len(versions_df)

        print(f"Found {total_versions} model versions to process")

        # enumerate() gives the true 1-based position regardless of the frame's
        # index labels; itertuples() yields plain Python scalars, not numpy ones.
        rows = versions_df[['version_id', 'model_id']].itertuples(index=False)
        for position, (version_id, model_id) in enumerate(rows, start=1):
            print(f"\nProcessing version {version_id} of model {model_id} ({position}/{total_versions})")

            try:
                await image_crawler.collect_images_by_model_version(version_id)
            except Exception as e:
                print(f"Error processing version {version_id}: {e}")
                await asyncio.sleep(30)  # back off longer after an error
                continue

            await asyncio.sleep(2)  # short pause between versions

        print("\nImage crawling completed!")

    except Exception as e:
        print(f"Critical error occurred: {e}")
        traceback.print_exc()

if __name__ == "__main__":
    # asyncio.run() creates, runs and closes a fresh event loop. The previous
    # get_event_loop() + run_until_complete() pattern is deprecated when no loop
    # is running (and an error in newer Python versions).
    try:
        _running_loop = asyncio.get_running_loop()
    except RuntimeError:
        _running_loop = None

    if _running_loop is None:
        asyncio.run(main())
    else:
        # Already inside a running loop (e.g. a Jupyter kernel), where both
        # asyncio.run() and run_until_complete() fail: schedule main() as a task.
        crawl_task = _running_loop.create_task(main())