# Import

In [None]:
import random
import numpy as np
import pandas as pd
from pathlib import Path
from typing import Dict, List, Optional

import cv2
import torch
import torch.nn.functional as F
from PIL import Image
from tqdm import tqdm
from transformers import ViTForImageClassification, ViTImageProcessor

# Settings

In [2]:
# Seed every random number generator used by this notebook with one value
# so that repeated runs start from the same state.
SEED = 42
for _seed_fn in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_fn(SEED)

# Ask cuDNN for deterministic algorithms and switch off its benchmark autotuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [3]:
# Checkpoint identifier handed to from_pretrained() for both model and processor.
MODEL_ID = "prithivMLmods/Deep-Fake-Detector-v2-Model"
# Directory containing the test files to score.
TEST_DIR = Path("./test_data")

# Submission output: make sure the folder exists before anything is written.
OUTPUT_DIR = Path("./output")
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)

# Final CSV location inside the output folder.
OUT_CSV = OUTPUT_DIR.joinpath("baseline_submission.csv")

In [None]:
# Lower-case file suffixes (with the leading dot) handled as images / videos.
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".jfif"}
VIDEO_EXTS = {".mp4", ".mov"}

# Canvas size every frame is padded to before inference.
TARGET_SIZE = (224, 224)
# Number of frames sampled from each video.
NUM_FRAMES = 10

# Run on the GPU when one is visible to torch, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"Device: {DEVICE}")

# Utils

In [5]:
def uniform_frame_indices(total_frames: int, num_frames: int) -> np.ndarray:
    """Pick evenly spaced frame indices from a clip.

    Args:
        total_frames: Number of frames available in the video.
        num_frames: Number of frames wanted.

    Returns:
        An integer array of indices in ``[0, total_frames - 1]``. When the
        clip has no more frames than requested, every index is returned;
        a non-positive ``total_frames`` yields an empty array.
    """
    wants_subsample = total_frames > 0 and total_frames > num_frames
    if wants_subsample:
        # Spread the picks from the first frame to the last, inclusive.
        return np.linspace(0, total_frames - 1, num_frames, dtype=int)

    # Short (or empty) clip: keep every frame that exists.
    available = total_frames if total_frames > 0 else 0
    return np.arange(available, dtype=int)

def get_full_frame_padded(pil_img: Image.Image, target_size=(224, 224)) -> Image.Image:
    """Fit the whole image onto a black canvas while keeping its aspect ratio.

    Args:
        pil_img: Source image in any PIL mode; it is converted to RGB.
        target_size: ``(width, height)`` of the output canvas.

    Returns:
        A new RGB image of ``target_size`` with the resized source centred
        on a black background.
    """
    # convert() returns a new image, so the in-place thumbnail() below does
    # not modify the caller's object.
    fitted = pil_img.convert("RGB")
    fitted.thumbnail(target_size, Image.BICUBIC)

    canvas = Image.new("RGB", target_size, (0, 0, 0))
    fitted_w, fitted_h = fitted.size
    # Split the leftover space evenly on both sides to centre the image.
    left = (target_size[0] - fitted_w) // 2
    top = (target_size[1] - fitted_h) // 2
    canvas.paste(fitted, (left, top))
    return canvas

def read_rgb_frames(file_path: Path, num_frames: int = NUM_FRAMES) -> List[np.ndarray]:
    """Extract RGB frames from an image or a video file.

    Args:
        file_path: Path of the media file; the type is chosen from its
            lower-cased suffix (``IMAGE_EXTS`` / ``VIDEO_EXTS``).
        num_frames: Number of frames to sample from a video. Ignored for
            still images.

    Returns:
        A list of RGB arrays: one entry for an image, up to ``num_frames``
        entries for a video. An empty list is returned for unsupported
        extensions, undecodable images, and videos reporting no frames.
    """
    ext = file_path.suffix.lower()

    # Still image
    if ext in IMAGE_EXTS:
        try:
            # The context manager closes the underlying file handle, which
            # Image.open otherwise keeps open after lazy loading.
            with Image.open(file_path) as img:
                return [np.array(img.convert("RGB"))]
        except Exception as exc:
            # Best effort: keep returning no frames, but leave a trace so a
            # corrupt file is not silently scored as if nothing happened.
            print(f"[WARN] {file_path.name}: failed to decode image ({exc})")
            return []

    # Video
    if ext in VIDEO_EXTS:
        cap = cv2.VideoCapture(str(file_path))
        try:
            total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            if total <= 0:
                return []

            frames: List[np.ndarray] = []
            for idx in uniform_frame_indices(total, num_frames):
                cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
                ret, frame = cap.read()
                if not ret:
                    # Skip frames that fail to decode instead of aborting.
                    continue
                # OpenCV decodes to BGR; the rest of the pipeline works in RGB.
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            return frames
        finally:
            # Release the capture even if seeking or conversion raises.
            cap.release()

    return []

