## Train

In [None]:
import os
import random
import shutil
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor, get_linear_schedule_with_warmup
import torchvision.transforms as transforms
from torchvision.models import efficientnet_b3, EfficientNet_B3_Weights
from PIL import Image, ImageFile
import librosa
import numpy as np
from collections import defaultdict, Counter
from sklearn.utils.class_weight import compute_class_weight
import json
import gc

ImageFile.LOAD_TRUNCATED_IMAGES = True  # ask PIL to load truncated image files instead of raising

# =========================
# 0. Settings
# =========================
# Seed the Python, PyTorch and NumPy random number generators.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)
np.random.seed(SEED)

# Source dataset roots and the working directory that receives the
# train/val/test copies.
BEHAVIOR_ROOT = "files/1_Animal_Behavior"
EMOTION_ROOT  = "files/2_Animal_emotions"
SOUND_ROOT    = "files/3_Animal_Sound"
PATELLA_ROOT  = "files/6_Animal_Patella"
HEALTH_ROOT   = "files/7_Animal_Health"   # obesity (BCS) dataset
WORK_DIR      = "files/work/omni_dataset"

# -- Sampling removed: every sample is used --
# (behavior/emotion/sound all use the full data; imbalance is compensated with class_weight)

BATCH_SIZE  = 32
EPOCHS      = 100
LR_VIDEO    = 5e-5
LR_AUDIO    = 1e-5
# NOTE(review): "cuda:1" is hard-coded; on a machine with a single visible GPU
# this device does not exist even though CUDA is available.
DEVICE      = "cuda:1" if torch.cuda.is_available() else "cpu"
# NOTE(review): worker count is hard-coded; confirm it suits the host.
NUM_WORKERS = 24
SR          = 16000          # audio sample rate passed to librosa and the feature extractor
MAX_AUDIO_LEN = SR * 5       # number of samples in 5 seconds at SR

# -- LOSS_WEIGHTS --
# emotion: 0.8 -> 1.0 (imbalance is handled by class_weight, so task importance is kept equal)
# health : 1.0 (obesity; high clinical importance)
LOSS_WEIGHTS = {
    "behavior": 1.0,
    "emotion":  1.0,
    "sound":    0.6,
    "patella":  1.0,
    "health":   1.0,
}

# BCS score -> 3-class obesity label index (index into BCS_CLASSES)
BCS_TO_LABEL = {1: 0, 2: 0, 3: 0,   # underweight
                4: 1, 5: 1, 6: 1,   # normal
                7: 2, 8: 2, 9: 2}   # overweight
BCS_CLASSES  = ["underweight", "normal", "overweight"]

# NOTE(review): the feature extractor is loaded at import time.
AUDIO_MODEL_NAME = "facebook/wav2vec2-base"
FEATURE_EXTRACTOR = Wav2Vec2FeatureExtractor.from_pretrained(AUDIO_MODEL_NAME)

print(f"üéØ Device: {DEVICE}")

# =========================
# Audio Augmentation
# =========================
def augment_audio(waveform, p=0.5):
    """Randomly perturb a waveform with pitch shift, time stretch and noise.

    A uniform random draw decides whether to augment: when the draw exceeds
    ``p`` the input object is returned untouched. Otherwise the signal is
    pitch-shifted by a random amount in [-2, 2] steps, time-stretched by a
    random rate in [0.9, 1.1], truncated or zero-padded to exactly
    ``MAX_AUDIO_LEN`` samples, scaled by 0.99 and mixed with Gaussian noise
    (standard deviation 0.003).

    Args:
        waveform: 1-D audio signal (assumed to be a NumPy array sampled at
            ``SR`` -- TODO confirm against callers).
        p: Threshold for applying the augmentation chain.

    Returns:
        ``waveform`` itself when augmentation is skipped, otherwise a new
        array of length ``MAX_AUDIO_LEN``.
    """
    if random.random() > p:
        return waveform

    # The order of the random draws below is kept fixed so seeded runs match.
    semitones = random.uniform(-2, 2)
    shifted = librosa.effects.pitch_shift(waveform, sr=SR, n_steps=semitones)

    speed = random.uniform(0.9, 1.1)
    resampled = librosa.effects.time_stretch(shifted, rate=speed)

    # Force a fixed length: cut the tail or right-pad with zeros.
    shortfall = MAX_AUDIO_LEN - len(resampled)
    if shortfall < 0:
        fixed = resampled[:MAX_AUDIO_LEN]
    else:
        fixed = np.pad(resampled, (0, shortfall))

    jitter = np.random.normal(0, 0.003, len(fixed))
    return fixed * 0.99 + jitter

# =========================
# 1. Dataset Preparation
# =========================
def collect_samples(root, exts):
    """Collect ``(class_name, file_path)`` pairs from a class-per-folder tree.

    Every immediate sub-directory of ``root`` is treated as one class; the
    directory is walked recursively and each file whose lower-cased name ends
    with one of ``exts`` is recorded under that class.

    Directories and files are visited in sorted order. ``os.walk`` otherwise
    yields entries in filesystem order, which would make the seeded
    shuffle/split performed on the result differ between machines.

    Args:
        root: Directory containing one sub-directory per class.
        exts: Iterable of lower-case filename suffixes (e.g. ``['.jpg']``).

    Returns:
        List of ``(class_name, file_path)`` tuples.
    """
    suffixes = tuple(exts)  # str.endswith accepts a tuple of alternatives
    samples = []
    for class_dir in sorted(os.listdir(root)):
        class_path = os.path.join(root, class_dir)
        if not os.path.isdir(class_path):
            continue

        for current_dir, subdirs, files in os.walk(class_path):
            # Sorting in place fixes the order in which os.walk descends.
            subdirs.sort()
            for filename in sorted(files):
                if filename.lower().endswith(suffixes):
                    samples.append((class_dir, os.path.join(current_dir, filename)))

    n_classes = len(set(s[0] for s in samples))
    print(f"  ‚Üí {len(samples)} samples, {n_classes} classes")
    return samples

def collect_patella_samples(root):
    """Collect patella image/annotation pairs from a grade/date/direction tree.

    Expected layout: ``root/<grade>/<date>/<Back|Front|Left|Right>/<name>.jpg``
    with a sibling ``<name>.json``. Images without an annotation file are
    skipped.

    The annotation path is derived with ``os.path.splitext`` rather than a
    case-sensitive ``str.replace('.jpg', '.json')``: the image filter is
    case-insensitive, so for ``X.JPG`` the replace was a no-op and the image
    itself was accepted as its own "annotation".

    Args:
        root: Directory containing one sub-directory per grade (class).

    Returns:
        List of ``(grade, img_path, json_path)`` tuples, in sorted order.
    """
    samples = []

    for grade in sorted(os.listdir(root)):
        grade_path = os.path.join(root, grade)
        if not os.path.isdir(grade_path):
            continue

        for date_dir in sorted(os.listdir(grade_path)):
            date_path = os.path.join(grade_path, date_dir)
            if not os.path.isdir(date_path):
                continue

            for direction in ('Back', 'Front', 'Left', 'Right'):
                direction_path = os.path.join(date_path, direction)
                # isdir (not exists) so a stray file with this name cannot
                # make os.listdir raise.
                if not os.path.isdir(direction_path):
                    continue

                for filename in sorted(os.listdir(direction_path)):
                    if not filename.lower().endswith('.jpg'):
                        continue
                    img_path = os.path.join(direction_path, filename)
                    json_path = os.path.splitext(img_path)[0] + '.json'
                    if os.path.isfile(json_path):
                        samples.append((grade, img_path, json_path))

    n_classes = len(set(s[0] for s in samples))
    print(f"  ‚Üí {len(samples)} samples, {n_classes} classes")
    return samples

