In [None]:
import os
import cv2
import xml.etree.ElementTree as ET

# ========== Settings ==========
# Directory that is scanned for .mp4 videos and their same-named .xml annotations.
input_video_dir = r"D:\dataset\이상행동 CCTV 영상\02.싸움(fight)\fight"
# Root output directory; clips are written into one sub-folder per label.
output_clip_dir = r"D:\clips"
# Frame rate used both to convert timestamps to frames and to encode the clips.
clip_fps = 30
clip_length = clip_fps * 2  # 2 seconds
time_window = 4 * clip_fps  # 4 seconds on each side of an event (8 seconds in total)
# Every frame is resized to this size before it is written.
width, height = 224, 224
target_event_name = "fight" # event to extract

# ========== Time -> frame conversion ==========
def time_to_frames(time_str, fps=30):
    """Convert a ``[[H:]M:]S[.fff]`` timestamp into a frame count.

    Args:
        time_str: Timestamp such as ``"01:02:03.5"``, ``"02:03"`` or ``"3.5"``.
            Surrounding whitespace is ignored.
        fps: Frames per second used for the conversion.

    Returns:
        The number of whole frames elapsed at ``time_str``; a fractional
        frame is truncated.

    Raises:
        ValueError: If the timestamp has more than three ``:``-separated
            fields or a field is not numeric.
    """
    parts = time_str.strip().split(":")
    if len(parts) > 3:
        raise ValueError(f"Invalid time format: {time_str}")

    # Missing leading fields (hours, minutes) count as zero.
    h, m, s = [0] * (3 - len(parts)) + parts
    total_seconds = int(h) * 3600 + int(m) * 60 + float(s)

    # A binary float product can land a hair below the exact value
    # (4.1 * 30 evaluates to 122.99999999999999), so snap to micro-frame
    # precision before truncating; otherwise the result is one frame short.
    return int(round(total_seconds * fps, 6))


# ========== Event range parsing ==========
def parse_event_ranges(xml_path):
    """Collect the frame ranges of every ``target_event_name`` event in an XML file.

    Only ``<event>`` elements that are direct children of the root element are
    examined. Each one is expected to carry ``<eventname>``, ``<starttime>``
    and ``<duration>`` children; events missing any of them are skipped with
    a printed warning instead of raising ``AttributeError``.

    Args:
        xml_path: Path of the annotation XML file.

    Returns:
        A list of ``(start_frame, end_frame)`` tuples converted at
        ``clip_fps``, where ``end_frame`` is ``start_frame`` plus the
        event duration in frames.
    """
    root = ET.parse(xml_path).getroot()
    event_ranges = []

    for event in root.findall("event"):
        # findtext() falls back to the default when the tag is absent and
        # returns "" for an empty tag, so no None ever reaches .strip().
        name = event.findtext("eventname", default="").strip()
        if name != target_event_name:
            continue

        start_time = event.findtext("starttime", default="").strip()
        duration = event.findtext("duration", default="").strip()
        if not start_time or not duration:
            print(f"Skipping incomplete event in {xml_path}")
            continue

        start_frame = time_to_frames(start_time, clip_fps)
        end_frame = start_frame + time_to_frames(duration, clip_fps)
        event_ranges.append((start_frame, end_frame))

    return event_ranges

# ========== Overlap check ==========
def is_overlapping(start1, end1, start2, end2):
    """Return True when the ranges [start1, end1) and [start2, end2) intersect.

    Ranges that merely touch (one ends where the other starts) do not count
    as overlapping.
    """
    latest_start = max(start1, start2)
    earliest_end = min(end1, end2)
    # A non-empty intersection exists only if it starts before it ends.
    return latest_start < earliest_end

# ========== Clip saving ==========
def save_clip(frames, out_path):
    """Encode ``frames`` into an mp4v video file at ``out_path``.

    Every frame is resized to ``(width, height)`` and the file is written at
    ``clip_fps``. The writer is always released, even if resizing or writing
    a frame raises.

    Args:
        frames: Iterable of image frames accepted by ``cv2.resize``.
        out_path: Destination path of the video file.
    """
    size = (width, height)
    out = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*'mp4v'), clip_fps, size)
    try:
        if not out.isOpened():
            # Without this check a writer that failed to open drops every
            # frame silently and no usable file is produced.
            print(f"Could not open video writer: {out_path}")
            return
        for frame in frames:
            out.write(cv2.resize(frame, size))
    finally:
        out.release()

