# In [None]:
import torch
import torch.nn as nn
from transformers import ViTImageProcessor, ViTForImageClassification
from transformers import TrainingArguments, Trainer
from PIL import Image
import numpy as np
from torch.utils.data import Dataset
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split
import os
import json
from datetime import datetime
import wandb
import random

# Wandb 초기화 함수
def init_wandb(project_name="robot-task-classifier", run_name=None, config=None):
    """Start a wandb run for this training session.

    Args:
        project_name: wandb project the run is logged under.
        run_name: display name of the run. When None, a timestamped name of
            the form ``vit-base-YYYYmmdd-HHMMSS`` is generated.
        config: hyperparameter dict recorded with the run (may be None).
    """
    resolved_name = run_name
    if resolved_name is None:
        timestamp = datetime.now().strftime('%Y%m%d-%H%M%S')
        resolved_name = "vit-base-" + timestamp

    run_tags = ["ViT", "image-classification", "robot-tasks"]
    wandb.init(project=project_name, name=resolved_name, config=config, tags=run_tags)

    print(f"Wandb 초기화 완료: {project_name}/{resolved_name}")

# 1. ViT 모델과 프로세서 로드
def load_vit_model(num_classes, model_name="google/vit-base-patch16-224"):
    """Load a ViT image-classification model and its image processor.

    The classification head is sized to ``num_classes``.
    ``ignore_mismatched_sizes=True`` lets the checkpoint load even when its
    stored head has a different number of labels.

    Args:
        num_classes: number of output labels for the classifier head.
        model_name: Hugging Face checkpoint id or local path.

    Returns:
        (model, processor)
    """
    for line in (f"모델 로딩: {model_name}", f"분류 클래스 수: {num_classes}"):
        print(line)

    processor = ViTImageProcessor.from_pretrained(model_name)
    model = ViTForImageClassification.from_pretrained(
        model_name, num_labels=num_classes, ignore_mismatched_sizes=True
    )

    parameters = list(model.parameters())
    total_params = sum(param.numel() for param in parameters)
    trainable_params = sum(param.numel() for param in parameters if param.requires_grad)

    print(f"모델 파라미터 수: {total_params:,}")
    print(f"훈련 가능한 파라미터 수: {trainable_params:,}")

    # Only report model size when a wandb run has been started.
    if wandb.run is None:
        return model, processor

    wandb.log({
        "model/total_parameters": total_params,
        "model/trainable_parameters": trainable_params,
        "model/num_classes": num_classes,
    })
    return model, processor

# 2. 커스텀 데이터셋 클래스
class RobotHeadDataset(Dataset):
    """Image-classification dataset backed by a list of image file paths.

    Each item is a dict with ``pixel_values`` (the processor output with its
    leading batch axis removed) and ``labels`` (a ``torch.long`` scalar).

    Args:
        image_paths: sequence of image file paths.
        labels: integer class index for each path, in the same order.
        processor: callable used as ``processor(images=..., return_tensors="pt")``
            that returns a mapping with a ``"pixel_values"`` tensor
            (e.g. a ``ViTImageProcessor``).
        augment: when True, each image is flipped horizontally with
            probability 0.5.
    """

    def __init__(self, image_paths, labels, processor, augment=False):
        self.image_paths = image_paths
        self.labels = labels
        self.processor = processor
        self.augment = augment

        print(f"데이터셋 생성 완료: {len(image_paths)}개 샘플, 증강: {augment}")

    def __len__(self):
        return len(self.image_paths)

    def _load_sample(self, idx):
        """Load, optionally augment, and preprocess the sample at ``idx``."""
        # The context manager closes the file handle; convert() returns a new,
        # fully loaded image that stays usable after the file is closed.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert('RGB')

        # Simple augmentation (training only): random horizontal flip.
        if self.augment and random.random() > 0.5:
            image = image.transpose(Image.FLIP_LEFT_RIGHT)

        inputs = self.processor(images=image, return_tensors="pt")
        return {
            # squeeze(0) removes only the batch axis added by return_tensors="pt".
            'pixel_values': inputs['pixel_values'].squeeze(0),
            'labels': torch.tensor(self.labels[idx], dtype=torch.long),
        }

    def __getitem__(self, idx):
        """Return sample ``idx``, falling back to the following samples if it fails to load.

        Raises:
            IndexError: if ``idx`` is out of range. This keeps the sequence
                protocol (and plain ``for`` iteration) terminating.
            RuntimeError: if no sample in the dataset can be loaded.
        """
        n = len(self.image_paths)
        if not -n <= idx < n:
            raise IndexError(f"index {idx} out of range for dataset of size {n}")
        idx %= n

        # Try idx, idx+1, ... (wrapping around). Then one unreadable file does
        # not stop training and failures are not all replaced by sample 0. The
        # loop is bounded to n attempts, so a fully broken dataset raises
        # instead of recursing forever.
        for offset in range(n):
            candidate = (idx + offset) % n
            try:
                return self._load_sample(candidate)
            except Exception as e:
                print(f"이미지 로드 실패: {self.image_paths[candidate]}, 에러: {e}")
        raise RuntimeError(f"could not load any image out of {n} samples")