def sample_balanced(samples):
    """Print per-class counts and return ``samples`` unchanged.

    Despite the name, no sampling happens here: the whole dataset is kept and
    class imbalance is expected to be compensated with class weights at
    training time.

    Args:
        samples: List of ``(label, path)`` pairs.

    Returns:
        The same ``samples`` object that was passed in.
    """
    per_class = Counter(label for label, _ in samples)
    print(f"  üìä {len(per_class)} classes, total {len(samples)} samples (all used)")
    for name in sorted(per_class):
        print(f"    {name}: {per_class[name]}")
    return samples

def sample_balanced_audio(samples):
    """Print per-class counts for audio samples and return them unchanged.

    No sampling is applied; the full audio dataset is kept and class
    imbalance is expected to be compensated with class weights.

    Args:
        samples: List of ``(label, path)`` pairs.

    Returns:
        The same ``samples`` object that was passed in.
    """
    tally = {}
    for label, _path in samples:
        tally[label] = tally.get(label, 0) + 1

    print(f"  üìä {len(tally)} classes, total {len(samples)} samples (all used)")
    for label in sorted(tally):
        count = tally[label]
        print(f"    {label}: {count}")
    return samples

def collect_health_samples(root):
    """Collect jpg+json pairs for the obesity (BCS) task.

    Each immediate sub-directory of ``root`` (e.g. ``Cat``, ``Dog``) is
    scanned, non-recursively, for ``.jpg`` images with a sibling ``.json``
    file. The BCS score is read from ``meta['metadata']['physical']['BCS']``
    and mapped through ``BCS_TO_LABEL`` / ``BCS_CLASSES`` to one of three
    class names.

    Files whose metadata cannot be read or mapped are skipped; the number of
    skipped files is printed rather than silently discarded.

    Args:
        root: Directory containing one sub-directory per species.

    Returns:
        List of ``(label_str, img_path, json_path)`` tuples, in sorted order
        so that the seeded split performed on the result is reproducible.
    """
    samples = []
    skipped = 0
    for species in sorted(os.listdir(root)):
        species_path = os.path.join(root, species)
        if not os.path.isdir(species_path):
            continue
        for fname in sorted(os.listdir(species_path)):
            if not fname.lower().endswith('.jpg'):
                continue
            img_path = os.path.join(species_path, fname)
            # splitext swaps only the final extension, for any letter case,
            # and never touches '.jpg' appearing in a directory name.
            json_path = os.path.splitext(img_path)[0] + '.json'
            if not os.path.exists(json_path):
                continue
            try:
                with open(json_path, 'r', encoding='utf-8') as f:
                    meta = json.load(f)
                bcs = int(meta['metadata']['physical']['BCS'])
                label = BCS_CLASSES[BCS_TO_LABEL[bcs]]
            except (OSError, ValueError, KeyError, TypeError, IndexError):
                # Unreadable file, malformed JSON, missing field, or a BCS
                # value outside the mapping: best-effort skip, but counted.
                skipped += 1
                continue
            samples.append((label, img_path, json_path))

    label_counts = defaultdict(int)
    for label, _, _ in samples:
        label_counts[label] += 1
    print(f"  ‚Üí {len(samples)} samples, {len(label_counts)} classes (all used)")
    for lbl, cnt in sorted(label_counts.items()):
        print(f"    {lbl}: {cnt}")
    if skipped:
        print(f"    skipped {skipped} files with unreadable or invalid BCS metadata")
    return samples