# Data Preprocessing

In [6]:
class PreprocessOutput:
    """Container for the result of preprocessing a single input file."""

    def __init__(self, filename: str, imgs: List[Image.Image], error: Optional[str] = None):
        # Name of the processed file.
        self.filename = filename
        # Preprocessed frames extracted from the file.
        self.imgs = imgs
        # Error message if preprocessing failed, otherwise None.
        self.error = error

def preprocess_one(file_path: Path, num_frames: int = NUM_FRAMES) -> PreprocessOutput:
    """Preprocess a single media file into padded frames.

    Args:
        file_path: Path of the file to process.
        num_frames: Number of frames to extract when the file is a video.

    Returns:
        A ``PreprocessOutput`` holding the padded frames. If anything raises
        during reading or padding, the frame list is empty and ``error``
        carries the exception message.
    """
    name = file_path.name
    try:
        # Read raw RGB frames, then pad each one onto the TARGET_SIZE canvas.
        padded = [
            get_full_frame_padded(Image.fromarray(frame), TARGET_SIZE)
            for frame in read_rgb_frames(file_path, num_frames=num_frames)
        ]
    except Exception as exc:
        return PreprocessOutput(name, [], str(exc))

    return PreprocessOutput(name, padded, None)

# Model Load

In [None]:
print("Loading model...")
# Load the classifier weights, move them to the chosen device, and load the
# matching image processor from the same checkpoint.
model = ViTForImageClassification.from_pretrained(MODEL_ID)
model = model.to(DEVICE)
processor = ViTImageProcessor.from_pretrained(MODEL_ID)
# Inference only: switch the module to evaluation mode.
model.eval()

model_config = model.config
print(f"Model loaded: {MODEL_ID}")
print(f"Model config: num_labels={model_config.num_labels}")
# Show the index -> label mapping when the config provides one.
if hasattr(model_config, 'id2label'):
    print(f"id2label: {model_config.id2label}")

In [None]:
def detect_face_regions(image: Image.Image) -> List[Image.Image]:
    """Crop the detected face regions out of an image.

    Args:
        image: Input PIL image.

    Returns:
        One crop per detected face, each resized to ``TARGET_SIZE``. When no
        usable face is found, a single-element list containing the original
        image is returned.
    """
    import mediapipe as mp

    # MediaPipe expects an RGB array. A PIL image is already RGB-ordered, so
    # it must not be swapped to BGR. convert() also normalises grayscale and
    # RGBA inputs to three channels.
    rgb_array = np.array(image.convert("RGB"))
    height, width = rgb_array.shape[:2]

    # The context manager releases the detector's resources after use.
    with mp.solutions.face_detection.FaceDetection(
        min_detection_confidence=0.5
    ) as detector:
        results = detector.process(rgb_array)

    face_images: List[Image.Image] = []
    for detection in results.detections or []:
        bbox = detection.location_data.relative_bounding_box
        # Relative box -> pixel box with a 10 px margin, clamped to the image.
        x_min = max(0, int(bbox.xmin * width) - 10)
        y_min = max(0, int(bbox.ymin * height) - 10)
        x_max = min(width, int((bbox.xmin + bbox.width) * width) + 10)
        y_max = min(height, int((bbox.ymin + bbox.height) * height) + 10)

        # A box lying outside the frame collapses to zero area; skip it.
        if x_max <= x_min or y_max <= y_min:
            continue

        face_crop = image.crop((x_min, y_min, x_max, y_max))
        face_images.append(face_crop.resize(TARGET_SIZE, Image.BICUBIC))

    return face_images if face_images else [image]