# 3. 데이터 분할 함수
def _collect_image_paths(data_folder, class_names, supported_formats):
    """Scan ``data_folder/<class_name>/`` for image files and assign label indices.

    File names are sorted, so the result (and therefore the seeded split)
    does not depend on the filesystem's ``os.listdir`` ordering.

    Returns:
        (paths, labels, class_distribution) where class_distribution maps each
        existing class folder name to its image count.
    """
    paths, labels, class_distribution = [], [], {}

    for class_idx, class_name in enumerate(class_names):
        class_folder = os.path.join(data_folder, class_name)

        if not os.path.exists(class_folder):
            print(f"경고: {class_folder} 폴더가 존재하지 않습니다.")
            continue

        image_files = sorted(
            name for name in os.listdir(class_folder)
            if name.lower().endswith(supported_formats)
        )
        paths.extend(os.path.join(class_folder, name) for name in image_files)
        labels.extend([class_idx] * len(image_files))

        class_distribution[class_name] = len(image_files)
        print(f"  {class_name}: {len(image_files)}개 이미지")

    return paths, labels, class_distribution


def _report_split(total, class_distribution, splits):
    """Print each split's share of ``total``. If a wandb run is active, log the counts.

    Args:
        total: number of images before splitting.
        class_distribution: per-class image counts.
        splits: list of (display_label, wandb_key, paths) tuples.
    """
    print("데이터 분할 완료:")
    for label, _, paths in splits:
        print(f"  {label}: {len(paths)}개 ({len(paths)/total*100:.1f}%)")

    if wandb.run is not None:
        log_data = {"data/total_samples": total}
        for _, key, paths in splits:
            log_data[f"data/{key}_samples"] = len(paths)
        log_data["data/class_distribution"] = class_distribution
        wandb.log(log_data)


