# Setting up paths

# In [None]:
import asyncio
import json
import logging
import os
import shutil
import time
import traceback
from collections import Counter
from datetime import datetime
from typing import Optional, Dict, Any, List

import aiohttp
import nest_asyncio
import pandas as pd
import requests
from google.colab import drive

# Mount Google Drive, then create one sub-directory per data category under
# base_dir. DATA_DIRS maps each category name to its absolute path.
try:
    drive.mount('/content/drive')
    base_dir = '/content/drive/MyDrive/Civitai_Data'
    DATA_DIRS = {
        category: os.path.join(base_dir, category)
        for category in ('creators', 'models', 'csv', 'logs', 'checkpoints', 'images')
    }

    for dir_path in DATA_DIRS.values():
        os.makedirs(dir_path, exist_ok=True)  # idempotent across reruns
except Exception as e:
    raise Exception(f"Failed to setup Google Drive and directories: {e}")

# In [None]:
class RateLimitManager:
    def __init__(self, max_requests: int = 100, time_window: int = 60):
        self.max_requests = max_requests
        self.time_window = time_window
        self.requests = []
        self.lock = asyncio.Lock()

    async def wait_if_needed(self):
        """Check and wait if rate limit is approaching"""
        async with self.lock:
            current_time = time.time()
            self.requests = [t for t in self.requests if current_time - t < self.time_window]

            if len(self.requests) >= self.max_requests * 0.8:
                oldest_request = self.requests[0]
                wait_time = self.time_window - (current_time - oldest_request)
                if wait_time > 0:
                    print(f"Approaching rate limit. Waiting {wait_time:.2f} seconds...")
                    await asyncio.sleep(wait_time)
                    self.requests = []

            self.requests.append(current_time)