def apply_laplacian_filter(image: Image.Image) -> np.ndarray:
    """Compute a Laplacian high-frequency map of an image.

    Args:
        image: Input PIL image in any mode.

    Returns:
        A 2-D float32 array of shape ``(height, width)`` with values clipped
        to ``[0, 1]``. Negative filter responses are clipped to zero.
    """
    # Filter the luminance so all colour channels contribute; previously only
    # the red channel of an RGB image was used.
    gray = np.array(image.convert("L"), dtype=np.float32) / 255.0

    # 4-neighbour Laplacian kernel.
    laplacian_kernel = np.array(
        [
            [0, -1, 0],
            [-1, 4, -1],
            [0, -1, 0],
        ],
        dtype=np.float32,
    )

    # ddepth=-1 keeps the output in the same depth as the input (float32).
    laplacian = cv2.filter2D(gray, -1, laplacian_kernel)
    return np.clip(laplacian, 0, 1)

def infer_fake_probs(pil_images: List[Image.Image]) -> List[float]:
    """Run the classifier on a batch of images and return class-1 probabilities.

    Args:
        pil_images: Images to classify; all are processed as a single batch.

    Returns:
        One softmax probability per input image, taken from logit column 1.
        An empty input yields an empty list.
    """
    if not pil_images:
        return []

    with torch.inference_mode():
        encoded = processor(images=pil_images, return_tensors="pt")
        on_device = {
            name: tensor.to(DEVICE, non_blocking=True)
            for name, tensor in encoded.items()
        }
        logits = model(**on_device).logits
        # NOTE(review): assumes label index 1 is the "fake" class; confirm
        # against model.config.id2label.
        fake_column = F.softmax(logits, dim=1)[:, 1]
        return list(fake_column.cpu().tolist())

def infer_with_multi_scale(pil_images: List[Image.Image]) -> List[float]:
    """Score each image at its original size and at two resized variants.

    Args:
        pil_images: Images to classify.

    Returns:
        A flat list of probabilities: the scores for the unmodified images,
        followed by the 112x112 variants, then the 336x336 variants (three
        values per input image). An empty input yields an empty list.
    """
    if not pil_images:
        return []

    # Original resolution first.
    collected = list(infer_fake_probs(pil_images))

    # Then a smaller and a larger bicubic resample of every image.
    # NOTE(review): if the processor resizes inputs to a fixed model size,
    # these variants act as resampling augmentations rather than true
    # multi-scale inputs; confirm against the processor config.
    for side in (112, 336):
        resized = [img.resize((side, side), Image.BICUBIC) for img in pil_images]
        collected.extend(infer_fake_probs(resized))

    return collected

# Inference

In [None]:
files = sorted(p for p in TEST_DIR.iterdir() if p.is_file())
print(f"Test data length: {len(files)}")

# filename -> fake probability
results: Dict[str, float] = {}

# Preprocess and score every file. Anything that cannot be scored keeps the
# default of 0.0 (treated as real).
for file_path in tqdm(files, desc="Processing"):
    out = preprocess_one(file_path)
    score = 0.0

    if out.error:
        # Preprocessing failed: log it and keep the default score.
        print(f"[WARN] {out.filename}: {out.error}")
    elif out.imgs:
        # Score all frames at several input sizes.
        all_probs = infer_with_multi_scale(out.imgs)
        if all_probs:
            # Blend mean and max, weighting the most suspicious frame more.
            mean_prob = float(np.mean(all_probs))
            max_prob = float(np.max(all_probs))
            score = mean_prob * 0.4 + max_prob * 0.6

    results[out.filename] = score

print(f"Inference completed. Processed: {len(results)} files")

# Submission

In [None]:
# Start from the sample submission so row order and columns match the template.
submission = pd.read_csv('./sample_submission.csv')

# Look up each filename's score; files without a result fall back to 0.0.
mapped_probs = submission['filename'].map(results)
submission['prob'] = mapped_probs.fillna(0.0)

# Write the CSV (UTF-8 with BOM, no index column).
submission.to_csv(OUT_CSV, index=False, encoding='utf-8-sig')
print(f"Saved submission to: {OUT_CSV}")