def prepare_and_split_data(data_folder, class_names, val_split=0.2, test_split=0.1, random_state=42):
    """Collect images from per-class folders and split them into train/val[/test].

    Images are read from ``data_folder/<class_name>/``. The label is the
    index of the class in ``class_names``. Missing class folders are skipped
    with a warning. Every split is stratified by label.

    Args:
        data_folder: folder holding one sub-folder per class.
        class_names: list of class names (sub-folder names).
        val_split: validation fraction of the whole dataset.
        test_split: test fraction of the whole dataset. When 0 or less, no
            test split is produced.
        random_state: seed for reproducible splits.

    Returns:
        (train_paths, train_labels, val_paths, val_labels, test_paths, test_labels)
        when ``test_split > 0``. Otherwise
        (train_paths, train_labels, val_paths, val_labels).
        If no images are found, every element is an empty list.
    """
    print(f"데이터 분할 시작: val={val_split}, test={test_split}")

    supported_formats = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff')
    all_image_paths, all_labels, class_distribution = _collect_image_paths(
        data_folder, class_names, supported_formats
    )
    print(f"총 {len(all_image_paths)}개 이미지 로드됨")

    if not all_image_paths:
        # train_test_split raises on empty input. Instead, return empty splits
        # with the usual arity so callers can detect "no data".
        print("경고: 분할할 이미지가 없습니다.")
        return tuple([] for _ in range(6 if test_split > 0 else 4))

    total = len(all_image_paths)

    if test_split > 0:
        # First split: (train + val) vs test.
        train_val_paths, test_paths, train_val_labels, test_labels = train_test_split(
            all_image_paths, all_labels,
            test_size=test_split,
            random_state=random_state,
            stratify=all_labels,
        )

        # Second split: val_split is a fraction of the whole dataset, so
        # rescale it to the train+val remainder.
        val_size_adjusted = val_split / (1 - test_split)
        train_paths, val_paths, train_labels, val_labels = train_test_split(
            train_val_paths, train_val_labels,
            test_size=val_size_adjusted,
            random_state=random_state,
            stratify=train_val_labels,
        )

        _report_split(total, class_distribution, [
            ("훈련", "train", train_paths),
            ("검증", "val", val_paths),
            ("테스트", "test", test_paths),
        ])
        return train_paths, train_labels, val_paths, val_labels, test_paths, test_labels

    # Only train vs val.
    train_paths, val_paths, train_labels, val_labels = train_test_split(
        all_image_paths, all_labels,
        test_size=val_split,
        random_state=random_state,
        stratify=all_labels,
    )

    _report_split(total, class_distribution, [
        ("훈련", "train", train_paths),
        ("검증", "val", val_paths),
    ])
    return train_paths, train_labels, val_paths, val_labels

# 4. Wandb 연동 Trainer 클래스
class WandbTrainer(Trainer):
    """Trainer that mirrors numeric training logs to wandb under a ``train/`` prefix."""

    def log(self, logs, *args, **kwargs):
        """Forward ``logs`` to ``Trainer.log``, then mirror the numeric entries to wandb.

        Extra positional/keyword arguments are forwarded unchanged to the base
        class. Newer transformers releases call ``log(logs, start_time)``, and
        the old one-argument override raised TypeError there.
        """
        super().log(logs, *args, **kwargs)

        if wandb.run is None:
            return

        wandb_logs = {
            f"train/{key}": value
            for key, value in logs.items()
            if isinstance(value, (int, float))
        }
        if wandb_logs:
            # Record the trainer step as a field instead of passing step=...:
            # other wandb.log calls in this script advance wandb's own step
            # counter, and wandb ignores data logged at an explicit step lower
            # than its current one.
            wandb_logs["train/global_step"] = self.state.global_step
            wandb.log(wandb_logs)