class HealthDataset(Dataset):
    """Dataset for obesity (BCS) classification.

    - Each image is cropped to the bounding box stored in its JSON
      annotation (``points``) before the transform is applied.
    - Labels are indices into ``BCS_CLASSES``
      (0 = underweight, 1 = normal, 2 = overweight).
    """
    LABEL_TO_ID = {name: index for index, name in enumerate(BCS_CLASSES)}

    def __init__(self, samples, augment=False):
        """Store the sample list and build the image transform.

        Args:
            samples: ``[(label, img_path, json_path), ...]`` as returned by
                ``collect_health_samples()`` or one part of ``split_health()``.
            augment: When True, use resize + random crop / flip / colour
                jitter; otherwise a plain resize to 224x224.
        """
        self.samples = samples
        self.augment = augment

        # Steps shared by both pipelines: tensor conversion + normalisation.
        tail = [
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
        if augment:
            head = [
                transforms.Resize((256, 256)),
                transforms.RandomCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ColorJitter(0.2, 0.2, 0.2),
            ]
        else:
            head = [transforms.Resize((224, 224))]
        self.transform = transforms.Compose(head + tail)

    def __len__(self):
        return len(self.samples)

    @staticmethod
    def _crop_to_annotation(image, json_path):
        """Return ``image`` cropped to the annotated box, or ``image`` itself.

        The box is read from ``meta['annotations']['label']['points']`` as two
        corner points, clamped to the image bounds. Any failure (or an empty
        box) leaves the image uncropped.
        """
        width, height = image.size
        try:
            with open(json_path, 'r', encoding='utf-8') as handle:
                corners = json.load(handle)['annotations']['label']['points']
            ax, ay = corners[0]
            bx, by = corners[1]
            left, right = max(0, min(ax, bx)), min(width, max(ax, bx))
            top, bottom = max(0, min(ay, by)), min(height, max(ay, by))
            if right > left and bottom > top:
                return image.crop((left, top, right, bottom))
        except Exception:
            pass   # crop failed: fall back to the original image
        return image

    def __getitem__(self, idx):
        """Return ``(transformed_image, class_id)`` for sample ``idx``."""
        label, img_path, json_path = self.samples[idx]
        image = Image.open(img_path).convert("RGB")
        image = self._crop_to_annotation(image, json_path)
        return self.transform(image), self.LABEL_TO_ID[label]


def split_health(samples, train_ratio=0.8, val_ratio=0.1):
    """Split health samples into train/val/test per class, in memory.

    The samples are shuffled (using the global ``random`` state), grouped by
    their first element (the label), and each class is cut into a train part
    (``train_ratio``), a validation part (``val_ratio``) and a test part (the
    remainder). No files are copied.

    The shuffle is done on a copy so the caller's list keeps its order.

    Args:
        samples: Sequence of tuples whose first element is the class label.
        train_ratio: Fraction of each class assigned to train.
        val_ratio: Fraction of each class assigned to validation.

    Returns:
        ``(train, val, test)`` lists of the original tuples.
    """
    shuffled = list(samples)
    random.shuffle(shuffled)

    class_split = defaultdict(list)
    for item in shuffled:
        class_split[item[0]].append(item)

    train, val, test = [], [], []
    for items in class_split.values():
        n = len(items)
        n_train = int(n * train_ratio)
        n_val   = int(n * val_ratio)
        train.extend(items[:n_train])
        val.extend(items[n_train:n_train + n_val])
        test.extend(items[n_train + n_val:])

    print(f"  Health split ‚Üí train:{len(train)}, val:{len(val)}, test:{len(test)}")
    return train, val, test


def split_and_copy(samples, task_name, is_patella=False, original_samples=None):
    """Split samples per class and copy them into ``WORK_DIR/<split>/<task>/<label>``.

    Without ``original_samples`` each class is shuffled and cut 80/10/10 into
    train/val/test.

    With ``original_samples`` (used for the sound task) the test set for each
    class is the first ``max(10, n // 5)`` entries of a shuffled copy of the
    originals. Those paths are then removed from the train/val pool before it
    is split, so a file can never be in both test and train/val. Previously
    the test files were taken from the originals but left in the pool, so
    they leaked into training. The remaining pool is split 80% train, rest
    validation.

    Args:
        samples: ``[(label, path), ...]`` or, when ``is_patella`` is True,
            ``[(label, img_path, json_path), ...]``.
        task_name: Sub-directory name under each split directory.
        is_patella: Copy an image together with its JSON annotation.
        original_samples: Optional ``[(label, path), ...]`` from which the
            held-out test set is drawn.

    Returns:
        None. Files are copied as a side effect; neither input list is
        reordered.
    """
    # Shuffle a copy: the caller may pass the same list as original_samples.
    pool = list(samples)
    random.shuffle(pool)
    class_samples = defaultdict(list)

    if is_patella:
        for label, img_path, json_path in pool:
            class_samples[label].append((img_path, json_path))
    else:
        for label, path in pool:
            class_samples[label].append(path)

    for split in ("train", "val", "test"):
        os.makedirs(os.path.join(WORK_DIR, split, task_name), exist_ok=True)

    # Held-out test items, drawn from the originals in random order.
    test_items_by_label = None
    if original_samples is not None:
        originals = list(original_samples)
        random.shuffle(originals)
        orig_class = defaultdict(list)
        for label, path in originals:
            orig_class[label].append(path)
        # NOTE(review): a class with 10 or fewer originals is assigned
        # entirely to test and gets no train/val files.
        test_items_by_label = {
            label: paths[:max(10, len(paths) // 5)]
            for label, paths in orig_class.items()
        }

    for label, items in class_samples.items():
        if test_items_by_label is not None:
            test_items = test_items_by_label.get(label, [])
            held_out = set(test_items)
            # Remove every copy of a test file from the train/val pool.
            remaining = [item for item in items if item not in held_out]
            n_train = int(len(remaining) * 0.8)
            train_items = remaining[:n_train]
            val_items   = remaining[n_train:]
        else:
            n = len(items)
            n_train = int(n * 0.8)
            n_val   = int(n * 0.1)
            train_items = items[:n_train]
            val_items   = items[n_train:n_train + n_val]
            test_items  = items[n_train + n_val:]

        split_map = {"train": train_items, "val": val_items, "test": test_items}

        for split_name, split_items in split_map.items():
            dst_label_dir = os.path.join(WORK_DIR, split_name, task_name, label)
            os.makedirs(dst_label_dir, exist_ok=True)

            # NOTE(review): destination names use only the basename, so two
            # source files with the same name in different sub-folders of one
            # class overwrite each other.
            for item in tqdm(split_items, desc=f"{task_name}/{split_name}/{label}", leave=False):
                if is_patella:
                    img_path, json_path = item
                    dst_img = os.path.join(dst_label_dir, f"{label}_{os.path.basename(img_path)}")
                    shutil.copy(img_path, dst_img)
                    # splitext so an upper-case '.JPG' image is not
                    # overwritten by its own annotation.
                    dst_json = os.path.splitext(dst_img)[0] + '.json'
                    shutil.copy(json_path, dst_json)
                else:
                    dst_path = os.path.join(dst_label_dir, f"{label}_{os.path.basename(item)}")
                    shutil.copy(item, dst_path)

def _task_ready(task_name):
    """Return True if the task's train folder exists and is not empty."""
    train_dir = os.path.join(WORK_DIR, "train", task_name)
    if not os.path.isdir(train_dir):
        return False
    return len(os.listdir(train_dir)) > 0


def prepare_dataset():
    """Prepare the on-disk splits for each task and the in-memory health split.

    Each file-copy task (behavior, emotion, sound, patella) is checked
    independently with ``_task_ready`` and only the missing ones are
    collected, split and copied. The health/BCS samples are never copied;
    they are collected and split in memory on every call.

    Returns:
        ``(health_train, health_val, health_test)`` sample lists.
    """

    def _build_behavior():
        print("\n\U0001f4e6 Collecting behavior (all samples)...")
        behavior_all = collect_samples(BEHAVIOR_ROOT, ['.jpg', '.png', '.jpeg'])
        behavior = sample_balanced(behavior_all)  # all samples are used
        print("  \U0001f4cb Splitting & Copying behavior...")
        split_and_copy(behavior, "behavior")

    def _build_emotion():
        print("\n\U0001f4e6 Collecting emotion (all samples)...")
        emotion_all = collect_samples(EMOTION_ROOT, ['.jpg', '.png', '.jpeg'])
        emotion = sample_balanced(emotion_all)  # all samples are used
        print("  \U0001f4cb Splitting & Copying emotion...")
        split_and_copy(emotion, "emotion")

    def _build_sound():
        print("\n\U0001f4e6 Collecting sound (all samples)...")
        sound_all = collect_samples(SOUND_ROOT, ['.wav', '.mp3', '.m4a'])
        sound = sample_balanced_audio(sound_all)  # all samples are used
        print("  \U0001f4cb Splitting & Copying sound...")
        split_and_copy(sound, "sound", original_samples=sound_all)

    def _build_patella():
        print("\n\U0001f4e6 Collecting patella luxation (all samples)...")
        patella_all = collect_patella_samples(PATELLA_ROOT)
        print("  \u2139\ufe0f  Patella: Using all samples")
        print("  \U0001f4cb Splitting & Copying patella...")
        split_and_copy(patella_all, "patella", is_patella=True)

    # Order matters: it fixes both the console output and the sequence in
    # which the global random state is consumed.
    builders = (
        ("behavior", _build_behavior),
        ("emotion", _build_emotion),
        ("sound", _build_sound),
        ("patella", _build_patella),
    )
    pending = {task: not _task_ready(task) for task, _ in builders}

    if not any(pending.values()):
        print("\u2705 All file-copy tasks already prepared, skipping.")
    else:
        for split in ("train", "val", "test"):
            os.makedirs(os.path.join(WORK_DIR, split), exist_ok=True)

        for task, build in builders:
            if pending[task]:
                build()
            else:
                print(f"\u2705 {task} already prepared, skipping.")

    # Health: in-memory split without copying files, so always collected.
    print("\n\U0001f4e6 Collecting health/BCS (all samples)...")
    health_all = collect_health_samples(HEALTH_ROOT)
    health_train, health_val, health_test = split_health(health_all)

    print("\n\u2705 Dataset preparation complete.")
    return health_train, health_val, health_test

# =========================
# 2. Dataset Classes
# =========================
class ImageDataset(Dataset):
    """Image classification dataset read from a ``<task_dir>/<label>/`` tree.

    Sub-directories of ``task_dir`` (in sorted order) define the classes and
    their integer ids; ``.jpg``/``.png``/``.jpeg`` files directly inside each
    one are the samples.
    """

    def __init__(self, task_dir, augment=False):
        """Index the files and build the image transform.

        Args:
            task_dir: Directory holding one sub-directory per class.
            augment: When True, use resize + random crop / flip / colour
                jitter; otherwise a plain resize to 224x224.
        """
        self.samples = []
        self.label_to_id = {}

        class_names = [
            name for name in sorted(os.listdir(task_dir))
            if os.path.isdir(os.path.join(task_dir, name))
        ]
        for class_id, name in enumerate(class_names):
            self.label_to_id[name] = class_id
            folder = os.path.join(task_dir, name)
            self.samples.extend(
                (os.path.join(folder, fname), name)
                for fname in os.listdir(folder)
                if fname.lower().endswith(('.jpg', '.png', '.jpeg'))
            )

        print(f"  üìä {os.path.basename(task_dir)}: {len(self.samples)} samples, {len(self.label_to_id)} classes")

        # Steps shared by both pipelines: tensor conversion + normalisation.
        tail = [
            transforms.ToTensor(),
            transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
        ]
        if augment:
            head = [
                transforms.Resize((256,256)),
                transforms.RandomCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ColorJitter(0.2, 0.2, 0.2),
            ]
        else:
            head = [transforms.Resize((224,224))]
        self.transform = transforms.Compose(head + tail)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(transformed_image, class_id)`` for sample ``idx``."""
        path, label = self.samples[idx]
        image = Image.open(path).convert("RGB")
        return self.transform(image), self.label_to_id[label]

class PatellaDataset(Dataset):
    """Patella dataset yielding an image, a keypoint vector and a class id.

    Sub-directories of ``task_dir`` (in sorted order) define the classes.
    Every ``.jpg`` image with a sibling ``.json`` annotation is a sample.

    The annotation path is derived with ``os.path.splitext``: the image
    filter is case-insensitive, so with the former
    ``replace('.jpg', '.json')`` an upper-case ``X.JPG`` resolved to itself
    and ``__getitem__`` tried to parse the image as JSON.
    """

    # Length of the flattened (x, y) keypoint vector returned per sample.
    NUM_KEYPOINT_VALUES = 18

    def __init__(self, task_dir, augment=False):
        """Index image/annotation pairs and build the image transform.

        Args:
            task_dir: Directory holding one sub-directory per class.
            augment: When True, use resize + random crop / flip / colour
                jitter; otherwise a plain resize to 224x224.
        """
        self.samples = []
        self.label_to_id = {}

        for label in sorted(os.listdir(task_dir)):
            label_dir = os.path.join(task_dir, label)
            if not os.path.isdir(label_dir):
                continue

            self.label_to_id[label] = len(self.label_to_id)

            for file in os.listdir(label_dir):
                if not file.lower().endswith('.jpg'):
                    continue
                img_path = os.path.join(label_dir, file)
                json_path = os.path.splitext(img_path)[0] + '.json'
                if os.path.isfile(json_path):
                    self.samples.append((img_path, json_path, label))

        print(f"  üìä {os.path.basename(task_dir)}: {len(self.samples)} samples, {len(self.label_to_id)} classes")

        if augment:
            self.transform = transforms.Compose([
                transforms.Resize((256,256)),
                transforms.RandomCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ColorJitter(0.2, 0.2, 0.2),
                transforms.ToTensor(),
                transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
            ])
        else:
            self.transform = transforms.Compose([
                transforms.Resize((224,224)),
                transforms.ToTensor(),
                transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
            ])

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image_tensor, keypoints, class_id)`` for sample ``idx``.

        ``keypoints`` is a float32 tensor of ``NUM_KEYPOINT_VALUES`` entries:
        the ``x``/``y`` values of each ``annotation_info`` item, flattened,
        zero-padded or truncated to that length.
        """
        img_path, json_path, label = self.samples[idx]

        img = Image.open(img_path).convert("RGB")
        img = self.transform(img)

        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # NOTE(review): coordinates are used exactly as stored; they are not
        # rescaled to the resized image nor mirrored when the random flip
        # fires -- confirm this is intended.
        keypoints = []
        for annotation in data.get('annotation_info', []):
            keypoints.append(float(annotation.get('x', 0)))
            keypoints.append(float(annotation.get('y', 0)))

        size = self.NUM_KEYPOINT_VALUES
        keypoints.extend([0.0] * (size - len(keypoints)))  # no-op when already long enough
        keypoints = torch.tensor(keypoints[:size], dtype=torch.float32)

        return img, keypoints, self.label_to_id[label]

class AudioDataset(Dataset):
    """Audio classification dataset read from a ``<task_dir>/<label>/`` tree.

    Sub-directories of ``task_dir`` (in sorted order) define the classes;
    ``.wav``/``.mp3``/``.m4a`` files directly inside each one are the samples.
    Both ``label_to_id`` and the reverse ``id_to_label`` mapping are exposed.
    """

    def __init__(self, task_dir, augment=False):
        """Index the audio files.

        Args:
            task_dir: Directory holding one sub-directory per class.
            augment: When True, ``augment_audio`` is applied to each loaded
                waveform.
        """
        self.augment = augment
        self.samples = []

        class_names = [
            name for name in sorted(os.listdir(task_dir))
            if os.path.isdir(os.path.join(task_dir, name))
        ]
        self.label_to_id = {name: index for index, name in enumerate(class_names)}
        self.id_to_label = {index: name for name, index in self.label_to_id.items()}

        for name in class_names:
            folder = os.path.join(task_dir, name)
            for fname in os.listdir(folder):
                if fname.lower().endswith(('.wav', '.mp3', '.m4a')):
                    self.samples.append((os.path.join(folder, fname), name))

        print(f"  üìä {os.path.basename(task_dir)}: {len(self.samples)} samples, {len(self.label_to_id)} classes, augment={augment}")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``{"input_values": tensor, "labels": tensor}`` for ``idx``.

        The file is loaded as mono audio at ``SR``; a file that fails to load
        is replaced by ``MAX_AUDIO_LEN`` zeros. The waveform is optionally
        augmented, then truncated or zero-padded to ``MAX_AUDIO_LEN`` samples
        and passed through ``FEATURE_EXTRACTOR``.
        """
        path, label = self.samples[idx]

        try:
            waveform, _ = librosa.load(path, sr=SR, mono=True)
        except Exception:
            waveform = np.zeros(MAX_AUDIO_LEN)

        if self.augment:
            waveform = augment_audio(waveform)

        # Force a fixed length: cut the tail or right-pad with zeros.
        excess = len(waveform) - MAX_AUDIO_LEN
        if excess > 0:
            waveform = waveform[:MAX_AUDIO_LEN]
        else:
            waveform = np.pad(waveform, (0, -excess))

        features = FEATURE_EXTRACTOR(waveform, sampling_rate=SR, return_tensors="pt")
        label_id = torch.tensor(self.label_to_id[label], dtype=torch.long)
        # Returned as a dict so the collate function can stack by key.
        return {
            "input_values": features.input_values.squeeze(0),
            "labels": label_id,
        }


def collate_fn_audio(batch):
    """Stack a list of AudioDataset items into one batch dict.

    Args:
        batch: List of dicts, each with ``"input_values"`` and ``"labels"``
            tensors of matching shapes across items.

    Returns:
        Dict with the same two keys, each holding the tensors stacked along a
        new leading dimension.
    """
    return {
        key: torch.stack([sample[key] for sample in batch])
        for key in ("input_values", "labels")
    }

# =========================
# 3. Individual Models (independent models)
# =========================
def _efficientnet_b3_backbone():
    """Build the shared EfficientNet-B3 feature extractor.

    The ImageNet-pretrained network is created and its classifier is replaced
    by ``nn.Identity`` so the module outputs the feature vector directly.

    Returns:
        ``(backbone, in_features)`` where ``in_features`` is the input width
        of the removed classification layer (1536 according to the original
        author's note).
    """
    pretrained = EfficientNet_B3_Weights.IMAGENET1K_V1
    net = efficientnet_b3(weights=pretrained)
    feature_dim = net.classifier[1].in_features
    net.classifier = nn.Identity()
    return net, feature_dim


class BehaviorModel(nn.Module):
    """Behavior classifier: EfficientNet-B3 backbone + dropout/linear head."""

    def __init__(self, num_classes):
        """Create the backbone and a ``num_classes``-way head."""
        super().__init__()
        encoder, feature_dim = _efficientnet_b3_backbone()
        self.backbone = encoder
        head_layers = [nn.Dropout(0.3), nn.Linear(feature_dim, num_classes)]
        self.head = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return class logits for the image batch ``x``."""
        features = self.backbone(x)
        return self.head(features)


class EmotionModel(nn.Module):
    """Emotion classifier: its own EfficientNet-B3 backbone + linear head.

    The backbone is not shared with the other models; class weighting for
    this task is applied outside this class.
    """

    def __init__(self, num_classes):
        """Create the backbone and a ``num_classes``-way head."""
        super().__init__()
        extractor, width = _efficientnet_b3_backbone()
        self.backbone = extractor
        dropout = nn.Dropout(0.3)
        projection = nn.Linear(width, num_classes)
        self.head = nn.Sequential(dropout, projection)

    def forward(self, x):
        """Return class logits for the image batch ``x``."""
        embedding = self.backbone(x)
        logits = self.head(embedding)
        return logits


class PatellaModel(nn.Module):
    """Patella luxation classifier: EfficientNet-B3 features + keypoints.

    The image feature vector is concatenated with an 18-value keypoint vector
    and passed through a two-layer MLP head.
    """

    def __init__(self, num_classes):
        """Create the backbone and the fusion head."""
        super().__init__()
        encoder, feature_dim = _efficientnet_b3_backbone()
        self.backbone = encoder
        fused_dim = feature_dim + 18  # image features + flattened keypoints
        head_layers = [
            nn.Linear(fused_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes),
        ]
        self.head = nn.Sequential(*head_layers)

    def forward(self, x, keypoints):
        """Return class logits from an image batch and its keypoint batch."""
        image_features = self.backbone(x)
        fused = torch.cat([image_features, keypoints], dim=1)
        return self.head(fused)


class HealthModel(nn.Module):
    """Obesity (BCS) classifier: EfficientNet-B3 backbone, 3 classes by default
    (underweight / normal / overweight)."""

    def __init__(self, num_classes=3):
        """Create the backbone and a ``num_classes``-way head."""
        super().__init__()
        trunk, trunk_width = _efficientnet_b3_backbone()
        self.backbone = trunk
        self.head = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(trunk_width, num_classes),
        )

    def forward(self, x):
        """Return class logits for the image batch ``x``."""
        pooled = self.backbone(x)
        return self.head(pooled)

class AudioModel(nn.Module):
    """Thin wrapper around ``Wav2Vec2ForSequenceClassification``."""

    def __init__(self, num_classes, freeze_backbone=False):
        """Load the pretrained model with a ``num_classes``-way head.

        Args:
            num_classes: Number of output labels.
            freeze_backbone: When True, gradients are disabled for every
                parameter of the ``wav2vec2`` sub-module.
        """
        super().__init__()
        load_options = {
            "num_labels": num_classes,
            "ignore_mismatched_sizes": True,
        }
        self.model = Wav2Vec2ForSequenceClassification.from_pretrained(
            AUDIO_MODEL_NAME, **load_options
        )

        if not freeze_backbone:
            return
        for weight in self.model.wav2vec2.parameters():
            weight.requires_grad = False

    def forward(self, input_values, labels=None):
        """Forward ``input_values`` (and optional ``labels``) to the wrapped model.

        Returns whatever the wrapped model returns; when ``labels`` is given
        the model is presumably responsible for computing the loss itself --
        confirm against the transformers documentation.
        """
        outputs = self.model(input_values=input_values, labels=labels)
        return outputs

# =========================
# 4. Helper Functions
# =========================
def mixup_data(x, y, alpha=0.4):
    """Blend each sample of a batch with a randomly chosen partner (mixup).

    Args:
        x: Input batch tensor; the first dimension is the batch.
        y: Targets indexed along the same first dimension.
        alpha: Beta-distribution parameter; when not positive, no mixing is
            done (``lam`` is 1).

    Returns:
        ``(mixed_x, y_a, y_b, lam)`` where ``mixed_x = lam * x +
        (1 - lam) * x[perm]``, ``y_a`` is ``y`` and ``y_b`` is ``y[perm]``.
    """
    # Draw the mixing coefficient first, then the permutation, so the random
    # streams are consumed in a fixed order.
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1

    partner = torch.randperm(x.size(0)).to(x.device)
    blended = lam * x + (1 - lam) * x[partner]

    return blended, y, y[partner], lam

def clear_memory():
    """Free memory: run the Python garbage collector, then empty the CUDA cache."""
    gc.collect()
    torch.cuda.empty_cache()

# =========================
# 5. Sequential Training (memory-efficient)
# =========================
def _balanced_class_weights(label_ids, num_classes):
    """Return sklearn 'balanced' class weights for ids 0..num_classes-1 as a float tensor."""
    weights = compute_class_weight(
        'balanced',
        classes=np.arange(num_classes),
        y=label_ids
    )
    return torch.tensor(weights, dtype=torch.float)


def _predict_tensor_batch(model, batch):
    """Forward a default-collated batch whose last element is the label tensor."""
    *inputs, labels = (t.to(DEVICE) for t in batch)
    return model(*inputs), labels


def _predict_audio_batch(model, batch):
    """Forward a collate_fn_audio batch and return (logits, labels)."""
    labels = batch["labels"].to(DEVICE)
    outputs = model(input_values=batch["input_values"].to(DEVICE))
    return outputs.logits, labels


def _train_mixup_task(model, dataset, optimizer, scaler, criterion,
                      loss_weight, desc, use_amp):
    """Train `model` for one pass over `dataset` with mixup; return the mean batch loss.

    Batches are (imgs, labels) or (imgs, extra_inputs..., labels). Every input
    tensor is mixed with the same coefficient and the same permutation so that
    auxiliary inputs (e.g. patella keypoints) stay consistent with the blended
    image and the blended target. The model is moved to DEVICE for the pass and
    back to CPU afterwards.
    """
    model.to(DEVICE)
    model.train()
    loader = DataLoader(dataset, BATCH_SIZE, True,
                        num_workers=NUM_WORKERS, pin_memory=True)
    num_batches = len(loader)

    running_loss = 0.0
    for batch in tqdm(loader, desc=desc, leave=False):
        *inputs, labels = (t.to(DEVICE) for t in batch)

        optimizer.zero_grad()
        with torch.amp.autocast("cuda", enabled=use_amp):
            # Pass sample positions as the "labels" so that mixup_data hands
            # back the permutation it drew; it is needed to mix the remaining
            # inputs and to pick the partner labels.
            positions = torch.arange(labels.size(0), device=labels.device)
            mixed_imgs, _, perm, lam = mixup_data(inputs[0], positions)
            lam = float(lam)
            mixed_extras = [lam * t + (1 - lam) * t[perm] for t in inputs[1:]]

            logits = model(mixed_imgs, *mixed_extras)
            loss = (lam * criterion(logits, labels)
                    + (1 - lam) * criterion(logits, labels[perm]))
            loss = loss * loss_weight

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()

    # Free GPU memory before the next task is loaded.
    model.cpu()
    del loader
    clear_memory()
    return running_loss / max(num_batches, 1)


def _train_sound_task(model, dataset, optimizer, scheduler, scaler, use_amp):
    """Train the audio model for one pass over `dataset`; return the mean batch loss.

    The loss is a class-weighted cross-entropy computed on the model logits
    ('balanced' weights derived from `dataset.samples`), scaled by
    LOSS_WEIGHTS["sound"]. Gradients are clipped to norm 1.0 and the LR
    scheduler advances once per optimizer step that was actually applied.
    """
    model.to(DEVICE)
    model.train()

    label_ids = [dataset.label_to_id[item[1]] for item in dataset.samples]
    class_weights = _balanced_class_weights(label_ids, len(dataset.label_to_id))
    weighted_ce = nn.CrossEntropyLoss(weight=class_weights.to(DEVICE))

    loader = DataLoader(
        dataset, BATCH_SIZE, True,
        num_workers=2, pin_memory=True,
        collate_fn=collate_fn_audio
    )
    num_batches = len(loader)

    running_loss = 0.0
    for batch in tqdm(loader, desc="Sound", leave=False):
        audios = batch["input_values"].to(DEVICE)
        labels = batch["labels"].to(DEVICE)

        optimizer.zero_grad()
        with torch.amp.autocast("cuda", enabled=use_amp):
            outputs = model(input_values=audios)
            # Per-class weighting must happen inside the cross-entropy;
            # multiplying an unweighted mean loss by a batch-level scalar
            # would only rescale the whole batch.
            loss = weighted_ce(outputs.logits.float(), labels) * LOSS_WEIGHTS["sound"]

        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        scale_before = scaler.get_scale()
        scaler.step(optimizer)
        scaler.update()
        # When the scaler finds inf/NaN gradients it skips optimizer.step()
        # and lowers its scale; do not advance the LR schedule in that case.
        if scaler.get_scale() >= scale_before:
            scheduler.step()

        running_loss += loss.item()

    model.cpu()
    del loader
    clear_memory()
    return running_loss / max(num_batches, 1)


def _evaluate_task(model, dataset, desc, num_workers, predict, collate_fn=None):
    """Return top-1 accuracy of `model` on `dataset` (0.0 when the dataset is empty).

    `predict(model, batch)` must return `(logits, labels)` with both tensors on
    the same device. The model is moved to DEVICE for evaluation and back to
    CPU afterwards.
    """
    model.to(DEVICE)
    model.eval()
    loader = DataLoader(dataset, BATCH_SIZE, False,
                        num_workers=num_workers, pin_memory=True,
                        collate_fn=collate_fn)

    correct, total = 0, 0
    with torch.no_grad():
        for batch in tqdm(loader, desc=desc, leave=False):
            logits, labels = predict(model, batch)
            pred = logits.argmax(-1)
            correct += (pred == labels).sum().item()
            total += labels.size(0)

    model.cpu()
    del loader
    clear_memory()
    return correct / total if total > 0 else 0.0


def _plot_history(history, out_path='pet_normal_omni_history.png'):
    """Save a 2x5 grid (loss row, accuracy row) of per-task training curves."""
    _, axes = plt.subplots(2, 5, figsize=(25, 8))

    # (history key suffix, line colour, panel title). The colours must be
    # plain colour names: a format string such as 'b-' is rejected by the
    # `color=` keyword.
    panels = [
        ('b', 'b',      'Behavior'),
        ('e', 'r',      'Emotion'),
        ('s', 'g',      'Sound'),
        ('p', 'purple', 'Patella'),
        ('h', 'orange', 'Health'),
    ]
    rows = [('loss', 'Loss'), ('acc', 'Accuracy')]

    for col, (suffix, color, title) in enumerate(panels):
        for row, (prefix, ylabel) in enumerate(rows):
            ax = axes[row, col]
            ax.plot([h[f'{prefix}_{suffix}'] for h in history], color=color, linewidth=2)
            ax.set_title(f'{title} {ylabel}')
            ax.set_xlabel('Epoch')
            ax.set_ylabel(ylabel)
            if prefix == 'acc':
                ax.set_ylim(0, 1)
            ax.grid(True, alpha=0.3)

    plt.suptitle('Pet Normal Omni Model Training History', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(out_path, dpi=150, bbox_inches='tight')
    plt.close()


def train():
    """Train the five task models sequentially and keep the best checkpoint.

    Each epoch trains behavior, emotion, sound, patella and health models one
    after another (only one model lives on DEVICE at a time), then validates
    all five. Whenever the mean validation accuracy improves, all model
    weights, the label maps and the history so far are written to
    ``pet_normal_omni_best.pth``. After the last epoch the loss/accuracy
    curves are saved to ``pet_normal_omni_history.png``.
    """
    # prepare_dataset returns the health (BCS) train/val/test split.
    health_train_samples, health_val_samples, _ = prepare_dataset()

    # Mixed precision only makes sense on a CUDA device.
    use_amp = str(DEVICE).startswith("cuda")
    train_root = os.path.join(WORK_DIR, "train")
    val_root = os.path.join(WORK_DIR, "val")

    # Load the label mappings once up front.
    print("\n\U0001f504 Pre-loading label mappings...")
    temp_b = ImageDataset(os.path.join(train_root, "behavior"), augment=False)
    temp_e = ImageDataset(os.path.join(train_root, "emotion"), augment=False)
    temp_s = AudioDataset(os.path.join(train_root, "sound"), augment=False)
    temp_p = PatellaDataset(os.path.join(train_root, "patella"), augment=False)

    behavior_label_to_id = temp_b.label_to_id
    emotion_label_to_id  = temp_e.label_to_id
    sound_label_to_id    = temp_s.label_to_id
    sound_id_to_label    = temp_s.id_to_label
    patella_label_to_id  = temp_p.label_to_id

    # Emotion class weights (imbalance correction).
    emotion_class_weights_tensor = _balanced_class_weights(
        [emotion_label_to_id[label] for _, label in temp_e.samples],
        len(emotion_label_to_id),
    )

    # The sound loader keeps its last partial batch, so round up.
    sound_steps_per_epoch = (len(temp_s) + BATCH_SIZE - 1) // BATCH_SIZE

    del temp_b, temp_e, temp_s, temp_p
    clear_memory()

    # Models are created on CPU and moved to DEVICE only while in use.
    print("\n\U0001f504 Initializing models...")
    num_health_classes = len(BCS_CLASSES)
    behavior_model = BehaviorModel(len(behavior_label_to_id))
    emotion_model  = EmotionModel(len(emotion_label_to_id))
    patella_model  = PatellaModel(len(patella_label_to_id))
    audio_model    = AudioModel(len(sound_label_to_id), freeze_backbone=False)
    health_model   = HealthModel(num_classes=num_health_classes)

    # Optimizers
    behavior_opt = torch.optim.AdamW(behavior_model.parameters(), lr=LR_VIDEO, weight_decay=0.01)
    emotion_opt  = torch.optim.AdamW(emotion_model.parameters(),  lr=LR_VIDEO, weight_decay=0.01)
    patella_opt  = torch.optim.AdamW(patella_model.parameters(),  lr=LR_VIDEO, weight_decay=0.01)
    audio_opt    = torch.optim.AdamW(audio_model.parameters(),    lr=LR_AUDIO, weight_decay=0.01)
    health_opt   = torch.optim.AdamW(health_model.parameters(),   lr=LR_VIDEO, weight_decay=0.01)

    # Audio LR schedule: linear warmup, then linear decay over all sound steps.
    audio_scheduler = get_linear_schedule_with_warmup(
        audio_opt,
        num_warmup_steps=100,
        num_training_steps=sound_steps_per_epoch * EPOCHS
    )

    # Gradient scalers (no-ops when AMP is disabled).
    video_scaler = torch.amp.GradScaler("cuda", enabled=use_amp)
    audio_scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

    # Default criterion: cross-entropy with label smoothing.
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    # Emotion criterion: adds class weights to correct the imbalance.
    criterion_emotion = nn.CrossEntropyLoss(
        weight=emotion_class_weights_tensor.to(DEVICE),
        label_smoothing=0.1
    )
    # Health criterion: BCS class weights, computed once because the health
    # training samples are the same list every epoch.
    health_class_weights = _balanced_class_weights(
        [HealthDataset.LABEL_TO_ID[s[0]] for s in health_train_samples],
        num_health_classes,
    )
    criterion_health = nn.CrossEntropyLoss(
        weight=health_class_weights.to(DEVICE),
        label_smoothing=0.1
    )

    best_avg_acc = 0
    history = []

    for epoch in range(EPOCHS):
        print(f"\n{'='*60}")
        print(f"Epoch {epoch+1}/{EPOCHS}")
        print(f"{'='*60}")

        # ========== 1. Behavior ==========
        print("\n\U0001f43e Training Behavior...")
        loss_b = _train_mixup_task(
            behavior_model,
            ImageDataset(os.path.join(train_root, "behavior"), augment=True),
            behavior_opt, video_scaler, criterion,
            LOSS_WEIGHTS["behavior"], "Behavior", use_amp,
        )
        print(f"  \u2192 Avg Loss: {loss_b:.4f}")

        # ========== 2. Emotion ==========
        print("\n\U0001f60a Training Emotion...")
        loss_e = _train_mixup_task(
            emotion_model,
            ImageDataset(os.path.join(train_root, "emotion"), augment=True),
            emotion_opt, video_scaler, criterion_emotion,
            LOSS_WEIGHTS["emotion"], "Emotion", use_amp,
        )
        print(f"  \u2192 Avg Loss: {loss_e:.4f}")

        # ========== 3. Sound ==========
        print("\n\U0001f50a Training Sound...")
        loss_s = _train_sound_task(
            audio_model,
            AudioDataset(os.path.join(train_root, "sound"), augment=True),
            audio_opt, audio_scheduler, audio_scaler, use_amp,
        )
        print(f"  \u2192 Avg Loss: {loss_s:.4f}")

        # ========== 4. Patella ==========
        print("\n\U0001f9b4 Training Patella...")
        loss_p = _train_mixup_task(
            patella_model,
            PatellaDataset(os.path.join(train_root, "patella"), augment=True),
            patella_opt, video_scaler, criterion,
            LOSS_WEIGHTS["patella"], "Patella", use_amp,
        )
        print(f"  \u2192 Avg Loss: {loss_p:.4f}")

        # ========== 5. Health (BCS) ==========
        print("\n\U0001f4aa Training Health(BCS)...")
        loss_h = _train_mixup_task(
            health_model,
            HealthDataset(health_train_samples, augment=True),
            health_opt, video_scaler, criterion_health,
            LOSS_WEIGHTS["health"], "Health", use_amp,
        )
        print(f"  \u2192 Avg Loss: {loss_h:.4f}")

        # ========== Validation ==========
        print("\n\U0001f50d Validation...")
        acc_b = _evaluate_task(
            behavior_model,
            ImageDataset(os.path.join(val_root, "behavior"), augment=False),
            "Val Behavior", NUM_WORKERS // 2, _predict_tensor_batch,
        )
        acc_e = _evaluate_task(
            emotion_model,
            ImageDataset(os.path.join(val_root, "emotion"), augment=False),
            "Val Emotion", NUM_WORKERS // 2, _predict_tensor_batch,
        )
        acc_s = _evaluate_task(
            audio_model,
            AudioDataset(os.path.join(val_root, "sound"), augment=False),
            "Val Sound", 2, _predict_audio_batch,
            collate_fn=collate_fn_audio,
        )
        acc_p = _evaluate_task(
            patella_model,
            PatellaDataset(os.path.join(val_root, "patella"), augment=False),
            "Val Patella", NUM_WORKERS // 2, _predict_tensor_batch,
        )
        acc_h = _evaluate_task(
            health_model,
            HealthDataset(health_val_samples, augment=False),
            "Val Health", NUM_WORKERS // 2, _predict_tensor_batch,
        )

        avg_acc = (acc_b + acc_e + acc_s + acc_p + acc_h) / 5

        print(f"\n\U0001f4ca Results:")
        print(f"  Behavior: Loss {loss_b:.4f} | Acc {acc_b:.4f} ({acc_b*100:.1f}%)")
        print(f"  Emotion:  Loss {loss_e:.4f} | Acc {acc_e:.4f} ({acc_e*100:.1f}%)")
        print(f"  Sound:    Loss {loss_s:.4f} | Acc {acc_s:.4f} ({acc_s*100:.1f}%)")
        print(f"  Patella:  Loss {loss_p:.4f} | Acc {acc_p:.4f} ({acc_p*100:.1f}%)")
        print(f"  Health:   Loss {loss_h:.4f} | Acc {acc_h:.4f} ({acc_h*100:.1f}%)")
        print(f"  Average Acc: {avg_acc:.4f} ({avg_acc*100:.1f}%)")

        history.append({
            'epoch' : epoch + 1,
            'loss_b': loss_b, 'loss_e': loss_e, 'loss_s': loss_s,
            'loss_p': loss_p, 'loss_h': loss_h,
            'acc_b' : acc_b,  'acc_e' : acc_e,  'acc_s' : acc_s,
            'acc_p' : acc_p,  'acc_h' : acc_h,  'acc_avg': avg_acc,
        })

        # Keep only the checkpoint with the best mean validation accuracy.
        if avg_acc > best_avg_acc:
            best_avg_acc = avg_acc
            torch.save({
                "behavior_model":       behavior_model.state_dict(),
                "emotion_model":        emotion_model.state_dict(),
                "audio_model":          audio_model.state_dict(),
                "patella_model":        patella_model.state_dict(),
                "health_model":         health_model.state_dict(),
                "behavior_label_to_id": behavior_label_to_id,
                "emotion_label_to_id":  emotion_label_to_id,
                "sound_label_to_id":    sound_label_to_id,
                "sound_id_to_label":    sound_id_to_label,
                "patella_label_to_id":  patella_label_to_id,
                "health_classes":       BCS_CLASSES,
                "best_epoch":           epoch + 1,
                "best_acc":             best_avg_acc,
                "history":              history,
            }, "pet_normal_omni_best.pth")
            print(f"  \U0001f4be Saved new best model! (Acc: {best_avg_acc:.4f})")

    # Visualize the training curves.
    print("\n\U0001f4c8 Generating training history plot...")
    _plot_history(history, 'pet_normal_omni_history.png')
    print("  \u2705 Saved: pet_normal_omni_history.png")

    print(f"\n\U0001f389 Training Finished!")
    print(f"  Best Average Acc: {best_avg_acc:.4f} ({best_avg_acc*100:.1f}%)")

# Start training only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    train()

  from .autonotebook import tqdm as notebook_tqdm


üéØ Device: cuda:1
‚úÖ All file-copy tasks already prepared, skipping.

üì¶ Collecting health/BCS (all samples)...
  ‚Üí 4171 samples, 3 classes (all used)
    normal: 4029
    overweight: 93
    underweight: 49
  Health split ‚Üí train:3336, val:415, test:420

‚úÖ Dataset preparation complete.

üîÑ Pre-loading label mappings...
  üìä behavior: 11843 samples, 25 classes
  üìä emotion: 44766 samples, 10 classes
  üìä sound: 995 samples, 14 classes, augment=False
  üìä patella: 80696 samples, 5 classes

üîÑ Initializing models...


Loading weights: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 211/211 [00:00<00:00, 606.34it/s, Materializing param=wav2vec2.masked_spec_embed]                                            
[1mWav2Vec2ForSequenceClassification LOAD REPORT[0m from: facebook/wav2vec2-base
Key                          | Status     | 
-----------------------------+------------+-
project_hid.weight           | UNEXPECTED | 
project_q.bias               | UNEXPECTED | 
quantizer.weight_proj.bias   | UNEXPECTED | 
project_q.weight             | UNEXPECTED | 
project_hid.bias             | UNEXPECTED | 
quantizer.weight_proj.weight | UNEXPECTED | 
quantizer.codevectors        | UNEXPECTED | 
projector.bias               | MISSING    | 
projector.weight             | MISSING    | 
classifier.bias              | MISSING    | 
classifier.weight            | MISSING    | 

[3mNotes:
- UNEXPECTED[3m	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING[3m	:those pa

  üìä sound: 995 samples, 14 classes, augment=False

Epoch 1/100

üêæ Training Behavior...
  üìä behavior: 11843 samples, 25 classes


                                                           

  ‚Üí Avg Loss: 3.0736

üòä Training Emotion...
  üìä emotion: 44766 samples, 10 classes


                                                            

  ‚Üí Avg Loss: 2.6950

üîä Training Sound...
  üìä sound: 995 samples, 14 classes, augment=True


  audio_scheduler.step()
                                                      

  ‚Üí Avg Loss: 1.6198

ü¶¥ Training Patella...
  üìä patella: 80696 samples, 5 classes


Patella:   1%|‚ñè         | 36/2522 [00:09<03:59, 10.37it/s] 

## Test

In [None]:
# import os
# import random
# import shutil
# from tqdm import tqdm
# import matplotlib.pyplot as plt
# import torch
# import torch.nn as nn
# from torch.utils.data import Dataset, DataLoader
# from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor, get_linear_schedule_with_warmup
# import torchvision.transforms as transforms
# from torchvision.models import resnet34, ResNet34_Weights
# from PIL import Image
# import librosa
# import numpy as np
# from collections import defaultdict, Counter
# from sklearn.utils.class_weight import compute_class_weight

# AUDIO_MODEL_NAME = "facebook/wav2vec2-base"
# FEATURE_EXTRACTOR = Wav2Vec2FeatureExtractor.from_pretrained(AUDIO_MODEL_NAME)

# class VideoMultiBackbone(nn.Module):
#     def __init__(self, num_b, num_e):
#         super().__init__()
        
#         backbone_b = resnet34(weights=ResNet34_Weights.IMAGENET1K_V1)
#         in_features_b = backbone_b.fc.in_features
#         backbone_b.fc = nn.Identity()
#         self.behavior_backbone = backbone_b
#         self.behavior_head = nn.Linear(in_features_b, num_b)
        
#         backbone_e = resnet34(weights=ResNet34_Weights.IMAGENET1K_V1)
#         in_features_e = backbone_e.fc.in_features
#         backbone_e.fc = nn.Identity()
#         self.emotion_backbone = backbone_e
#         self.emotion_head = nn.Linear(in_features_e, num_e)
    
#     def forward(self, x, task):
#         if task == "behavior":
#             feat = self.behavior_backbone(x)
#             return self.behavior_head(feat)
#         elif task == "emotion":
#             feat = self.emotion_backbone(x)
#             return self.emotion_head(feat)
#         else:
#             raise ValueError("Task must be 'behavior' or 'emotion'")
        
# class AudioModel(nn.Module):
#     def __init__(self, num_classes, freeze_backbone=False):  # üî• Í∏∞Î≥∏Í∞í False
#         super().__init__()
#         self.model = Wav2Vec2ForSequenceClassification.from_pretrained(
#             AUDIO_MODEL_NAME,
#             num_labels=num_classes,
#             ignore_mismatched_sizes=True
#         )
        
#         # üî• Freeze ÏòµÏÖò (Í∏∞Î≥∏: Ï†ÑÏ≤¥ ÌïôÏäµ)
#         if freeze_backbone:
#             for param in self.model.wav2vec2.parameters():
#                 param.requires_grad = False
    
#     def forward(self, x):
#         return self.model(input_values=x).logits

# def test():
#     from transformers import Wav2Vec2FeatureExtractor
#     FEATURE_EXTRACTOR = Wav2Vec2FeatureExtractor.from_pretrained(
#         "facebook/wav2vec2-base"
#     )

#     DEVICE = "cuda:1" if torch.cuda.is_available() else "cpu"
#     BATCH_SIZE = 16
#     SR = 16000
#     MAX_AUDIO_LEN = SR * 5

#     print("üîé Loading best model...")
#     checkpoint = torch.load("pet_omni_best.pth", map_location=DEVICE)

#     behavior_label_to_id = checkpoint["behavior_label_to_id"]
#     emotion_label_to_id = checkpoint["emotion_label_to_id"]
#     sound_label_to_id = checkpoint["sound_label_to_id"]

#     # -----------------------------
#     # Î™®Îç∏ Î≥µÏõê
#     # -----------------------------
#     video_model = VideoMultiBackbone(
#         len(behavior_label_to_id),
#         len(emotion_label_to_id)
#     ).to(DEVICE)

#     audio_model = AudioModel(
#         len(sound_label_to_id)
#     ).to(DEVICE)

#     video_model.load_state_dict(checkpoint["video_model"])
#     audio_model.load_state_dict(checkpoint["audio_model"])

#     video_model.eval()
#     audio_model.eval()

#     print("üì¶ Loading TEST datasets...")

#     TEST_DIR = os.path.join("files", "work", "omni_dataset", "test")

#     # -----------------------------
#     # Train ÏùòÏ°¥ ÏóÜÎäî Dataset Ï†ïÏùò
#     # -----------------------------
#     class TestImageDataset(Dataset):
#         def __init__(self, task_dir, label_to_id):
#             self.samples = []
#             self.label_to_id = label_to_id

#             for label in os.listdir(task_dir):
#                 if label not in label_to_id:
#                     continue

#                 label_dir = os.path.join(task_dir, label)
#                 for file in os.listdir(label_dir):
#                     if file.lower().endswith(('.jpg','.png','.jpeg')):
#                         self.samples.append(
#                             (os.path.join(label_dir,file),
#                              label_to_id[label])
#                         )

#             self.transform = transforms.Compose([
#                 transforms.Resize((224,224)),
#                 transforms.ToTensor(),
#                 transforms.Normalize(
#                     [0.485,0.456,0.406],
#                     [0.229,0.224,0.225]
#                 )
#             ])

#         def __len__(self):
#             return len(self.samples)

#         def __getitem__(self, idx):
#             path, label_id = self.samples[idx]
#             img = Image.open(path).convert("RGB")
#             img = self.transform(img)
#             return img, label_id


#     class TestAudioDataset(Dataset):
#         def __init__(self, task_dir, label_to_id):
#             self.samples = []
#             self.label_to_id = label_to_id

#             for label in os.listdir(task_dir):
#                 if label not in label_to_id:
#                     continue

#                 label_dir = os.path.join(task_dir, label)
#                 for file in os.listdir(label_dir):
#                     if file.lower().endswith(('.wav','.mp3','.m4a')):
#                         self.samples.append(
#                             (os.path.join(label_dir,file),
#                              label_to_id[label])
#                         )

#         def __len__(self):
#             return len(self.samples)

#         def __getitem__(self, idx):
#             path, label_id = self.samples[idx]
#             waveform, _ = librosa.load(path, sr=SR, mono=True)

#             if len(waveform) > MAX_AUDIO_LEN:
#                 waveform = waveform[:MAX_AUDIO_LEN]
#             else:
#                 waveform = np.pad(
#                     waveform,
#                     (0, MAX_AUDIO_LEN - len(waveform))
#                 )

#             inputs = FEATURE_EXTRACTOR(
#                 waveform,
#                 sampling_rate=SR,
#                 return_tensors="pt"
#             )

#             return inputs.input_values.squeeze(0), label_id


#     # -----------------------------
#     # Loader
#     # -----------------------------
#     behavior_loader = DataLoader(
#         TestImageDataset(
#             os.path.join(TEST_DIR,"behavior"),
#             behavior_label_to_id
#         ),
#         BATCH_SIZE, False
#     )

#     emotion_loader = DataLoader(
#         TestImageDataset(
#             os.path.join(TEST_DIR,"emotion"),
#             emotion_label_to_id
#         ),
#         BATCH_SIZE, False
#     )

#     sound_loader = DataLoader(
#         TestAudioDataset(
#             os.path.join(TEST_DIR,"sound"),
#             sound_label_to_id
#         ),
#         BATCH_SIZE, False
#     )

#     # -----------------------------
#     # Evaluation
#     # -----------------------------
#     def evaluate(loader, task):
#         correct, total = 0, 0
#         with torch.no_grad():
#             for x, y in loader:
#                 x, y = x.to(DEVICE), y.to(DEVICE)

#                 if task in ["behavior","emotion"]:
#                     logits = video_model(x, task)
#                 else:
#                     logits = audio_model(x)

#                 pred = logits.argmax(-1)
#                 correct += (pred == y).sum().item()
#                 total += y.size(0)

#         return correct / total if total > 0 else 0


#     acc_b = evaluate(behavior_loader, "behavior")
#     acc_e = evaluate(emotion_loader, "emotion")
#     acc_s = evaluate(sound_loader, "sound")

#     avg_acc = (acc_b + acc_e + acc_s) / 3

#     print("\nüìä TEST Results:")
#     print(f"  Behavior Acc: {acc_b:.4f} ({acc_b*100:.1f}%)")
#     print(f"  Emotion Acc:  {acc_e:.4f} ({acc_e*100:.1f}%)")
#     print(f"  Sound Acc:    {acc_s:.4f} ({acc_s*100:.1f}%)")
#     print(f"  Average Acc:  {avg_acc:.4f} ({avg_acc*100:.1f}%)")


# if __name__ == "__main__":
#     test()



üîé Loading best model...


Loading weights: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 211/211 [00:00<00:00, 658.49it/s, Materializing param=wav2vec2.masked_spec_embed]                                            
[1mWav2Vec2ForSequenceClassification LOAD REPORT[0m from: facebook/wav2vec2-base
Key                          | Status     | 
-----------------------------+------------+-
project_hid.weight           | UNEXPECTED | 
project_q.weight             | UNEXPECTED | 
quantizer.codevectors        | UNEXPECTED | 
quantizer.weight_proj.bias   | UNEXPECTED | 
quantizer.weight_proj.weight | UNEXPECTED | 
project_hid.bias             | UNEXPECTED | 
project_q.bias               | UNEXPECTED | 
projector.weight             | MISSING    | 
classifier.bias              | MISSING    | 
projector.bias               | MISSING    | 
classifier.weight            | MISSING    | 

[3mNotes:
- UNEXPECTED[3m	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING[3m	:those pa

üì¶ Loading TEST datasets...

üìä TEST Results:
  Behavior Acc: 0.7273 (72.7%)
  Emotion Acc:  0.7525 (75.2%)
  Sound Acc:    0.9138 (91.4%)
  Average Acc:  0.7979 (79.8%)