# ========== Clip extraction ==========
def _read_clip_frames(cap, start_frame, num_frames):
    """Seek to ``start_frame`` and read up to ``num_frames`` consecutive frames."""
    cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
    frames = []
    for _ in range(num_frames):
        # Each read() call advances the stream, so it is called exactly once
        # per frame and both the status flag and the image are taken from it.
        ok, frame = cap.read()
        if not ok:
            # A short clip is discarded by the caller; no point reading on.
            break
        frames.append(frame)
    return frames


def _save_labeled_clip(frames, out_dir, label, video_basename, clip_id):
    """Save ``frames`` to ``<out_dir>/<label>/<video_basename>_<label>_<clip_id>.mp4``."""
    label_dir = os.path.join(out_dir, label)
    os.makedirs(label_dir, exist_ok=True)
    save_clip(frames, os.path.join(label_dir, f"{video_basename}_{label}_{clip_id}.mp4"))


def extract_clips(video_path, event_ranges, out_dir):
    """Cut event clips and surrounding "normal" clips out of one video.

    For every event range, three groups of ``clip_length``-frame clips are
    extracted:

    1. back-to-back clips inside the event, labelled ``target_event_name``;
    2. clips from the ``time_window`` frames before the event start;
    3. clips from the ``time_window`` frames after the event end.

    Groups 2 and 3 are labelled ``"normal"`` and a candidate is dropped when
    it overlaps any event range. Only clips for which all ``clip_length``
    frames could be read are saved.

    Args:
        video_path: Path of the source video.
        event_ranges: List of ``(start_frame, end_frame)`` tuples.
        out_dir: Root output directory; one sub-folder per label is created.

    Returns:
        A ``(normal_count, event_count)`` tuple with the number of clips
        saved for each label. ``(0, 0)`` if the video cannot be opened.
    """
    cap = cv2.VideoCapture(video_path)
    normal_count, event_count = 0, 0
    try:
        if not cap.isOpened():
            print(f"Could not open video: {video_path}")
            return normal_count, event_count

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        video_basename = os.path.splitext(os.path.basename(video_path))[0]

        for start, end in event_ranges:
            # --- 1. Event clips ---
            for i in range(start, end - clip_length + 1, clip_length):
                frames = _read_clip_frames(cap, i, clip_length)
                if len(frames) == clip_length:
                    # The running count doubles as the clip id in the file name.
                    _save_labeled_clip(frames, out_dir, target_event_name, video_basename, event_count)
                    event_count += 1

            # --- 2. Normal clip candidates before the event ---
            before = [start - offset for offset in range(time_window, 0, -clip_length)]
            # --- 3. Normal clip candidates after the event ---
            after = [end + offset for offset in range(0, time_window, clip_length)]

            # Candidates before the event must not start before frame 0;
            # candidates after it must fit inside the reported frame count.
            candidates = [i for i in before if i >= 0]
            candidates += [i for i in after if i + clip_length <= total_frames]

            for i in candidates:
                if any(is_overlapping(i, i + clip_length, ev_start, ev_end)
                       for ev_start, ev_end in event_ranges):
                    continue
                frames = _read_clip_frames(cap, i, clip_length)
                if len(frames) == clip_length:
                    _save_labeled_clip(frames, out_dir, "normal", video_basename, normal_count)
                    normal_count += 1
    finally:
        cap.release()

    return normal_count, event_count

# ========== Process every video ==========
total_normal, total_event = 0, 0

for filename in os.listdir(input_video_dir):
    # Anything that is not an .mp4 file is ignored.
    if not filename.endswith(".mp4"):
        continue

    video_path = os.path.join(input_video_dir, filename)
    # The annotation file sits next to the video and shares its base name.
    xml_name = os.path.splitext(filename)[0] + ".xml"
    xml_path = os.path.join(input_video_dir, xml_name)

    if not os.path.exists(xml_path):
        print(f"XML not found: {filename}")
        continue

    print(f"Processing {filename}")
    event_ranges = parse_event_ranges(xml_path)
    normal_cnt, event_cnt = extract_clips(video_path, event_ranges, output_clip_dir)
    # Accumulate the per-video counts into the running totals.
    total_normal, total_event = total_normal + normal_cnt, total_event + event_cnt

print("\n 완료!")
print(f"총 normal 클립 수: {total_normal}")
print(f"총 event 클립 수: {total_event}")