# 5. 모델 훈련 함수 (Wandb 연동)
def train_classifier_with_wandb(train_dataset, val_dataset, model, output_dir="./robot-task-classifier", config=None):
    """Fine-tune ``model`` with ``WandbTrainer``, evaluating and checkpointing every epoch.

    Args:
        train_dataset: training dataset.
        val_dataset: validation dataset (also used to pick the best checkpoint
            by accuracy).
        model: model to train. Its weights are updated in place.
        output_dir: checkpoint / final model directory.
        config: optional dict. Keys read, with their defaults:
            batch_size (16), num_epochs (10), learning_rate (2e-5),
            weight_decay (0.01), warmup_steps (100), num_workers (4),
            fp16 (True only when CUDA is available), and class_names ([]),
            which enables per-class validation accuracy logging.

    Returns:
        The trainer after training (best checkpoint loaded at the end).
    """
    import inspect

    if config is None:
        config = {}

    # transformers renamed `evaluation_strategy` to `eval_strategy` and later
    # removed the old name. Use whichever the installed version accepts.
    eval_strategy_kwarg = (
        "eval_strategy"
        if "eval_strategy" in inspect.signature(TrainingArguments).parameters
        else "evaluation_strategy"
    )

    training_args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=config.get('batch_size', 16),
        per_device_eval_batch_size=config.get('batch_size', 16),
        num_train_epochs=config.get('num_epochs', 10),
        learning_rate=config.get('learning_rate', 2e-5),
        weight_decay=config.get('weight_decay', 0.01),
        logging_steps=10,
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        greater_is_better=True,
        warmup_steps=config.get('warmup_steps', 100),
        # fp16 mixed precision needs a GPU. Forcing it on a CPU-only machine
        # makes TrainingArguments raise.
        fp16=config.get('fp16', torch.cuda.is_available()),
        dataloader_num_workers=config.get('num_workers', 4),
        remove_unused_columns=False,
        report_to="wandb" if wandb.run is not None else "none",
        **{eval_strategy_kwarg: "epoch"},
    )

    class_names = config.get('class_names', [])

    def compute_metrics(eval_pred):
        """Return overall accuracy. Per-class accuracy goes to wandb when a run is active."""
        predictions, labels = eval_pred
        predictions = np.argmax(predictions, axis=1)

        accuracy = accuracy_score(labels, predictions)

        class_accuracies = {}
        for i, class_name in enumerate(class_names):
            class_mask = labels == i
            if class_mask.any():
                class_acc = (predictions[class_mask] == labels[class_mask]).mean()
                class_accuracies[f"val_accuracy/{class_name}"] = float(class_acc)

        if wandb.run is not None and class_accuracies:
            wandb.log(class_accuracies)

        return {"accuracy": accuracy}

    trainer = WandbTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics,
    )

    print("=" * 50)
    print("훈련 시작 (Wandb 연동)")
    print(f"훈련 데이터: {len(train_dataset)}개")
    print(f"검증 데이터: {len(val_dataset)}개")
    print(f"배치 크기: {training_args.per_device_train_batch_size}")
    print(f"에폭 수: {training_args.num_train_epochs}")
    print(f"학습률: {training_args.learning_rate}")
    print("=" * 50)

    trainer.train()

    # Save the final (best, because of load_best_model_at_end) model to output_dir.
    trainer.save_model()
    print(f"모델이 {output_dir}에 저장되었습니다.")

    if wandb.run is not None:
        final_metrics = trainer.evaluate()
        wandb.log({"final/val_accuracy": final_metrics.get("eval_accuracy", 0)})

    return trainer