class ProgressTracker:
    """Persist crawler progress to ``crawler_progress.json`` in a checkpoint directory.

    ``self.progress`` holds:
    - the finished creator usernames and processed model IDs (sets in memory, lists on disk);
    - the creator recorded as current;
    - the last save time;
    - a running total of collected models.
    """

    def __init__(self, checkpoint_dir: str):
        self.checkpoint_file = os.path.join(checkpoint_dir, 'crawler_progress.json')
        self.progress = self.load_progress()

    def load_progress(self) -> Dict:
        """Load saved progress, or return the default state.

        If the checkpoint cannot be read or parsed, or is not a JSON object,
        it is copied to a timestamped ``.bak`` file and a fresh state is
        returned.
        """
        if not os.path.exists(self.checkpoint_file):
            return self._get_default_progress()

        try:
            with open(self.checkpoint_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
            if not isinstance(data, dict):
                raise ValueError("checkpoint is not a JSON object")
            # JSON has no set type: lists on disk are turned back into sets.
            return {
                'processed_creators': set(data.get('processed_creators') or []),
                'current_creator': data.get('current_creator'),
                'processed_models': set(data.get('processed_models') or []),
                'last_update': data.get('last_update'),
                'total_models': data.get('total_models') or 0,
            }
        except (ValueError, TypeError, OSError):
            # ValueError also covers json.JSONDecodeError. TypeError covers
            # fields of the wrong type (e.g. a number where a list is expected).
            print("Error loading progress file, starting fresh")
            self._backup_checkpoint()
            return self._get_default_progress()

    def _backup_checkpoint(self) -> None:
        """Copy the unreadable checkpoint aside (best effort) before it gets overwritten."""
        if not os.path.exists(self.checkpoint_file):
            return
        backup_file = f"{self.checkpoint_file}.{datetime.now().strftime('%Y%m%d_%H%M%S')}.bak"
        try:
            shutil.copy2(self.checkpoint_file, backup_file)
        except OSError as e:
            print(f"Failed to back up progress file: {e}")

    def _get_default_progress(self):
        """Return a fresh, empty progress state."""
        return {
            'processed_creators': set(),
            'current_creator': None,
            'processed_models': set(),
            'last_update': None,
            'total_models': 0
        }

    def save_progress(self):
        """Write the current progress to disk, stamping ``last_update`` with now.

        The data is written to a temporary file which then replaces the
        checkpoint via ``os.replace``. An interruption mid-write therefore
        leaves the previous checkpoint intact instead of a truncated one.
        """
        current_progress = {
            'processed_creators': list(self.progress['processed_creators']),
            'current_creator': self.progress['current_creator'],
            'processed_models': list(self.progress['processed_models']),
            'last_update': datetime.now().isoformat(),
            'total_models': self.progress['total_models']
        }
        tmp_file = f"{self.checkpoint_file}.tmp"
        with open(tmp_file, 'w', encoding='utf-8') as f:
            json.dump(current_progress, f, indent=2)
        os.replace(tmp_file, self.checkpoint_file)

    def mark_creator_complete(self, creator: str, models_count: int):
        """Mark ``creator`` finished, clear the current creator, add to the total and save."""
        self.progress['processed_creators'].add(creator)
        self.progress['current_creator'] = None
        self.progress['total_models'] += models_count
        self.save_progress()

    def is_creator_processed(self, creator: str) -> bool:
        """Return True if ``creator`` was previously marked complete."""
        return creator in self.progress['processed_creators']

    def update_current(self, creator: str):
        """Record ``creator`` as the one in progress and save immediately."""
        self.progress['current_creator'] = creator
        self.save_progress()

    def add_processed_model(self, model_id: int):
        """Record a processed model ID (no-op if already recorded)."""
        if model_id not in self.progress['processed_models']:
            self.progress['processed_models'].add(model_id)
            self._auto_save()

    def _auto_save(self):
        """Save whenever the processed-model count reaches a multiple of 100."""
        if len(self.progress['processed_models']) % 100 == 0:
            self.save_progress()

# In [None]:
class CivitaiModelCrawler:
    """Crawl model metadata from the Civitai REST API, one creator at a time.

    Collected detail records are kept in ``self.all_models`` and persisted
    incrementally to:
    - a JSON file (full records);
    - two CSV files (one row per model, one row per model version).

    All files live under the directories listed in ``DATA_DIRS``.
    """

    def __init__(self, api_key: Optional[str] = None, retry_delay: int = 2):
        self.base_url = "https://civitai.com/api/v1"
        self.headers = {"Content-Type": "application/json"}
        if api_key:
            self.headers["Authorization"] = f"Bearer {api_key}"

        # Base delay (seconds) for exponential backoff between retries.
        self.retry_delay = retry_delay
        self.rate_limiter = RateLimitManager()
        self.logger = self._setup_logger()

        # Output file paths
        self.models_file = os.path.join(DATA_DIRS['models'], 'all_models.json')
        self.models_csv = os.path.join(DATA_DIRS['csv'], 'all_models.csv')
        self.versions_csv = os.path.join(DATA_DIRS['csv'], 'all_versions.csv')

        # Load previously saved models so their IDs are skipped when crawling.
        self.all_models = self.load_existing_models()
        self.processed_models = {
            model['id'] for model in self.all_models
            if isinstance(model, dict) and model.get('id') is not None
        }

        # Caps the number of concurrent model-detail requests.
        self.semaphore = asyncio.Semaphore(5)

    def _setup_logger(self):
        """Return the shared 'CivitaiModelCrawler' logger, attaching handlers once.

        The handlers are a daily log file in ``DATA_DIRS['logs']`` plus the console.
        """
        logger = logging.getLogger('CivitaiModelCrawler')
        logger.setLevel(logging.INFO)

        if not logger.handlers:
            # File handler
            fh = logging.FileHandler(
                os.path.join(DATA_DIRS['logs'], f'models_crawler_{datetime.now().strftime("%Y%m%d")}.log')
            )
            fh.setLevel(logging.INFO)

            # Console handler
            ch = logging.StreamHandler()
            ch.setLevel(logging.INFO)

            formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
            fh.setFormatter(formatter)
            ch.setFormatter(formatter)

            logger.addHandler(fh)
            logger.addHandler(ch)

        return logger

    def _backup_file(self, path: str) -> None:
        """Copy ``path`` to a timestamped ``.bak`` sibling (best effort)."""
        if not os.path.exists(path):
            return
        backup_file = f"{path}.{datetime.now().strftime('%Y%m%d_%H%M%S')}.bak"
        try:
            shutil.copy2(path, backup_file)
            self.logger.info(f"Backed up corrupted file to {backup_file}")
        except OSError as e:
            self.logger.error(f"Failed to back up {path}: {e}")

    @staticmethod
    def _write_json_atomic(path: str, data) -> None:
        """Write ``data`` as JSON to a temp file, then ``os.replace`` it over ``path``.

        An interruption mid-write leaves the previous file intact.
        """
        tmp_path = f"{path}.tmp"
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, path)

    def load_existing_models(self) -> List[Dict]:
        """Load previously saved models from ``self.models_file``.

        Returns an empty list if the file does not exist. If the file is
        unreadable or is not a JSON list, it is backed up to a timestamped
        ``.bak`` file and an empty list is returned.
        """
        if not os.path.exists(self.models_file):
            return []
        try:
            with open(self.models_file, 'r', encoding='utf-8') as f:
                models = json.load(f)
            if not isinstance(models, list):
                raise ValueError("models file is not a JSON list")
        except (ValueError, OSError):
            # ValueError also covers json.JSONDecodeError.
            self.logger.error("Error reading models file")
            self._backup_file(self.models_file)
            return []
        self.logger.info(f"Successfully loaded {len(models)} existing models")
        return models

    async def _batch_request(self, urls: List[str]) -> List[Dict]:
        """Fetch several URLs concurrently (bounded by ``self.semaphore``).

        Failed requests appear in the result as exception objects, because
        ``gather`` is called with ``return_exceptions=True``.
        """
        async def fetch_url(url):
            async with self.semaphore:
                return await self._make_request(url)

        tasks = [fetch_url(url) for url in urls]
        return await asyncio.gather(*tasks, return_exceptions=True)

    async def _make_request(self, url: str, params: Dict = None) -> Optional[Dict]:
        """GET ``url`` and return the decoded JSON body.

        Retries up to 5 attempts with exponential backoff
        (``retry_delay * 2**attempt``) on:
        - HTTP 429 and 5xx;
        - network errors;
        - undecodable responses.

        On 429, an integer ``Retry-After`` header overrides the backoff.

        Returns:
            The parsed JSON, or ``None`` if the resource is not found (404).

        Raises:
            PermissionError: on HTTP 403 (not retried).
            RuntimeError: if every attempt was rate limited or hit a server error.
            aiohttp.ClientError / asyncio.TimeoutError / ValueError: the last
                network or decoding error, once retries are exhausted.
        """
        max_retries = 5
        base_delay = self.retry_delay
        exponential_base = 2

        for attempt in range(max_retries):
            backoff = base_delay * (exponential_base ** attempt)
            try:
                async with aiohttp.ClientSession() as session:
                    async with session.get(url, headers=self.headers, params=params) as response:
                        if response.status == 429:
                            retry_after = response.headers.get('Retry-After')
                            try:
                                wait_time = int(retry_after) if retry_after else backoff
                            except ValueError:
                                # Retry-After may be an HTTP date; fall back to backoff.
                                wait_time = backoff
                            self.logger.warning(f"Rate limited. Waiting {wait_time} seconds...")
                            await asyncio.sleep(wait_time)
                            continue

                        if response.status == 403:
                            self.logger.error("API access forbidden. Check API key.")
                            # Deliberately outside the retry handler: a bad key won't recover.
                            raise PermissionError("API access forbidden")

                        if response.status == 404:
                            self.logger.warning(f"Resource not found: {url}")
                            return None

                        if response.status >= 500:
                            self.logger.warning(f"Server error ({response.status}). Waiting {backoff} seconds...")
                            await asyncio.sleep(backoff)
                            continue

                        response.raise_for_status()
                        return await response.json()

            except (aiohttp.ClientError, asyncio.TimeoutError, ValueError) as e:
                self.logger.error(f"Request error: {str(e)}")
                if attempt == max_retries - 1:
                    raise
                await asyncio.sleep(backoff)

        # Previously this returned None, which callers treat as "no more data".
        # That let a throttled creator be marked complete with missing models.
        self.logger.error(f"Giving up on {url} after {max_retries} attempts")
        raise RuntimeError(f"Giving up on {url} after {max_retries} attempts")

    async def collect_models(self, creator_username: str, start_page: int = 1,
                             progress_tracker: Optional["ProgressTracker"] = None) -> List[Dict]:
        """Fetch full details for every not-yet-seen model by one creator.

        Steps:
        1. Pages through ``/models?username=...`` from ``start_page``.
        2. Skips IDs already in ``self.processed_models`` and items whose
           creator does not match.
        3. Fetches each new model's detail record concurrently.
        4. Persists every batch via ``_save_batch``.

        A page whose models were all seen before does NOT end the crawl.
        After an interrupted run, the first pages are typically already
        processed and the unseen models live on later pages.

        Returns:
            The detail records collected during this call.

        Raises:
            Any request or persistence error, after logging it.
        """
        self.logger.info(f"Starting to collect models for creator: {creator_username}")
        creator_models = []
        page = start_page
        previous_page_ids = None

        try:
            while True:
                await self.rate_limiter.wait_if_needed()

                params = {
                    "username": creator_username,
                    "page": page,
                    "limit": 100
                }

                response_data = await self._make_request(f"{self.base_url}/models", params)
                if not response_data or 'items' not in response_data:
                    break

                models = response_data.get('items') or []
                if not models:
                    break

                # If the API ignores ``page`` and keeps returning the same
                # items, stop. Fully-processed pages no longer end the loop,
                # so without this check it would never terminate.
                page_ids = [model.get('id') for model in models]
                if page_ids == previous_page_ids:
                    self.logger.warning(
                        f"Page {page} repeated the previous page for {creator_username}; stopping"
                    )
                    break
                previous_page_ids = page_ids

                metadata = response_data.get('metadata') or {}
                total_items = metadata.get('totalItems', 0)
                self.logger.info(f"Found {total_items} total models for creator {creator_username}")

                # Keep only unseen models that really belong to this creator.
                new_models = [
                    model for model in models
                    if model.get('id') and model.get('id') not in self.processed_models
                    and (model.get('creator') or {}).get('username') == creator_username
                ]

                if new_models:
                    # Fetch full model details concurrently.
                    model_urls = [f"{self.base_url}/models/{model['id']}" for model in new_models]
                    batch_results = await self._batch_request(model_urls)
                    # Drops exceptions and 404 (None) results returned by gather.
                    valid_models = [
                        model for model in batch_results
                        if isinstance(model, dict) and model.get('id')
                    ]

                    if valid_models:
                        creator_models.extend(valid_models)
                        self.processed_models.update(model['id'] for model in valid_models)
                        if progress_tracker:
                            for model in valid_models:
                                progress_tracker.add_processed_model(model['id'])
                        self._save_batch(valid_models)
                else:
                    self.logger.info(f"No new models on page {page} for {creator_username}")

                self.logger.info(f"Processed {len(creator_models)} models for {creator_username}")

                if not metadata.get('nextPage'):
                    break

                page += 1
                await asyncio.sleep(1)

            return creator_models

        except Exception as e:
            self.logger.error(f"Error collecting models for {creator_username}: {str(e)}")
            raise

    def _save_batch(self, new_models: List[Dict]):
        """Add a batch to ``self.all_models`` and persist it to JSON and CSV."""
        if not new_models:
            return

        self.all_models.extend(new_models)

        self._save_json()
        self._export_batch_to_csv(new_models)

    def _save_json(self):
        """Merge ``self.all_models`` into the JSON file on disk and rewrite it.

        Models already in the file (matched by ``id``) are kept as stored;
        models only in memory are appended. An unreadable file is backed up
        first. The rewrite is atomic (temp file + ``os.replace``).
        """
        existing_models = []
        if os.path.exists(self.models_file):
            try:
                with open(self.models_file, 'r', encoding='utf-8') as f:
                    loaded = json.load(f)
                if not isinstance(loaded, list):
                    raise ValueError("models file is not a JSON list")
                existing_models = loaded
                self.logger.info(f"Loaded {len(existing_models)} existing models for merging")
            except (ValueError, OSError):
                self.logger.error("Error reading existing models file")
                self._backup_file(self.models_file)

        # Merge and de-duplicate by id.
        existing_ids = {model.get('id') for model in existing_models if isinstance(model, dict)}
        new_models = [model for model in self.all_models if model.get('id') not in existing_ids]
        merged_models = existing_models + new_models

        self._write_json_atomic(self.models_file, merged_models)

        self.logger.info(f"Saved {len(merged_models)} models ({len(new_models)} new) to JSON")

    @staticmethod
    def _join_list(values) -> str:
        """Comma-join ``values``, treating ``None``/missing as an empty list."""
        return ','.join(str(v) for v in (values or []))

    @staticmethod
    def _append_csv(path: str, rows: List[Dict]) -> None:
        """Append ``rows`` to ``path``; write the header only if the file is missing or empty.

        An empty ``rows`` writes nothing. An empty DataFrame would otherwise
        emit a column-less header line that corrupts later appends.
        """
        if not rows:
            return
        write_header = not os.path.exists(path) or os.path.getsize(path) == 0
        pd.DataFrame(rows).to_csv(
            path,
            mode='a',
            header=write_header,
            index=False,
            encoding='utf-8'
        )

    def _export_batch_to_csv(self, models: List[Dict]):
        """Append one row per model and one row per model version to the CSV files."""
        if not models:
            return

        models_data = []
        versions_data = []

        for model in models:
            stats = model.get('stats') or {}
            models_data.append({
                'id': model.get('id'),
                'name': model.get('name'),
                'description': model.get('description'),
                'type': model.get('type'),
                'nsfw': model.get('nsfw'),  # boolean
                'model_poi': model.get('poi'),
                'tags': self._join_list(model.get('tags')),
                'creator': (model.get('creator') or {}).get('username'),
                'stats_download': stats.get('downloadCount'),
                'stats_comment_count': stats.get('commentCount'),
                'stats_favorite_count': stats.get('favoriteCount'),
                'stats_rating_count': stats.get('ratingCount'),
                'stats_rating': stats.get('rating')
            })

            for version in model.get('modelVersions') or []:
                versions_data.append({
                    'model_id': model.get('id'),
                    'version_id': version.get('id'),
                    'version_name': version.get('name'),
                    'version_description': version.get('description'),
                    'created_at': version.get('createdAt'),
                    'updated_at': version.get('updatedAt'),
                    'download_url': version.get('downloadUrl'),
                    'training_words': self._join_list(version.get('trainedWords')),
                    'base_model': version.get('baseModel')
                })

        self._append_csv(self.models_csv, models_data)
        self._append_csv(self.versions_csv, versions_data)

        self.logger.info(f"Exported {len(models_data)} models and {len(versions_data)} versions")

    def get_stats(self) -> Dict:
        """Summarize ``self.all_models``.

        Returns a dict with these keys (always present, also for an empty list):
        - ``total_models``;
        - ``unique_creators``: number of distinct creator usernames;
        - ``model_types``: a ``Counter`` of model types;
        - ``avg_downloads``: mean ``stats.downloadCount`` over models that
          report a numeric value, or 0 if none do.
        """
        models = self.all_models
        creators = {
            model['creator'].get('username')
            for model in models
            if model.get('creator')
        }
        downloads = [
            count
            for count in ((model.get('stats') or {}).get('downloadCount') for model in models)
            if isinstance(count, (int, float)) and not isinstance(count, bool)
        ]
        return {
            'total_models': len(models),
            'unique_creators': len(creators),
            'model_types': Counter(
                model.get('type')
                for model in models
                if model.get('type')
            ),
            'avg_downloads': sum(downloads) / len(downloads) if downloads else 0
        }

    def validate_data(self):
        """Print a summary of the output files and their contents for sanity checking."""
        files_to_check = {
            'models_json': self.models_file,
            'models_csv': self.models_csv,
            'versions_csv': self.versions_csv
        }

        for name, filepath in files_to_check.items():
            if os.path.exists(filepath):
                size = os.path.getsize(filepath) / (1024 * 1024)  # MB
                mtime = datetime.fromtimestamp(os.path.getmtime(filepath))
                print(f"\n{name}:")
                print(f"  路径: {filepath}")
                print(f"  大小: {size:.2f} MB")
                print(f"  最后修改: {mtime}")
            else:
                print(f"\n{name} 不存在: {filepath}")

        # Model data from the JSON file
        try:
            with open(self.models_file, 'r', encoding='utf-8') as f:
                models = json.load(f)
            creator_count = len({(m.get('creator') or {}).get('username') for m in models})
            type_counts = Counter(m.get('type') for m in models)
            print(f"\n模型数据:")
            print(f"  总数: {len(models)}")
            print(f"  创作者数: {creator_count}")
            print(f"  类型分布: {type_counts}")
        except Exception as e:
            print(f"\n读取模型数据失败: {e}")

        # Model data from the CSV file
        try:
            if os.path.exists(self.models_csv):
                df_models = pd.read_csv(self.models_csv)
                print(f"\nCSV数据:")
                print(f"  模型数: {len(df_models)}")
                print(f"  创作者数: {df_models['creator'].nunique()}")
                print(f"  类型分布: {df_models['type'].value_counts().to_dict()}")
        except Exception as e:
            print(f"\n读取CSV数据失败: {e}")