# 6. 모델 저장 및 로드 함수들 (기존과 동일)
def save_model_and_config(model, processor, class_names, save_path):
    """Save the model, processor, and class names to ``save_path``.

    The class names are stored inside the Hugging Face model config: an extra
    ``class_names`` field, plus ``id2label`` / ``label2id`` when the count
    matches ``num_labels``. The ``config.json`` written by ``save_pretrained``
    therefore stays a valid model config and still exposes ``class_names``.
    Training metadata is written to a separate ``training_metadata.json``.
    If a wandb run is active, the whole directory is uploaded as a model
    artifact.
    """
    print(f"모델 저장 중: {save_path}")

    os.makedirs(save_path, exist_ok=True)

    class_names = list(class_names)
    # Previously a hand-written dict was dumped to config.json *after*
    # save_pretrained. That overwrote the HF model config (architecture,
    # num_labels, ...), so from_pretrained could no longer rebuild the model.
    model.config.class_names = class_names
    if getattr(model.config, "num_labels", None) == len(class_names):
        model.config.id2label = dict(enumerate(class_names))
        model.config.label2id = {name: i for i, name in enumerate(class_names)}

    model.save_pretrained(save_path)
    processor.save_pretrained(save_path)

    metadata = {
        "class_names": class_names,
        "num_classes": len(class_names),
        "model_name": "google/vit-base-patch16-224",
        "save_date": datetime.now().isoformat(),
        "model_type": "ViT-Base"
    }

    with open(os.path.join(save_path, "training_metadata.json"), "w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2, ensure_ascii=False)

    if wandb.run is not None:
        artifact = wandb.Artifact(
            name="robot-task-classifier",
            type="model",
            description="Trained ViT model for robot task classification"
        )
        artifact.add_dir(save_path)
        wandb.log_artifact(artifact)

    print("저장 완료!")

def load_saved_model(save_path):
    """Load a model, processor, and class names from ``save_path``.

    Two on-disk layouts are supported:

    * ``config.json`` is the Hugging Face model config written by
      ``save_pretrained``. Class names come from its ``class_names`` field,
      or else from ``id2label``.
    * Legacy: ``config.json`` was overwritten with a plain dict containing
      ``class_names`` / ``num_classes`` / ``model_name``. The architecture
      config is then rebuilt from ``model_name``, which may need network
      access or a local cache for that checkpoint, before the saved weights
      are loaded.

    Returns:
        (model, processor, class_names)
    """
    print(f"모델 로드 중: {save_path}")

    with open(os.path.join(save_path, "config.json"), "r", encoding="utf-8") as f:
        config = json.load(f)

    # save_pretrained records "architectures". The legacy hand-written dict
    # does not, but it carries model_name / num_classes instead.
    is_legacy = (
        "architectures" not in config
        and "model_name" in config
        and "num_classes" in config
    )
    if is_legacy:
        hf_config = ViTForImageClassification.config_class.from_pretrained(
            config["model_name"], num_labels=config["num_classes"]
        )
        model = ViTForImageClassification.from_pretrained(save_path, config=hf_config)
    else:
        model = ViTForImageClassification.from_pretrained(save_path)
    processor = ViTImageProcessor.from_pretrained(save_path)

    class_names = config.get("class_names")
    if class_names is None:
        class_names = [model.config.id2label[i] for i in range(model.config.num_labels)]

    print("로드 완료!")
    return model, processor, class_names

# 7. 예측 및 평가 함수
def evaluate_model_with_wandb(model, test_dataset, processor, class_names):
    """Evaluate ``model`` on ``test_dataset`` and log the results to wandb if a run is active.

    Args:
        model: classifier whose forward pass returns an object with ``.logits``.
        test_dataset: dataset yielding dicts with ``pixel_values`` and ``labels``.
        processor: unused here, since samples are already preprocessed by the
            dataset. Kept for interface compatibility.
        class_names: class names indexed by label id.

    Returns:
        (accuracy, report), where report is ``classification_report(...,
        output_dict=True)``.
    """
    from torch.utils.data import DataLoader
    from sklearn.metrics import confusion_matrix, classification_report

    model.eval()
    # After Trainer.train() the model may live on a GPU, so each batch is
    # moved to wherever the model's weights are.
    device = next(model.parameters()).device
    dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False)

    all_predictions = []
    all_labels = []

    print("테스트 데이터 평가 중...")
    with torch.no_grad():
        for batch in dataloader:
            outputs = model(batch['pixel_values'].to(device))
            predictions = torch.argmax(outputs.logits, dim=-1)

            all_predictions.extend(predictions.cpu().numpy())
            all_labels.extend(batch['labels'].cpu().numpy())

    accuracy = accuracy_score(all_labels, all_predictions)
    print(f"테스트 정확도: {accuracy:.4f}")

    # Pin the label set so target_names and the matrix axes still line up with
    # class_names when some class is absent from this split.
    label_ids = list(range(len(class_names)))
    report = classification_report(
        all_labels, all_predictions,
        labels=label_ids, target_names=class_names,
        output_dict=True, zero_division=0,
    )
    cm = confusion_matrix(all_labels, all_predictions, labels=label_ids)

    if wandb.run is not None:
        # Plotting libraries are only needed when there is somewhere to send the plot.
        import seaborn as sns
        import matplotlib.pyplot as plt

        metrics = {"test/accuracy": accuracy}
        for class_name in class_names:
            if class_name in report:
                metrics[f"test/precision/{class_name}"] = report[class_name]['precision']
                metrics[f"test/recall/{class_name}"] = report[class_name]['recall']
                metrics[f"test/f1/{class_name}"] = report[class_name]['f1-score']
        wandb.log(metrics)

        fig, ax = plt.subplots(figsize=(10, 8))
        try:
            sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                        xticklabels=class_names, yticklabels=class_names, ax=ax)
            ax.set_title('Confusion Matrix')
            ax.set_ylabel('True Label')
            ax.set_xlabel('Predicted Label')
            wandb.log({"test/confusion_matrix": wandb.Image(fig)})
        finally:
            plt.close(fig)

    return accuracy, report