# In [None]:
async def main():
    """Crawl models for every active creator listed in ``creators_completed.json``.

    Creators with ``modelCount > 0`` are processed in order. Creators already
    marked complete in the progress checkpoint are skipped. A creator whose
    crawl failed is recorded as current and retried first on the next run.
    """
    try:
        print("Starting the Civitai crawler...")

        progress = ProgressTracker(DATA_DIRS['checkpoints'])
        model_crawler = CivitaiModelCrawler()

        # The creator list is produced by a separate creator crawler. Prefer a
        # copy in the working directory, else fall back to the Drive creators
        # dir. (The old ``os.path('...')`` call raised TypeError: a module is
        # not callable.)
        creators_file = 'creators_completed.json'
        if not os.path.exists(creators_file):
            creators_file = os.path.join(DATA_DIRS['creators'], 'creators_completed.json')
        with open(creators_file, 'r', encoding='utf-8') as f:
            creators_data = json.load(f)

        # Active creators: named entries with at least one model.
        active_creators = [
            c['username'] for c in creators_data
            if c.get('username') and (c.get('modelCount') or 0) > 0
        ]
        print(f"Found {len(active_creators)} active creators")

        # Move an interrupted creator to the front rather than slicing the
        # list: slicing would silently drop earlier creators that are still
        # unfinished. Finished ones are skipped below anyway.
        resume_creator = progress.progress['current_creator']
        if resume_creator:
            if resume_creator in active_creators:
                active_creators.remove(resume_creator)
                active_creators.insert(0, resume_creator)
                print(f"Resuming from creator: {resume_creator}")
            else:
                print("Previous creator not found, starting from beginning")

        total_creators = len(active_creators)
        for idx, creator in enumerate(active_creators, 1):
            if progress.is_creator_processed(creator):
                print(f"Skipping already processed creator: {creator}")
                continue

            print(f"\nProcessing creator: {creator} ({idx}/{total_creators})")
            try:
                creator_models = await model_crawler.collect_models(
                    creator_username=creator,
                    progress_tracker=progress
                )

                models_count = len(creator_models or [])
                progress.mark_creator_complete(creator, models_count)
                if models_count:
                    print(f"Completed creator {creator} with {models_count} models")
                else:
                    print(f"No new models for creator {creator}")

                stats = model_crawler.get_stats()
                print("\nCurrent Statistics:")
                print(f"Total Models: {stats['total_models']}")
                print(f"Unique Creators: {stats['unique_creators']}")
                print(f"Model Types: {dict(stats['model_types'])}")

                # Pause between creators to stay well under API limits.
                await asyncio.sleep(5)

            except Exception as e:
                print(f"Error processing creator {creator}: {e}")
                progress.update_current(creator)
                await asyncio.sleep(30)  # back off longer after an error

        print("\nCrawling completed!")

        model_crawler.validate_data()

    except Exception as e:
        print(f"Critical error occurred: {e}")
        traceback.print_exc()

if __name__ == "__main__":
    # Colab/Jupyter already runs an asyncio event loop. nest_asyncio patches it
    # so run_until_complete() may be called re-entrantly from a notebook cell.
    nest_asyncio.apply()
    loop = asyncio.get_event_loop()
    loop.run_until_complete(main())  # drive the crawler coroutine to completion