# 8. 메인 실행 함수
def main():
    """Run the pipeline: wandb setup, data split, training, saving, and optional test evaluation.

    The wandb run started here is always finished, including when the
    pipeline returns early (no training data) or raises.
    """
    # === Settings ===
    config = {
        'class_names': ["pick_and_place", "navigation", "manipulation", "inspection", "assembly"],
        'data_folder': "data/all",  # holds one sub-folder per class name
        'val_split': 0.2,           # validation fraction of the whole dataset
        'test_split': 0.1,          # test fraction of the whole dataset (0 disables it)
        'batch_size': 16,
        'num_epochs': 10,
        'learning_rate': 2e-5,
        'weight_decay': 0.01,
        'warmup_steps': 100,
        'random_state': 42
    }

    model_save_path = "./saved_models/robot_task_classifier"

    init_wandb(
        project_name="robot-task-classifier",
        run_name=f"vit-base-{datetime.now().strftime('%Y%m%d-%H%M%S')}",
        config=config
    )
    # Read the URL before wandb.finish(). Finishing resets wandb.run to None,
    # so reading it afterwards always printed 'N/A'.
    run_url = wandb.run.url if wandb.run is not None else None

    try:
        print("=" * 60)
        print("로봇 태스크 분류기 훈련 시작 (Wandb 연동)")
        print("=" * 60)

        model, processor = load_vit_model(len(config['class_names']))

        result = prepare_and_split_data(
            config['data_folder'],
            config['class_names'],
            val_split=config['val_split'],
            test_split=config['test_split'],
            random_state=config['random_state']
        )

        # A 6-tuple means a test split was produced as well.
        has_test = len(result) == 6
        if has_test:
            train_paths, train_labels, val_paths, val_labels, test_paths, test_labels = result
        else:
            train_paths, train_labels, val_paths, val_labels = result

        if len(train_paths) == 0:
            print("오류: 훈련 데이터가 없습니다.")
            return  # the finally clause still closes the wandb run

        train_dataset = RobotHeadDataset(train_paths, train_labels, processor, augment=True)
        val_dataset = RobotHeadDataset(val_paths, val_labels, processor, augment=False)

        if has_test:
            test_dataset = RobotHeadDataset(test_paths, test_labels, processor, augment=False)

        train_classifier_with_wandb(train_dataset, val_dataset, model, model_save_path, config)

        save_model_and_config(model, processor, config['class_names'], model_save_path)

        if has_test:
            print("\n테스트 데이터셋 평가 중...")
            test_accuracy, _ = evaluate_model_with_wandb(model, test_dataset, processor, config['class_names'])
            print(f"최종 테스트 정확도: {test_accuracy:.4f}")
    finally:
        wandb.finish()

    print("=" * 60)
    print("훈련 완료!")
    print(f"Wandb 대시보드에서 결과를 확인하세요: {run_url or 'N/A'}")
    print("=" * 60)

# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()