In [None]:
import os
import cv2
import numpy as np
import pandas as pd
import glob
import zipfile
import matplotlib.pyplot as plt
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split
from google.colab import drive

# 1. Mount Google Drive
drive.mount('/content/drive')

# 2. Đường dẫn đến dữ liệu
drive_folder = "/content/drive/MyDrive/kaggle_data/aptos2019"
extract_root = "/content/extracted_zip_files"
os.makedirs(extract_root, exist_ok=True)

# Giải nén các file ZIP nếu chưa giải (nếu đã giải thì bỏ qua)
zip_files = glob.glob(os.path.join(drive_folder, "*.zip"))
for zip_path in zip_files:
    zip_name = os.path.basename(zip_path).replace(".zip", "")
    extract_path = os.path.join(extract_root, zip_name)
    os.makedirs(extract_path, exist_ok=True)
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(extract_path)
    print(f"✅ Đã giải nén: {zip_path} → {extract_path}")

# 3. Đọc file CSV
df_train = pd.read_csv(os.path.join(drive_folder, "train.csv"))
df_test = pd.read_csv(os.path.join(drive_folder, "test.csv"))

# 4. Định nghĩa hàm xử lý ảnh: cắt, resize và tăng cường ảnh
def crop_image_from_gray_to_color(img, tol=7):
    """
    Cắt bỏ các vùng không cần thiết (đặc biệt là các cạnh tối) của ảnh dựa trên thông tin từ ảnh xám,
    sau đó áp dụng vùng cắt này lên ảnh màu gốc.
    """
    gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    mask = gray > tol
    if mask.sum() == 0:
        return img
    rows = mask.any(axis=1)
    cols = mask.any(axis=0)
    cropped_img = img[np.ix_(rows, cols)]
    return cropped_img

def load_ben_color(path, sigmaX=10, IMG_SIZE=244):
    """
    Load ảnh từ đường dẫn, cắt bỏ biên tối dựa trên ảnh xám, resize và tăng cường ảnh bằng GaussianBlur.
    """
    image = cv2.imread(path)
    if image is None:
        raise ValueError(f"Không thể đọc được ảnh từ đường dẫn: {path}")
    # Chuyển BGR sang RGB
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # Cắt ảnh theo vùng sáng
    image = crop_image_from_gray_to_color(image, tol=7)
    # Resize ảnh về kích thước mong muốn
    image = cv2.resize(image, (IMG_SIZE, IMG_SIZE))
    # Tăng cường ảnh bằng GaussianBlur và weighted addition
    image = cv2.addWeighted(image, 4, cv2.GaussianBlur(image, (0, 0), sigmaX), -4, 128)
    return image

# 5. Xử lý và lưu ảnh đã xử lý vào một thư mục tạm thời trên Colab
train_img_folder = os.path.join(extract_root, "train_images")  # Thư mục chứa ảnh gốc
processed_folder = "/content/processed_train_images"          # Thư mục lưu ảnh đã xử lý
os.makedirs(processed_folder, exist_ok=True)

processed_ids = []  # Lưu lại id của các ảnh đã được xử lý thành công

for idx, row in df_train.iterrows():
    img_filename = f"{row['id_code']}.png"
    img_path = os.path.join(train_img_folder, img_filename)

    try:
        proc_img = load_ben_color(img_path, sigmaX=10, IMG_SIZE=244)
        # cv2.imwrite lưu ảnh theo định dạng BGR nên chuyển từ RGB sang BGR
        proc_img_bgr = cv2.cvtColor(proc_img, cv2.COLOR_RGB2BGR)
        save_path = os.path.join(processed_folder, img_filename)
        cv2.imwrite(save_path, proc_img_bgr)
        processed_ids.append(row['id_code'])
    except Exception as e:
        print(f"Lỗi khi xử lý ảnh {img_filename}: {e}")

print(f"Đã xử lý thành công {len(processed_ids)} ảnh.")

# 6. Cập nhật DataFrame chỉ với các ảnh đã xử lý thành công
df_train_processed = df_train[df_train['id_code'].isin(processed_ids)].copy()

# 7. Chia dữ liệu thành tập train và validation dựa trên file CSV
x = df_train_processed['id_code']
y = df_train_processed['diagnosis']

# Xáo trộn dữ liệu để đảm bảo tính ngẫu nhiên
x, y = shuffle(x, y, random_state=42)

# Chia tập train+validation và test (80% - 20%)
x_temp, test_x, y_temp, test_y = train_test_split(x, y, test_size=0.20, stratify=y, random_state=42)

# Chia tập train và validation (85% train, 15% val trong 80% dữ liệu ban đầu)
train_x, valid_x, train_y, valid_y = train_test_split(x_temp, y_temp, test_size=0.15/0.80, stratify=y_temp, random_state=42)

# In thông tin kiểm tra
print("Train X size:", len(train_x))
print("Train y size:", len(train_y))
print("Valid X size:", len(valid_x))
print("Valid y size:", len(valid_y))
print("Test X size:", len(test_x))
print("Test y size:", len(test_y))

Mounted at /content/drive
✅ Đã giải nén: /content/drive/MyDrive/kaggle_data/aptos2019/train_images.zip → /content/extracted_zip_files/train_images
✅ Đã giải nén: /content/drive/MyDrive/kaggle_data/aptos2019/test_images.zip → /content/extracted_zip_files/test_images
Đã xử lý thành công 3662 ảnh.
Train X size: 2379
Train y size: 2379
Valid X size: 550
Valid y size: 550
Test X size: 733
Test y size: 733


In [None]:
import matplotlib.pyplot as plt
import skimage.io
from skimage.transform import resize
import albumentations as A
from tqdm import tqdm
import PIL
from PIL import Image, ImageOps
import cv2
from sklearn.utils import class_weight, shuffle
from keras.losses import binary_crossentropy, categorical_crossentropy

# Tiền xử lý cho các mô hình khác nhau
from keras.applications.resnet50 import preprocess_input as resnet50_preprocess
from keras.applications.inception_v3 import preprocess_input as inception_preprocess
from keras.applications.densenet import preprocess_input as densenet_preprocess
from keras.applications.efficientnet import preprocess_input as efficientnet_preprocess
from keras.applications.xception import preprocess_input as xception_preprocess  # Đã thay thế

import keras.backend as K
import tensorflow as tf
from sklearn.metrics import f1_score, fbeta_score, cohen_kappa_score, accuracy_score
from keras.utils import Sequence, to_categorical
from sklearn.model_selection import train_test_split

# Import các mô hình CNN từ Keras Applications
from keras.applications import ResNet50, EfficientNetB0, InceptionV3, DenseNet121, Xception  # Đã thay thế MobileNetV2


WORKERS = 2
CHANNEL = 3

import warnings
warnings.filterwarnings("ignore")
SIZE = 244
NUM_CLASSES = 5


  check_for_updates()


In [None]:
import numpy as np

def to_multi_label(target, num_classes=5):
    """ Chuyển đổi nhãn đơn thành multi-label (One-vs-All). """
    return (np.arange(num_classes) < (target[:, None] + 1)).astype(int)

In [None]:
import tensorflow as tf

class AdamAccumulate(tf.keras.optimizers.Adam):
    def __init__(self, accum_iters=2, **kwargs):
        """
        accum_iters: số bước tích lũy gradient trước khi cập nhật trọng số.
        kwargs: các tham số khác của Adam (như learning_rate, beta_1, beta_2, epsilon, decay, v.v.)
        """
        super(AdamAccumulate, self).__init__(**kwargs)
        if accum_iters < 1:
            raise ValueError('accum_iters phải >= 1')
        self.accum_iters = accum_iters
        # Đếm số batch đã tích lũy
        self._accum_steps = 0
        # Bộ nhớ chứa gradient đã tích lũy (dictionary: key là tham chiếu của biến)
        self._grad_accum = {}

    @tf.function
    def apply_gradients(self, grads_and_vars, name=None, experimental_aggregate_gradients=True):
        # Nếu đây là lần chạy đầu tiên, khởi tạo biến tích lũy cho mỗi biến
        if not self._grad_accum:
            for grad, var in grads_and_vars:
                if grad is None:
                    continue
                self._grad_accum[var.ref()] = tf.Variable(tf.zeros_like(var), trainable=False)

        # Tích lũy gradient cho mỗi biến
        for grad, var in grads_and_vars:
            if grad is None:
                continue
            self._grad_accum[var.ref()].assign_add(grad)

        self._accum_steps += 1

        # Chỉ cập nhật trọng số khi số bước tích lũy đạt đến accum_iters
        if self._accum_steps % self.accum_iters == 0:
            # Tính trung bình gradient và tạo danh sách cập nhật
            avg_grads_and_vars = []
            for grad, var in grads_and_vars:
                if grad is None:
                    continue
                accumulated_grad = self._grad_accum[var.ref()]
                avg_grad = accumulated_grad / tf.cast(self.accum_iters, accumulated_grad.dtype)
                avg_grads_and_vars.append((avg_grad, var))
                # Reset lại gradient tích lũy cho biến
                accumulated_grad.assign(tf.zeros_like(var))
            # Gọi phương thức apply_gradients của lớp Adam gốc để cập nhật trọng số
            super(AdamAccumulate, self).apply_gradients(avg_grads_and_vars, name, experimental_aggregate_gradients)
        # Nếu chưa đủ số bước, không cập nhật trọng số


In [None]:
import albumentations as A
import numpy as np
import cv2
import os
from sklearn.utils import shuffle
from keras.utils import Sequence

# Import preprocess_input từ các mô hình
from tensorflow.keras.applications.resnet50 import preprocess_input as resnet50_preprocess
from tensorflow.keras.applications.inception_v3 import preprocess_input as inception_preprocess
from tensorflow.keras.applications.densenet import preprocess_input as densenet_preprocess
from tensorflow.keras.applications.efficientnet import preprocess_input as efficientnet_preprocess
from tensorflow.keras.applications.xception import preprocess_input as xception_preprocess

class My_Generator(Sequence):
    def __init__(self, image_filenames, labels, batch_size, is_train=False,
                 mix=False, augment=False, size1=224, size2=299, model_type="default",
                 balance_classes=False):
        self.image_filenames = np.array(image_filenames)
        self.labels = np.array(labels)
        self.batch_size = batch_size
        self.is_train = is_train
        self.is_augment = augment
        self.is_mix = mix
        self.model_type = str(model_type).lower()
        self.n_classes = self.labels.shape[1] if self.labels.ndim > 1 else int(max(self.labels) + 1)

        if "inceptionv3" in self.model_type:
            self.target_size = (size2, size2)
        else:
            self.target_size = (size1, size1)

        self.base_path = "/content/processed_train_images/"

        if self.is_augment and self.is_train:
            self.augmenter = A.Compose([
                A.OneOf([
                    A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0, p=1),
                    A.MultiplicativeNoise(multiplier=(0.9, 1.1), per_channel=True, p=1),
                    A.RandomBrightnessContrast(brightness_limit=0, contrast_limit=0.1, p=1)
                ], p=0.5),
                A.HorizontalFlip(p=0.5),
                A.VerticalFlip(p=0.5),
                A.CropAndPad(percent=(-0.1, 0), p=0.5)
            ])

        self.rare_augmenter = A.Compose([
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.7),
            A.GaussNoise(p=0.5),
            A.Rotate(limit=30, p=0.5),
            A.RandomScale(scale_limit=0.2, p=0.5),
            A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, p=0.5),
            A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.5)
        ])

        self.class_counts = self._compute_initial_class_counts()
        self.augmented_class_counts = self.class_counts.copy()
        self.class_weights = None  # Khởi tạo trọng số lớp là None

        if self.is_train and balance_classes:
            self.balance_classes()

        if self.is_train:
            self.on_epoch_end()

    def _compute_initial_class_counts(self):
        labels = np.argmax(self.labels, axis=1) if self.labels.ndim > 1 else self.labels
        return np.bincount(labels, minlength=self.n_classes)

    def _compute_class_weights(self):
        total_samples = np.sum(self.augmented_class_counts)
        if total_samples == 0:
            return np.ones(self.n_classes)
        class_weights = total_samples / (self.n_classes * self.augmented_class_counts)
        class_weights = np.where(np.isinf(class_weights) | (self.augmented_class_counts == 0), 1.0, class_weights)
        return class_weights / np.min(class_weights[np.isfinite(class_weights)])

    def get_class_weights(self):
        """Trả về trọng số lớp hiện tại để sử dụng trong huấn luyện."""
        return self.class_weights

    def balance_classes(self):
        class_counts = self._compute_initial_class_counts()
        max_count = class_counts[0]  # Sử dụng số lượng mẫu của lớp 0 làm mục tiêu

        print(f"Số lượng mẫu ban đầu: {class_counts}")
        print(f"Số lượng mẫu mục tiêu cho mỗi lớp (dựa trên lớp 0): {max_count}")

        new_filenames = []
        new_labels = []
        for cls in range(self.n_classes):
            current_count = class_counts[cls]
            if current_count == 0:
                print(f"Lớp {cls} không có mẫu, bỏ qua.")
                continue
            if cls == 3:
                target_count = int(max_count * 1.3)  # Lớp 3 được tăng thêm 30%
            else:
                target_count = max_count
            if current_count < target_count:
                samples_to_add = target_count - current_count
                label_indices = np.argmax(self.labels, axis=1) if self.labels.ndim > 1 else self.labels
                class_indices = np.where(label_indices == cls)[0]
                for i in range(samples_to_add):
                    idx = np.random.choice(class_indices)
                    img_id = self.image_filenames[idx]
                    label = self.labels[idx]
                    img = self._load_image(img_id)
                    if img is None:
                        continue
                    aug_img = self.rare_augmenter(image=img)['image']
                    new_img_id = f"{img_id}_balance_aug_{i}"
                    save_path = os.path.join(self.base_path, f"{new_img_id}.png")
                    if cv2.imwrite(save_path, cv2.cvtColor(aug_img, cv2.COLOR_RGB2BGR)):
                        new_filenames.append(new_img_id)
                        new_labels.append(np.array(label, dtype=self.labels.dtype))
                        self.augmented_class_counts[cls] += 1
                    else:
                        print(f"Lỗi khi lưu ảnh tăng cường {new_img_id}")

        if new_labels:
            new_labels_array = np.array(new_labels)
            if new_labels_array.ndim == 1:
                new_labels_array = new_labels_array[:, np.newaxis]
            self.image_filenames = np.concatenate([self.image_filenames, new_filenames])
            self.labels = np.concatenate([self.labels, new_labels_array])

        # Không gọi on_epoch_end() ở đây để tránh tính trọng số lớp ngay lập tức

        # Kiểm tra số lượng mẫu sau khi cân bằng
        updated_counts = self._compute_initial_class_counts()
        print(f"Số lượng mẫu sau khi cân bằng và tăng lớp 3: {updated_counts}")

    def augment_weak_classes(self, weak_classes, augment_factor=2):
        new_filenames = []
        new_labels = []
        for idx, label in enumerate(self.labels):
            label_class = np.argmax(label) if label.ndim > 1 else label
            if np.isscalar(label_class) and np.isin(label_class, weak_classes):
                img_id = self.image_filenames[idx]
                img = self._load_image(img_id)
                if img is None:
                    continue
                for i in range(augment_factor):
                    aug_img = self.rare_augmenter(image=img)['image']
                    new_img_id = f"{img_id}_weak_aug_{i}"
                    save_path = os.path.join(self.base_path, f"{new_img_id}.png")
                    if cv2.imwrite(save_path, cv2.cvtColor(aug_img, cv2.COLOR_RGB2BGR)):
                        new_filenames.append(new_img_id)
                        new_labels.append(np.array(label, dtype=self.labels.dtype))
                        self.augmented_class_counts[label_class] += 1
                    else:
                        print(f"Lỗi khi lưu ảnh tăng cường {new_img_id}")
        if new_labels:
            new_labels_array = np.array(new_labels)
            if new_labels_array.ndim == 1:
                new_labels_array = new_labels_array[:, np.newaxis]
            self.image_filenames = np.concatenate([self.image_filenames, new_filenames])
            self.labels = np.concatenate([self.labels, new_labels_array])

    def __len__(self):
        return int(np.ceil(len(self.image_filenames) / self.batch_size))

    def __getitem__(self, idx):
        batch_x = self.image_filenames[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_y = self.labels[idx * self.batch_size:(idx + 1) * self.batch_size]
        return self._generate_batch(batch_x, batch_y, augment=self.is_train)

    def on_epoch_end(self):
        if self.is_train:
            self.image_filenames, self.labels = shuffle(self.image_filenames, self.labels)
            # Tính trọng số lớp vào cuối mỗi epoch
            self.class_weights = self._compute_class_weights()
            print(f"Trọng số lớp sau epoch: {self.class_weights}")
            print(f"Số lượng mẫu tăng cường: {self.augmented_class_counts}")

    def _load_image(self, img_id):
        img_path = os.path.join(self.base_path, f"{img_id}.png")
        try:
            img = cv2.imread(img_path)
            if img is None:
                raise ValueError(f"Hình ảnh không tìm thấy hoặc bị hỏng: {img_path}")
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = cv2.resize(img, self.target_size)
            return img
        except Exception as e:
            print(f"Lỗi khi tải hình ảnh {img_id}: {str(e)}")
            return None

    def _generate_batch(self, batch_x, batch_y, augment=False):
        batch_images = []
        valid_labels = []

        for img_id, label in zip(batch_x, batch_y):
            img = self._load_image(img_id)
            if img is None:
                continue
            if augment and self.is_augment:
                img = self.augmenter(image=img.astype(np.uint8))['image']
            img = img.astype(np.float32) / 255.0

            if "resnet50" in self.model_type:
                img = resnet50_preprocess(img)
            elif "efficientnetb0" in self.model_type:
                img = efficientnet_preprocess(img)
            elif "inceptionv3" in self.model_type:
                img = inception_preprocess(img)
            elif "densenet121" in self.model_type:
                img = densenet_preprocess(img)
            elif "xception" in self.model_type:
                img = xception_preprocess(img)

            batch_images.append(img)
            valid_labels.append(label)

        if not batch_images:
            return np.zeros((1, *self.target_size, 3), dtype=np.float32), np.zeros((1, *batch_y.shape[1:]), dtype=np.float32)

        batch_images = np.array(batch_images)
        valid_labels = np.array(valid_labels)

        if self.is_mix and len(batch_images) > 1:
            batch_images, valid_labels = self._mixup(batch_images, valid_labels)

        return batch_images, valid_labels

    def _mixup(self, x, y):
        lam = np.random.beta(0.2, 0.4)
        index = np.random.permutation(len(x))
        mixed_x = np.zeros_like(x)
        mixed_y = np.zeros_like(y)
        for i in range(len(x)):
            if np.argmax(y[i]) == np.argmax(y[index[i]]):
                mixed_x[i] = lam * x[i] + (1 - lam) * x[index[i]]
                mixed_y[i] = y[i]
            else:
                mixed_x[i] = x[i]
                mixed_y[i] = y[i]
        return mixed_x, mixed_y

In [None]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential, load_model
from keras.layers import (Activation, Dropout, Flatten, Dense, GlobalMaxPooling2D,
                          BatchNormalization, Input, Conv2D, GlobalAveragePooling2D)

from keras.applications.resnet50 import ResNet50
from keras.applications import EfficientNetB0
from keras.applications.inception_v3 import InceptionV3
from keras.applications.densenet import DenseNet121
from keras.applications.xception import Xception  # 🔄 Thay MobileNetV2 bằng Xception

from keras.callbacks import ModelCheckpoint
from keras import metrics
from keras.optimizers import Adam
from keras import backend as K
import keras
from keras.models import Model


In [None]:
def create_model(input_shape, n_out, model_type, weights_path=None, weights="imagenet"):
    input_tensor = Input(shape=input_shape)

    # Khởi tạo mô hình với weights hoặc weights_path
    if model_type == "resnet50":
        base_model = ResNet50(include_top=False, weights=weights if not weights_path else None, input_tensor=input_tensor)
    elif model_type == "efficientnetb0":
        base_model = EfficientNetB0(include_top=False, weights=weights if not weights_path else None, input_tensor=input_tensor)
    elif model_type == "inceptionv3":
        base_model = InceptionV3(include_top=False, weights=weights if not weights_path else None, input_tensor=input_tensor)
    elif model_type == "densenet121":
        base_model = DenseNet121(include_top=False, weights=weights if not weights_path else None, input_tensor=input_tensor)
    elif model_type == "xception":  # Đổi từ "mobilenetv2" thành "xception"
        base_model = Xception(include_top=False, weights=weights if not weights_path else None, input_tensor=input_tensor)

    else:
        raise ValueError(f"Unsupported model type: {model_type}")

    # Nếu có weights_path, tải trọng số từ file
    if weights_path:
        try:
            base_model.load_weights(weights_path)
            print(f"Loaded weights from {weights_path}")
        except Exception as e:
            print(f"Error loading weights from {weights_path}: {e}")
            raise

    x = GlobalAveragePooling2D(name='global_avg_pool')(base_model.output)
    x = Dropout(0.5)(x)
    x = Dense(1024, activation='relu')(x)
    x = Dropout(0.5)(x)
    final_output = Dense(n_out, activation="softmax", name='final_output')(x)

    model = Model(input_tensor, final_output)
    return model

In [None]:
# Cấu hình mô hình
model_configs = {
    # "xception": {
    #     "model_type": "xception",
    #     "weights": "imagenet",
    #     "save_path": "/content/drive/MyDrive/working/Xception_bestqwk.h5"
    # },
    "resnet50": {
        "model_type": "resnet50",
        "weights_path": "/content/drive/MyDrive/keras_weights/resnet50/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5",
        "save_path": "/content/drive/MyDrive/working/ResNet50_bestqwk.h5"
    },
    # "efficientnetb0": {
    #     "model_type": "efficientnetb0",
    #     "weights": "imagenet",
    #     "save_path": "/content/drive/MyDrive/working/EfficientNetB1_bestqwk.h5"
    # },
    # "inceptionv3": {
    #     "model_type": "inceptionv3",
    #     "weights_path": "/content/drive/MyDrive/keras_weights/inceptionv3/inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5",
    #     "save_path": "/content/drive/MyDrive/working/InceptionV3_bestqwk.h5"
    # },
    "densenet121": {
        "model_type": "densenet121",
        "weights_path": "/content/drive/MyDrive/keras_weights/densenet121/densenet121_weights_tf_dim_ordering_tf_kernels_notop.h5",
        "save_path": "/content/drive/MyDrive/working/DenseNet121_bestqwk.h5"
    }
}

In [None]:
from sklearn.metrics import f1_score, recall_score, cohen_kappa_score, confusion_matrix
from datetime import datetime
import json
import os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import seaborn as sns
from tensorflow.keras.callbacks import EarlyStopping, Callback
from tensorflow.keras.losses import CategoricalCrossentropy
import shutil
from sklearn.utils import resample
import subprocess

# Class QWKEvaluation
class QWKEvaluation(tf.keras.callbacks.Callback):
    def __init__(self, validation_data=(), batch_size=64, interval=1, model_type=None, save_paths=None):
        super().__init__()
        self.interval = interval
        self.batch_size = batch_size
        self.valid_generator, self.y_val = validation_data
        self.history = []
        self.model_type = model_type
        self.save_paths = save_paths if save_paths is not None else {}
        self.save_path = self.save_paths.get(model_type, None)
        self.best_qwk = -float('inf')
        self.best_y_true = None
        self.best_y_pred = None

    def on_epoch_end(self, epoch, logs={}):
        if epoch % self.interval == 0:
            steps = int(np.ceil(len(self.y_val) / self.batch_size))
            y_pred = self.model.predict(self.valid_generator, steps=steps, verbose=1)

            if len(self.y_val.shape) > 1 and self.y_val.shape[1] > 1:
                y_true = np.argmax(self.y_val, axis=1)
                y_pred_classes = np.argmax(y_pred, axis=1)
            else:
                y_true = self.y_val.astype(int)
                y_pred_classes = np.argmax(y_pred, axis=1)

            score = cohen_kappa_score(y_true, y_pred_classes, labels=[0, 1, 2, 3, 4], weights='quadratic')
            print(f"\nEpoch {epoch+1} - QWK: {score:.4f}")

            f1 = f1_score(y_true, y_pred_classes, average=None, labels=[0, 1, 2, 3, 4])
            sensitivity = recall_score(y_true, y_pred_classes, average=None, labels=[0, 1, 2, 3, 4])
            print(f"F1-score per class: {f1}")
            print(f"Sensitivity per class: {sensitivity}")

            self.history.append(score)

            if score > self.best_qwk:
                self.best_qwk = score
                self.best_y_true = y_true
                self.best_y_pred = y_pred_classes
                print(f"New best QWK: {self.best_qwk:.4f} at Epoch {epoch+1}")

                if self.save_path:
                    keras_save_path = self.save_path.replace('.h5', '.keras')
                    save_dir = os.path.dirname(keras_save_path)
                    os.makedirs(save_dir, exist_ok=True)

                    self.model.save(keras_save_path, overwrite=True)
                    print(f"Saved (overwritten) full model to {keras_save_path}")

                    save_dir = self.save_path.replace('.h5', '')
                    os.makedirs(save_dir, exist_ok=True)

                    model_json = self.model.to_json()
                    config_path = os.path.join(save_dir, "config.json")
                    with open(config_path, "w") as json_file:
                        json_file.write(model_json)
                    print(f"Saved model architecture to {config_path}")

                    weights_path = os.path.join(save_dir, "model.weights.h5")
                    self.model.save_weights(weights_path)
                    print(f"Saved model weights to {weights_path}")

                    metadata = {
                        "keras_version": tf.keras.__version__,
                        "save_date": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                        "model_type": self.model_type
                    }
                    metadata_path = os.path.join(save_dir, "metadata.json")
                    with open(metadata_path, "w") as meta_file:
                        json.dump(metadata, meta_file)
                    print(f"Saved metadata to {metadata_path}")

# Class QWKReduceLROnPlateau
class QWKReduceLROnPlateau(tf.keras.callbacks.Callback):
    def __init__(self, qwk_callback, factor=0.5, patience=3, min_lr=1e-6, verbose=1):
        super().__init__()
        self.qwk_callback = qwk_callback
        self.factor = factor
        self.patience = patience
        self.min_lr = min_lr
        self.verbose = verbose
        self.best_qwk = -float('inf')
        self.wait = 0

    def on_epoch_end(self, epoch, logs=None):
        current_qwk = self.qwk_callback.history[-1] if self.qwk_callback.history else -float('inf')
        if current_qwk > self.best_qwk:
            self.best_qwk = current_qwk
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                old_lr = float(self.model.optimizer.learning_rate)
                if old_lr > self.min_lr:
                    new_lr = max(old_lr * self.factor, self.min_lr)
                    self.model.optimizer.learning_rate.assign(new_lr)
                    if self.verbose > 0:
                        print(f"\nEpoch {epoch+1}: QWKReduceLROnPlateau reducing learning rate to {new_lr:.6f}.")
                    self.wait = 0

# Class DynamicRareClassAugmentationCallback
class DynamicRareClassAugmentationCallback(Callback):
    def __init__(self, train_generator, valid_generator, valid_labels, threshold=0.6, augment_factor=2):
        super().__init__()
        self.train_generator = train_generator
        self.valid_generator = valid_generator
        self.valid_labels = valid_labels
        self.threshold = threshold
        self.augment_factor = augment_factor
        self.f1_history = []
        self.batch_size = self.train_generator.batch_size
        self.num_classes = self.train_generator.n_classes

    def on_epoch_end(self, epoch, logs=None):
        steps = int(np.ceil(len(self.valid_labels) / self.batch_size))
        y_pred = self.model.predict(self.valid_generator, steps=steps, verbose=1)
        y_true = np.argmax(self.valid_labels, axis=1)
        y_pred_classes = np.argmax(y_pred, axis=1)

        f1_scores = f1_score(y_true, y_pred_classes, average=None, labels=list(range(self.num_classes)))
        print(f"F1-scores at epoch {epoch+1}: {f1_scores}")

        self.f1_history.append(f1_scores)

        weak_classes = [i for i, f1 in enumerate(f1_scores) if f1 < self.threshold]
        print(f"Weak classes at epoch {epoch+1} (F1 < {self.threshold}): {weak_classes}")

        if weak_classes:
            self.train_generator.augment_weak_classes(weak_classes, augment_factor=self.augment_factor)
            print(f"Augmented {self.augment_factor} samples for weak classes: {weak_classes}")

            self.train_generator.on_epoch_end()
            print(f"Updated class weights: {self.train_generator.get_class_weights()}")

        if len(self.f1_history) > 1:
            prev_f1 = self.f1_history[-2]
            curr_f1 = self.f1_history[-1]
            print(f"F1-score comparison (epoch {epoch} vs {epoch+1}):")
            for i in range(self.num_classes):
                print(f"Class {i}: {prev_f1[i]:.4f} -> {curr_f1[i]:.4f} (Change: {curr_f1[i] - prev_f1[i]:.4f})")

# Class để thu thập và vẽ loss
class LossHistoryCallback(Callback):
    def __init__(self):
        super().__init__()
        self.losses = []
        self.val_losses = []

    def on_epoch_end(self, epoch, logs=None):
        self.losses.append(logs.get('loss'))
        self.val_losses.append(logs.get('val_loss'))

    def plot_and_save_loss(self, model_type, save_dir="/content/drive/MyDrive/working/"):
        plt.figure(figsize=(10, 6))
        plt.plot(self.losses, label='Training Loss', marker='o')
        plt.plot(self.val_losses, label='Validation Loss', marker='s')
        plt.title(f'Training and Validation Loss - {model_type}')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True)

        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, f'loss_plot_{model_type}.png')
        plt.savefig(save_path)
        plt.close()
        print(f"Saved loss plot to {save_path}")

# Chuyển đổi nhãn
NUM_CLASSES = 5
SIZE = 244
if len(train_y.shape) == 1 or train_y.shape[1] != NUM_CLASSES:
    train_y_multi = tf.keras.utils.to_categorical(train_y, num_classes = NUM_CLASSES)
    valid_y_multi = tf.keras.utils.to_categorical(valid_y, num_classes = NUM_CLASSES)
else:
    train_y_multi = train_y
    valid_y_multi = valid_y

batch_size = 64
resized_train_x = train_x.values
resized_valid_x = valid_x.values

# Định nghĩa callback
early_stopping = EarlyStopping(monitor='accuracy', patience=7, restore_best_weights=True, verbose=1, mode='max')

# Cấu hình mô hình
# Cấu hình mô hình
model_configs = {
    "xception": {
        "model_type": "xception",
        "weights": "imagenet",
        "save_path": "/content/drive/MyDrive/working/Xception_bestqwk_aptos.h5"
    },
    "resnet50": {
        "model_type": "resnet50",
        "weights_path": "/content/drive/MyDrive/keras_weights/resnet50/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5",
        "save_path": "/content/drive/MyDrive/working/ResNet50_bestqwk_aptos.h5"
    },
    "efficientnetb0": {
        "model_type": "efficientnetb0",
        "weights": "imagenet",
        "save_path": "/content/drive/MyDrive/working/EfficientNetB1_bestqwk_aptos.h5"
    },
    "inceptionv3": {
        "model_type": "inceptionv3",
        "weights_path": "/content/drive/MyDrive/keras_weights/inceptionv3/inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5",
        "save_path": "/content/drive/MyDrive/working/InceptionV3_bestqwk_aptos.h5"
    },
    "densenet121": {
        "model_type": "densenet121",
        "weights_path": "/content/drive/MyDrive/keras_weights/densenet121/densenet121_weights_tf_dim_ordering_tf_kernels_notop.h5",
        "save_path": "/content/drive/MyDrive/working/DenseNet121_bestqwk_aptos.h5"
    }
}

# Kiểm tra trước khi huấn luyện
assert 'train_x' in globals() and 'valid_x' in globals(), "train_x hoặc valid_x không được định nghĩa"
assert 'train_y' in globals() and 'valid_y' in globals(), "train_y hoặc valid_y không được định nghĩa"
assert callable(create_model), "create_model không phải là hàm"
assert 'My_Generator' in globals(), "My_Generator không được định nghĩa"

# In thông tin dữ liệu để debug
print("train_y shape:", train_y.shape)
print("valid_y shape:", valid_y.shape)
print("train_y_multi shape:", train_y_multi.shape)
print("valid_y_multi shape:", valid_y_multi.shape)

# Vòng lặp huấn luyện
for model_name, config in model_configs.items():
    print(f"\n==> Đang huấn luyện mô hình {model_name} ...")

    if config["model_type"] == "inceptionv3":
        model_input_shape = (299, 299, 3)
        img_size = 299
    else:
        model_input_shape = (SIZE, SIZE, 3)
        img_size = SIZE

    # Tạo generator với balance_classes
    train_generator = My_Generator(
        resized_train_x, train_y_multi, batch_size,
        is_train=True, mix=False, augment=True,
        size1=SIZE, size2=299, model_type=config["model_type"],
        balance_classes=True
    )

    # Kiểm tra số lượng mẫu
    try:
        print(f"Số lượng mẫu: {train_generator.augmented_class_counts}")
    except AttributeError:
        print("Không truy cập được augmented_class_counts, tiếp tục huấn luyện...")

    valid_generator = My_Generator(
        resized_valid_x, valid_y_multi, batch_size,
        is_train=False, size1=SIZE, size2=299, model_type=config["model_type"]
    )

    # Lấy trọng số lớp ban đầu (sẽ được cập nhật sau mỗi epoch)
    class_weights = train_generator.get_class_weights()
    if class_weights is None:
        class_weights = np.ones(NUM_CLASSES)  # Mặc định nếu chưa tính
    class_weight = {i: float(w) for i, w in enumerate(class_weights)}
    print(f"Trọng số lớp ban đầu: {class_weight}")

    # Tạo mô hình
    weights_path = config.get("weights_path", None)
    pretrained_weights = config.get("weights", "imagenet")
    model = create_model(
        input_shape=model_input_shape,
        n_out=NUM_CLASSES,
        model_type=config["model_type"],
        weights_path=weights_path,
        weights=pretrained_weights
    )

    # Khởi tạo QWKEvaluation
    qwk_callback = QWKEvaluation(
        validation_data=(valid_generator, valid_y_multi),
        batch_size=batch_size,
        interval=1,
        model_type=config["model_type"],
        save_paths={config["model_type"]: config["save_path"]}
    )

    # Khởi tạo QWKReduceLROnPlateau
    qwk_reduce_lr = QWKReduceLROnPlateau(
        qwk_callback=qwk_callback,
        factor=0.5,
        patience=3,
        min_lr=1e-6,
        verbose=1
    )

    # Khởi tạo LossHistoryCallback
    loss_history = LossHistoryCallback()

    # Giai đoạn khởi động (warm-up) - KHÔNG dùng class_weight
    for layer in model.layers[:50]:
        layer.trainable = False
    for layer in model.layers[50:]:
        layer.trainable = True

    # Sử dụng CategoricalCrossentropy với label smoothing
    loss_fn = CategoricalCrossentropy(label_smoothing=0.1)
    model.compile(
        loss=loss_fn,
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
        metrics=['accuracy']
    )

    augment_callback = DynamicRareClassAugmentationCallback(
        train_generator=train_generator,
        valid_generator=valid_generator,
        valid_labels=valid_y_multi,
        threshold=0.6,
        augment_factor=2
    )

    # Huấn luyện khởi động (Không áp dụng class_weight)
    model.fit(
        train_generator,
        steps_per_epoch=int(np.ceil(len(train_generator.image_filenames) / batch_size)),
        epochs=5,
        validation_data=valid_generator,
        validation_steps=int(np.ceil(len(valid_x) / batch_size)),
        verbose=1,
        callbacks=[qwk_callback, qwk_reduce_lr, early_stopping, augment_callback, loss_history]
    )

    # Huấn luyện đầy đủ
    for layer in model.layers:
        layer.trainable = True

    model.compile(
        loss=CategoricalCrossentropy(label_smoothing=0.1),
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
        metrics=['accuracy']
    )

    train_mixup = My_Generator(
        resized_train_x, train_y_multi, batch_size,
        is_train=True, mix=True, augment=True,
        size1=SIZE, size2=299, model_type=config["model_type"],
        balance_classes=True
    )

    # Kiểm tra số lượng mẫu
    try:
        print(f"Số lượng mẫu (mixup): {train_mixup.augmented_class_counts}")
    except AttributeError:
        print("Không truy cập được augmented_class_counts (mixup), tiếp tục huấn luyện...")

    augment_callback = DynamicRareClassAugmentationCallback(
        train_generator=train_mixup,
        valid_generator=valid_generator,
        valid_labels=valid_y_multi,
        threshold=0.6,
        augment_factor=2
    )

    # Huấn luyện đầy đủ với vòng lặp epoch để cập nhật class_weight
    epochs = 30
    for epoch in range(epochs):
        print(f"\nBắt đầu epoch {epoch + 1} cho mô hình {model_name}")
        # Lấy trọng số lớp cho epoch hiện tại
        class_weights = train_mixup.get_class_weights()
        if class_weights is None:
            class_weights = np.ones(NUM_CLASSES)  # Mặc định nếu chưa tính
        class_weight = {i: float(w) for i, w in enumerate(class_weights)}
        print(f"Trọng số lớp cho epoch {epoch + 1}: {class_weight}")

        # Huấn luyện một epoch (Áp dụng class_weight)
        model.fit(
            train_mixup,
            steps_per_epoch=int(np.ceil(len(train_mixup.image_filenames) / batch_size)),
            epochs=1,  # Huấn luyện từng epoch
            validation_data=valid_generator,
            validation_steps=int(np.ceil(len(valid_x) / batch_size)),
            verbose=1,
            callbacks=[qwk_callback, qwk_reduce_lr, augment_callback, early_stopping, loss_history],
            class_weight=class_weight  # Áp dụng class_weight
        )

    # Lưu mô hình cuối cùng
    final_save_path = config["save_path"].replace('.h5', '_final.keras')
    model.save(final_save_path, overwrite=True)
    print(f"Đã lưu mô hình cuối cùng tại {final_save_path}")

    # Vẽ và lưu biểu đồ loss
    loss_history.plot_and_save_loss(config["model_type"], save_dir="/content/drive/MyDrive/working/")

    # Vẽ và lưu biểu đồ ma trận nhầm lẫn tốt nhất cho mô hình này
    if qwk_callback.best_y_true is not None and qwk_callback.best_y_pred is not None:
        cm = confusion_matrix(qwk_callback.best_y_true, qwk_callback.best_y_pred, labels=[0, 1, 2, 3, 4])
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=[0, 1, 2, 3, 4],
                    yticklabels=[0, 1, 2, 3, 4])
        plt.title(f'Best Confusion Matrix - QWK: {qwk_callback.best_qwk:.4f} ({config["model_type"]})')
        plt.xlabel('Predicted')
        plt.ylabel('True')
        save_cm_path = f"/content/drive/MyDrive/working/best_confusion_matrix_{config['model_type']}.png"
        plt.savefig(save_cm_path)
        plt.close()  # Đóng figure để tránh chiếm bộ nhớ
        print(f"Saved best confusion matrix to {save_cm_path}")
    else:
        print(f"Không có best QWK được ghi nhận cho mô hình {config['model_type']}, không vẽ biểu đồ.")

    #TEST
    import numpy as np
    import tensorflow as tf
    import seaborn as sns
    import matplotlib.pyplot as plt
    from sklearn.metrics import cohen_kappa_score, f1_score, recall_score, precision_score, accuracy_score, confusion_matrix

    # ... (Giả sử các biến như model, test_x, test_y, batch_size, NUM_CLASSES, config, My_Generator đã được định nghĩa trước đó)

    print(f"\n==> Đang kiểm tra mô hình {model_name} trên tập test ...")

    # Chuyển nhãn test_y sang dạng one-hot nếu cần
    if len(test_y.shape) == 1 or test_y.shape[1] != NUM_CLASSES:
        test_y_multi = tf.keras.utils.to_categorical(test_y, num_classes=NUM_CLASSES)
    else:
        test_y_multi = test_y

    # Tạo generator cho tập test
    test_generator = My_Generator(
        test_x.values, test_y_multi, batch_size,
        is_train=False, size1=SIZE, size2=299, model_type=config["model_type"]
    )

    # Dự đoán trên tập test
    steps_test = int(np.ceil(len(test_x) / batch_size))
    y_pred_test = model.predict(test_generator, steps=steps_test, verbose=1)

    # Chuyển dự đoán và nhãn thật sang dạng lớp (class indices)
    y_true_test = np.argmax(test_y_multi, axis=1)
    y_pred_classes_test = np.argmax(y_pred_test, axis=1)

    # Tính Quadratic Weighted Kappa (QWK)
    qwk_test = cohen_kappa_score(y_true_test, y_pred_classes_test, labels=[0, 1, 2, 3, 4], weights='quadratic')
    print(f"QWK trên tập test: {qwk_test:.4f}")

    # Tính độ chính xác tổng thể (Accuracy)
    accuracy_test = accuracy_score(y_true_test, y_pred_classes_test)
    print(f"Độ chính xác trên tập test: {accuracy_test:.4f}")

    # Tính F1-score và độ nhạy (Sensitivity/Recall) cho từng lớp
    f1_test = f1_score(y_true_test, y_pred_classes_test, average=None, labels=[0, 1, 2, 3, 4])
    sensitivity_test = recall_score(y_true_test, y_pred_classes_test, average=None, labels=[0, 1, 2, 3, 4])
    print(f"F1-score cho từng lớp trên tập test: {[f'{f1:.4f}' for f1 in f1_test]}")
    print(f"Độ nhạy cho từng lớp trên tập test: {[f'{sens:.4f}' for sens in sensitivity_test]}")

    # Tính độ đặc hiệu (Specificity) cho từng lớp
    specificity_test = []
    cm = confusion_matrix(y_true_test, y_pred_classes_test, labels=[0, 1, 2, 3, 4])
    for cls in range(NUM_CLASSES):
        # Specificity = TN / (TN + FP)
        tn = np.sum(cm) - np.sum(cm[cls, :]) - np.sum(cm[:, cls]) + cm[cls, cls]
        fp = np.sum(cm[:, cls]) - cm[cls, cls]
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
        specificity_test.append(specificity)
    print(f"Độ đặc hiệu cho từng lớp trên tập test: {[f'{spec:.4f}' for spec in specificity_test]}")

    # Tính độ chính xác (Precision) cho từng lớp
    precision_test = precision_score(y_true_test, y_pred_classes_test, average=None, labels=[0, 1, 2, 3, 4])
    print(f"Độ chính xác cho từng lớp trên tập test: {[f'{prec:.4f}' for prec in precision_test]}")

    # Vẽ và lưu ma trận nhầm lẫn cho tập test
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=[0, 1, 2, 3, 4],
                yticklabels=[0, 1, 2, 3, 4])
    plt.title(f'Ma trận nhầm lẫn tập test - QWK: {qwk_test:.4f} ({config["model_type"]})')
    plt.xlabel('Dự đoán')
    plt.ylabel('Thật')
    save_cm_test_path = f"/content/drive/MyDrive/working/test_confusion_matrix_{config['model_type']}.png"
    plt.savefig(save_cm_test_path)
    plt.close()
    print(f"Đã lưu ma trận nhầm lẫn tập test tại {save_cm_test_path}")

NameError: name 'create_model' is not defined

**Huấn luyện meta-learner**

mẫu 2

In [None]:
import tensorflow as tf
import logging

# Cấu hình logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Cấu hình GPU trước khi TensorFlow khởi tạo
physical_devices = tf.config.list_physical_devices('GPU')
if physical_devices:
    for device in physical_devices:
        try:
            tf.config.experimental.set_memory_growth(device, True)
            logging.info(f"Đã bật memory growth cho {device}")
        except Exception as e:
            logging.error(f"Lỗi khi cài đặt memory growth cho {device}: {str(e)}")

import numpy as np
import tensorflow as tf
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import cohen_kappa_score, confusion_matrix, f1_score, recall_score, classification_report, roc_curve, auc, precision_recall_curve
from sklearn.model_selection import train_test_split, KFold
from sklearn.utils import shuffle
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import EfficientNetB1, Xception, InceptionV3, ResNet50, DenseNet121
from tensorflow.keras.applications.efficientnet import preprocess_input as efficientnet_preprocess
from tensorflow.keras.applications.xception import preprocess_input as xception_preprocess
from tensorflow.keras.applications.inception_v3 import preprocess_input as inceptionv3_preprocess
from tensorflow.keras.applications.resnet50 import preprocess_input as resnet50_preprocess
from tensorflow.keras.applications.densenet import preprocess_input as densenet121_preprocess
from tensorflow.keras.layers import Layer, Input, Conv2D, BatchNormalization, GlobalAveragePooling2D, Dense, Dropout
from tensorflow.keras.models import Model, model_from_json  # Thêm model_from_json
from tensorflow.keras.callbacks import Callback
from pathlib import Path
import os
import gc
import psutil
import logging
import cv2
import json
import time
import random
from itertools import product
from sklearn.calibration import calibration_curve
from datetime import datetime
import albumentations as A
import GPUtil

# Thiết lập logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logging.info(f"Pipeline khởi động lúc: {datetime.now().strftime('%H:%M %p %z, %d/%m/%Y')}")

# Tham số cố định
FEATURE_SAVE_DIR = "/content/drive/MyDrive/working"
PROCESSED_FOLDER = "/content/processed_train_images"
TEMP_AUGMENT_DIR = "/content/temp_augment"
DRIVE_FOLDER = "/content/drive/MyDrive/kaggle_data/aptos2019"
SIZE = 244
NUM_CLASSES = 5
BATCH_SIZE = 16
SUPPORT_SET_SIZE = 10
QUERY_SET_SIZE = 10

# Định nghĩa model_configs
MODEL_CONFIGS = {
    "efficientnetb1": {
        "model_type": "efficientnetb1",
        "config_path": "/content/drive/MyDrive/working/EfficientNetB1_bestqwk_aptos/config.json",
        "weights_path": "/content/drive/MyDrive/working/EfficientNetB1_bestqwk_aptos/model.weights.h5",
        "preprocess": efficientnet_preprocess,
        "img_size": 244,  # Sửa từ 244 thành 224
        "base_model": EfficientNetB1
    },
    "xception": {
        "model_type": "xception",
        "config_path": "/content/drive/MyDrive/working/Xception_bestqwk_aptos/config.json",
        "weights_path": "/content/drive/MyDrive/working/Xception_bestqwk_aptos/model.weights.h5",
        "preprocess": xception_preprocess,
        "img_size": 244,
        "base_model": Xception
    },
    "inceptionv3": {
        "model_type": "inceptionv3",
        "config_path": "/content/drive/MyDrive/working/InceptionV3_bestqwk_aptos/config.json",
        "weights_path": "/content/drive/MyDrive/working/InceptionV3_bestqwk_aptos/model.weights.h5",
        "preprocess": inceptionv3_preprocess,
        "img_size": 299,
        "base_model": InceptionV3
    },
    "resnet50": {
        "model_type": "resnet50",
        "config_path": "/content/drive/MyDrive/working/ResNet50_bestqwk_aptos/config.json",
        "weights_path": "/content/drive/MyDrive/working/ResNet50_bestqwk_aptos/model.weights.h5",
        "preprocess": resnet50_preprocess,
        "img_size": 244,
        "base_model": ResNet50
    },
    "densenet121": {
        "model_type": "densenet121",
        "config_path": "/content/drive/MyDrive/working/DenseNet121_bestqwk_aptos/config.json",
        "weights_path": "/content/drive/MyDrive/working/DenseNet121_bestqwk_aptos/model.weights.h5",
        "preprocess": densenet121_preprocess,
        "img_size": 244,
        "base_model": DenseNet121
    }
}

# Kiểm tra và mount Google Drive
from google.colab import drive
if not os.path.ismount('/content/drive'):
    drive.mount('/content/drive')
else:
    logging.info("Google Drive đã được mount.")

# Tạo thư mục lưu trữ
os.makedirs(FEATURE_SAVE_DIR, exist_ok=True)
os.makedirs(PROCESSED_FOLDER, exist_ok=True)
os.makedirs(TEMP_AUGMENT_DIR, exist_ok=True)

def validate_model_configs(model_configs):
    for model_name, config in model_configs.items():
        config_path = config['config_path']
        weights_path = config['weights_path']
        if not os.path.exists(config_path):
            logging.error(f"Tệp cấu hình không tồn tại cho {model_name}: {config_path}")
            raise FileNotFoundError(f"Thiếu tệp cấu hình: {config_path}")
        if not os.path.exists(weights_path):
            logging.error(f"Tệp trọng số không tồn tại cho {model_name}: {weights_path}")
            raise FileNotFoundError(f"Thiếu tệp trọng số: {weights_path}")
        logging.info(f"Xác nhận tệp hợp lệ cho {model_name}: {config_path}, {weights_path}")

def load_model_from_config(config_path, weights_path, base_model_fn, input_shape):
    if not os.path.exists(config_path) or not os.path.exists(weights_path):
        logging.error(f"Thiếu tệp: config={config_path}, weights={weights_path}")
        raise FileNotFoundError("Thiếu tệp cấu hình hoặc trọng số")
    try:
        with open(config_path, 'r') as f:
            model_config = json.load(f)
        model = model_from_json(json.dumps(model_config), custom_objects={
            'CustomGridDropout': CustomGridDropout,
            'MemoryAugmentedLayer': MemoryAugmentedLayer,
            'GradientReversalLayer': GradientReversalLayer
        })
        model.load_weights(weights_path, skip_mismatch=True)  # Bỏ by_name=True
        logging.info(f"Đã tải trọng số từ {weights_path}")
    except Exception as e:
        logging.warning(f"Không tải được mô hình: {str(e)}. Khởi tạo với trọng số ImageNet...")
        model = base_model_fn(include_top=False, weights='imagenet', input_shape=input_shape)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )
    return model

# Hàm tải mô hình từ config
def load_processed_image(image_id, processed_folder, model_type):
    try:
        if model_type not in MODEL_CONFIGS:
            logging.error(f"model_type không hợp lệ: {model_type}")
            raise ValueError(f"model_type không hợp lệ: {model_type}")
        target_size = MODEL_CONFIGS[model_type]['img_size']
        img_path = os.path.join(processed_folder, f"{image_id}.png")
        if not os.path.exists(img_path):
            logging.error(f"Ảnh không tồn tại: {img_path}")
            return None
        img = cv2.imread(img_path)
        if img is None:
            logging.error(f"Không đọc được ảnh: {img_path}")
            return None
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (target_size, target_size), interpolation=cv2.INTER_AREA)
        if img.shape != (target_size, target_size, 3):
            logging.error(f"Shape ảnh không đúng: {img.shape}, kỳ vọng: ({target_size}, {target_size}, 3)")
            return None
        img = img.astype(np.float32) / 255.0  # Chuẩn hóa
        logging.debug(f"Đã tải ảnh {image_id} cho {model_type}: shape={img.shape}")
        return img
    except Exception as e:
        logging.error(f"Lỗi khi tải ảnh {image_id} cho {model_type}: {str(e)}")
        return None
# Custom layers
class CustomGridDropout(Layer):
    def __init__(self, ratio=0.3, holes_number=4, p=0.5, **kwargs):
        super().__init__(**kwargs)
        self.ratio = ratio
        self.holes_number = holes_number
        self.p = p

    def call(self, inputs, training=None):
        if not training:
            return inputs
        inputs = tf.cast(inputs, tf.float16)  # Ép kiểu sang float16
        batch_size = tf.shape(inputs)[0]
        height = tf.shape(inputs)[1]
        width = tf.shape(inputs)[2]
        channels = tf.shape(inputs)[3]

        hole_height = tf.maximum(1, tf.cast(tf.cast(height, tf.float32) * self.ratio, tf.int32))
        hole_width = tf.maximum(1, tf.cast(tf.cast(width, tf.float32) * self.ratio, tf.int32))

        mask = tf.ones_like(inputs, dtype=tf.float16)  # Sử dụng float16
        random_probs = tf.random.uniform([self.holes_number], 0, 1)
        active_holes = tf.cast(random_probs < self.p, tf.int32)

        all_indices = []
        for i in range(self.holes_number):
            should_apply = active_holes[i]
            indices = tf.cond(
                should_apply > 0,
                lambda: self._generate_patch_indices(
                    batch_size, height, width, channels, hole_height, hole_width, i
                ),
                lambda: tf.zeros([0, 4], dtype=tf.int32)
            )
            all_indices.append(indices)

        all_indices = tf.concat(all_indices, axis=0)
        updates = tf.zeros([tf.shape(all_indices)[0]], dtype=tf.float16)

        if tf.shape(all_indices)[0] > 0:
            mask = tf.tensor_scatter_nd_update(mask, all_indices, updates)

        return tf.cast(inputs * mask, tf.float16)  # Đảm bảo đầu ra là float16

    def _generate_patch_indices(self, batch_size, height, width, channels, hole_height, hole_width, hole_idx):
        h_start = tf.random.uniform([], 0, height - hole_height + 1, dtype=tf.int32)
        w_start = tf.random.uniform([], 0, width - hole_width + 1, dtype=tf.int32)

        h_indices = tf.range(h_start, h_start + hole_height)
        w_indices = tf.range(w_start, w_start + hole_width)
        c_indices = tf.range(channels)

        batch_indices = tf.tile(tf.range(batch_size), [hole_height * hole_width * channels])
        h_grid, w_grid, c_grid = tf.meshgrid(h_indices, w_indices, c_indices, indexing='ij')
        h_grid = tf.reshape(h_grid, [-1])
        w_grid = tf.reshape(w_grid, [-1])
        c_grid = tf.reshape(c_grid, [-1])

        h_indices = tf.tile(h_grid, [batch_size])
        w_indices = tf.tile(w_grid, [batch_size])
        c_indices = tf.tile(c_grid, [batch_size])

        indices = tf.stack([batch_indices, h_indices, w_indices, c_indices], axis=1)
        return indices

    def get_config(self):
        config = super().get_config()
        config.update({"ratio": self.ratio, "holes_number": self.holes_number, "p": self.p})
        return config

class MemoryAugmentedLayer(Layer):
    def __init__(self, memory_size, memory_dim, **kwargs):
        super().__init__(**kwargs)
        self.memory_size = memory_size
        self.memory_dim = memory_dim
        self.memory = self.add_weight(
            shape=(memory_size, memory_dim),
            initializer='random_normal',
            trainable=True,
            name='memory'
        )

    def call(self, inputs):
        # Placeholder: Trả về inputs mà không thay đổi
        # Thay thế bằng triển khai thực tế của bạn
        return inputs

    def get_config(self):
        config = super().get_config()
        config.update({"memory_size": self.memory_size, "memory_dim": self.memory_dim})
        return config

def check_memory():
    try:
        memory_info = psutil.virtual_memory()
        logging.info(f"Memory usage: Total={memory_info.total / (1024 ** 3):.2f} GB, "
                     f"Used={memory_info.used / (1024 ** 3):.2f} GB, "
                     f"Free={memory_info.available / (1024 ** 3):.2f} GB, "
                     f"Percent={memory_info.percent}%")

        # Check GPU memory usage
        try:
            gpus = GPUtil.getGPUs()
            for gpu in gpus:
                logging.info(f"GPU {gpu.id}: {gpu.name}, "
                            f"Memory Used={gpu.memoryUsed / 1024:.2f} GB, "
                            f"Memory Total={gpu.memoryTotal / 1024:.2f} GB, "
                            f"Utilization={gpu.memoryUtil * 100:.2f}%")
        except:
            logging.warning(f"Failed to get GPU memory information")

        # Perform garbage collection
        gc.collect()

        # Warn if memory usage is high
        memory_threshold = 1024  # Threshold in MB
        if memory_info.available < memory_threshold:
            logging.warning(f"Low memory detected: {memory_info.available / (1024 ** 3):.2f} GB available")
            tf.keras.backend.clear_session()

            # Free GPU memory
            try:
                tf.config.experimental.set_memory_growth(tf.config.list_physical_devices('GPU')[0], True)
            except:
                logging.warning(f"Failed to set memory growth")

        return memory_info.available / (1024 ** 3)  # Return available memory in GB

    except Exception as e:
        logging.error(f"Error checking memory: {str(e)}")
        return None

class GradientReversalLayer(Layer):
    def __init__(self, lambda_=1.0, **kwargs):
        super().__init__(**kwargs)
        self.lambda_ = lambda_

    def call(self, inputs, training=None):
        inputs = tf.convert_to_tensor(inputs, dtype=tf.float32)
        return inputs if not training else tf.math.multiply(-self.lambda_, inputs)

    def get_config(self):
        config = super().get_config()
        config.update({"lambda_": self.lambda_})
        return config

# Hàm random erasing tùy chỉnh
def custom_random_erasing(image, scale=(0.01, 0.05), ratio=(0.5, 2.0), p=0.3, value=None):
    if np.random.random() > p:
        return image
    height, width, channels = image.shape
    area = height * width
    scale_factor = np.random.uniform(scale[0], scale[1])
    erase_area = area * scale_factor
    aspect_ratio = np.random.uniform(ratio[0], ratio[1])
    erase_height = int(np.sqrt(erase_area / aspect_ratio))
    erase_width = int(np.sqrt(erase_area * aspect_ratio))
    erase_height = min(erase_height, height)
    erase_width = min(erase_width, width)
    if erase_height < 1 or erase_width < 1:
        return image
    x = np.random.randint(0, width - erase_width + 1)
    y = np.random.randint(0, height - erase_height + 1)
    output = image.copy()
    if value is None:
        value = np.mean(image, axis=(0, 1))
    output[y:y+erase_height, x:x+erase_width, :] = value
    return output


# Hàm cân bằng và tăng cường dữ liệu
def balance_data(images, labels, target_classes=[0, 1, 2, 3, 4]):
    if not isinstance(target_classes, (list, np.ndarray)):
        raise TypeError(f"target_classes must be list or numpy array, got: {type(target_classes)}")

    num_classes = labels.shape[1]
    label_indices = np.argmax(labels, axis=1)

    # Filter samples belonging to target_classes
    keep_indices = np.isin(label_indices, target_classes)
    filtered_images = images[keep_indices]
    filtered_labels = labels[keep_indices]
    filtered_label_indices = label_indices[keep_indices]

    # Count samples per class
    class_counts = np.bincount(filtered_label_indices, minlength=num_classes)
    logging.info(f"Initial label distribution: {dict(zip(range(num_classes), class_counts))}")

    # Target ~7148 samples (~1430 per class)
    total_samples = 7148
    base_samples_per_class = total_samples // len(target_classes)  # 1429
    extra_samples = total_samples % len(target_classes)  # 3
    samples_per_class = [base_samples_per_class] * len(target_classes)
    for i in range(extra_samples):
        samples_per_class[i] += 1
    logging.info(f"Target samples per class: {dict(zip(target_classes, samples_per_class))}")

    new_images = []
    new_labels = []

    for cls_idx, cls in enumerate(target_classes):
        cls_indices = np.where(filtered_label_indices == cls)[0]
        cls_images = filtered_images[cls_indices]
        cls_labels = filtered_labels[cls_indices]
        current_count = len(cls_indices)
        target_count = samples_per_class[cls_idx]

        logging.info(f"Class {cls}: {current_count} initial samples, target {target_count} samples")

        # Add all available samples
        new_images.extend(cls_images)
        new_labels.extend(cls_labels)

        # Oversample if needed
        if current_count < target_count:
            oversample_count = target_count - current_count
            logging.info(f"Oversampling {oversample_count} samples for class {cls}")
            # Randomly select indices with replacement
            oversample_indices = np.random.choice(cls_indices, size=oversample_count, replace=True)
            new_images.extend(filtered_images[oversample_indices])
            new_labels.extend(filtered_labels[oversample_indices])

    new_images = np.array(new_images, dtype=np.float32)
    new_labels = np.array(new_labels, dtype=np.float32)

    # Verify total samples
    if len(new_images) != total_samples:
        logging.error(f"Sample count mismatch: got {len(new_images)}, expected {total_samples}")
        raise ValueError("Sample count mismatch after balancing")

    new_images, new_labels = shuffle(new_images, new_labels, random_state=42)

    final_class_counts = np.bincount(np.argmax(new_labels, axis=1), minlength=num_classes)
    logging.info(f"Final label distribution: {dict(zip(range(num_classes), final_class_counts))}")

    return new_images, new_labels


class My_Generator(tf.keras.utils.Sequence):
    def __init__(self, images, labels, batch_size, is_train=False, mix=False, augment=False, model_type="default", preprocess=None, image_paths=None, sample_ids=None, temp_augment_dir="/content/temp_augment"):
        self.batch_size = batch_size
        self.is_train = is_train
        self.is_mix = False  # Disable mixup
        self.augment = False  # Disable augmentation
        self.model_type = str(model_type).lower()
        self.preprocess = preprocess
        self.temp_augment_dir = temp_augment_dir
        os.makedirs(self.temp_augment_dir, exist_ok=True)
        if not os.access(self.temp_augment_dir, os.W_OK):
            raise PermissionError(f"No write permission for {self.temp_augment_dir}")

        # Set target_size dynamically based on model_type
        if self.model_type not in MODEL_CONFIGS:
            raise ValueError(f"Invalid model_type: {self.model_type}. Must be in {list(MODEL_CONFIGS.keys())}")
        self.target_size = (MODEL_CONFIGS[self.model_type]['img_size'], MODEL_CONFIGS[self.model_type]['img_size'])
        logging.info(f"Initialized My_Generator for {self.model_type} with target_size={self.target_size}")

        self.image_paths = []
        self.labels = []
        self.sample_ids = []

        if image_paths is not None:
            if len(image_paths) != len(labels) or len(image_paths) != len(sample_ids):
                raise ValueError(f"Mismatch: image_paths={len(image_paths)}, labels={len(labels)}, sample_ids={len(sample_ids)}")
            for path, label, sid in zip(image_paths, labels, sample_ids):
                if os.path.exists(path):
                    self.image_paths.append(path)
                    self.labels.append(label)
                    self.sample_ids.append(sid)
                else:
                    logging.warning(f"Skipping non-existent image: {path}")
        elif isinstance(images, np.ndarray):
            if len(images) != len(labels) or len(images) != len(sample_ids):
                raise ValueError(f"Mismatch: images={len(images)}, labels={len(labels)}, sample_ids={len(sample_ids)}")
            for i, (img, label, sid) in enumerate(zip(images, labels, sample_ids)):
                img_path = os.path.join(self.temp_augment_dir, f"img_{i}_{np.random.randint(1000000)}.png")
                try:
                    # Resize image to target_size
                    img = cv2.resize(img, self.target_size, interpolation=cv2.INTER_AREA)
                    if img.dtype != np.uint8:
                        if img.max() <= 1.0:
                            img = (img * 255).astype(np.uint8)
                        else:
                            img = img.astype(np.uint8)
                    cv2.imwrite(img_path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
                    if os.path.exists(img_path):
                        self.image_paths.append(img_path)
                        self.labels.append(label)
                        self.sample_ids.append(sid)
                    else:
                        logging.warning(f"Failed to save image: {img_path}")
                except Exception as e:
                    logging.warning(f"Error saving image {img_path}: {str(e)}")
                    continue
        else:
            raise ValueError("Require valid image_paths or images")

        if not self.image_paths:
            raise ValueError("No valid image_paths created")

        self.labels = np.array(self.labels, dtype=np.float32)
        self.sample_ids = np.array(self.sample_ids, dtype=str)

        if len(self.image_paths) != len(self.labels) or len(self.image_paths) != len(self.sample_ids):
            raise ValueError(f"Post-init mismatch: image_paths={len(self.image_paths)}, labels={len(self.labels)}, sample_ids={len(self.sample_ids)}")

        logging.info(f"Initialized My_Generator: {len(self.image_paths)} samples, batch_size={batch_size}, is_train={is_train}, target_size={self.target_size}")
        self.indices = np.arange(len(self.image_paths))
        if self.is_train:
            np.random.shuffle(self.indices)

    def __len__(self):
        num_samples = len(self.image_paths)
        if num_samples == 0:
            logging.error("No samples in generator")
            raise ValueError("Generator has no samples")
        num_batches = (num_samples + self.batch_size - 1) // self.batch_size
        logging.info(f"Generator len: {num_samples} samples, {num_batches} batches")
        return num_batches

    def __getitem__(self, index):
        start_idx = index * self.batch_size
        end_idx = min(start_idx + self.batch_size, len(self.image_paths))
        batch_indices = self.indices[start_idx:end_idx]

        if len(batch_indices) == 0:
            logging.warning(f"Batch {index} has no indices, returning empty batch")
            return np.array([]), np.array([]), []

        batch_images = []
        batch_labels = []
        batch_ids = []

        for idx in batch_indices:
            try:
                img = self._load_image(self.image_paths[idx])
                if img is not None and img.size > 0:
                    batch_images.append(img)
                    batch_labels.append(self.labels[idx])
                    batch_ids.append(self.sample_ids[idx])
                else:
                    logging.warning(f"Skipping empty or failed image at index {idx}: {self.image_paths[idx]}")
            except Exception as e:
                logging.warning(f"Error loading image at index {idx}: {str(e)}")
                continue

        if not batch_images:
            logging.warning(f"Batch {index} empty after processing, returning empty batch")
            return np.array([]), np.array([]), []

        batch_images = np.array(batch_images)
        batch_labels = np.array(batch_labels)
        logging.debug(f"Batch {index}: images_shape={batch_images.shape}, labels_shape={batch_labels.shape}, ids_count={len(batch_ids)}")
        return batch_images, batch_labels, batch_ids

    def _load_image(self, img_path):
        img = cv2.imread(img_path)
        if img is None:
            logging.error(f"Failed to read image: {img_path}")
            return None
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, self.target_size)
        if self.preprocess:
            img = self.preprocess(img)
        return img

    def on_epoch_end(self):
        if self.is_train:
            np.random.shuffle(self.indices)

# Callback để tính trọng số lớp từ confusion matrix
class ConfusionMatrixWeightCallback(tf.keras.callbacks.Callback):
    def __init__(self, valid_features, valid_labels, classification_model, num_classes=5, class_counts=None):
        super().__init__()
        self.valid_features = valid_features
        self.valid_labels = valid_labels
        self.classification_model = classification_model
        self.num_classes = num_classes
        self.prev_cm = None
        self.class_weights = np.ones(num_classes, dtype=np.float32)
        self.class_counts = class_counts
        self.history_dir = os.path.join(FEATURE_SAVE_DIR, "history")
        os.makedirs(self.history_dir, exist_ok=True)
        self.weights_history = []

    def on_epoch_end(self, epoch, logs=None):
        y_pred = self.classification_model.predict(self.valid_features, verbose=0, batch_size=32)
        y_true = np.argmax(self.valid_labels, axis=1)
        y_pred_classes = np.argmax(y_pred, axis=1)

        cm = confusion_matrix(y_true, y_pred_classes, labels=list(range(self.num_classes)))
        logging.info(f"Epoch {epoch+1} - Ma trận nhầm lẫn:\n{cm}")

        errors = np.sum(cm * (1 - np.eye(self.num_classes)), axis=1)
        total_samples_per_class = np.sum(cm, axis=1)
        total_samples_per_class = np.where(total_samples_per_class == 0, 1, total_samples_per_class)
        error_rates = errors / total_samples_per_class

        weak_classes = []
        if self.class_counts is not None:
            min_count = np.min(self.class_counts[self.class_counts > 0])
            weak_classes = np.where(self.class_counts <= min_count * 1.5)[0]
        high_error_classes = np.where(error_rates >= np.percentile(error_rates, 75))[0]
        weak_classes = np.unique(np.concatenate([weak_classes, high_error_classes])).astype(int)

        self.class_weights = 1.0 + error_rates
        for cls in weak_classes:
            self.class_weights[cls] *= 2.0
        self.class_weights /= self.class_weights.max()

        logging.info(f"Epoch {epoch+1} - Lớp yếu: {weak_classes}")
        logging.info(f"Epoch {epoch+1} - Trọng số lớp: {self.class_weights}")

        self.weights_history.append({
            "epoch": epoch + 1,
            "class_weights": self.class_weights.tolist(),
            "weak_classes": weak_classes.tolist(),
            "confusion_matrix": cm.tolist()
        })
        weights_path = os.path.join(self.history_dir, f"class_weights_epoch_{epoch+1}.json")
        with open(weights_path, 'w') as f:
            json.dump(self.weights_history[-1], f, indent=4)
        logging.info(f"Đã lưu trọng số lớp tại: {weights_path}")

        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=list(range(self.num_classes)),
                    yticklabels=list(range(self.num_classes)))
        plt.title(f'Ma trận nhầm lẫn - Epoch {epoch+1}')
        plt.xlabel('Dự đoán')
        plt.ylabel('Thực tế')
        cm_path = os.path.join(FEATURE_SAVE_DIR, f'confusion_matrix_epoch_{epoch+1}.png')
        plt.savefig(cm_path)
        plt.close()
        logging.info(f"Đã lưu ma trận nhầm lẫn tại: {cm_path}")

        self.prev_cm = cm.copy()

    def get_class_weights(self):
        return self.class_weights

# Callback cho checkpoint dựa trên nhiều metrics
class MultiMetricCheckpoint(tf.keras.callbacks.Callback):
    def __init__(self, filepath, monitor_metrics, mode='max', save_best_only=True):
        super().__init__()
        self.filepath = filepath
        self.monitor_metrics = monitor_metrics
        self.mode = mode
        self.save_best_only = save_best_only
        self.best_metrics = {metric: -float('inf') if mode == 'max' else float('inf') for metric in monitor_metrics}

    def on_epoch_end(self, epoch, logs=None):
        current_metrics = {}
        improved = False

        for metric in self.monitor_metrics:
            current = logs.get(metric, 0.0)
            current_metrics[metric] = current
            if self.mode == 'max' and current > self.best_metrics[metric]:
                self.best_metrics[metric] = current
                improved = True
            elif self.mode == 'min' and current < self.best_metrics[metric]:
                self.best_metrics[metric] = current
                improved = True

        if improved and self.save_best_only:
            self.model.save_weights(self.filepath.format(epoch=epoch + 1), overwrite=True)
            logging.info(f"Đã lưu mô hình tốt nhất tại epoch {epoch + 1} với metrics: {current_metrics}")

# Hàm kiểm tra bộ nhớ

# Hàm debug layers của mô hình
def debug_model_layers(model, model_name="model"):
    logging.info(f"Kiểm tra layers của {model_name}:")
    conv_layers = []
    for layer in model.layers:
        if hasattr(layer, 'output'):
            output_shape = layer.output.shape
            if len(output_shape) == 4:
                conv_layers.append((layer.name, output_shape))
                logging.info(f"Layer: {layer.name}, Type: {type(layer).__name__}, Output Shape: {output_shape}")
    if not conv_layers:
        logging.error(f"Không tìm thấy layer 4D nào trong {model_name}")
    else:
        logging.info(f"Tìm thấy {len(conv_layers)} layers với đầu ra 4D:")
        for name, shape in conv_layers:
            logging.info(f"  - {name}: {shape}")
    return conv_layers

# Hàm validate dataset
def validate_dataset(image_ids, processed_folder, save_dir):
    valid_ids = []
    corrupted_files = []

    for image_id in image_ids:
        img_path = os.path.join(processed_folder, f"{image_id}.png")
        if not os.path.exists(img_path):
            logging.error(f"Ảnh không tồn tại: {img_path}")
            corrupted_files.append(image_id)
            continue
        img = cv2.imread(img_path)
        if img is None:
            logging.error(f"Không đọc được ảnh: {img_path}")
            corrupted_files.append(image_id)
            continue
        valid_ids.append(image_id)

    report = {
        "total_images": len(image_ids),
        "valid_images": len(valid_ids),
        "corrupted_images": len(corrupted_files),
        "corrupted_ids": corrupted_files
    }
    report_path = os.path.join(save_dir, "dataset_validation_report.json")
    with open(report_path, 'w') as f:
        json.dump(report, f, indent=4)
    logging.info(f"Đã lưu báo cáo validation tại: {report_path}")

    if corrupted_files:
        logging.warning(f"Đã tìm thấy {len(corrupted_files)} ảnh bị lỗi hoặc thiếu: {corrupted_files[:5]}...")

    return valid_ids

# Hàm kiểm tra data leakage
def check_data_leakage(train_x, valid_x, test_x, metadata_df=None, id_column='id_code'):
    train_set = set(train_x)
    valid_set = set(valid_x)
    test_set = set(test_x)

    train_valid_overlap = train_set.intersection(valid_set)
    train_test_overlap = train_set.intersection(test_set)
    valid_test_overlap = valid_set.intersection(test_set)

    leakage_detected = False
    if train_valid_overlap:
        logging.warning(f"Phát hiện overlap giữa train và valid: {len(train_valid_overlap)} mẫu")
        leakage_detected = True
    if train_test_overlap:
        logging.warning(f"Phát hiện overlap giữa train và test: {len(train_test_overlap)} mẫu")
        leakage_detected = True
    if valid_test_overlap:
        logging.warning(f"Phát hiện overlap giữa valid và test: {len(valid_test_overlap)} mẫu")
        leakage_detected = True

    if metadata_df is not None and 'patient_id' in metadata_df.columns:
        train_patients = set(metadata_df[metadata_df[id_column].isin(train_x)]['patient_id'])
        valid_patients = set(metadata_df[metadata_df[id_column].isin(valid_x)]['patient_id'])
        test_patients = set(metadata_df[metadata_df[id_column].isin(test_x)]['patient_id'])

        patient_overlap = train_patients.intersection(valid_patients, test_patients)
        if patient_overlap:
            logging.warning(f"Phát hiện overlap bệnh nhân: {patient_overlap}")
            leakage_detected = True

    if not leakage_detected:
        logging.info("Không phát hiện data leakage.")

    return not leakage_detected

def get_image_ids(folder_path):
    try:
        if not os.path.exists(folder_path):
            print(f"ERROR: Thư mục {folder_path} không tồn tại.")
            return []

        image_files = [f for f in os.listdir(folder_path) if f.endswith('.png')]
        image_ids = [os.path.splitext(f)[0] for f in image_files]

        print(f"Tìm thấy {len(image_ids)} ID ảnh trong {folder_path}")
        print(f"Mẫu ID: {image_ids[:5]}...")

        return image_ids
    except Exception as e:
        print(f"ERROR: Lỗi khi liệt kê ID ảnh trong {folder_path}: {str(e)}")
        return []

def extract_features(generator, feature_extractor):
    features = []
    feature_ids = []
    total_samples = len(generator.image_paths)
    num_batches = len(generator)
    processed_samples = 0
    empty_batch_count = 0
    max_empty_batches = 10
    batch_size = 16

    logging.info(f"Trích xuất đặc trưng cho {total_samples} mẫu, {num_batches} batch")

    for batch_idx in range(num_batches):
        try:
            batch_images, _, batch_ids = generator[batch_idx]
        except Exception as e:
            logging.warning(f"Lỗi khi lấy batch {batch_idx}: {str(e)}")
            empty_batch_count += 1
            if empty_batch_count >= max_empty_batches:
                logging.error(f"Quá nhiều batch rỗng liên tiếp ({empty_batch_count})")
                break
            continue

        if len(batch_images) == 0 or len(batch_ids) == 0:
            logging.warning(f"Batch {batch_idx} rỗng, bỏ qua")
            empty_batch_count += 1
            if empty_batch_count >= max_empty_batches:
                logging.error(f"Quá nhiều batch rỗng liên tiếp ({empty_batch_count})")
                break
            continue

        empty_batch_count = 0
        try:
            batch_features = feature_extractor.predict(batch_images, batch_size=batch_size, verbose=0)
            logging.debug(f"Batch {batch_idx} feature shape trước pooling: {batch_features.shape}")
            if len(batch_features.shape) > 2:
                batch_features = np.mean(batch_features, axis=(1, 2))  # Global average pooling
            logging.debug(f"Batch {batch_idx} feature shape sau pooling: {batch_features.shape}")
            if batch_features.shape[0] != len(batch_ids):
                logging.warning(f"Số đặc trưng ({batch_features.shape[0]}) không khớp với số ID ({len(batch_ids)})")
                continue
            features.append(batch_features)
            feature_ids.extend(batch_ids)
            processed_samples += len(batch_ids)
            logging.info(f"Batch {batch_idx}: {len(batch_ids)} samples, features_shape={batch_features.shape}")
        except Exception as e:
            logging.warning(f"Lỗi khi xử lý batch {batch_idx}: {str(e)}")
            continue

    if not features:
        logging.error("Không trích xuất được đặc trưng nào.")
        raise ValueError("Không trích xuất được đặc trưng.")

    features = np.concatenate(features, axis=0)
    if len(features) != len(feature_ids):
        logging.error(f"Số đặc trưng ({len(features)}) không khớp với số feature_ids ({len(feature_ids)})")
        raise ValueError("Đặc trưng và ID không đồng bộ.")
    logging.info(f"Hoàn thành trích xuất: {len(features)} đặc trưng")
    return features, feature_ids

# Hàm trích xuất và lưu đặc trưng
def extract_and_save_features(model_name, extractor, generator, save_dir, sample_ids):
    if not sample_ids or len(sample_ids) == 0:
        logging.error(f"sample_ids rỗng hoặc không được cung cấp cho {model_name}.")
        raise ValueError("sample_ids rỗng hoặc không được cung cấp.")

    logging.info(f"Bắt đầu trích xuất đặc trưng cho {model_name}")
    features, feature_ids = extract_features(generator, extractor)
    logging.info(f"Trích xuất {len(features)} đặc trưng cho {model_name}")

    if len(features) != len(sample_ids):
        logging.warning(f"Số đặc trưng ({len(features)}) không khớp với số sample_ids ({len(sample_ids)})")
        missing_ids = set(sample_ids) - set(feature_ids)
        logging.warning(f"ID bị thiếu trong feature_ids: {list(missing_ids)[:5]}...")

    feature_path = os.path.join(save_dir, f"{model_name}_features.npy")
    np.save(feature_path, features)
    logging.info(f"Đã lưu đặc trưng tại: {feature_path}")

    ids_path = os.path.join(save_dir, f"{model_name}_feature_ids.npy")
    np.save(ids_path, np.array(feature_ids, dtype=str))
    logging.info(f"Đã lưu feature_ids tại: {ids_path}")

    return features, feature_ids

# Hàm kết hợp và giảm chiều đặc trưng
def combine_and_reduce_features(features_dict, features_dict_4d, labels, sample_ids, save_dir, n_components=50):
    try:
        features_2d_list = []
        feature_ids_list = []

        for model_name, features in features_dict.items():
            if features is None:
                logging.warning(f"Đặc trưng từ {model_name} là None, bỏ qua.")
                continue
            if len(features.shape) != 2:
                logging.warning(f"Đặc trưng từ {model_name} có shape {features.shape}, áp dụng global average pooling.")
                features = np.mean(features, axis=tuple(range(1, len(features.shape)-1)), keepdims=False)
            features_2d_list.append(features)
            feature_ids = np.load(os.path.join(save_dir, f"{model_name}_feature_ids.npy"), allow_pickle=True)
            feature_ids_list.append(feature_ids)
            logging.info(f"Đã thêm đặc trưng 2D từ {model_name}, shape: {features.shape}")

        if not features_2d_list:
            logging.error("Không có đặc trưng 2D hợp lệ nào để kết hợp.")
            raise ValueError("Yêu cầu ít nhất một tập đặc trưng 2D hợp lệ.")

        # Kiểm tra tính nhất quán của feature_ids
        feature_ids_array = np.array(feature_ids_list)
        if not np.all(feature_ids_array == feature_ids_array[0]):
            logging.warning("Feature_ids không nhất quán, sử dụng feature_ids từ mô hình đầu tiên.")
        common_ids = feature_ids_list[0]

        # Kiểm tra shape đồng nhất
        feature_shapes = [f.shape for f in features_2d_list]
        if len(set(tuple(s) for s in feature_shapes)) > 1:
            logging.error(f"Shape đặc trưng không đồng nhất: {feature_shapes}")
            raise ValueError("Đặc trưng có shape không đồng nhất.")

        combined_2d = np.concatenate(features_2d_list, axis=1)
        scaler = StandardScaler()
        combined_2d = scaler.fit_transform(combined_2d)

        indices = []
        missing_ids = []
        for sample_id in sample_ids:
            idx = np.where(common_ids == sample_id)[0]
            if len(idx) > 0:
                indices.append(idx[0])
            else:
                missing_ids.append(sample_id)
                logging.warning(f"Không tìm thấy sample_id {sample_id} trong feature_ids.")

        if not indices:
            logging.error("Không tìm thấy sample_id nào trong feature_ids.")
            raise ValueError("Không có sample_id hợp lệ.")

        indices = np.array(indices)
        combined_2d = combined_2d[indices]
        aligned_labels = labels[indices] if labels is not None else None

        if n_components is not None:
            pca = PCA(n_components=n_components, random_state=42)
            reduced_2d = pca.fit_transform(combined_2d)
            explained_variance = np.sum(pca.explained_variance_ratio_)
            logging.info(f"PCA: {n_components} thành phần, giải thích {explained_variance*100:.2f}% phương sai")
        else:
            reduced_2d = combined_2d
            explained_variance = 1.0
            logging.info("Bỏ qua PCA, giữ nguyên đặc trưng gốc.")

        save_path = os.path.join(save_dir, "combined_reduced_features.npy")
        np.save(save_path, reduced_2d)
        logging.info(f"Đã lưu đặc trưng giảm chiều tại: {save_path}")

        return reduced_2d, aligned_labels, indices, reduced_2d.shape[1]

    except Exception as e:
        logging.error(f"Lỗi khi kết hợp và giảm chiều đặc trưng: {str(e)}")
        raise

# Hàm apply temperature scaling và laplace smoothing
def apply_temperature_scaling(logits, temperature=2.0):
    logits = tf.convert_to_tensor(logits, dtype=tf.float32)
    return tf.nn.softmax(logits / temperature)

def laplace_smoothing(probs, epsilon=1e-5):
    probs = tf.convert_to_tensor(probs, dtype=tf.float32)
    return (probs + epsilon) / (tf.reduce_sum(probs, axis=-1, keepdims=True) + NUM_CLASSES * epsilon)

# Hàm lưu confusion matrix
def save_confusion_matrix(y_true, y_pred, episode, qwk, save_dir, prefix=''):
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1, 2, 3, 4])
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=[0, 1, 2, 3, 4], yticklabels=[0, 1, 2, 3, 4])
    plt.title(f'Ma trận nhầm lẫn - {prefix}QWK: {qwk:.4f} tại Episode {episode+1}')
    plt.xlabel('Dự đoán')
    plt.ylabel('Thực tế')
    cm_path = os.path.join(save_dir, f'confusion_matrix_{prefix}qwk_episode_{episode+1}.png')
    plt.savefig(cm_path)
    plt.close()
    logging.info(f"Đã lưu ma trận nhầm lẫn {prefix}QWK tại: {cm_path}")
    return cm

# Hàm đánh giá chi tiết per-class metrics
def evaluate_per_class_metrics(y_true, y_pred_probs, y_pred_classes, labels, save_dir, prefix="test"):
    report = classification_report(
        y_true, y_pred_classes, labels=labels, target_names=[f"Class {i}" for i in labels], output_dict=True
    )
    metrics_df = pd.DataFrame(report).transpose()

    metrics_path = Path(save_dir) / f"{prefix}_classification_report.csv"
    metrics_df.to_csv(metrics_path)
    logging.info(f"Đã lưu báo cáo phân loại tại: {metrics_path}")

    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    for i in labels:
        fpr, tpr, _ = roc_curve(y_true == i, y_pred_probs[:, i])
        roc_auc = auc(fpr, tpr)
        plt.plot(fpr, tpr, label=f"Class {i} (AUC = {roc_auc:.2f})")
    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("ROC Curves")
    plt.legend()

    plt.subplot(1, 2, 2)
    for i in labels:
        precision, recall, _ = precision_recall_curve(y_true == i, y_pred_probs[:, i])
        plt.plot(recall, precision, label=f"Class {i}")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.title("Precision-Recall Curves")
    plt.legend()

    plt.tight_layout()
    curves_path = Path(save_dir) / f"{prefix}_roc_pr_curves.png"
    plt.savefig(curves_path, bbox_inches="tight")
    plt.close()
    logging.info(f"Đã lưu ROC và PR curves tại: {curves_path}")

    return metrics_df

# Hàm đánh giá calibration
def evaluate_calibration(y_true, y_pred_probs, save_dir, prefix="test"):
    plt.figure(figsize=(10, 8))

    for cls in range(NUM_CLASSES):
        prob_true, prob_pred = calibration_curve(
            y_true == cls, y_pred_probs[:, cls], n_bins=10, strategy='uniform'
        )
        plt.plot(prob_pred, prob_true, marker='o', label=f"Class {cls}")

    plt.plot([0, 1], [0, 1], linestyle='--', color='gray')
    plt.xlabel("Mean Predicted Probability")
    plt.ylabel("Fraction of Positives")
    plt.title("Reliability Diagram")
    plt.legend()

    calib_path = Path(save_dir) / f"{prefix}_calibration_curve.png"
    plt.savefig(calib_path, bbox_inches="tight")
    plt.close()
    logging.info(f"Đã lưu calibration curve tại: {calib_path}")

# Hàm export predictions
def export_predictions(image_ids, y_true, y_pred_probs, y_pred_classes, save_dir, filename="predictions.csv"):
    predictions_df = pd.DataFrame({
        "id_code": image_ids,
        "true_label": y_true,
        "predicted_label": y_pred_classes
    })
    for i in range(y_pred_probs.shape[1]):
        predictions_df[f"prob_class_{i}"] = y_pred_probs[:, i]

    output_path = Path(save_dir) / filename
    predictions_df.to_csv(output_path, index=False)
    logging.info(f"Đã xuất dự đoán tại: {output_path}")

    return predictions_df



# Hàm lưu meta-learner model

def save_meta_learner_model(model, save_dir, episode, qwk_scores, training_config):
    if model is None:
        logging.error(f"Không thể lưu mô hình tại episode {episode + 1}: model là None")
        raise ValueError("Mô hình là None")
    try:
        os.makedirs(save_dir, exist_ok=True)
        model_path = os.path.join(save_dir, f"meta_model_episode_{episode + 1}.h5")
        model.save_weights(model_path, overwrite=True)  # Chỉ lưu trọng số
        logging.info(f"Đã lưu trọng số mô hình tại: {model_path}")
        meta_info = {
            "episode": episode + 1,
            "qwk_scores": [float(score) for score in qwk_scores],
            "training_config": training_config
        }
        info_path = os.path.join(save_dir, f"meta_info_episode_{episode + 1}.json")
        with open(info_path, 'w') as f:
            json.dump(meta_info, f, indent=4)
        logging.info(f"Đã lưu thông tin meta tại: {info_path}")
    except Exception as e:
        logging.error(f"Lỗi khi lưu meta mô hình tại episode {episode + 1}: {str(e)}")
        raise

# Hàm tạo episode cho meta-learning
def create_episode(data_df, support_size, query_size, num_classes, model_type="efficientnetb1"):
    config = MODEL_CONFIGS[model_type]
    size = config['img_size']
    support_images = []
    support_labels = []
    query_images = []
    query_labels = []
    included_classes = []

    for cls in range(num_classes):
        cls_samples = data_df[data_df['diagnosis'] == cls]
        required_samples = support_size + query_size
        if len(cls_samples) < required_samples:
            logging.warning(f"Không đủ mẫu cho lớp {cls}: cần {required_samples}, có {len(cls_samples)}")
            continue
        cls_samples = cls_samples.sample(n=required_samples, random_state=42)
        support_samples = cls_samples.iloc[:support_size]
        query_samples = cls_samples.iloc[support_size:support_size + query_size]

        for _, row in support_samples.iterrows():
            img = load_processed_image(row['id_code'], PROCESSED_FOLDER, model_type=model_type)
            if img is not None and img.shape == (size, size, 3):
                support_images.append(img)
                support_labels.append(cls)
            else:
                logging.warning(f"Bỏ qua ảnh support không hợp lệ: {row['id_code']}, shape: {getattr(img, 'shape', 'None')}")

        for _, row in query_samples.iterrows():
            img = load_processed_image(row['id_code'], PROCESSED_FOLDER, model_type=model_type)
            if img is not None and img.shape == (size, size, 3):
                query_images.append(img)
                query_labels.append(cls)
            else:
                logging.warning(f"Bỏ qua ảnh query không hợp lệ: {row['id_code']}, shape: {getattr(img, 'shape', 'None')}")

        included_classes.append(cls)

    if not included_classes:
        logging.error("Không có lớp nào có đủ mẫu cho episode")
        return np.zeros((0, size, size, 3)), np.zeros((0, num_classes)), np.zeros((0, size, size, 3)), np.zeros((0, num_classes))

    if len(support_images) < support_size * len(included_classes) or len(query_images) < query_size * len(included_classes):
        logging.error(f"Không đủ ảnh hợp lệ: support_images={len(support_images)}, query_images={len(query_images)}")
        return np.zeros((0, size, size, 3)), np.zeros((0, num_classes)), np.zeros((0, size, size, 3)), np.zeros((0, num_classes))

    try:
        # Kiểm tra và in hình dạng của từng ảnh
        for i, img in enumerate(support_images + query_images):
            if img.shape != (size, size, 3):
                logging.error(f"Ảnh {i} có shape không hợp lệ: {img.shape}")
                raise ValueError(f"Ảnh không hợp lệ trong episode: shape {img.shape}")
        support_images = np.array(support_images, dtype=np.float32)
        support_labels = tf.keras.utils.to_categorical(support_labels, num_classes)
        query_images = np.array(query_images, dtype=np.float32)
        query_labels = tf.keras.utils.to_categorical(query_labels, num_classes)
    except Exception as e:
        logging.error(f"Lỗi khi chuyển đổi mảng: {str(e)}")
        return np.zeros((0, size, size, 3)), np.zeros((0, num_classes)), np.zeros((0, size, size, 3)), np.zeros((0, num_classes))

    logging.info(f"Đã tạo episode: support_images={support_images.shape}, support_labels={support_labels.shape}, "
                 f"query_images={query_images.shape}, query_labels={query_labels.shape}, classes={included_classes}")
    return support_images, support_labels, query_images, query_labels

import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, BatchNormalization, ReLU, GlobalAveragePooling2D, Dense, Dropout, Concatenate
from tensorflow.keras.models import Model
import numpy as np
import logging
import os
from sklearn.metrics import cohen_kappa_score, f1_score

def build_meta_conv_model(img_shape, feature_2d_shape, num_classes):
    """
    Xây dựng mô hình meta-learning với kiến trúc đơn giản hơn.
    """
    img_input = Input(shape=img_shape, name='img_input')
    x = Conv2D(32, (3, 3), padding='same', activation=None)(img_input)  # Giảm từ 64
    x = BatchNormalization(trainable=True)(x)
    x = ReLU()(x)
    x = MaxPooling2D((2, 2))(x)
    x = Conv2D(64, (3, 3), padding='same', activation=None)(x)  # Giảm từ 128
    x = BatchNormalization(trainable=True)(x)
    x = ReLU()(x)
    x = MaxPooling2D((2, 2))(x)
    x = Conv2D(128, (3, 3), padding='same', activation=None)(x)  # Giảm từ 256
    x = BatchNormalization(trainable=True)(x)
    x = ReLU()(x)
    x = MaxPooling2D((4, 4))(x)
    conv_features = x

    x = GlobalAveragePooling2D()(conv_features)

    feature_2d_input = Input(shape=feature_2d_shape, name='feature_2d_input')
    y = Dense(128, activation='relu')(feature_2d_input)  # Giảm từ 256
    y = BatchNormalization(trainable=True)(y)
    y = Dropout(0.5)(y)

    combined = Concatenate()([x, y])
    z = Dense(64, activation='relu')(combined)  # Giảm từ 128
    z = BatchNormalization(trainable=True)(z)
    z = Dropout(0.5)(z)
    logits = Dense(num_classes, activation='softmax')(z)

    model = Model(inputs=[img_input, feature_2d_input], outputs=[conv_features, logits], name='meta_conv_model')
    return model

def maml_fomaml_train_manual(
    meta_model,
    data_df,
    valid_features,
    valid_labels,
    valid_images,
    input_dim,
    model_type="efficientnetb1",
    n_episodes=40,
    inner_lr=0.01,
    outer_lr=0.001,
    fine_tune_lr=0.0001,
    support_size=16,
    query_size=16,
    save_dir="./features",
    sample_ids=None,
    callbacks=None
):
    logging.info(f"Khởi tạo huấn luyện MAML/FOMAML cho {model_type}")

    optimizer = tf.keras.optimizers.Adam(learning_rate=outer_lr)
    qwk_scores = []
    best_qwk = -1.0
    history = {"qwk": [], "loss": []}
    NUM_CLASSES = valid_labels.shape[1] if valid_labels is not None else 5
    BATCH_SIZE = 16

    if valid_images.shape[0] != valid_features.shape[0]:
        logging.error(f"Số lượng mẫu không khớp: valid_images={valid_images.shape[0]}, valid_features={valid_features.shape[0]}")
        raise ValueError("Số lượng mẫu không khớp giữa ảnh và đặc trưng.")
    logging.info(f"Dữ liệu validation: images_shape={valid_images.shape}, features_shape={valid_features.shape}, labels_shape={valid_labels.shape}")

    config = MODEL_CONFIGS[model_type]
    input_shape = (config['img_size'], config['img_size'], 3)
    try:
        # Trong vòng lặp model_name trong main_pipeline
        base_model = load_model_from_config(
            config['config_path'], config['weights_path'], config['base_model'], input_shape
        )
        feature_2d_layers = [
            layer for layer in base_model.layers
            if not isinstance(layer, tf.keras.layers.InputLayer) and len(layer.output.shape) == 2
        ]
        if not feature_2d_layers:
            logging.warning(f"Không tìm thấy tầng 2D trong {model_name}, thêm GlobalAveragePooling2D")
            x = base_model.output
            x = GlobalAveragePooling2D()(x)
            feature_2d_output = x
        else:
            feature_2d_output = feature_2d_layers[-1].output
        feature_extractor = tf.keras.Model(inputs=base_model.input, outputs=feature_2d_output)
        logging.info(f"Feature extractor output shape: {feature_2d_output.shape}")
    except Exception as e:
        logging.error(f"Lỗi tải base model hoặc tạo feature extractor cho {model_type}: {str(e)}")
        raise

    # Tiếp tục với các bước còn lại
    try:
        if not isinstance(valid_images, np.ndarray):
            logging.warning(f"valid_images không phải np.ndarray, type: {type(valid_images)}. Chuyển đổi sang np.ndarray.")
            valid_images = np.array(valid_images)
        if len(valid_images.shape) != 4:
            logging.error(f"valid_images có shape không hợp lệ: {valid_images.shape}. Kỳ vọng mảng 4D.")
            raise ValueError("valid_images phải là mảng 4D.")

        valid_images_preprocessed = config['preprocess'](valid_images)
        if not isinstance(valid_images_preprocessed, np.ndarray):
            logging.warning(f"valid_images_preprocessed không phải np.ndarray, type: {type(valid_images_preprocessed)}. Chuyển đổi sang np.ndarray.")
            valid_images_preprocessed = np.array(valid_images_preprocessed)
        logging.info(f"Valid images preprocessed shape: {valid_images_preprocessed.shape}, dtype: {valid_images_preprocessed.dtype}")
        valid_features_2d = feature_extractor.predict(valid_images_preprocessed, batch_size=BATCH_SIZE, verbose=0)
        logging.info(f"Valid features 2D shape: {valid_features_2d.shape}, dtype: {valid_features_2d.dtype}")
        tf.keras.backend.clear_session()
        gc.collect()
        check_memory()
    except Exception as e:
        logging.error(f"Lỗi tiền xử lý hoặc trích xuất đặc trưng: {str(e)}")
        raise

    # ... (phần còn lại của hàm giữ nguyên)

    # Xây dựng meta_conv_model nếu chưa có
    if meta_model is None:
        try:
            meta_model = build_meta_conv_model(
                img_shape=input_shape,
                feature_2d_shape=(valid_features_2d.shape[1],),
                num_classes=NUM_CLASSES
            )
            logging.info("Đã xây dựng meta_conv_model thành công.")
        except Exception as e:
            logging.error(f"Lỗi xây dựng meta_conv_model: {str(e)}")
            raise

    # Tạo classification model cho tinh chỉnh và dự đoán
    img_input = tf.keras.Input(shape=input_shape, name='img_input')
    feature_2d_input = tf.keras.Input(shape=(valid_features_2d.shape[1],), name='feature_2d_input')
    conv_features, logits = meta_model([img_input, feature_2d_input])
    classification_model = tf.keras.Model(
        inputs=[img_input, feature_2d_input],
        outputs=logits,
        name='meta_classification_model'
    )

    # Callback cho confusion matrix
    class_counts = np.sum(valid_labels, axis=0) if valid_labels is not None else None
    cm_callback = ConfusionMatrixWeightCallback(
        valid_features=[valid_images_preprocessed, valid_features_2d],
        valid_labels=valid_labels,
        classification_model=classification_model,
        num_classes=NUM_CLASSES,
        class_counts=class_counts
    )
    checkpoint_callback = MultiMetricCheckpoint(
        filepath=os.path.join(save_dir, f"meta_model_best_{model_type}_{{epoch}}.weights.h5"),
        monitor_metrics={'qwk': 'max', 'f1': 'max'},
        mode='max',
        save_best_only=True
    )
    callbacks = [cm_callback, checkpoint_callback] + (callbacks or [])

    # Huấn luyện từng episode
    for episode in range(n_episodes):
        logging.info(f"Huấn luyện episode {episode + 1}/{n_episodes}")
        try:
            # Tạo support và query set
            support_images, support_labels, query_images, query_labels = create_episode(
                data_df, support_size, query_size, NUM_CLASSES, model_type=model_type
            )
            if support_images.shape[0] == 0 or query_images.shape[0] == 0:
                logging.warning(f"Episode rỗng tại {episode + 1}, bỏ qua")
                continue

            # Tiền xử lý và trích xuất đặc trưng
            support_preprocessed = config['preprocess'](support_images)
            query_preprocessed = config['preprocess'](query_images)
            support_features_2d = feature_extractor.predict(support_preprocessed, batch_size=BATCH_SIZE, verbose=0)
            query_features_2d = feature_extractor.predict(query_preprocessed, batch_size=BATCH_SIZE, verbose=0)
            tf.keras.backend.clear_session()
            gc.collect()
            check_memory()

            logging.info(f"Support set: images={support_preprocessed.shape}, features={support_features_2d.shape}")
            logging.info(f"Query set: images={query_preprocessed.shape}, features={query_features_2d.shape}")

            # Kiểm tra shape
            if support_preprocessed.shape[0] != support_features_2d.shape[0]:
                logging.error(f"Số lượng mẫu không khớp trong support set: images={support_preprocessed.shape[0]}, features={support_features_2d.shape[0]}")
                raise ValueError("Số lượng mẫu không khớp trong support set.")
            if query_preprocessed.shape[0] != query_features_2d.shape[0]:
                logging.error(f"Số lượng mẫu không khớp trong query set: images={query_preprocessed.shape[0]}, features={query_features_2d.shape[0]}")
                raise ValueError("Số lượng mẫu không khớp trong query set.")

            # MAML inner loop (FOMAML style)
            fast_weights = [w.numpy() for w in meta_model.weights]  # Lưu trọng số gốc
            with tf.GradientTape() as inner_tape:
                support_conv_features, support_logits = meta_model(
                    [support_preprocessed, support_features_2d], training=True
                )
                support_loss = tf.reduce_mean(tf.keras.losses.categorical_crossentropy(support_labels, support_logits))
            inner_grads = inner_tape.gradient(support_loss, meta_model.trainable_variables)

            # Cập nhật trọng số nhanh
            fast_weights_updated = fast_weights.copy()
            trainable_indices = [i for i, w in enumerate(meta_model.weights) if w.trainable]
            grad_idx = 0
            for i in trainable_indices:
                if grad_idx < len(inner_grads) and inner_grads[grad_idx] is not None:
                    fast_weights_updated[i] = fast_weights[i] - inner_lr * inner_grads[grad_idx]
                grad_idx += 1

            meta_model.set_weights(fast_weights_updated)  # Cập nhật trọng số nhanh

            # MAML outer loop
            with tf.GradientTape() as outer_tape:
                query_conv_features, query_logits = meta_model(
                    [query_preprocessed, query_features_2d], training=True
                )
                query_loss = tf.reduce_mean(tf.keras.losses.categorical_crossentropy(query_labels, query_logits))
                query_logits = apply_temperature_scaling(query_logits, temperature=2.0)
                query_logits = laplace_smoothing(query_logits)
            outer_grads = outer_tape.gradient(query_loss, meta_model.trainable_variables)
            valid_grads_and_vars = [(g, v) for g, v in zip(outer_grads, meta_model.trainable_variables) if g is not None]
            if not valid_grads_and_vars:
                logging.error("Không có gradient ngoài hợp lệ để áp dụng.")
                raise ValueError("Không có gradient ngoài hợp lệ.")
            optimizer.apply_gradients(valid_grads_and_vars)

            # Khôi phục trọng số gốc
            meta_model.set_weights(fast_weights)

            # Đánh giá
            y_true = np.argmax(query_labels, axis=1)
            y_pred = np.argmax(query_logits, axis=1)
            qwk = cohen_kappa_score(y_true, y_pred, weights='quadratic')
            qwk_scores.append(qwk)
            history["qwk"].append(float(qwk))
            history["loss"].append(float(query_loss))
            logging.info(f"Episode {episode + 1}: QWK={qwk:.4f}, Loss={query_loss:.4f}")

            # Gọi callbacks
            for callback in callbacks:
                callback.on_epoch_end(episode, logs={"qwk": qwk, "f1": f1_score(y_true, y_pred, average='weighted')})

            # Lưu confusion matrix
            save_confusion_matrix(y_true, y_pred, episode, qwk, save_dir, prefix=f'meta_{model_type}_')

            # Lưu mô hình tốt nhất
            if qwk > best_qwk:
                best_qwk = qwk
                training_config = {
                    "inner_lr": float(inner_lr),
                    "outer_lr": float(outer_lr),
                    "fine_tune_lr": float(fine_tune_lr),
                    "train_config": {
                        "support_size": support_size,
                        "query_size": query_size,
                        "n_episodes": n_episodes,
                        "model_type": model_type
                    }
                }
                save_meta_learner_model(meta_model, save_dir, episode, qwk_scores, training_config)
                logging.info(f"Đã lưu mô hình tốt nhất tại episode {episode + 1}")

            # Quản lý bộ nhớ
            tf.keras.backend.clear_session()
            gc.collect()
            check_memory()

        except Exception as e:
            logging.error(f"Lỗi trong episode {episode + 1}: {str(e)}")
            continue

    # Tinh chỉnh classification_model
    try:
        classification_model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=fine_tune_lr),
            loss='categorical_crossentropy',
            metrics=['accuracy']
        )
        classification_model.fit(
            [valid_images_preprocessed, valid_features_2d],
            valid_labels,
            batch_size=BATCH_SIZE,
            epochs=5,
            verbose=1,
            callbacks=callbacks
        )
        logging.info("Hoàn thành tinh chỉnh classification_model.")
    except Exception as e:
        logging.error(f"Lỗi khi tinh chỉnh classification_model: {str(e)}")
        raise

    return meta_model, classification_model, history

# Hàm cross-validation
def cross_validate_pipeline(
    x, y, processed_folder, model_configs, feature_save_dir, n_splits=5, n_episodes=20, **kwargs
):
    """
    Performs k-fold cross-validation for the meta-learning pipeline.

    Args:
        x (pd.Series): Series of image IDs (id_code).
        y (np.ndarray): Array of labels (diagnosis).
        processed_folder (str): Path to folder containing processed images.
        model_configs (dict): Dictionary of model configurations (MODEL_CONFIGS).
        feature_save_dir (str): Directory to save features and results.
        n_splits (int): Number of folds for cross-validation.
        n_episodes (int): Number of episodes for meta-learning.
        **kwargs: Additional arguments for maml_fomaml_train_manual.

    Returns:
        dict: Cross-validation results with QWK, F1, and recall scores.
    """
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)
    cv_results = {"qwk": [], "f1": [], "recall": []}

    for fold, (train_idx, valid_idx) in enumerate(kf.split(x)):
        logging.info(f"Processing fold {fold + 1}/{n_splits}")

        # Split data
        train_x, valid_x = x.iloc[train_idx].values, x.iloc[valid_idx].values
        train_y, valid_y = y[train_idx], y[valid_idx]

        # Prepare image paths and filter missing images
        train_image_paths = [os.path.join(processed_folder, f"{id_code}.png") for id_code in train_x]
        valid_image_paths = [os.path.join(processed_folder, f"{id_code}.png") for id_code in valid_x]
        train_sample_ids = []
        train_y_filtered = []
        valid_sample_ids = []
        valid_y_filtered = []

        # Filter train data
        for path, id_code, label in zip(train_image_paths, train_x, train_y):
            if os.path.exists(path):
                train_sample_ids.append(id_code)
                train_y_filtered.append(label)
            else:
                logging.warning(f"Skipping non-existent train image: {path}")
        train_y_filtered = np.array(train_y_filtered)

        # Filter valid data
        for path, id_code, label in zip(valid_image_paths, valid_x, valid_y):
            if os.path.exists(path):
                valid_sample_ids.append(id_code)
                valid_y_filtered.append(label)
            else:
                logging.warning(f"Skipping non-existent valid image: {path}")
        valid_y_filtered = np.array(valid_y_filtered)

        if len(train_sample_ids) == 0 or len(valid_sample_ids) == 0:
            logging.error(f"Fold {fold + 1}: Empty train or valid set after filtering.")
            continue

        # Convert labels to one-hot encoding
        train_y_multi = tf.keras.utils.to_categorical(train_y_filtered, num_classes=NUM_CLASSES)
        valid_y_multi = tf.keras.utils.to_categorical(valid_y_filtered, num_classes=NUM_CLASSES)

        # Load train and valid images
        train_images = np.array([
            img for img in [load_processed_image(id_code, processed_folder, model_type="efficientnetb1") for id_code in train_sample_ids]
            if img is not None
        ])
        valid_images = np.array([
            img for img in [load_processed_image(id_code, processed_folder, model_type="efficientnetb1") for id_code in valid_sample_ids]
            if img is not None
        ])
        train_y_multi = train_y_multi[:len(train_images)]  # Adjust labels to match images
        valid_y_multi = valid_y_multi[:len(valid_images)]

        # Balance training data
        balanced_train_x, balanced_train_y_multi = balance_data(
            train_images, train_y_multi, target_classes=[0, 1, 2, 3, 4]
        )

        # Initialize feature dictionaries
        train_features_dict, valid_features_dict = {}, {}

        # Extract features for each model
        for model_name, config in model_configs.items():
            logging.info(f"Extracting features for {model_name} in fold {fold + 1}")

            # Create generators
            train_generator = My_Generator(
                images=None,
                labels=balanced_train_y_multi,
                batch_size=BATCH_SIZE,
                is_train=True,
                mix=False,
                augment=False,
                model_type=config['model_type'],
                preprocess=config['preprocess'],
                image_paths=[os.path.join(processed_folder, f"{id_code}.png") for id_code in train_sample_ids],
                sample_ids=train_sample_ids,
                temp_augment_dir=TEMP_AUGMENT_DIR
            )
            valid_generator = My_Generator(
                images=None,
                labels=valid_y_multi,
                batch_size=BATCH_SIZE,
                is_train=False,
                mix=False,
                augment=False,
                model_type=config['model_type'],
                preprocess=config['preprocess'],
                image_paths=[os.path.join(processed_folder, f"{id_code}.png") for id_code in valid_sample_ids],
                sample_ids=valid_sample_ids,
                temp_augment_dir=TEMP_AUGMENT_DIR
            )

            # Load model and create feature extractor
            base_model = load_model_from_config(
                config['config_path'],
                config['weights_path'],
                config['base_model'],
                input_shape=(config['img_size'], config['img_size'], 3)
            )
            feature_layer = base_model.layers[-2].output
            feature_extractor = Model(inputs=base_model.input, outputs=feature_layer)

            # Extract and save features
            train_features, train_features_ids = extract_and_save_features(
                model_name, feature_extractor, train_generator, feature_save_dir, sample_ids=train_sample_ids
            )
            valid_features, valid_features_ids = extract_and_save_features(
                model_name, feature_extractor, valid_generator, feature_save_dir, sample_ids=valid_sample_ids
            )

            train_features_dict[model_name] = train_features
            valid_features_dict[model_name] = valid_features

            # Clean up
            del base_model, feature_extractor
            tf.keras.backend.clear_session()
            gc.collect()

        # Combine and reduce features (skip PCA to keep 1024 dims)
        train_combined, train_labels, train_indices, input_dim = combine_and_reduce_features(
            train_features_dict, {}, balanced_train_y_multi, train_sample_ids, feature_save_dir, n_components=None
        )
        valid_combined, valid_labels, valid_indices, _ = combine_and_reduce_features(
            valid_features_dict, {}, valid_y_multi, valid_sample_ids, feature_save_dir, n_components=None
        )

        # Create feature extractor for validation
        img_size = model_configs["efficientnetb1"]["img_size"]
        base_model = load_model_from_config(
            model_configs["efficientnetb1"]["config_path"],
            model_configs["efficientnetb1"]["weights_path"],
            model_configs["efficientnetb1"]["base_model"],
            input_shape=(img_size, img_size, 3)
        )
        feature_layer = base_model.layers[-2].output
        feature_extractor = Model(inputs=base_model.input, outputs=feature_layer)
        valid_images_preprocessed = model_configs["efficientnetb1"]["preprocess"](valid_images)
        valid_features_2d = feature_extractor.predict(valid_images_preprocessed, batch_size=BATCH_SIZE)
        del base_model, feature_extractor
        tf.keras.backend.clear_session()
        gc.collect()

        # Build meta-model
        meta_model = build_meta_conv_model(
            img_shape=(img_size, img_size, 3),
            feature_2d_shape=(1024,),
            num_classes=NUM_CLASSES
        )

        # Train meta-model
        meta_model, meta_classification_model, history = maml_fomaml_train_manual(
            meta_model=meta_model,
            data_df=pd.DataFrame({'id_code': np.array(train_sample_ids)[train_indices], 'diagnosis': np.argmax(train_labels[train_indices], axis=1)}),
            valid_features=valid_features_2d,
            valid_labels=valid_y_multi[valid_indices],
            valid_images=valid_images[valid_indices],
            input_dim=input_dim,
            model_type="efficientnetb1",
            n_episodes=n_episodes,
            sample_ids=np.array(valid_sample_ids)[valid_indices],
            **kwargs
        )

        # Predict and evaluate
        logging.info(f"Valid images shape: {valid_images_preprocessed.shape}")
        logging.info(f"Valid features 2D shape: {valid_features_2d.shape}")
        if valid_images_preprocessed.shape[0] != valid_features_2d.shape[0]:
            logging.error(f"Sample count mismatch: images={valid_images_preprocessed.shape[0]}, features={valid_features_2d.shape[0]}")
            raise ValueError("Sample count mismatch")

        _, y_pred = meta_classification_model.predict(
            [valid_images_preprocessed[valid_indices], valid_features_2d[valid_indices]], batch_size=BATCH_SIZE
        )
        y_pred_classes = np.argmax(y_pred, axis=1)
        y_true = np.argmax(valid_y_multi[valid_indices], axis=1)

        qwk = cohen_kappa_score(y_true, y_pred_classes, labels=[0, 1, 2, 3, 4], weights='quadratic')
        f1 = f1_score(y_true, y_pred_classes, average='weighted')
        recall = recall_score(y_true, y_pred_classes, average='weighted')

        cv_results["qwk"].append(qwk)
        cv_results["f1"].append(f1)
        cv_results["recall"].append(recall)

        logging.info(f"Fold {fold + 1}: QWK={qwk:.4f}, F1={f1:.4f}, Recall={recall:.4f}")

        # Save confusion matrix
        save_confusion_matrix(y_true, y_pred_classes, fold, qwk, feature_save_dir, prefix=f'fold_{fold + 1}_')

        # Clean up
        del meta_model, meta_classification_model
        tf.keras.backend.clear_session()
        gc.collect()

    # Summarize results
    cv_summary = {
        "mean_qwk": np.mean(cv_results["qwk"]) if cv_results["qwk"] else 0.0,
        "std_qwk": np.std(cv_results["qwk"]) if cv_results["qwk"] else 0.0,
        "mean_f1": np.mean(cv_results["f1"]) if cv_results["f1"] else 0.0,
        "std_f1": np.std(cv_results["f1"]) if cv_results["f1"] else 0.0,
        "mean_recall": np.mean(cv_results["recall"]) if cv_results["recall"] else 0.0,
        "std_recall": np.std(cv_results["recall"]) if cv_results["recall"] else 0.0
    }

    summary_path = Path(feature_save_dir) / "cross_validation_summary.json"
    with open(summary_path, 'w') as f:
        json.dump(cv_summary, f, indent=4)
    logging.info(f"Saved cross-validation summary at: {summary_path}")

    return cv_results
# Hàm hyperparameter tuning
def hyperparameter_tuning(
    x, y, processed_folder, model_configs, feature_save_dir, param_grid, n_iter=10, n_episodes=20
):
    if n_iter is None:
        param_combinations = list(product(*param_grid.values()))
    else:
        param_combinations = [
            {k: random.choice(v) for k, v in param_grid.items()} for _ in range(n_iter)
        ]

    best_qwk = -float('inf')
    best_params = None

    for params in param_combinations:
        logging.info(f"Thử tham số: {params}")

        cv_results = cross_validate_pipeline(
            x, y, processed_folder, model_configs, feature_save_dir, n_splits=3, n_episodes=n_episodes, **params
        )

        mean_qwk = np.mean(cv_results["qwk"])
        if mean_qwk > best_qwk:
            best_qwk = mean_qwk
            best_params = params

        logging.info(f"Mean QWK: {mean_qwk:.4f}")

    logging.info(f"Tham số tốt nhất: {best_params}, QWK tốt nhất: {best_qwk:.4f}")

    return best_params, best_qwk

# Hàm tạo Grad-CAM heatmap
def generate_gradcam_heatmap(
    meta_conv_model,
    img_array,
    feature_2d,
    class_idx,
    conv_layer_output='conv2d_3'  # Tên tầng cuối cùng trong build_meta_conv_model
):
    try:
        # Tạo mô hình trung gian để lấy conv_features
        intermediate_model = Model(
            inputs=meta_conv_model.input,
            outputs=meta_conv_model.get_layer(conv_layer_output).output
        )
        with tf.GradientTape() as tape:
            conv_output = intermediate_model([img_array, feature_2d], training=False)
            tape.watch(conv_output)
            _, predictions = meta_conv_model([img_array, feature_2d], training=False)
            loss = predictions[:, class_idx]
        grads = tape.gradient(loss, conv_output)
        if grads is None:
            logging.error(f"Gradient trả về None cho class {class_idx}")
            return None
        pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
        conv_output = conv_output[0]
        heatmap = tf.reduce_mean(tf.multiply(conv_output, pooled_grads), axis=-1)
        heatmap = tf.maximum(heatmap, 0) / (tf.reduce_max(heatmap) + 1e-10)
        heatmap = tf.image.resize(
            heatmap[..., tf.newaxis],
            [img_array.shape[1], img_array.shape[2]],
            method='bilinear'
        )
        heatmap = tf.squeeze(heatmap).numpy()
        return heatmap
    except Exception as e:
        logging.error(f"Lỗi khi tạo Grad-CAM: {str(e)}")
        return None

# Hàm ensemble predictions
def ensemble_predictions(model_configs, generators, test_ids, feature_save_dir, weights=None, meta_classification_model=None):
    predictions = []

    for model_name, config in model_configs.items():
        logging.info(f"Trích xuất đặc trưng cho {model_name}...")
        base_model = load_model_from_config(
            config['config_path'], config['weights_path'], config['base_model']
        )
        feature_layer = base_model.layers[-2].output
        feature_extractor = Model(inputs=base_model.input, outputs=feature_layer)

        features_2d, _ = extract_and_save_features(
            model_name,
            feature_extractor,
            generators[model_name],
            feature_save_dir,
            sample_ids=test_ids
        )

        probs = meta_classification_model.predict(features_2d, batch_size=32)
        predictions.append(probs)

        del base_model, feature_extractor
        tf.keras.backend.clear_session()
        gc.collect()
        logging.info(f"Đã trích xuất {len(features_2d)} đặc trưng cho {model_name}")

    weights = weights or [1.0 / len(predictions)] * len(predictions)
    ensemble_probs = np.average(predictions, axis=0, weights=weights)
    return ensemble_probs

# Hàm tạo pipeline report
def generate_pipeline_report(
    dataset_stats, cv_results, test_metrics, runtime, save_dir, filename="pipeline_report.json"
):
    report = {
        "timestamp": datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
        "dataset_statistics": dataset_stats,
        "cross_validation_results": cv_results,
        "test_metrics": test_metrics,
        "runtime_seconds": runtime,
        "model_configurations": list(MODEL_CONFIGS.keys()),
        "num_classes": NUM_CLASSES,
        "image_size": SIZE
    }

    report_path = Path(save_dir) / filename
    with open(report_path, 'w') as f:
        json.dump(report, f, indent=4)
    logging.info(f"Đã lưu báo cáo pipeline tại: {report_path}")


def prepare_data():
    try:
        df_train = pd.read_csv(os.path.join(DRIVE_FOLDER, "train.csv"))
        logging.info(f"Đã đọc train.csv: {len(df_train)} mẫu")
    except Exception as e:
        logging.error(f"Lỗi khi đọc train.csv: {str(e)}")
        return None, None, None, None, None, None, None, None, None, None, None, None

    if not {'id_code', 'diagnosis'}.issubset(df_train.columns):
        logging.error("File train.csv thiếu cột 'id_code' hoặc 'diagnosis'.")
        return None, None, None, None, None, None, None, None, None, None, None, None

    # Xác thực dataset một lần
    processed_image_ids = validate_dataset(get_image_ids(PROCESSED_FOLDER), PROCESSED_FOLDER, FEATURE_SAVE_DIR)
    if not processed_image_ids:
        logging.error("Không tìm thấy ảnh trong PROCESSED_FOLDER.")
        return None, None, None, None, None, None, None, None, None, None, None, None
    logging.info(f"Tìm thấy {len(processed_image_ids)} ID ảnh trong PROCESSED_FOLDER: {processed_image_ids[:5]}...")

    # Lọc df_train
    df_train_processed = df_train[df_train['id_code'].isin(processed_image_ids)].copy()
    logging.info(f"Số mẫu sau khi lọc với PROCESSED_FOLDER: {len(df_train_processed)}")

    if df_train_processed.empty:
        logging.error("Không có mẫu hợp lệ sau khi lọc.")
        return None, None, None, None, None, None, None, None, None, None, None, None

    x = df_train_processed['id_code']
    y = df_train_processed['diagnosis']
    x, y = shuffle(x, y, random_state=42)

    # Chia dữ liệu
    x_temp, test_x, y_temp, test_y = train_test_split(x, y, test_size=0.20, stratify=y, random_state=42)
    train_x, valid_x, train_y, valid_y = train_test_split(x_temp, y_temp, test_size=0.15/0.80, stratify=y_temp, random_state=42)

    train_ids = np.array(train_x)
    valid_ids = np.array(valid_x)
    test_ids = np.array(test_x)

    # Kiểm tra rò rỉ dữ liệu
    if not check_data_leakage(train_ids, valid_ids, test_ids, df_train_processed):
        logging.error("Phát hiện data leakage.")
        raise ValueError("Data leakage được phát hiện.")

    logging.info(f"Số mẫu: Train={len(train_x)}, Valid={len(valid_x)}, Test={len(test_x)}")

    # Tải ảnh train
    train_images = []
    train_ids_filtered = []
    train_y_filtered = []
    train_error_ids = []
    for id_code, label in zip(train_x, train_y):
        img = load_processed_image(id_code, PROCESSED_FOLDER, model_type="efficientnetb1")
        if img is not None:
            train_images.append(img)
            train_ids_filtered.append(id_code)
            train_y_filtered.append(label)
        else:
            train_error_ids.append(id_code)
            logging.warning(f"Không tải được ảnh train: {id_code}")
    if not train_images:
        logging.error("Không tải được ảnh train nào.")
        raise ValueError("Tập train rỗng sau khi lọc.")
    train_images = np.array(train_images)
    train_y_filtered = np.array(train_y_filtered)
    train_ids_filtered = np.array(train_ids_filtered)
    logging.info(f"Train: {len(train_images)} ảnh, {len(train_ids_filtered)} IDs, {len(train_y_filtered)} nhãn, {len(train_error_ids)} lỗi")

    # Tải ảnh valid
    valid_images = []
    valid_ids_filtered = []
    valid_y_filtered = []
    valid_error_ids = []
    for id_code, label in zip(valid_x, valid_y):
        img = load_processed_image(id_code, PROCESSED_FOLDER, model_type="efficientnetb1")
        if img is not None:
            valid_images.append(img)
            valid_ids_filtered.append(id_code)
            valid_y_filtered.append(label)
        else:
            valid_error_ids.append(id_code)
            logging.warning(f"Không tải được ảnh valid: {id_code}")
    valid_images = np.array(valid_images)
    valid_y_filtered = np.array(valid_y_filtered)
    valid_ids_filtered = np.array(valid_ids_filtered)
    logging.info(f"Valid: {len(valid_images)} ảnh, {len(valid_ids_filtered)} IDs, {len(valid_y_filtered)} nhãn, {len(valid_error_ids)} lỗi")

    # Tải ảnh test
    test_images = []
    test_ids_filtered = []
    test_y_filtered = []
    test_error_ids = []
    for id_code, label in zip(test_x, test_y):
        img = load_processed_image(id_code, PROCESSED_FOLDER, model_type="efficientnetb1")
        if img is not None:
            test_images.append(img)
            test_ids_filtered.append(id_code)
            test_y_filtered.append(label)
        else:
            test_error_ids.append(id_code)
            logging.warning(f"Không tải được ảnh test: {id_code}")
    test_images = np.array(test_images)
    test_y_filtered = np.array(test_y_filtered)
    test_ids_filtered = np.array(test_ids_filtered)
    logging.info(f"Test: {len(test_images)} ảnh, {len(test_ids_filtered)} IDs, {len(test_y_filtered)} nhãn, {len(test_error_ids)} lỗi")

    # Lưu báo cáo lỗi
    error_report = {
        "train_error_ids": train_error_ids,
        "valid_error_ids": valid_error_ids,
        "test_error_ids": test_error_ids,
        "timestamp": datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    }
    error_report_path = os.path.join(FEATURE_SAVE_DIR, "image_load_errors.json")
    with open(error_report_path, 'w') as f:
        json.dump(error_report, f, indent=4)
    logging.info(f"Đã lưu báo cáo ảnh lỗi tại: {error_report_path}")

    # Chuyển nhãn sang one-hot encoding
    train_y_multi = tf.keras.utils.to_categorical(train_y_filtered, num_classes=NUM_CLASSES)
    valid_y_multi = tf.keras.utils.to_categorical(valid_y_filtered, num_classes=NUM_CLASSES)
    test_y_multi = tf.keras.utils.to_categorical(test_y_filtered, num_classes=NUM_CLASSES)

    # Không cân bằng dữ liệu
    balanced_train_x = train_images
    balanced_train_y_multi = train_y_multi
    balanced_train_ids = train_ids_filtered

    # Báo cáo phân bố nhãn
    label_indices = np.argmax(balanced_train_y_multi, axis=1)
    class_counts = np.bincount(label_indices, minlength=NUM_CLASSES)
    logging.info(f"Phân bố nhãn tập train: {dict(zip(range(NUM_CLASSES), class_counts))}")

    return (balanced_train_x, balanced_train_y_multi, valid_images, valid_y_multi, test_images, test_y_multi,
            df_train_processed, train_ids_filtered, valid_ids_filtered, test_ids_filtered, balanced_train_ids, processed_image_ids)

def main_pipeline(
    balanced_train_x, balanced_train_y_multi, valid_images, valid_y_multi, test_images, test_y_multi,
    df_train_processed, train_ids, valid_ids, test_ids, balanced_train_ids, processed_image_ids
):
    start_time = time.time()
    logging.info("Starting pipeline...")

    # Initialize dictionaries for features
    train_features_dict = {}
    valid_features_dict = {}
    test_features_dict = {}
    meta_models = {}
    BATCH_SIZE = 16  # Reduced for memory efficiency

    # Filter test data to match processed IDs
    test_indices = [i for i, sid in enumerate(test_ids) if sid in processed_image_ids]
    if not test_indices:
        logging.error("No test samples match processed_image_ids.")
        raise ValueError("No valid test samples.")
    test_images = test_images[test_indices]
    test_y_multi = test_y_multi[test_indices]
    test_ids = [test_ids[i] for i in test_indices]
    logging.info(f"Filtered test data: {len(test_indices)} samples")

    # Process each model
    for model_name in MODEL_CONFIGS:
        logging.info(f"Processing {model_name}...")
        try:
            config = MODEL_CONFIGS[model_name]
            input_shape = (config['img_size'], config['img_size'], 3)

            # Load base model and create feature extractor
            base_model = load_model_from_config(
                config['config_path'], config['weights_path'], config['base_model'], input_shape
            )
            feature_2d_layers = [
                layer for layer in base_model.layers
                if not isinstance(layer, tf.keras.layers.InputLayer) and len(layer.output.shape) == 2
            ]
            if not feature_2d_layers:
                logging.error(f"No 2D layers found in {model_name}")
                continue
            selected_layer = feature_2d_layers[-1]
            feature_extractor = tf.keras.Model(inputs=base_model.input, outputs=selected_layer.output)
            logging.info(f"Feature extractor output shape: {selected_layer.output.shape}")

            # Create generators
            train_generator = My_Generator(
                images=balanced_train_x, labels=balanced_train_y_multi, batch_size=BATCH_SIZE,
                is_train=True, model_type=model_name, preprocess=config['preprocess'],
                sample_ids=balanced_train_ids
            )
            valid_generator = My_Generator(
                images=valid_images, labels=valid_y_multi, batch_size=BATCH_SIZE,
                model_type=model_name, preprocess=config['preprocess'], sample_ids=valid_ids
            )
            test_generator = My_Generator(
                images=test_images, labels=test_y_multi, batch_size=BATCH_SIZE,
                model_type=model_name, preprocess=config['preprocess'], sample_ids=test_ids
            )

            # Extract features
            def extract_features(generator, feature_extractor):
                features = []
                feature_ids = []
                for batch_idx in range(len(generator)):
                    batch_images, _, batch_ids = generator[batch_idx]
                    if len(batch_images) == 0:
                        logging.warning(f"Empty batch {batch_idx}, skipping")
                        continue
                    batch_features = feature_extractor.predict(
                        batch_images, batch_size=BATCH_SIZE, verbose=0
                    )
                    if len(batch_features.shape) > 2:
                        batch_features = np.mean(batch_features, axis=(1, 2))
                    features.append(batch_features)
                    feature_ids.extend(batch_ids)
                    tf.keras.backend.clear_session()
                    gc.collect()
                if not features:
                    raise ValueError("No features extracted.")
                return np.concatenate(features, axis=0), feature_ids

            train_features, train_feature_ids = extract_features(train_generator, feature_extractor)
            valid_features, valid_feature_ids = extract_features(valid_generator, feature_extractor)
            test_features, test_feature_ids = extract_features(test_generator, feature_extractor)

            # Verify feature consistency
            if len(train_features) != len(balanced_train_ids) or len(valid_features) != len(valid_ids) or len(test_features) != len(test_ids):
                logging.error(f"Feature count mismatch for {model_name}")
                continue

            train_features_dict[model_name] = train_features
            valid_features_dict[model_name] = valid_features
            test_features_dict[model_name] = test_features
            logging.info(f"Extracted features for {model_name}: train={len(train_features)}, valid={len(valid_features)}, test={len(test_features)}")

            # Save features
            os.makedirs(FEATURE_SAVE_DIR, exist_ok=True)
            np.save(os.path.join(FEATURE_SAVE_DIR, f"{model_name}_train_features.npy"), train_features)
            np.save(os.path.join(FEATURE_SAVE_DIR, f"{model_name}_valid_features.npy"), valid_features)
            np.save(os.path.join(FEATURE_SAVE_DIR, f"{model_name}_test_features.npy"), test_features)

            # Train meta-model
            meta_model = build_meta_conv_model(
                img_shape=input_shape,
                feature_2d_shape=(train_features.shape[1],),
                num_classes=NUM_CLASSES
            )
            cm_callback = ConfusionMatrixWeightCallback(
                valid_features=[config['preprocess'](valid_images), valid_features],
                valid_labels=valid_y_multi,
                classification_model=meta_model,
                num_classes=NUM_CLASSES,
                class_counts=np.sum(valid_y_multi, axis=0)
            )
            checkpoint_callback = MultiMetricCheckpoint(
                filepath=os.path.join(FEATURE_SAVE_DIR, f"meta_model_best_{model_name}_{{epoch}}.weights.h5"),
                monitor_metrics={'qwk': 'max', 'f1': 'max'},
                mode='max',
                save_best_only=True
            )
            meta_model, history = maml_fomaml_train_manual(
                meta_model=meta_model,
                data_df=df_train_processed,
                valid_features=valid_features,
                valid_labels=valid_y_multi,
                valid_images=valid_images,
                input_dim=train_features.shape[1],
                model_type=model_name,
                support_size=16,
                query_size=16,
                save_dir=FEATURE_SAVE_DIR,
                sample_ids=valid_ids,
                callbacks=[cm_callback, checkpoint_callback]
            )
            meta_models[model_name] = meta_model
            logging.info(f"Trained meta-model for {model_name}")

            # Clear memory
            del base_model, feature_extractor, train_generator, valid_generator, test_generator
            tf.keras.backend.clear_session()
            gc.collect()
            check_memory()

        except Exception as e:
            logging.error(f"Error processing {model_name}: {str(e)}")
            continue

    # Ensemble predictions
    logging.info("Ensembling predictions...")
    ensemble_probs = []
    successful_models = []
    for model_name in meta_models:
        try:
            test_features = test_features_dict[model_name]
            test_images_preprocessed = MODEL_CONFIGS[model_name]['preprocess'](test_images)
            probs = meta_models[model_name].predict(
                [test_images_preprocessed, test_features],
                batch_size=BATCH_SIZE,
                verbose=0
            )[1]  # Get softmax probabilities
            probs = apply_temperature_scaling(probs, temperature=0.1)
            probs = laplace_smoothing(probs)
            ensemble_probs.append(probs)
            successful_models.append(model_name)
            logging.info(f"Predicted for {model_name}: shape={probs.shape}")
        except Exception as e:
            logging.error(f"Prediction error for {model_name}: {str(e)}")
            continue

    # Compute final predictions
    if not ensemble_probs:
        logging.warning("No valid predictions. Using default zero probabilities.")
        final_probs = np.zeros((len(test_y_multi), NUM_CLASSES))
        y_pred_classes = np.zeros(len(test_y_multi), dtype=np.int32)
        y_true = np.argmax(test_y_multi, axis=1)
    else:
        weights = np.array([1.0 / len(ensemble_probs)] * len(ensemble_probs))
        final_probs = np.average(ensemble_probs, axis=0, weights=weights)
        y_pred_classes = np.argmax(final_probs, axis=1)
        y_true = np.argmax(test_y_multi, axis=1)

    # Evaluate
    qwk = cohen_kappa_score(y_true, y_pred_classes, labels=range(NUM_CLASSES), weights='quadratic')
    f1 = f1_score(y_true, y_pred_classes, average='weighted')
    recall = recall_score(y_true, y_pred_classes, average='weighted')
    logging.info(f"Ensemble results: QWK={qwk:.4f}, F1={f1:.4f}, Recall={recall:.2f}")

    # Generate report
    test_metrics = {
        "qwk": float(qwk),
        "f1_score": float(f1),
        "recall": float(recall),
        "successful_models": successful_models
    }
    dataset_stats = {
        "train_samples": len(balanced_train_x),
        "valid_samples": len(valid_images),
        "test_samples": len(test_images),
        "class_distribution": np.sum(test_y_multi, axis=0).tolist()
    }
    runtime = time.time() - start_time
    generate_pipeline_report(dataset_stats, {}, test_metrics, runtime, FEATURE_SAVE_DIR)

    logging.info(f"Pipeline completed in {runtime:.2f} seconds.")
    return {
        "train_features_dict": train_features_dict,
        "valid_features_dict": valid_features_dict,
        "test_features_dict": test_features_dict,
        "meta_models": meta_models
    }

def run_pipeline(tune_hyperparameters=False, use_gpu=True):
    start_time = time.time()
    logging.info(f"Bắt đầu pipeline tại: {datetime.now().strftime('%H:%M:%S %d/%m/%Y')}")

    try:
        # Kiểm tra môi trường
        if not os.path.ismount('/content/drive'):
            logging.error("Google Drive chưa được mount.")
            raise RuntimeError("Yêu cầu mount Google Drive.")
        for folder in [FEATURE_SAVE_DIR, PROCESSED_FOLDER, DRIVE_FOLDER]:
            if not os.path.exists(folder):
                logging.error(f"Thư mục không tồn tại: {folder}")
                raise FileNotFoundError(f"Thư mục không tồn tại: {folder}")
            logging.info(f"Thư mục tồn tại: {folder}")

        # Cấu hình GPU hoặc CPU
        if use_gpu:
            physical_devices = tf.config.list_physical_devices('GPU')
            if physical_devices:
                for device in physical_devices:
                    tf.config.experimental.set_memory_growth(device, True)
                    logging.info(f"Đã bật memory growth cho {device}")
            else:
                logging.warning("Không tìm thấy GPU, chuyển sang CPU.")
                use_gpu = False
        else:
            os.environ["CUDA_VISIBLE_DEVICES"] = ""
            logging.info("Chạy trên CPU theo yêu cầu.")

        os.environ['TF_XLA_FLAGS'] = '--tf_xla_enable_xla_devices=false'
        logging.info("Đã tắt XLA compilation.")
        from tensorflow.keras.mixed_precision import set_global_policy
        set_global_policy('mixed_float16')
        logging.info("Đã bật mixed precision.")

        # Chuẩn bị dữ liệu
        logging.info("Chuẩn bị dữ liệu...")
        data = prepare_data()
        if data is None or any(x is None for x in data):
            logging.error("Chuẩn bị dữ liệu thất bại.")
            raise ValueError("Chuẩn bị dữ liệu thất bại.")

        (balanced_train_x, balanced_train_y_multi, valid_images, valid_y_multi, test_images, test_y_multi,
         df_train_processed, train_ids, valid_ids, test_ids, balanced_train_ids, processed_image_ids) = data

        logging.info(f"Kích thước dataset: train={len(balanced_train_x)}, valid={len(valid_images)}, test={len(test_images)}")

        # Kiểm tra rò rỉ dữ liệu
        if not check_data_leakage(train_ids, valid_ids, test_ids, df_train_processed):
            logging.error("Phát hiện rò rỉ dữ liệu.")
            raise ValueError("Rò rỉ dữ liệu được phát hiện.")
        logging.info("Không phát hiện rò rỉ dữ liệu.")

        # Kiểm tra cấu hình mô hình
        validate_model_configs(MODEL_CONFIGS)
        logging.info("Xác thực cấu hình mô hình hoàn tất.")

        # Tối ưu siêu tham số
        best_params = {'inner_lr': 0.01, 'outer_lr': 0.001, 'batch_size': 8}
        if tune_hyperparameters:
            logging.info("Tối ưu siêu tham số...")
            param_grid = {
                'inner_lr': [0.01, 0.001],
                'outer_lr': [0.001, 0.0001],
                'batch_size': [8, 16]
            }
            best_params, best_qwk = hyperparameter_tuning(
                df_train_processed['id_code'],
                df_train_processed['diagnosis'].values,
                PROCESSED_FOLDER,
                MODEL_CONFIGS,
                FEATURE_SAVE_DIR,
                param_grid,
                n_iter=5,
                n_episodes=10
            )
            logging.info(f"Tham số tốt nhất: {best_params}, QWK tốt nhất: {best_qwk:.4f}")

        # Kiểm tra bộ nhớ
        available_memory = check_memory()
        if available_memory is None or available_memory < 1.0:
            logging.warning(f"Bộ nhớ thấp: {available_memory} GB.")

        # Chạy pipeline chính
        logging.info("Chạy pipeline chính...")
        pipeline_results = main_pipeline(
            balanced_train_x,
            balanced_train_y_multi,
            valid_images,
            valid_y_multi,
            test_images,
            test_y_multi,
            df_train_processed,
            train_ids,
            valid_ids,
            test_ids,
            balanced_train_ids,
            processed_image_ids
        )

        train_features_dict = pipeline_results['train_features_dict']
        valid_features_dict = pipeline_results['valid_features_dict']
        test_features_dict = pipeline_results['test_features_dict']

        # Tạo báo cáo
        dataset_stats = {
            'train_samples': len(balanced_train_x),
            'valid_samples': len(valid_images),
            'test_samples': len(test_images),
            'class_distribution': np.bincount(np.argmax(balanced_train_y_multi, axis=1)).tolist()
        }
        cv_results = {}
        test_metrics = {}
        runtime = time.time() - start_time
        generate_pipeline_report(dataset_stats, cv_results, test_metrics, runtime, FEATURE_SAVE_DIR)

        logging.info(f"Pipeline hoàn tất sau {runtime:.2f} giây!")
        return train_features_dict, valid_features_dict, test_features_dict

    except Exception as e:
        logging.error(f"Pipeline thất bại: {str(e)}")
        raise
    finally:
        logging.info("Dọn dẹp tài nguyên...")
        tf.keras.backend.clear_session()
        gc.collect()
        logging.info(f"Pipeline kết thúc tại: {datetime.now().strftime('%H:%M:%S %d/%m/%Y')}")

if __name__ == "__main__":
    # Chạy pipeline với tùy chọn hyperparameter tuning
    train_features_dict, valid_features_dict, test_features_dict = run_pipeline(tune_hyperparameters=False)

Tìm thấy 3662 ID ảnh trong /content/processed_train_images
Mẫu ID: ['cb2f3c5d71a7', '5c6194562ed2', 'e893e86dde94', '58059e73d2d4', '8ee50c26fc13']...


  saveable.load_own_variables(weights_store.get(inner_path))
  saveable.load_own_variables(weights_store.get(inner_path))
ERROR:root:Lỗi trong episode 1: 'NoneType' object has no attribute 'save_weights'
ERROR:root:Lỗi trong episode 2: setting an array element with a sequence. The requested array has an inhomogeneous shape after 2 dimensions. The detected shape was (2, 550) + inhomogeneous part.
ERROR:root:Lỗi trong episode 3: setting an array element with a sequence. The requested array has an inhomogeneous shape after 2 dimensions. The detected shape was (2, 550) + inhomogeneous part.
ERROR:root:Lỗi trong episode 4: setting an array element with a sequence. The requested array has an inhomogeneous shape after 2 dimensions. The detected shape was (2, 550) + inhomogeneous part.
ERROR:root:Lỗi trong episode 5: setting an array element with a sequence. The requested array has an inhomogeneous shape after 2 dimensions. The detected shape was (2, 550) + inhomogeneous part.
ERROR:root:Lỗi t

KeyboardInterrupt: 

In [None]:
!pip install GPUtil

Collecting GPUtil
  Downloading GPUtil-1.4.0.tar.gz (5.5 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: GPUtil
  Building wheel for GPUtil (setup.py) ... [?25l[?25hdone
  Created wheel for GPUtil: filename=GPUtil-1.4.0-py3-none-any.whl size=7392 sha256=9eb9bb1c2a4699c385210cbaf09bca6ca453f02d1af6995b7b6ea6fe967e853b
  Stored in directory: /root/.cache/pip/wheels/2b/4d/8f/55fb4f7b9b591891e8d3f72977c4ec6c7763b39c19f0861595
Successfully built GPUtil
Installing collected packages: GPUtil
Successfully installed GPUtil-1.4.0


In [None]:
import pandas as pd
import os
import json
from pathlib import Path

def compare_id_code_with_processed_folder(csv_path, processed_folder, output_dir):
    """
    So sánh df_train['id_code'] với các tệp ảnh trong PROCESSED_FOLDER.

    Args:
        csv_path (str): Đường dẫn đến file train.csv
        processed_folder (str): Đường dẫn đến thư mục PROCESSED_FOLDER chứa ảnh .png
        output_dir (str): Thư mục để lưu báo cáo JSON
    """
    try:
        # Kiểm tra và đọc file CSV
        if not os.path.exists(csv_path):
            print(f"ERROR: File CSV không tồn tại: {csv_path}")
            return
        df_train = pd.read_csv(csv_path)
        if 'id_code' not in df_train.columns:
            print("ERROR: File CSV không có cột 'id_code'")
            return
        print(f"Đã đọc file CSV: {len(df_train)} mẫu")

        # Lấy danh sách id_code từ DataFrame
        df_ids = set(df_train['id_code'].astype(str).values)
        print(f"Số id_code trong df_train: {len(df_ids)}")
        print(f"Mẫu id_code: {list(df_ids)[:5]}...")

        # Lấy danh sách tệp ảnh từ PROCESSED_FOLDER
        if not os.path.exists(processed_folder):
            print(f"ERROR: Thư mục PROCESSED_FOLDER không tồn tại: {processed_folder}")
            return
        image_files = [f for f in os.listdir(processed_folder) if f.endswith('.png')]
        processed_ids = set(os.path.splitext(f)[0] for f in image_files)
        print(f"Số tệp ảnh trong PROCESSED_FOLDER: {len(processed_ids)}")
        print(f"Mẫu processed_ids: {list(processed_ids)[:5]}...")

        # So sánh id_code và processed_ids
        matching_ids = df_ids.intersection(processed_ids)
        missing_in_processed = df_ids - processed_ids
        missing_in_df = processed_ids - df_ids

        # Tạo báo cáo
        report = {
            "timestamp": pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S'),
            "csv_path": csv_path,
            "processed_folder": processed_folder,
            "total_df_ids": len(df_ids),
            "total_processed_ids": len(processed_ids),
            "matching_ids_count": len(matching_ids),
            "matching_ids_sample": list(matching_ids)[:10],  # Lấy mẫu 10 ID khớp
            "missing_in_processed_count": len(missing_in_processed),
            "missing_in_processed": list(missing_in_processed),
            "missing_in_df_count": len(missing_in_df),
            "missing_in_df": list(missing_in_df)
        }

        # In kết quả
        print(f"Số ID khớp: {len(matching_ids)}")
        print(f"Số ID trong df_train nhưng không có trong PROCESSED_FOLDER: {len(missing_in_processed)}")
        if missing_in_processed:
            print(f"WARNING: Mẫu ID thiếu trong PROCESSED_FOLDER: {list(missing_in_processed)[:5]}...")
        print(f"Số tệp ảnh trong PROCESSED_FOLDER nhưng không có trong df_train: {len(missing_in_df)}")
        if missing_in_df:
            print(f"WARNING: Mẫu ID thiếu trong df_train: {list(missing_in_df)[:5]}...")

        # Lưu báo cáo vào file JSON
        os.makedirs(output_dir, exist_ok=True)
        report_path = os.path.join(output_dir, "id_code_comparison_report.json")
        with open(report_path, 'w', encoding='utf-8') as f:
            json.dump(report, f, indent=4)
        print(f"Đã lưu báo cáo tại: {report_path}")

    except Exception as e:
        print(f"ERROR: Lỗi trong quá trình so sánh: {str(e)}")
        raise

if __name__ == "__main__":
    # Thay đổi các đường dẫn này theo cấu hình của bạn
    CSV_PATH = "/content/drive/MyDrive/kaggle_data/aptos2019/train.csv"
    PROCESSED_FOLDER = "/content/processed_train_images"
    OUTPUT_DIR = "/content/drive/MyDrive/working"

    print("Bắt đầu so sánh id_code với PROCESSED_FOLDER...")
    compare_id_code_with_processed_folder(CSV_PATH, PROCESSED_FOLDER, OUTPUT_DIR)
    print("Hoàn tất so sánh.")

Bắt đầu so sánh id_code với PROCESSED_FOLDER...
Đã đọc file CSV: 3662 mẫu
Số id_code trong df_train: 3662
Mẫu id_code: ['2fde69f20585', '8dc22e65c06f', 'ea05c22d92e9', 'a721efb1e049', '0eced86c9db8']...
Số tệp ảnh trong PROCESSED_FOLDER: 3662
Mẫu processed_ids: ['2fde69f20585', '8dc22e65c06f', 'ea05c22d92e9', 'a721efb1e049', '0eced86c9db8']...
Số ID khớp: 3662
Số ID trong df_train nhưng không có trong PROCESSED_FOLDER: 0
Số tệp ảnh trong PROCESSED_FOLDER nhưng không có trong df_train: 0
Đã lưu báo cáo tại: /content/drive/MyDrive/working/id_code_comparison_report.json
Hoàn tất so sánh.


In [None]:
import os

PROCESSED_DIR = "/content/processed_train_images"
sample_ids = df_train_processed['id_code'].values
missing_images = []
for sample_id in sample_ids:
    img_path = os.path.join(PROCESSED_FOLDER, f"{sample_id}.png")
    if not os.path.exists(img_path):
        missing_images.append(sample_id)
print(f"Missing images ({len(missing_images)}): {missing_images[:5]}...")

Missing images (0): []...


In [None]:
import json
import os

metadata_path = "/content/drive/MyDrive/working/efficientnetb1_features_metadata.json"
if os.path.exists(metadata_path):
    with open(metadata_path, 'r') as f:
        metadata = json.load(f)
    print("Nội dung metadata:")
    print(json.dumps(metadata, indent=4))
    print(f"features_4d_path: {metadata.get('features_4d_path', 'Không tìm thấy')}")
else:
    print(f"Metadata không tồn tại tại: {metadata_path}")

Nội dung metadata:
{
    "model_name": "efficientnetb1",
    "conv_layer_name": "top_conv",
    "features_2d_path": "/content/drive/MyDrive/working/efficientnetb1_features_2d.npy",
    "features_4d_path": "/content/drive/MyDrive/working/efficientnetb1_features_4d.npy",
    "sample_ids": [
        "189cbbc9e5e3",
        "d2901144070c",
        "b9519abce0c1",
        "b16dd4483ca5",
        "82bb8a01935f",
        "0c55d58bebaf",
        "80dbeb0fdc75",
        "2fdfb80ea53c",
        "239f2c348ea4",
        "041f09eec1e8",
        "5712e2aa73a2",
        "7ccf9d25dc48",
        "e4d3d437b0a8",
        "9be71d6d7e59",
        "9519a590985d",
        "165cd2070ebd",
        "437900a99871",
        "e599151ca14b",
        "89ed6a0dd53f",
        "d3e884109b45",
        "a11bf2edd470",
        "36a1e3c780a0",
        "1e8a1fdee5b9",
        "69df7ade0575",
        "5efa24b03d5e",
        "50915e2329a1",
        "e811f39a1243",
        "af87d48ffe01",
        "857230f64a2e",
        "cd54d

In [None]:
import tensorflow as tf
import numpy as np
import os
import pandas as pd
import cv2
import albumentations as A
from tensorflow.keras.models import Model
from tensorflow.keras.applications import EfficientNetB1, Xception, InceptionV3, ResNet50, DenseNet121
from tensorflow.keras.applications.efficientnet import preprocess_input as efficientnet_preprocess
from tensorflow.keras.applications.xception import preprocess_input as xception_preprocess
from tensorflow.keras.applications.inception_v3 import preprocess_input as inceptionv3_preprocess
from tensorflow.keras.applications.resnet50 import preprocess_input as resnet50_preprocess
from tensorflow.keras.applications.densenet import preprocess_input as densenet121_preprocess
from sklearn.metrics import cohen_kappa_score, f1_score, recall_score, confusion_matrix, precision_score, accuracy_score
from sklearn.utils import shuffle
from tensorflow.keras.layers import Dense, Input, BatchNormalization, Layer, Dropout
from sklearn.decomposition import PCA
from sklearn.isotonic import IsotonicRegression
import logging
from tensorflow.keras.callbacks import TensorBoard, ReduceLROnPlateau
import json
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import seaborn as sns
import gc
from google.colab import drive
import time
import subprocess

# Thiết lập tăng trưởng bộ nhớ GPU
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        print("Đã kích hoạt tăng trưởng bộ nhớ GPU")
    except RuntimeError as e:
        print(f"Lỗi khi thiết lập tăng trưởng bộ nhớ GPU: {e}")

# Thiết lập thư mục lưu trữ
feature_save_dir = "/content/drive/MyDrive/working"
log_dir = os.path.join(feature_save_dir, "logs")
os.makedirs(log_dir, exist_ok=True)
os.makedirs(feature_save_dir, exist_ok=True)
tensorboard_callback = TensorBoard(log_dir=log_dir, histogram_freq=1)

SIZE = 244
NUM_CLASSES = 5
TEMP_AUGMENT_DIR = "/tmp/temp_augmented_images"
os.makedirs(TEMP_AUGMENT_DIR, exist_ok=True)



# Kiểm tra và mount Google Drive
if not os.path.ismount('/content/drive'):
    drive.mount('/content/drive')
else:
    print("Google Drive đã được mount.")

# Thiết lập thư mục lưu trữ trên Google Drive
feature_save_dir = "/content/drive/MyDrive/working"
log_dir = os.path.join(feature_save_dir, "logs")
os.makedirs(log_dir, exist_ok=True)
os.makedirs(feature_save_dir, exist_ok=True)
tensorboard_callback = TensorBoard(log_dir=log_dir, histogram_freq=1)

# Đọc dữ liệu
drive_folder = "/content/drive/MyDrive/kaggle_data/aptos2019"
processed_folder = "/content/processed_train_images"
df_train = pd.read_csv(os.path.join(drive_folder, "train.csv"))
processed_ids = [f.replace('.png', '') for f in os.listdir(processed_folder) if f.endswith('.png')]
df_train_processed = df_train[df_train['id_code'].isin(processed_ids)].copy()



# Chia dữ liệu
x = df_train_processed['id_code']  # id_code có đuôi .png
y = df_train_processed['diagnosis']
if len(x) < 2:
    logging.error("Số mẫu quá ít để chia dữ liệu.")
    raise ValueError("Cần ít nhất 2 mẫu để chia train/valid.")
x, y = shuffle(x, y, random_state=42)
train_x, valid_x, train_y, valid_y = train_test_split(x, y, test_size=0.15, stratify=y, random_state=42)

logging.info(f"Train X shape: {train_x.shape}")
logging.info(f"Valid X shape: {valid_x.shape}")

# Tải ảnh đã xử lý
def load_processed_image(image_id, processed_folder, size=244):
    try:
        img_path = os.path.join(processed_folder, f"{image_id}.png")
        if not os.path.exists(img_path):
            logging.error(f"Ảnh không tồn tại: {img_path}")
            return None
        img = cv2.imread(img_path)
        if img is None:
            logging.error(f"Không đọc được ảnh: {img_path}")
            return None
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA)
        return img
    except Exception as e:
        logging.error(f"Lỗi khi tải ảnh {image_id}: {str(e)}")
        return None

resized_train_x = np.array([load_processed_image(id_code, processed_folder, size=SIZE) for id_code in train_x])
resized_valid_x = np.array([load_processed_image(id_code, processed_folder, size=SIZE) for id_code in valid_x])

train_y_multi = tf.keras.utils.to_categorical(train_y, num_classes=NUM_CLASSES)
valid_y_multi = tf.keras.utils.to_categorical(valid_y, num_classes=NUM_CLASSES)

# Tải ảnh đã xử lý
def load_processed_image(image_id, processed_folder, size=244):
    try:
        img_path = os.path.join(processed_folder, f"{image_id}.png")
        if not os.path.exists(img_path):
            logging.error(f"Ảnh không tồn tại: {img_path}")
            return None
        img = cv2.imread(img_path)
        if img is None:
            logging.error(f"Không đọc được ảnh: {img_path}")
            return None
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA)
        return img
    except Exception as e:
        logging.error(f"Lỗi khi tải ảnh {image_id}: {str(e)}")
        return None

resized_train_x = np.array([load_processed_image(id_code, processed_folder, size=SIZE) for id_code in train_x])
resized_valid_x = np.array([load_processed_image(id_code, processed_folder, size=SIZE) for id_code in valid_x])
resized_test_x = np.array([load_processed_image(id_code, processed_folder, size=SIZE) for id_code in test_x])

train_y_multi = tf.keras.utils.to_categorical(train_y, num_classes=NUM_CLASSES)
valid_y_multi = tf.keras.utils.to_categorical(valid_y, num_classes=NUM_CLASSES)
test_y_multi = tf.keras.utils.to_categorical(test_y, num_classes=NUM_CLASSES)

# Hàm custom_random_erasing
def custom_random_erasing(image, scale=(0.01, 0.05), ratio=(0.5, 2.0), p=0.3, value=None):
    if np.random.random() > p:
        return image
    height, width, channels = image.shape
    area = height * width
    scale_factor = np.random.uniform(scale[0], scale[1])
    erase_area = area * scale_factor
    aspect_ratio = np.random.uniform(ratio[0], ratio[1])
    erase_height = int(np.sqrt(erase_area / aspect_ratio))
    erase_width = int(np.sqrt(erase_area * aspect_ratio))
    erase_height = min(erase_height, height)
    erase_width = min(erase_width, width)
    if erase_height < 1 or erase_width < 1:
        return image
    x = np.random.randint(0, width - erase_width + 1)
    y = np.random.randint(0, height - erase_height + 1)
    output = image.copy()
    if value is None:
        value = np.mean(image, axis=(0, 1))
    output[y:y+erase_height, x:x+erase_width, :] = value
    return output

# Hàm balance_and_augment_data
def balance_and_augment_data(images, labels, target_classes=[0, 1, 2, 3, 4], samples_per_class=None):
    num_classes = labels.shape[1]
    label_indices = np.argmax(labels, axis=1)
    keep_indices = np.isin(label_indices, target_classes)
    filtered_images = images[keep_indices]
    filtered_labels = labels[keep_indices]
    filtered_label_indices = label_indices[keep_indices]
    class_counts = np.bincount(filtered_label_indices, minlength=num_classes)
    print(f"Phân bố nhãn ban đầu: {dict(zip(range(num_classes), class_counts))}")
    cls_counts = [class_counts[cls] for cls in target_classes]
    max_count = samples_per_class or max(cls_counts)
    print(f"Số mẫu mục tiêu mỗi lớp: {max_count}")
    augmenter = A.Compose([
        A.HorizontalFlip(p=0.5),
        A.Rotate(limit=15, p=0.3),  # Reduced rotation
        A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.3),  # Reduced intensity
        A.GaussNoise(var_limit=(10.0, 20.0), p=0.2),
        A.CLAHE(clip_limit=1.0, tile_grid_size=(8, 8), p=0.2),  # Reduced CLAHE
    ])
    new_images = []
    new_labels = []
    for cls in target_classes:
        cls_indices = np.where(filtered_label_indices == cls)[0]
        cls_images = filtered_images[cls_indices]
        cls_labels = filtered_labels[cls_indices]
        current_count = len(cls_indices)
        new_images.extend(cls_images)
        new_labels.extend(cls_labels)
        print(f"Lớp {cls}: {current_count} mẫu ban đầu")
        augment_count = max_count - current_count
        if augment_count > 0:
            print(f"Tăng cường {augment_count} mẫu cho lớp {cls}")
            for _ in range(augment_count):
                idx = np.random.choice(cls_indices)
                img = filtered_images[idx].astype(np.uint8)
                aug_img = augmenter(image=img)['image']
                aug_img = custom_random_erasing(
                    aug_img, scale=(0.01, 0.05), ratio=(0.5, 2.0), p=0.0, value=np.mean(aug_img, axis=(0, 1))  # Disabled
                )
                new_images.append(aug_img)
                new_labels.append(filtered_labels[idx])
    new_images = np.array(new_images, dtype=np.float32)
    new_labels = np.array(new_labels, dtype=np.float32)
    new_images, new_labels = shuffle(new_images, new_labels, random_state=42)
    final_class_counts = np.bincount(np.argmax(new_labels, axis=1), minlength=num_classes)
    print(f"Phân bố nhãn sau cân bằng: {dict(zip(range(num_classes), final_class_counts))}")
    return new_images, new_labels

# Cân bằng dữ liệu train
class_counts = np.bincount(train_y)
class_0_count = class_counts[0]
print(f"Số mẫu lớp 0: {class_0_count}")
balanced_train_x, balanced_train_y_multi = balance_and_augment_data(
    resized_train_x, train_y_multi, target_classes=[1, 2, 3, 4], samples_per_class=class_0_count
)
class_0_indices = np.where(np.argmax(train_y_multi, axis=1) == 0)[0]
class_0_images = resized_train_x[class_0_indices]
class_0_labels = train_y_multi[class_0_indices]
balanced_train_x = np.concatenate([balanced_train_x, class_0_images], axis=0)
balanced_train_y_multi = np.concatenate([balanced_train_y_multi, class_0_labels], axis=0)
balanced_train_x, balanced_train_y_multi = shuffle(balanced_train_x, balanced_train_y_multi, random_state=42)
final_class_counts = np.bincount(np.argmax(balanced_train_y_multi, axis=1), minlength=5)
print(f"Phân bố nhãn sau khi thêm lớp 0: {dict(zip(range(5), final_class_counts))}")
print("balanced_train_x shape:", balanced_train_x.shape)
print("balanced_train_y_multi shape:", balanced_train_y_multi.shape)

# Định nghĩa model_configs
model_configs = {
    "efficientnetb1": {
        "model_type": "efficientnetb1",
        "config_path": "/content/drive/MyDrive/working/EfficientNetB1_bestqwk_aptos/config.json",
        "weights_path": "/content/drive/MyDrive/working/EfficientNetB1_bestqwk_aptos/model.weights.h5",
        "preprocess": efficientnet_preprocess,
        "img_size": 244,
        "base_model": EfficientNetB1
    },
    "xception": {
        "model_type": "xception",
        "config_path": "/content/drive/MyDrive/working/Xception_bestqwk_aptos/config.json",
        "weights_path": "/content/drive/MyDrive/working/Xception_bestqwk_aptos/model.weights.h5",
        "preprocess": xception_preprocess,
        "img_size": 244,
        "base_model": Xception
    },
    "inceptionv3": {
        "model_type": "inceptionv3",
        "config_path": "/content/drive/MyDrive/working/InceptionV3_bestqwk_aptos/config.json",
        "weights_path": "/content/drive/MyDrive/working/InceptionV3_bestqwk_aptos/model.weights.h5",
        "preprocess": inceptionv3_preprocess,
        "img_size": 299,
        "base_model": InceptionV3
    },
    "resnet50": {
        "model_type": "resnet50",
        "config_path": "/content/drive/MyDrive/working/ResNet50_bestqwk_aptos/config.json",
        "weights_path": "/content/drive/MyDrive/working/ResNet50_bestqwk_aptos/model.weights.h5",
        "preprocess": resnet50_preprocess,
        "img_size": 244,
        "base_model": ResNet50
    },
    "densenet121": {
        "model_type": "densenet121",
        "config_path": "/content/drive/MyDrive/working/DenseNet121_bestqwk_aptos/config.json",
        "weights_path": "/content/drive/MyDrive/working/DenseNet121_bestqwk_aptos/model.weights.h5",
        "preprocess": densenet121_preprocess,
        "img_size": 244,
        "base_model": DenseNet121
    }
}

# My_Generator với mixup
class My_Generator(tf.keras.utils.Sequence):
    def __init__(self, images, labels, batch_size, is_train=False, mix=True, augment=False, size1=244, size2=299, model_type="default", preprocess=None):
        self.labels = np.array(labels, dtype=np.float32)
        self.batch_size = batch_size
        self.is_train = is_train
        self.is_augment = augment
        self.is_mix = mix
        self.model_type = str(model_type).lower()
        self.preprocess = preprocess
        self.temp_augment_dir = TEMP_AUGMENT_DIR
        os.makedirs(self.temp_augment_dir, exist_ok=True)
        self.target_size = (size2, size2) if 'inceptionv3' in self.model_type or 'xception' in self.model_type else (size1, size1)
        self.image_paths = []
        if isinstance(images, np.ndarray):
            for i, img in enumerate(images):
                img_path = os.path.join(self.temp_augment_dir, f"img_{i}_{np.random.randint(1000000)}.png")
                try:
                    if img.dtype != np.uint8:
                        if img.max() <= 1.0:
                            img = (img * 255).astype(np.uint8)
                        else:
                            img = img.astype(np.uint8)
                    cv2.imwrite(img_path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
                    self.image_paths.append(img_path)
                except Exception as e:
                    logging.error(f"Lỗi khi lưu ảnh {img_path}: {str(e)}")
                    continue
        else:
            self.image_paths = list(images)
        unique_paths = []
        unique_indices = []
        seen = set()
        for i, path in enumerate(self.image_paths):
            if path not in seen:
                seen.add(path)
                unique_paths.append(path)
                unique_indices.append(i)
        self.image_paths = unique_paths
        self.labels = self.labels[unique_indices]
        if len(self.image_paths) != len(self.labels):
            logging.error(f"Số image_paths ({len(self.image_paths)}) không khớp với số labels ({len(self.labels)})")
            self.labels = self.labels[:len(self.image_paths)]
        if not self.image_paths:
            raise ValueError("Không có image_paths hợp lệ được tạo.")
        print(f"Khởi tạo My_Generator: {len(self.image_paths)} mẫu, target_size={self.target_size}, is_train={is_train}")
        self.dataset = self._create_dataset()

    def _load_image(self, img_path):
        img = tf.io.read_file(img_path)
        img = tf.image.decode_png(img, channels=3)
        img = tf.image.resize(img, self.target_size, method=tf.image.ResizeMethod.BILINEAR)
        img = tf.ensure_shape(img, [self.target_size[0], self.target_size[1], 3])
        img = tf.cast(img, tf.float32) / 255.0
        if self.preprocess is not None:
            img = self.preprocess(img)
        return img

    def _mixup(self, images, labels):
        batch_size = tf.shape(images)[0]
        lam = tf.random.uniform([], minval=0.2, maxval=0.4, dtype=tf.float32)
        indices = tf.random.shuffle(tf.range(batch_size, dtype=tf.int32))
        mixed_images = lam * images + (1 - lam) * tf.gather(images, indices)
        mixed_labels = lam * labels + (1 - lam) * tf.gather(labels, indices)
        return mixed_images, mixed_labels

    def _create_dataset(self):
        dataset = tf.data.Dataset.from_tensor_slices((self.image_paths, self.labels))
        dataset = dataset.map(
            lambda img_path, label: (self._load_image(img_path), label),
            num_parallel_calls=tf.data.AUTOTUNE
        )
        if self.is_train:
            dataset = dataset.shuffle(buffer_size=len(self.image_paths))
        dataset = dataset.batch(self.batch_size, drop_remainder=False)
        if self.is_train and self.is_mix:
            dataset = dataset.map(
                self._mixup,
                num_parallel_calls=tf.data.AUTOTUNE
            )
        dataset = dataset.prefetch(tf.data.AUTOTUNE)
        return dataset

    def __len__(self):
        return int(np.ceil(len(self.image_paths) / self.batch_size))

    def __iter__(self):
        self.iterator = iter(self.dataset)
        return self

    def __next__(self):
        return next(self.iterator)

# Callback để tính trọng số lớp từ confusion matrix
class ConfusionMatrixWeightCallback(tf.keras.callbacks.Callback):
    def __init__(self, valid_features, valid_labels, classification_model, num_classes=5, class_counts=None):
        super().__init__()
        self.valid_features = valid_features
        self.valid_labels = valid_labels
        self.classification_model = classification_model
        self.num_classes = num_classes
        self.prev_cm = None
        self.class_weights = np.ones(num_classes, dtype=np.float32)
        self.class_counts = class_counts
        self.history_dir = os.path.join(feature_save_dir, "history")
        os.makedirs(self.history_dir, exist_ok=True)
        self.weights_history = []

    def on_epoch_end(self, epoch, logs=None):
        y_pred = self.classification_model.predict(self.valid_features, verbose=0, batch_size=32)
        y_true = np.argmax(self.valid_labels, axis=1)
        y_pred_classes = np.argmax(y_pred, axis=1)
        cm = confusion_matrix(y_true, y_pred_classes, labels=list(range(self.num_classes)))
        print(f"Epoch {epoch+1} - Ma trận nhầm lẫn:\n{cm}")
        errors = np.sum(cm * (1 - np.eye(self.num_classes)), axis=1)
        total_samples_per_class = np.sum(cm, axis=1)
        total_samples_per_class = np.where(total_samples_per_class == 0, 1, total_samples_per_class)
        error_rates = errors / total_samples_per_class
        weak_classes = []
        if self.class_counts is not None:
            min_count = np.min(self.class_counts[self.class_counts > 0])
            weak_classes = np.where(self.class_counts <= min_count * 1.5)[0]
        high_error_classes = np.where(error_rates >= np.percentile(error_rates, 75))[0]
        weak_classes = np.unique(np.concatenate([weak_classes, high_error_classes])).astype(int)
        self.class_weights = 1.0 + error_rates
        for cls in weak_classes:
            self.class_weights[cls] *= 2.0
        self.class_weights /= self.class_weights.max()
        print(f"Epoch {epoch+1} - Lớp yếu: {weak_classes}")
        print(f"Epoch {epoch+1} - Trọng số lớp: {self.class_weights}")
        self.weights_history.append({
            "epoch": epoch + 1,
            "class_weights": self.class_weights.tolist(),
            "weak_classes": weak_classes.tolist(),
            "confusion_matrix": cm.tolist()
        })
        weights_path = os.path.join(self.history_dir, f"class_weights_epoch_{epoch+1}.json")
        with open(weights_path, 'w') as f:
            json.dump(self.weights_history[-1], f, indent=4)
        print(f"Đã lưu trọng số lớp tại: {weights_path}")
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=list(range(self.num_classes)),
                    yticklabels=list(range(self.num_classes)))
        plt.title(f'Ma trận nhầm lẫn - Epoch {epoch+1}')
        plt.xlabel('Dự đoán')
        plt.ylabel('Thực tế')
        cm_path = os.path.join(feature_save_dir, f'confusion_matrix_epoch_{epoch+1}.png')
        plt.savefig(cm_path)
        plt.close()
        print(f"Đã lưu ma trận nhầm lẫn tại: {cm_path}")
        self.prev_cm = cm.copy()

    def get_class_weights(self):
        return self.class_weights

# Các hàm và lớp hỗ trợ
def load_model_from_config(config_path, weights_path, base_model_class):
    try:
        if config_path and os.path.exists(config_path) and weights_path and os.path.exists(weights_path):
            with open(config_path, 'r') as f:
                model_config = json.load(f)
            model = tf.keras.models.model_from_json(json.dumps(model_config))
            model.load_weights(weights_path)
            return model
        raise FileNotFoundError
    except:
        return base_model_class(weights='imagenet', include_top=False, pooling='avg')

class GradientReversalLayer(Layer):
    def __init__(self, lambda_=1.0, **kwargs):
        super().__init__(**kwargs)
        self.lambda_ = lambda_
    def call(self, inputs, training=None):
        inputs = tf.convert_to_tensor(inputs, dtype=tf.float32)
        return inputs if not training else tf.math.multiply(-self.lambda_, inputs)
    def get_config(self):
        config = super().get_config()
        config.update({"lambda_": self.lambda_})
        return config

class MemoryAugmentedLayer(tf.keras.layers.Layer):
    def __init__(self, memory_size, memory_dim, **kwargs):
        super().__init__(**kwargs)
        self.memory_size = memory_size
        self.memory_dim = memory_dim
    def build(self, input_shape):
        self.memory = self.add_weight(
            shape=(self.memory_size, self.memory_dim),
            initializer='zeros',
            trainable=False,
            dtype=tf.float32
        )
        super().build(input_shape)
    def call(self, inputs):
        batch_size = tf.shape(inputs)[0]
        memory_size = tf.shape(self.memory)[0]
        memory_sliced = tf.cond(
            batch_size > memory_size,
            lambda: tf.tile(self.memory, [(batch_size + memory_size - 1) // memory_size, 1])[:batch_size],
            lambda: self.memory[:batch_size]
        )
        return tf.reduce_mean(tf.stack([inputs, memory_sliced], axis=0), axis=0)
    def get_config(self):
        config = super().get_config()
        config.update({'memory_size': self.memory_size, 'memory_dim': self.memory_dim})
        return config

class CustomGridDropout(tf.keras.layers.Layer):
    def __init__(self, ratio=0.3, holes_number=4, p=0.5, **kwargs):
        super().__init__(**kwargs)
        self.ratio = ratio
        self.holes_number = holes_number
        self.p = p
    def call(self, inputs, training=None):
        if not training:
            return inputs
        inputs = tf.convert_to_tensor(inputs, dtype=tf.float32)
        batch_size = tf.shape(inputs)[0]
        feature_dim = tf.shape(inputs)[1]
        hole_size = tf.maximum(1, tf.cast(tf.cast(feature_dim, tf.float32) * self.ratio, tf.int32))
        mask = tf.ones_like(inputs, dtype=tf.float32)
        random_probs = tf.random.uniform([self.holes_number], 0, 1)
        active_holes = tf.cast(random_probs < self.p, tf.int32)
        hole_indices = tf.range(self.holes_number)
        start_indices = (hole_indices * feature_dim) // self.holes_number
        end_indices = tf.minimum(start_indices + hole_size, feature_dim)
        all_indices = []
        for i in range(self.holes_number):
            should_apply = active_holes[i]
            indices = tf.cond(
                should_apply > 0,
                lambda: tf.stack([
                    tf.tile(tf.range(batch_size), [end_indices[i] - start_indices[i]]),
                    tf.repeat(tf.range(start_indices[i], end_indices[i]), batch_size)
                ], axis=1),
                lambda: tf.zeros([0, 2], dtype=tf.int32)
            )
            all_indices.append(indices)
        all_indices = tf.concat(all_indices, axis=0)
        updates = tf.zeros([tf.shape(all_indices)[0]], dtype=tf.float32)
        mask = tf.tensor_scatter_nd_update(mask, all_indices, updates)
        return inputs * mask
    def get_config(self):
        config = super().get_config()
        config.update({
            "ratio": self.ratio,
            "holes_number": self.holes_number,
            "p": self.p
        })
        return config

from sklearn.preprocessing import StandardScaler
def normalize_features(features, target_dim):
    try:
        if len(features.shape) != 2:
            raise ValueError(f"Expected 2D features, got shape {features.shape}")
        scaler = StandardScaler()
        normalized = scaler.fit_transform(features)
        current_dim = normalized.shape[1]
        if current_dim == target_dim:
            return normalized
        elif current_dim > target_dim:
            return normalized[:, :target_dim]
        else:
            padding = np.zeros((normalized.shape[0], target_dim - current_dim), dtype=np.float32)
            return np.concatenate([normalized, padding], axis=1)
    except Exception as e:
        logging.error(f"Error normalizing features: {str(e)}")
        return np.zeros((features.shape[0], target_dim), dtype=np.float32)

def extract_and_save_features(model_name, feature_extractor, generator, save_dir, sample_ids):
    expected_samples = len(generator.image_paths)
    features_2d = []
    processed_samples = []
    iterator = iter(generator.dataset)
    steps = int(np.ceil(expected_samples / generator.batch_size))
    for step in range(steps):
        try:
            batch_data = next(iterator, None)
            if batch_data is None:
                break
            batch_images, _ = batch_data
            if batch_images.shape[0] == 0:
                continue
            batch_features_2d = feature_extractor(batch_images, training=False)
            features_2d.append(batch_features_2d.numpy().astype(np.float32))
            processed_samples.extend(sample_ids[step * generator.batch_size: (step + 1) * generator.batch_size][:batch_images.shape[0]])
        except Exception as e:
            logging.error(f"Lỗi tại batch {step+1}: {str(e)}")
            continue
    features_2d = np.concatenate(features_2d, axis=0) if features_2d else np.zeros((expected_samples, 512), dtype=np.float32)
    if features_2d.shape[0] != expected_samples:
        features_2d = features_2d[:expected_samples] if features_2d.shape[0] > expected_samples else \
                      np.pad(features_2d, ((0, expected_samples - features_2d.shape[0]), (0, 0)), mode='edge')
    os.makedirs(save_dir, exist_ok=True)
    features_2d_path = os.path.join(save_dir, f"{model_name}_features_2d.npy")
    try:
        np.save(features_2d_path, features_2d)
        for _ in range(3):
            subprocess.run(["sync"])
            time.sleep(1)
            if os.path.exists(features_2d_path):
                print(f"Đã lưu đặc trưng 2D tại: {features_2d_path}, shape={features_2d.shape}")
                break
        else:
            logging.error(f"Không thể lưu tệp 2D tại: {features_2d_path} sau nhiều lần thử")
    except Exception as e:
        logging.error(f"Lỗi khi lưu đặc trưng 2D tại {features_2d_path}: {str(e)}")
        raise
    metadata = {
        "model_name": model_name,
        "features_2d_path": features_2d_path,
        "sample_ids": processed_samples,
        "timestamp": pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')
    }
    metadata_path = os.path.join(save_dir, f"{model_name}_features_metadata.json")
    try:
        with open(metadata_path, 'w') as f:
            json.dump(metadata, f, indent=4)
        subprocess.run(["sync"])
        print(f"Đã lưu metadata tại: {metadata_path}")
    except Exception as e:
        logging.error(f"Lỗi khi lưu metadata tại {metadata_path}: {str(e)}")
    return features_2d

def combine_and_reduce_features(features_dict, labels, sample_ids, save_dir, n_components=50, target_dim=512):
    num_samples = len(sample_ids)
    normalized_features = {}
    for model_name, features in features_dict.items():
        if features.shape[0] == num_samples:
            normalized_features[model_name] = normalize_features(features, target_dim)
        else:
            features_adj = features[:num_samples] if features.shape[0] > num_samples else \
                           np.pad(features, ((0, num_samples - features.shape[0]), (0, 0)), mode='edge')
            normalized_features[model_name] = normalize_features(features_adj, target_dim)
    for model_name, features in normalized_features.items():
        features_path = os.path.join(save_dir, f"{model_name}_normalized_features_2d.npy")
        np.save(features_path, features)
        print(f"Đã lưu đặc trưng 2D chuẩn hóa cho {model_name} tại: {features_path}, shape={features.shape}")
    combined_features = []
    valid_indices = []
    for i in range(num_samples):
        sample_features = [normalized_features[model_name][i] for model_name in normalized_features
                          if i < len(normalized_features[model_name])]
        if sample_features:
            combined_features.append(np.concatenate(sample_features))
            valid_indices.append(i)
    combined_features = np.array(combined_features, dtype=np.float32) if combined_features else \
                       np.zeros((num_samples, target_dim * len(normalized_features)), dtype=np.float32)
    valid_indices = np.array(valid_indices) if valid_indices else np.arange(num_samples)
    combined_features_path = os.path.join(save_dir, "combined_features_2d_before_pca.npy")
    np.save(combined_features_path, combined_features)
    print(f"Đã lưu đặc trưng 2D kết hợp (trước PCA) tại: {combined_features_path}, shape={combined_features.shape}")
    pca = None
    reduced_features = combined_features
    if n_components is not None:
        pca = PCA(n_components=n_components)
        reduced_features = pca.fit_transform(combined_features)
    reduced_features_path = os.path.join(save_dir, "combined_features_2d_after_pca.npy")
    np.save(reduced_features_path, reduced_features)
    print(f"Đã lưu đặc trưng 2D kết hợp (sau PCA) tại: {reduced_features_path}, shape={reduced_features.shape}")
    metadata = {
        "model_names": list(features_dict.keys()),
        "num_samples": num_samples,
        "target_dim": target_dim,
        "n_components": n_components,
        "sample_ids": sample_ids.tolist(),
        "labels": labels.tolist(),
        "features_2d_paths": {model_name: os.path.join(save_dir, f"{model_name}_normalized_features_2d.npy")
                             for model_name in features_dict},
        "combined_features_before_pca": combined_features_path,
        "combined_features_after_pca": reduced_features_path,
        "timestamp": pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')
    }
    metadata_path = os.path.join(save_dir, "combined_features_metadata.json")
    with open(metadata_path, 'w') as f:
        json.dump(metadata, f, indent=4)
    print(f"Đã lưu metadata tại: {metadata_path}")
    return reduced_features, pca, valid_indices, reduced_features.shape[1]

def save_meta_learner_features(meta_feature_model, features, labels, sample_ids, save_dir):
    meta_features_2d = meta_feature_model.predict(features, batch_size=32, verbose=0)
    meta_features_2d_path = os.path.join(save_dir, "meta_learner_features_2d.npy")
    np.save(meta_features_2d_path, meta_features_2d)
    print(f"Đã lưu đặc trưng 2D của meta-learner tại: {meta_features_2d_path}, shape={meta_features_2d.shape}")
    metadata = {
        "model_name": "meta_learner",
        "features_2d_path": meta_features_2d_path,
        "sample_ids": sample_ids.tolist(),
        "labels": labels.tolist(),
        "timestamp": pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')
    }
    metadata_path = os.path.join(save_dir, "meta_learner_features_metadata.json")
    with open(metadata_path, 'w') as f:
        json.dump(metadata, f, indent=4)
    print(f"Đã lưu metadata meta-learner tại: {metadata_path}")
    return meta_features_2d

def augment_single_class(features, labels, cls, num_samples_needed):
    cls_indices = np.where(np.argmax(labels, axis=1) == cls)[0]
    if len(cls_indices) == 0:
        return [], []
    augmenter = A.Compose([
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.Rotate(limit=45, p=0.7),
        A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.7),
        A.GaussNoise(p=0.5),
        A.CLAHE(clip_limit=2.0, tile_grid_size=(8, 8), p=0.3),
        A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.5),
        A.RandomCrop(height=int(SIZE*0.9), width=int(SIZE*0.9), p=0.3),
        A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=30, p=0.5)
    ])
    aug_features = []
    aug_labels = []
    for _ in range(num_samples_needed):
        idx = np.random.choice(cls_indices)
        img = features[idx].astype(np.uint8)
        aug_img = augmenter(image=img)['image']
        aug_img = custom_random_erasing(
            aug_img, scale=(0.01, 0.05), ratio=(0.5, 2.0), p=0.3, value=np.mean(aug_img, axis=(0, 1))
        )
        aug_features.append(aug_img)
        aug_labels.append(labels[idx])
    return np.array(aug_features, dtype=np.float32), np.array(aug_labels, dtype=np.float32)

def create_episode(features, labels, n_support=10, n_query=10, hard_sample_ratio=0.3, class_3_multiplier=2):
    target_classes = [0, 1, 2, 3, 4]
    if len(labels.shape) > 1:
        label_indices = np.argmax(labels, axis=1)
    else:
        label_indices = labels
    keep_indices = np.isin(label_indices, target_classes)
    features = features[keep_indices]
    labels = labels[keep_indices]
    label_indices = label_indices[keep_indices]
    support_features, support_labels, query_features, query_labels = [], [], [], []
    hard_indices = []
    n_support_class_3 = int(n_support * class_3_multiplier)
    n_query_class_3 = int(n_query * class_3_multiplier)
    for label in target_classes:
        indices = np.where(label_indices == label)[0]
        n_support_cls = n_support_class_3 if label == 3 else n_support
        n_query_cls = n_query_class_3 if label == 3 else n_query
        min_samples_per_class = n_support_cls + n_query_cls
        if len(indices) == 0:
            logging.warning(f"Lớp {label} không có mẫu. Bỏ qua.")
            continue
        if len(indices) < min_samples_per_class:
            logging.info(f"Lớp {label} chỉ có {len(indices)} mẫu, cần {min_samples_per_class}. Sử dụng oversampling.")
            indices = np.random.choice(indices, size=min_samples_per_class, replace=True)
        support_indices = np.random.choice(indices, n_support_cls, replace=False)
        remaining_indices = np.setdiff1d(indices, support_indices)
        n_hard = int(n_query_cls * hard_sample_ratio)
        if len(remaining_indices) < n_hard:
            hard_samples = np.random.choice(remaining_indices, n_hard, replace=True)
        else:
            hard_samples = np.random.choice(remaining_indices, n_hard, replace=False)
        easy_samples = np.setdiff1d(remaining_indices, hard_samples)
        n_easy = n_query_cls - len(hard_samples)
        if n_easy > 0:
            if len(easy_samples) < n_easy:
                easy_samples = np.random.choice(easy_samples, n_easy, replace=True)
            else:
                easy_samples = np.random.choice(easy_samples, n_easy, replace=False)
            query_indices = np.concatenate([hard_samples, easy_samples])
        else:
            query_indices = hard_samples
        support_features.extend(features[support_indices])
        support_labels.extend(labels[support_indices])
        query_features.extend(features[query_indices])
        query_labels.extend(labels[query_indices])
        hard_indices.extend(hard_samples)
    support_features = np.array(support_features, dtype=np.float32)
    support_labels = np.array(support_labels, dtype=np.float32)
    query_features = np.array(query_features, dtype=np.float32)
    query_labels = np.array(query_labels, dtype=np.float32)
    hard_indices = np.array(hard_indices, dtype=np.int32)
    if support_features.size == 0 or query_features.size == 0:
        logging.warning("Episode rỗng được tạo ra. Trả về episode rỗng.")
        return np.array([]), np.array([]), np.array([]), np.array([]), np.array([])
    support_label_counts = np.bincount(np.argmax(support_labels, axis=1), minlength=5)
    query_label_counts = np.bincount(np.argmax(query_labels, axis=1), minlength=5)
    print(f"Episode distribution - Support: {dict(zip(range(5), support_label_counts))}")
    print(f"Episode distribution - Query: {dict(zip(range(5), query_label_counts))}")
    return support_features, support_labels, query_features, query_labels, hard_indices

class CustomIsotonicRegression:
    def __init__(self):
        self.iso_reg = IsotonicRegression()
        self.X_min_ = None
        self.X_max_ = None
    def fit(self, X, y):
        self.X_min_ = np.min(X)
        self.X_max_ = np.max(X)
        self.iso_reg.fit(X, y)
        return self
    def predict(self, X):
        X_clipped = np.clip(X, self.X_min_, self.X_max_)
        return self.iso_reg.predict(X_clipped)

import tensorflow as tf
import numpy as np
import logging
import matplotlib.pyplot as plt
import seaborn as sns
import os
import pandas as pd
import gc
from sklearn.metrics import cohen_kappa_score, f1_score, precision_score, recall_score, confusion_matrix
from tensorflow.keras.layers import Dense, Input, BatchNormalization, Dropout
from tensorflow.keras.models import Model

def maml_fomaml_train_manual(features, labels, valid_features, valid_labels, input_dim, n_episodes=20,
                             n_support=10, n_query=10, inner_lr=0.0005, outer_lr=0.0005, fine_tune_lr=0.0001,
                             use_fomaml=True, memory_size=20, sample_ids=None):
    """
    Huấn luyện meta-learner sử dụng MAML/FO-MAML với log đầy đủ chỉ số trên support set.

    Args:
        features: Đặc trưng tập train (train_combined_features).
        labels: Nhãn tập train (train_y_multi, one-hot).
        valid_features: Đặc trưng tập valid.
        valid_labels: Nhãn tập valid (one-hot).
        input_dim: Kích thước đầu vào của mô hình.
        n_episodes: Số episode huấn luyện.
        n_support: Số mẫu trong support set mỗi lớp.
        n_query: Số mẫu trong query set mỗi lớp.
        inner_lr: Learning rate cho inner loop.
        outer_lr: Learning rate cho outer loop.
        fine_tune_lr: Learning rate cho fine-tuning.
        use_fomaml: Sử dụng FO-MAML thay vì MAML.
        memory_size: Kích thước bộ nhớ cho MemoryAugmentedLayer.
        sample_ids: ID mẫu cho việc lưu đặc trưng meta-learner.

    Returns:
        meta_model: Mô hình meta-learner đầy đủ.
        meta_classification_model: Mô hình phân lớp.
        history: Lịch sử huấn luyện với các chỉ số.
    """
    # Định nghĩa thư mục lưu trữ
    feature_save_dir = "/content/drive/MyDrive/working"
    log_dir = os.path.join(feature_save_dir, "logs")
    os.makedirs(log_dir, exist_ok=True)

    # Định nghĩa lớp MemoryAugmentedLayer
    class MemoryAugmentedLayer(tf.keras.layers.Layer):
        def __init__(self, memory_size, memory_dim, **kwargs):
            super().__init__(**kwargs)
            self.memory_size = memory_size
            self.memory_dim = memory_dim
        def build(self, input_shape):
            self.memory = self.add_weight(
                shape=(self.memory_size, self.memory_dim),
                initializer='zeros',
                trainable=False,
                dtype=tf.float32
            )
            super().build(input_shape)
        def call(self, inputs):
            batch_size = tf.shape(inputs)[0]
            memory_size = tf.shape(self.memory)[0]
            memory_sliced = tf.cond(
                batch_size > memory_size,
                lambda: tf.tile(self.memory, [(batch_size + memory_size - 1) // memory_size, 1])[:batch_size],
                lambda: self.memory[:batch_size]
            )
            return tf.reduce_mean(tf.stack([inputs, memory_sliced], axis=0), axis=0)
        def get_config(self):
            config = super().get_config()
            config.update({'memory_size': self.memory_size, 'memory_dim': self.memory_dim})
            return config

    # Định nghĩa lớp GradientReversalLayer
    class GradientReversalLayer(tf.keras.layers.Layer):
        def __init__(self, lambda_=1.0, **kwargs):
            super().__init__(**kwargs)
            self.lambda_ = lambda_
        def call(self, inputs, training=None):
            inputs = tf.convert_to_tensor(inputs, dtype=tf.float32)
            return inputs if not training else tf.math.multiply(-self.lambda_, inputs)
        def get_config(self):
            config = super().get_config()
            config.update({"lambda_": self.lambda_})
            return config

    # Hàm tạo mô hình
    def create_model(input_dim):
        inputs = Input(shape=(input_dim,))
        x = Dense(512, activation='relu', kernel_regularizer=tf.keras.regularizers.l2(0.1), dtype=tf.float32)(inputs)
        x = BatchNormalization(dtype=tf.float32)(x)
        x = Dropout(0.6)(x)
        x = Dense(256, activation='relu', kernel_regularizer=tf.keras.regularizers.l2(0.1), dtype=tf.float32)(x)
        x = BatchNormalization(dtype=tf.float32)(x)
        x = Dropout(0.6)(x)
        x = Dense(128, activation='relu', kernel_regularizer=tf.keras.regularizers.l2(0.1), dtype=tf.float32)(x)
        x = BatchNormalization(dtype=tf.float32)(x)
        feature_output = x
        x = MemoryAugmentedLayer(memory_size=memory_size, memory_dim=128)(x)
        classification_output = Dense(5, activation='softmax', name='classification', dtype=tf.float32)(x)
        domain_inputs = GradientReversalLayer(lambda_=1.0)(x)
        domain_x = Dense(64, activation='relu', dtype=tf.float32)(domain_inputs)
        domain_output = Dense(2, activation='softmax', name='domain', dtype=tf.float32)(domain_x)
        model = Model(inputs=inputs, outputs=[classification_output, domain_output])
        classification_model = Model(inputs=inputs, outputs=classification_output)
        feature_model = Model(inputs=inputs, outputs=feature_output)
        memory_layer = [layer for layer in model.layers if isinstance(layer, MemoryAugmentedLayer)][0]
        return model, classification_model, memory_layer, feature_model

    # Hàm tính độ chính xác
    def compute_accuracy(y_true, y_pred):
        y_pred = tf.argmax(y_pred, axis=1, output_type=tf.int32)
        y_true = tf.argmax(y_true, axis=1, output_type=tf.int32)
        return tf.reduce_mean(tf.cast(tf.equal(y_true, y_pred), tf.float32))

    # Hàm tính prototype
    def compute_prototypes(features, labels, feature_model):
        features = feature_model(features, training=False)
        labels_arg = tf.argmax(labels, axis=1)
        prototypes = []
        for cls in range(5):
            cls_mask = tf.equal(labels_arg, cls)
            cls_features = tf.boolean_mask(features, cls_mask)
            prototype = tf.reduce_mean(cls_features, axis=0) if tf.shape(cls_features)[0] > 0 else \
                        tf.zeros([features.shape[1]], dtype=tf.float32)
            prototypes.append(prototype)
        return tf.stack(prototypes)

    # Hàm tính prototypical loss
    def prototypical_loss(query_features, query_labels, prototypes):
        query_labels_arg = tf.argmax(query_labels, axis=1)
        query_features_exp = tf.expand_dims(query_features, axis=1)
        prototypes_exp = tf.expand_dims(prototypes, axis=0)
        distances = tf.reduce_sum(tf.square(query_features_exp - prototypes_exp), axis=-1)
        logits = tf.cast(-distances, tf.float32)
        return tf.cast(tf.keras.losses.categorical_crossentropy(query_labels, logits, from_logits=True), tf.float32)

    # Hàm dự đoán prototypical
    def prototypical_predict(query_features, prototypes):
        query_features_exp = tf.expand_dims(query_features, axis=1)
        prototypes_exp = tf.expand_dims(prototypes, axis=0)
        distances = tf.reduce_sum(tf.square(query_features_exp - prototypes_exp), axis=-1)
        return tf.cast(tf.nn.softmax(tf.cast(-distances, tf.float32)), tf.float32)

    # Hàm áp dụng temperature scaling
    def apply_temperature_scaling(logits, temperature=2.0):
        logits = tf.convert_to_tensor(logits, dtype=tf.float32)
        return tf.nn.softmax(logits / temperature)

    # Hàm Laplace smoothing
    def laplace_smoothing(probs, epsilon=1e-5):
        probs = tf.convert_to_tensor(probs, dtype=tf.float32)
        return (probs + epsilon) / (tf.reduce_sum(probs, axis=-1, keepdims=True) + 5 * epsilon)

    # Hàm lưu confusion matrix
    def save_confusion_matrix(y_true, y_pred, episode, qwk, save_dir, prefix=''):
        cm = confusion_matrix(y_true, y_pred, labels=[0, 1, 2, 3, 4])
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=[0, 1, 2, 3, 4], yticklabels=[0, 1, 2, 3, 4])
        plt.title(f'Ma trận nhầm lẫn - {prefix}QWK: {qwk:.4f} tại Episode {episode+1}')
        plt.xlabel('Dự đoán')
        plt.ylabel('Thực tế')
        cm_path = os.path.join(save_dir, f'confusion_matrix_{prefix}episode_{episode+1}.png')
        plt.savefig(cm_path)
        plt.close()
        print(f"Đã lưu ma trận nhầm lẫn {prefix} tại: {cm_path}")
        return cm

    # Hàm giảm learning rate và early stopping
    def reduce_lr_and_early_stop(episode, qwk, best_qwk, patience_lr, patience_stop, lr_patience_counter, stop_patience_counter,
                                 inner_lr, outer_lr, fine_tune_lr, min_lr=1e-7, reduce_factor=0.5):
        stop_training = False
        if qwk > best_qwk:
            lr_patience_counter = 0
            stop_patience_counter = 0
        else:
            lr_patience_counter += 1
            stop_patience_counter += 1
        if lr_patience_counter >= patience_lr:
            inner_lr = max(inner_lr * reduce_factor, min_lr)
            outer_lr = max(outer_lr * reduce_factor, min_lr)
            fine_tune_lr = max(fine_tune_lr * reduce_factor, min_lr)
            lr_patience_counter = 0
            print(f"Episode {episode+1}: Giảm learning rate - inner_lr={inner_lr:.6f}, outer_lr={outer_lr:.6f}, fine_tune_lr={fine_tune_lr:.6f}")
        if stop_patience_counter >= patience_stop:
            stop_training = True
            print(f"Episode {episode+1}: Early stopping do QWK không cải thiện sau {patience_stop} episode")
        return inner_lr, outer_lr, fine_tune_lr, lr_patience_counter, stop_patience_counter, stop_training

    # Khởi tạo mô hình và optimizer
    meta_model, meta_classification_model, memory_layer, feature_model = create_model(input_dim)
    meta_optimizer = tf.keras.optimizers.Adam(learning_rate=outer_lr)
    fine_tune_optimizer = tf.keras.optimizers.Adam(learning_rate=fine_tune_lr)
    loss_fn = tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1)
    domain_loss_fn = tf.keras.losses.CategoricalCrossentropy()

    # Khởi tạo các biến theo dõi
    best_qwk = -float('inf')
    lr_patience_counter = 0
    stop_patience_counter = 0
    patience_lr = 10
    patience_stop = 50
    min_lr = 1e-7
    reduce_factor = 0.5
    weights_filepath = os.path.join(feature_save_dir, "meta_model_maml_fomaml_best_weights.weights.h5")
    history = {
        'qwk': [], 'loss': [], 'support_loss': [], 'support_accuracy': [], 'query_loss': [], 'query_accuracy': [],
        'precision': [], 'recall': [], 'support_qwk': [], 'support_precision': [], 'support_recall': [],
        'support_f1': [], 'support_cm': []
    }
    class_weights = np.ones(5, dtype=np.float32)
    class_weights[3] = 10 / (10 * 2)
    source_domain_labels = tf.keras.utils.to_categorical(tf.zeros(len(features), dtype=tf.int32), num_classes=2)
    source_domain_labels = tf.cast(source_domain_labels, tf.float32)
    target_domain_labels = tf.keras.utils.to_categorical(tf.ones(len(valid_features), dtype=tf.int32), num_classes=2)
    target_domain_labels = tf.cast(target_domain_labels, tf.float32)

    # Lưu đặc trưng meta-learner nếu có sample_ids
    if sample_ids is not None:
        print("Lưu đặc trưng 2D của meta-learner...")
        meta_features_2d = feature_model.predict(valid_features, batch_size=32, verbose=0)
        meta_features_2d_path = os.path.join(feature_save_dir, "meta_learner_features_2d.npy")
        np.save(meta_features_2d_path, meta_features_2d)
        print(f"Đã lưu đặc trưng 2D của meta-learner tại: {meta_features_2d_path}, shape={meta_features_2d.shape}")
        metadata = {
            "model_name": "meta_learner",
            "features_2d_path": meta_features_2d_path,
            "sample_ids": sample_ids.tolist(),
            "labels": valid_labels.tolist(),
            "timestamp": pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')
        }
        metadata_path = os.path.join(feature_save_dir, "meta_learner_features_metadata.json")
        with open(metadata_path, 'w') as f:
            json.dump(metadata, f, indent=4)
        print(f"Đã lưu metadata meta-learner tại: {metadata_path}")

    # Vòng lặp huấn luyện
    for episode in range(n_episodes):
        # Tạo episode
        support_features, support_labels, query_features, query_labels, hard_indices = create_episode(
            features, labels, n_support, n_query, hard_sample_ratio=0.3, class_3_multiplier=2)
        if support_features.size == 0 or query_features.size == 0:
            logging.warning(f"Episode {episode+1}: Không đủ mẫu")
            history['qwk'].append(0.0)
            history['support_loss'].append(0.0)
            history['support_accuracy'].append(0.0)
            history['query_loss'].append(0.0)
            history['query_accuracy'].append(0.0)
            history['precision'].append(0.0)
            history['recall'].append(0.0)
            history['support_qwk'].append(0.0)
            history['support_precision'].append(0.0)
            history['support_recall'].append(0.0)
            history['support_f1'].append(0.0)
            history['support_cm'].append(np.zeros((5, 5)).tolist())
            continue

        # Khởi tạo task model
        task_model, task_classification_model, task_memory_layer, task_feature_model = create_model(input_dim)
        task_model.set_weights(meta_model.get_weights())
        task_optimizer = tf.keras.optimizers.Adam(learning_rate=inner_lr)
        class_weight_dict = {i: float(w) for i, w in enumerate(class_weights)}
        print(f"Episode {episode+1} - Trọng số lớp: {class_weight_dict}")
        support_prototypes = compute_prototypes(support_features, support_labels, task_feature_model)

        # Inner loop training
        for _ in range(10):
            with tf.GradientTape() as tape:
                class_preds, domain_preds = task_model(support_features, training=True)
                min_size = min(class_preds.shape[0], support_labels.shape[0])
                if not min_size:
                    break
                class_preds = class_preds[:min_size]
                support_labels_adj = support_labels[:min_size]
                min_size_domain = min(domain_preds.shape[0], source_domain_labels.shape[0])
                domain_preds = domain_preds[:min_size_domain]
                source_domain_labels_slice = source_domain_labels[:min_size_domain]
                class_loss = tf.cast(loss_fn(support_labels_adj, class_preds, sample_weight=[
                    class_weight_dict.get(np.argmax(label), 1.0) for label in support_labels_adj]), tf.float32)
                domain_loss = tf.cast(domain_loss_fn(source_domain_labels_slice, domain_preds), tf.float32)
                support_features_task = task_feature_model(support_features, training=False)
                proto_loss = prototypical_loss(support_features_task, support_labels_adj, support_prototypes)
                total_loss = class_loss + 0.5 * domain_loss + 0.5 * proto_loss
            task_grads = tape.gradient(total_loss, task_model.trainable_variables)
            valid_grads = [(g, v) for g, v in zip(task_grads, task_model.trainable_variables) if g is not None]
            task_optimizer.apply_gradients(valid_grads)
            task_keys = task_feature_model(support_features, training=False)
            if task_keys.shape[1] != 128:
                task_keys = Dense(128, use_bias=False, dtype=tf.float32)(task_keys)
            task_keys = tf.concat([task_keys, tf.zeros([memory_size - task_keys.shape[0], task_keys.shape[1]], dtype=tf.float32)], axis=0) if task_keys.shape[0] < memory_size else task_keys[:memory_size]
            task_memory_layer.memory.assign(task_keys)
            del task_grads, valid_grads
            gc.collect()

        # Đánh giá trên support set
        support_preds = task_model(support_features, training=False)[0]
        support_loss_value = float(class_loss.numpy())
        support_accuracy = float(compute_accuracy(support_labels_adj, support_preds).numpy())
        support_preds_classes = np.argmax(support_preds.numpy(), axis=1)
        support_true_classes = np.argmax(support_labels_adj, axis=1)
        support_qwk = cohen_kappa_score(support_true_classes, support_preds_classes,
                                        labels=[0, 1, 2, 3, 4], weights='quadratic')
        support_precision = precision_score(support_true_classes, support_preds_classes,
                                           average='weighted', zero_division=0)
        support_recall = recall_score(support_true_classes, support_preds_classes,
                                     average='weighted', zero_division=0)
        support_f1 = f1_score(support_true_classes, support_preds_classes,
                             average='weighted', zero_division=0)
        support_cm = confusion_matrix(support_true_classes, support_preds_classes,
                                     labels=[0, 1, 2, 3, 4])
        history['support_qwk'].append(float(support_qwk))
        history['support_precision'].append(float(support_precision))
        history['support_recall'].append(float(support_recall))
        history['support_f1'].append(float(support_f1))
        history['support_cm'].append(support_cm.tolist())
        support_cm = save_confusion_matrix(support_true_classes, support_preds_classes, episode,
                                          support_qwk, feature_save_dir, prefix='support_')

        # Đánh giá trên query set
        with tf.GradientTape() as outer_tape:
            query_preds, domain_preds = task_model(query_features, training=True)
            min_size = min(query_preds.shape[0], query_labels.shape[0])
            if not min_size:
                logging.warning(f"Episode {episode+1}: Không có dữ liệu query hợp lệ")
                history['qwk'].append(0.0)
                history['support_loss'].append(support_loss_value)
                history['support_accuracy'].append(support_accuracy)
                history['query_loss'].append(0.0)
                history['query_accuracy'].append(0.0)
                history['precision'].append(0.0)
                history['recall'].append(0.0)
                continue
            query_preds = query_preds[:min_size]
            query_labels_adj = query_labels[:min_size]
            min_size_domain = min(domain_preds.shape[0], source_domain_labels.shape[0])
            domain_preds = domain_preds[:min_size_domain]
            source_domain_labels_slice = source_domain_labels[:min_size_domain]
            query_loss = tf.cast(loss_fn(query_labels_adj, query_preds, sample_weight=[
                class_weight_dict.get(np.argmax(label), 1.0) for label in query_labels_adj]), tf.float32)
            domain_loss = tf.cast(domain_loss_fn(source_domain_labels_slice, domain_preds), tf.float32)
            query_features_task = task_feature_model(query_features, training=False)
            proto_loss = prototypical_loss(query_features_task, query_labels_adj, support_prototypes)
            total_query_loss = query_loss + 0.5 * domain_loss + 0.5 * proto_loss
            query_accuracy = float(compute_accuracy(query_labels_adj, query_preds).numpy())
            query_loss_value = float(query_loss.numpy())

        # Cập nhật meta-model
        meta_grads = outer_tape.gradient(total_query_loss, task_model.trainable_variables)
        valid_grads = [(g, v) for g, v in zip(meta_grads, meta_model.trainable_variables) if g is not None]
        meta_optimizer.apply_gradients(valid_grads)
        memory_keys = feature_model(support_features, training=False)
        if memory_keys.shape[1] != 128:
            memory_keys = Dense(128, use_bias=False, dtype=tf.float32)(memory_keys)
        memory_keys = tf.concat([memory_keys, tf.zeros([memory_size - memory_keys.shape[0], memory_keys.shape[1]], dtype=tf.float32)], axis=0) if memory_keys.shape[0] < memory_size else memory_keys[:memory_size]
        memory_layer.memory.assign(memory_keys)
        del meta_grads, valid_grads
        gc.collect()

        # Fine-tuning trên query set
        for _ in range(5):
            with tf.GradientTape() as fine_tune_tape:
                fine_tune_preds = meta_classification_model(query_features, training=True)
                min_size = min(fine_tune_preds.shape[0], query_labels.shape[0])
                if not min_size:
                    continue
                fine_tune_preds = fine_tune_preds[:min_size]
                query_labels_adj = query_labels[:min_size]
                fine_tune_loss = tf.cast(loss_fn(query_labels_adj, fine_tune_preds, sample_weight=[
                    class_weight_dict.get(np.argmax(label), 1.0) for label in query_labels_adj]), tf.float32)
            fine_tune_grads = fine_tune_tape.gradient(fine_tune_loss, meta_classification_model.trainable_variables)
            valid_grads = [(g, v) for g, v in zip(fine_tune_grads, meta_classification_model.trainable_variables) if g is not None]
            fine_tune_optimizer.apply_gradients(valid_grads)
            del fine_tune_grads, valid_grads
            gc.collect()

        # Đánh giá trên valid set
        valid_preds_maml = tf.cast(meta_classification_model(valid_features, training=False), tf.float32)
        min_size = min(valid_preds_maml.shape[0], valid_labels.shape[0])
        if not min_size:
            logging.warning(f"Episode {episode+1}: Không có dữ liệu valid hợp lệ")
            history['qwk'].append(0.0)
            history['support_loss'].append(support_loss_value)
            history['support_accuracy'].append(support_accuracy)
            history['query_loss'].append(query_loss_value)
            history['query_accuracy'].append(query_accuracy)
            history['precision'].append(0.0)
            history['recall'].append(0.0)
            continue
        valid_preds_maml = valid_preds_maml[:min_size]
        valid_labels_adj = valid_labels[:min_size]
        valid_features_task = task_feature_model(valid_features, training=False)
        valid_prototypes = compute_prototypes(features, labels, task_feature_model)
        valid_preds_proto = tf.cast(prototypical_predict(valid_features_task, valid_prototypes), tf.float32)
        valid_preds_ensemble = (0.5 * valid_preds_maml + 0.5 * valid_preds_proto)
        valid_preds_scaled = tf.cast(apply_temperature_scaling(valid_preds_ensemble, temperature=2.0), tf.float32)
        valid_preds_scaled = laplace_smoothing(valid_preds_scaled, epsilon=1e-5)
        valid_preds_classes = np.argmax(valid_preds_scaled.numpy(), axis=1)
        qwk = cohen_kappa_score(np.argmax(valid_labels_adj, axis=1), valid_preds_classes,
                                labels=[0, 1, 2, 3, 4], weights='quadratic')
        precision = precision_score(np.argmax(valid_labels_adj, axis=1), valid_preds_classes,
                                   average='weighted', zero_division=0)
        recall = recall_score(np.argmax(valid_labels_adj, axis=1), valid_preds_classes,
                             average='weighted', zero_division=0)
        history['qwk'].append(float(qwk))
        history['precision'].append(float(precision))
        history['recall'].append(float(recall))
        history['support_loss'].append(float(support_loss_value))
        history['support_accuracy'].append(float(support_accuracy))
        history['query_loss'].append(float(query_loss_value))
        history['query_accuracy'].append(float(query_accuracy))

        # Giảm learning rate và early stopping
        inner_lr, outer_lr, fine_tune_lr, lr_patience_counter, stop_patience_counter, stop_training = \
            reduce_lr_and_early_stop(
                episode, qwk, best_qwk, patience_lr, patience_stop, lr_patience_counter, stop_patience_counter,
                inner_lr, outer_lr, fine_tune_lr, min_lr, reduce_factor
            )
        meta_optimizer.learning_rate.assign(outer_lr)
        fine_tune_optimizer.learning_rate.assign(fine_tune_lr)

        # Log vào TensorBoard
        with tf.summary.create_file_writer(log_dir).as_default():
            tf.summary.scalar('support_loss', support_loss_value, step=episode)
            tf.summary.scalar('support_accuracy', support_accuracy, step=episode)
            tf.summary.scalar('support_qwk', support_qwk, step=episode)
            tf.summary.scalar('support_precision', support_precision, step=episode)
            tf.summary.scalar('support_recall', support_recall, step=episode)
            tf.summary.scalar('support_f1', support_f1, step=episode)
            tf.summary.scalar('query_loss', query_loss_value, step=episode)
            tf.summary.scalar('query_accuracy', query_accuracy, step=episode)
            tf.summary.scalar('qwk', qwk, step=episode)
            tf.summary.scalar('precision', precision, step=episode)
            tf.summary.scalar('recall', recall, step=episode)
            tf.summary.scalar('inner_lr', inner_lr, step=episode)
            tf.summary.scalar('outer_lr', outer_lr, step=episode)
            tf.summary.scalar('fine_tune_lr', fine_tune_lr, step=episode)

        # In kết quả
        print(f"Episode {episode+1}/{n_episodes}:")
        print(f"  Support Loss: {support_loss_value:.4f}, Accuracy: {support_accuracy:.4f}")
        print(f"  Support QWK: {support_qwk:.4f}, F1: {support_f1:.4f}, Precision: {support_precision:.4f}, Recall: {support_recall:.4f}")
        print(f"  Support Confusion Matrix:\n{support_cm}")
        print(f"  Query Loss: {query_loss_value:.4f}, Accuracy: {query_accuracy:.4f}")
        print(f"  Valid QWK (Ensemble): {qwk:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}")
        print(f"  Inner LR: {inner_lr:.6f}, Outer LR: {outer_lr:.6f}, Fine-tune LR: {fine_tune_lr:.6f}")

        # Lưu trọng số nếu QWK tốt hơn
        if qwk > best_qwk:
            best_qwk = qwk
            try:
                meta_model.save_weights(weights_filepath, overwrite=True)
                print(f"Đã lưu trọng số tốt nhất tại episode {episode+1} với QWK: {best_qwk:.4f}")
                cm = save_confusion_matrix(
                    np.argmax(valid_labels_adj, axis=1), valid_preds_classes, episode,
                    best_qwk, feature_save_dir, prefix='best_'
                )
                print(f"Ma trận nhầm lẫn cho QWK tốt nhất tại Episode {episode+1}:\n{cm}")
            except Exception as e:
                logging.error(f"Lỗi khi lưu trọng số: {str(e)}. Thử lưu .h5")
                alt_weights_filepath = os.path.join(feature_save_dir, "meta_model_maml_fomaml_best_weights.h5")
                meta_model.save_weights(alt_weights_filepath, overwrite=True)
                print(f"Đã lưu trọng số (dạng thay thế) tại: {alt_weights_filepath}")

        if stop_training:
            print(f"Early stopping kích hoạt tại episode {episode+1}")
            break
        gc.collect()

    # Fine-tuning trên tập valid
    print("Fine-tuning trên tập valid...")
    meta_classification_model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=fine_tune_lr),
        loss='categorical_crossentropy',
        metrics=['accuracy', tf.keras.metrics.Precision(), tf.keras.metrics.Recall()]
    )
    meta_classification_model.fit(
        valid_features, valid_labels,
        validation_data=(valid_features, valid_labels),
        epochs=100,
        batch_size=32,
        verbose=1,
        callbacks=[
            tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1),
            tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3),
            tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
        ]
    )

    # Đánh giá cuối cùng trên valid set
    valid_preds = meta_classification_model.predict(valid_features, batch_size=32)
    valid_preds = apply_temperature_scaling(valid_preds, temperature=2.0)
    valid_preds = laplace_smoothing(valid_preds, epsilon=1e-5)
    valid_preds_classes = np.argmax(valid_preds, axis=1)
    valid_true_classes = np.argmax(valid_labels, axis=1)
    qwk_final = cohen_kappa_score(valid_true_classes, valid_preds_classes,
                                  labels=[0, 1, 2, 3, 4], weights='quadratic')
    f1_final = f1_score(valid_true_classes, valid_preds_classes, average='weighted')
    recall_final = recall_score(valid_true_classes, valid_preds_classes, average='weighted')
    precision_final = precision_score(valid_true_classes, valid_preds_classes, average='weighted')
    print(f"Kết quả cuối cùng trên tập valid:")
    print(f"  Quadratic Weighted Kappa (QWK): {qwk_final:.4f}")
    print(f"  Weighted F1 Score: {f1_final:.4f}")
    print(f"  Weighted Recall: {recall_final:.4f}")
    print(f"  Weighted Precision: {precision_final:.4f}")
    cm_final = save_confusion_matrix(
        valid_true_classes, valid_preds_classes, n_episodes, qwk_final, feature_save_dir, prefix='final_'
    )
    print(f"Ma trận nhầm lẫn cuối cùng:\n{cm_final}")

    # Lưu trữ số
    final_weights_filepath = os.path.join(feature_save_dir, "model.weights.h5")
    final_config_filepath = os.path.join(feature_save_dir, "config.json")
    final_metadata_filepath = os.path.join(feature_save_dir, "metadata.json")
    try:
        meta_model.save_weights(final_weights_filepath, overwrite=True)
        print(f"Đã lưu trọng số meta-model tại: {final_weights_filepath}")
    except Exception as e:
        logging.error("fLỗi khi lưu trọng số meta-model: {str(e)}")
        alt_final_weights_filepath = os.path.join(feature_save_dir, "f{model_weights_alt.h5}")
        meta_model.save_weights(alt_final_weights_filepath, overwrite=True)
        print(f"Đã lưu trọng số (dạng thay thế) tại: {alt_final_weights_filepath}")

    try:
        model_config = meta_model.to_json()
        with open(final_config_filepath, 'w') as f:
            json.dump(json.loads(model_config), f, indent=2)
        print(f"Đã lưu cấu hình meta-model tại: {final_config_filepath}")
    except Exception as e:
        logging.error(f"Lỗi khi lưu config meta-model: {str(e)}")

    metadata = {
        "model_type": "meta_model_maml_fomaml",
        "num_classes": 5,
        "input_dim": input_dim,
        "training_episodes": n_episodes,
        "timestamp": pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')
    }
    try:
        with open(final_metadata_filepath, 'w') as f:
            json.dump(metadata, f, indent=4)
        print(f"Đã lưu siêu dữ liệu meta-model tại: {final_metadata_filepath}")
    except Exception as e:
        logging.error(f"Lỗi khi lưu metadata: {str(e)}")

    metrics_history = {
        'qwk': float(qwk_final),
        'f1_score': float(f1_final),
        'recall': float(recall_final),
        'precision': float(precision_final),
        'training_history': {
            'qwk': [float(q) for q in history['qwk']],
            'support_loss': [float(l) for l in history['support_loss']],
            'support_accuracy': [float(a) for a in history['support_accuracy']],
            'query_loss': [float(l) for l in history['query_loss']],
            'query_accuracy': [float(a) for a in history['query_accuracy']],
            'precision': [float(p) for p in history['precision']],
            'recall': [float(r) for r in history['recall']],
            'support_qwk': [float(q) for q in history['support_qwk']],
            'support_precision': [float(p) for p in history['support_precision']],
            'support_recall': [float(r) for r in history['support_recall']],
            'support_f1': [float(f) for f in history['support_f1']],
            'support_cm': history['support_cm']
        }
    }

    metrics_filepath = os.path.join(feature_save_dir, 'final_metrics.json')
    try:
        with open(metrics_filepath, 'w') as f:
            json.dump(metrics_history, f, indent=4)
        print(f"Đã lưu các chỉ số đánh giá tại: {metrics_filepath}")
    except Exception as e:
        logging.error(f"Lỗi khi lưu metrics: {str(e)}")

    # Vẽ biểu đồ lịch sử huấn luyện
    plt.figure(figsize=(15, 10))
    plt.subplot(2, 3, 1)
    plt.plot(history['qwk'], label='Valid QWK')
    plt.title('QWK theo Episode')
    plt.xlabel('Episode')
    plt.ylabel('QWK')
    plt.legend()
    plt.subplot(2, 3, 2)
    plt.plot(history['support_loss'], label='Support Loss')
    plt.plot(history['query_loss'], label='Query Loss')
    plt.title('Loss theo Episode')
    plt.xlabel('Episode')
    plt.ylabel('Loss')
    plt.legend()
    plt.subplot(2, 3, 3)
    plt.plot(history['support_accuracy'], label='Support Accuracy')
    plt.plot(history['query_accuracy'], label='Query Accuracy')
    plt.title('Accuracy theo Episode')
    plt.xlabel('Episode')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.subplot(2, 3, 4)
    plt.plot(history['support_qwk'], label='Support QWK')
    plt.title('Support QWK theo Episode')
    plt.xlabel('Episode')
    plt.ylabel('QWK')
    plt.legend()
    plt.subplot(2, 3, 5)
    plt.plot(history['support_precision'], label='Support Precision')
    plt.plot(history['support_recall'], label='Support Recall')
    plt.plot(history['support_f1'], label='Support F1')
    plt.title('Support Metrics theo Episode')
    plt.xlabel('Episode')
    plt.ylabel('Score')
    plt.legend()
    plt.tight_layout()
    history_path = os.path.join(feature_save_dir, 'training_history.png')
    plt.savefig(history_path)
    plt.close()
    print(f"Đã lưu biểu đồ lịch sử huấn luyện tại: {history_path}")

    return meta_model, meta_classification_model, history

def evaluate_test_set(meta_classification_model, test_features, test_labels, save_dir):
    test_preds = meta_classification_model.predict(test_features, batch_size=32)
    test_preds = apply_temperature_scaling(test_preds, temperature=2.0)
    test_preds = laplace_smoothing(test_preds, epsilon=1e-5)
    test_probs = np.max(test_preds.numpy(), axis=1)
    test_probs = np.clip(test_probs, 0.0, 1.0)
    test_preds_classes = np.argmax(test_preds, axis=1)
    test_true_classes = np.argmax(test_labels, axis=1)
    qwk_test = cohen_kappa_score(test_true_classes, test_preds_classes, labels=[0, 1, 2, 3, 4], weights='quadratic')
    f1_test = f1_score(test_true_classes, test_preds_classes, average='weighted')
    precision_test = precision_score(test_true_classes, test_preds_classes, average='weighted')
    recall_test = recall_score(test_true_classes, test_preds_classes, average='weighted')
    accuracy_test = accuracy_score(test_true_classes, test_preds_classes)
    print(f"Kết quả trên tập test:")
    print(f"  Quadratic Weighted Kappa (QWK): {qwk_test:.4f}")
    print(f"  Weighted F1 Score: {f1_test:.4f}")
    print(f"  Weighted Precision: {precision_test:.4f}")
    print(f"  Weighted Recall: {recall_test:.4f}")
    print(f"  Accuracy: {accuracy_test:.4f}")
    cm_test = confusion_matrix(test_true_classes, test_preds_classes, labels=[0, 1, 2, 3, 4])
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm_test, annot=True, fmt='d', cmap='Blues',
                xticklabels=[0, 1, 2, 3, 4], yticklabels=[0, 1, 2, 3, 4])
    plt.title(f'Ma trận nhầm lẫn - Test Set, QWK: {qwk_test:.4f}')
    plt.xlabel('Dự đoán')
    plt.ylabel('Thực tế')
    cm_path = os.path.join(save_dir, 'confusion_matrix_test.png')
    plt.savefig(cm_path)
    plt.close()
    print(f"Đã lưu ma trận nhầm lẫn tập test tại: {cm_path}")
    print(f"Ma trận nhầm lẫn tập test:\n{cm_test}")
    test_metrics = {
        'qwk': float(qwk_test),
        'f1_score': float(f1_test),
        'precision': float(precision_test),
        'recall': float(recall_test),
        'accuracy': float(accuracy_test),
        'confusion_matrix': cm_test.tolist()
    }
    metrics_filepath = os.path.join(save_dir, 'test_metrics.json')
    try:
        with open(metrics_filepath, 'w') as f:
            json.dump(test_metrics, f, indent=4)
        print(f"Đã lưu các chỉ số đánh giá tập test tại: {metrics_filepath}")
    except Exception as e:
        logging.error(f"Lỗi khi lưu test metrics: {str(e)}")
    return test_metrics

def apply_temperature_scaling(logits, temperature=2.0):
    logits = tf.convert_to_tensor(logits, dtype=tf.float32)
    return tf.nn.softmax(logits / temperature)

def laplace_smoothing(probs, epsilon=1e-5):
    probs = tf.convert_to_tensor(probs, dtype=tf.float32)
    return (probs + epsilon) / (tf.reduce_sum(probs, axis=-1, keepdims=True) + NUM_CLASSES * epsilon)

def save_confusion_matrix(y_true, y_pred, episode, qwk, save_dir, prefix=''):
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1, 2, 3, 4])
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=[0, 1, 2, 3, 4], yticklabels=[0, 1, 2, 3, 4])
    plt.title(f'Ma trận nhầm lẫn - {prefix}QWK: {qwk:.4f} tại Episode {episode+1}')
    plt.xlabel('Dự đoán')
    plt.ylabel('Thực tế')
    cm_path = os.path.join(save_dir, f'confusion_matrix_{prefix}qwk_episode_{episode+1}.png')
    plt.savefig(cm_path)
    plt.close()
    print(f"Đã lưu ma trận nhầm lẫn {prefix}QWK tại: {cm_path}")
    return cm

def compute_prototypes(features, labels, feature_model):
    features = feature_model(features, training=False)
    labels_arg = tf.argmax(labels, axis=1)
    prototypes = []
    for cls in range(NUM_CLASSES):
        cls_mask = tf.equal(labels_arg, cls)
        cls_features = tf.boolean_mask(features, cls_mask)
        prototype = tf.reduce_mean(cls_features, axis=0) if tf.shape(cls_features)[0] > 0 else \
                    tf.zeros([features.shape[1]], dtype=tf.float32)
        prototypes.append(prototype)
    return tf.stack(prototypes)

def prototypical_loss(query_features, query_labels, prototypes):
    query_labels_arg = tf.argmax(query_labels, axis=1)
    query_features_exp = tf.expand_dims(query_features, axis=1)
    prototypes_exp = tf.expand_dims(prototypes, axis=0)
    distances = tf.reduce_sum(tf.square(query_features_exp - prototypes_exp), axis=-1)
    logits = tf.cast(-distances, tf.float32)
    return tf.cast(tf.keras.losses.categorical_crossentropy(query_labels, logits, from_logits=True), tf.float32)

def prototypical_predict(query_features, prototypes):
    query_features_exp = tf.expand_dims(query_features, axis=1)
    prototypes_exp = tf.expand_dims(prototypes, axis=0)
    distances = tf.reduce_sum(tf.square(query_features_exp - prototypes_exp), axis=-1)
    return tf.cast(tf.nn.softmax(tf.cast(-distances, tf.float32)), tf.float32)

def reduce_lr_and_early_stop(episode, qwk, best_qwk, patience_lr, patience_stop, lr_patience_counter, stop_patience_counter,
                             inner_lr, outer_lr, fine_tune_lr, min_lr=1e-7, reduce_factor=0.5):
    stop_training = False
    if qwk > best_qwk:
        lr_patience_counter = 0
        stop_patience_counter = 0
    else:
        lr_patience_counter += 1
        stop_patience_counter += 1
    if lr_patience_counter >= patience_lr:
        inner_lr = max(inner_lr * reduce_factor, min_lr)
        outer_lr = max(outer_lr * reduce_factor, min_lr)
        fine_tune_lr = max(fine_tune_lr * reduce_factor, min_lr)
        lr_patience_counter = 0
        print(f"Episode {episode+1}: Giảm learning rate - inner_lr={inner_lr:.6f}, outer_lr={outer_lr:.6f}, fine_tune_lr={fine_tune_lr:.6f}")
    if stop_patience_counter >= patience_stop:
        stop_training = True
        print(f"Episode {episode+1}: Early stopping do QWK không cải thiện sau {patience_stop} episode")
    return inner_lr, outer_lr, fine_tune_lr, lr_patience_counter, stop_patience_counter, stop_training

# Trích xuất và lưu đặc trưng
train_features_dict = {}
valid_features_dict = {}
test_features_dict = {}
for model_name, config in model_configs.items():
    print(f"Xử lý mô hình: {model_name}")
    base_model = load_model_from_config(
        config['config_path'], config['weights_path'], config['base_model']
    )
    feature_layer = base_model.layers[-2].output if len(base_model.layers) > 1 else base_model.output
    feature_extractor = Model(inputs=base_model.input, outputs=feature_layer)
    feature_extractor.trainable = False
    train_generator = My_Generator(
        balanced_train_x, balanced_train_y_multi, batch_size=32, is_train=True,
        mix=True, augment=True, size1=SIZE, size2=config['img_size'],
        model_type=config['model_type'], preprocess=config['preprocess']
    )
    valid_generator = My_Generator(
        resized_valid_x, valid_y_multi, batch_size=32, is_train=False,
        mix=False, augment=False, size1=SIZE, size2=config['img_size'],
        model_type=config['model_type'], preprocess=config['preprocess']
    )
    test_generator = My_Generator(
        resized_test_x, test_y_multi, batch_size=32, is_train=False,
        mix=False, augment=False, size1=SIZE, size2=config['img_size'],
        model_type=config['model_type'], preprocess=config['preprocess']
    )
    train_features_2d = extract_and_save_features(
        model_name, feature_extractor, train_generator, feature_save_dir, sample_ids=train_x.values
    )
    valid_features_2d = extract_and_save_features(
        model_name, feature_extractor, valid_generator, feature_save_dir, sample_ids=valid_x.values
    )
    test_features_2d = extract_and_save_features(
        model_name, feature_extractor, test_generator, feature_save_dir, sample_ids=test_x.values
    )
    train_features_dict[model_name] = train_features_2d
    valid_features_dict[model_name] = valid_features_2d
    test_features_dict[model_name] = test_features_2d
    del base_model, feature_extractor
    tf.keras.backend.clear_session()
    gc.collect()

# Kết hợp và giảm chiều đặc trưng
train_combined_features, train_pca, train_valid_indices, input_dim = combine_and_reduce_features(
    train_features_dict, balanced_train_y_multi, train_x.values, feature_save_dir, n_components=50
)
valid_combined_features, _, valid_valid_indices, _ = combine_and_reduce_features(
    valid_features_dict, valid_y_multi, valid_x.values, feature_save_dir, n_components=50
)
test_combined_features, _, test_valid_indices, _ = combine_and_reduce_features(
    test_features_dict, test_y_multi, test_x.values, feature_save_dir, n_components=50
)
train_y_multi = balanced_train_y_multi[train_valid_indices]
valid_y_multi = valid_y_multi[valid_valid_indices]
test_y_multi = test_y_multi[test_valid_indices]

# Huấn luyện meta-model
print("Huấn luyện meta-model với MAML/FO-MAML...")
meta_model, meta_classification_model, history = maml_fomaml_train_manual(
    train_combined_features, train_y_multi, valid_combined_features, valid_y_multi,
    input_dim=input_dim, n_episodes=20, n_support=10, n_query=10, inner_lr=0.001,
    outer_lr=0.001, fine_tune_lr=0.0001, use_fomaml=True, memory_size=20,
    sample_ids=valid_x.values[valid_valid_indices]
)

# Đánh giá trên tập test
print("Đánh giá trên tập test...")
test_metrics = evaluate_test_set(
    meta_classification_model, test_combined_features, test_y_multi, feature_save_dir
)

print("Hoàn thành quy trình huấn luyện và đánh giá!")

Đã kích hoạt tăng trưởng bộ nhớ GPU
Google Drive đã được mount.
Số mẫu lớp 0: 1534
Phân bố nhãn ban đầu: {0: np.int64(0), 1: np.int64(314), 2: np.int64(849), 3: np.int64(164), 4: np.int64(251)}
Số mẫu mục tiêu mỗi lớp: 1534
Lớp 1: 314 mẫu ban đầu
Tăng cường 1220 mẫu cho lớp 1


  A.GaussNoise(var_limit=(10.0, 20.0), p=0.2),


Lớp 2: 849 mẫu ban đầu
Tăng cường 685 mẫu cho lớp 2
Lớp 3: 164 mẫu ban đầu
Tăng cường 1370 mẫu cho lớp 3
Lớp 4: 251 mẫu ban đầu
Tăng cường 1283 mẫu cho lớp 4
Phân bố nhãn sau cân bằng: {0: np.int64(0), 1: np.int64(1534), 2: np.int64(1534), 3: np.int64(1534), 4: np.int64(1534)}
Phân bố nhãn sau khi thêm lớp 0: {0: np.int64(1534), 1: np.int64(1534), 2: np.int64(1534), 3: np.int64(1534), 4: np.int64(1534)}
balanced_train_x shape: (7670, 244, 244, 3)
balanced_train_y_multi shape: (7670, 5)
Xử lý mô hình: efficientnetb1
Khởi tạo My_Generator: 7670 mẫu, target_size=(244, 244), is_train=True
Khởi tạo My_Generator: 550 mẫu, target_size=(244, 244), is_train=False
Khởi tạo My_Generator: 733 mẫu, target_size=(244, 244), is_train=False
Đã lưu đặc trưng 2D tại: /content/drive/MyDrive/working/efficientnetb1_features_2d.npy, shape=(7670, 1024)
Đã lưu metadata tại: /content/drive/MyDrive/working/efficientnetb1_features_metadata.json
Đã lưu đặc trưng 2D tại: /content/drive/MyDrive/working/efficientnetb

In [None]:
print(processed_ids)

['cb2f3c5d71a7.png', '5c6194562ed2.png', 'e893e86dde94.png', '58059e73d2d4.png', '8ee50c26fc13.png', '6cee2e148520.png', 'd035c2bd9104.png', '22098b1fe461.png', '0a1076183736.png', 'b576c5269ad1.png', 'f4d3777f2710.png', 'd881c04f01fe.png', '14c3b41d289c.png', 'cb68fce07789.png', 'fcc55ae641ae.png', '64c6c6ee0d98.png', '4cae247d9909.png', '6b7cf869622a.png', 'e7291472109b.png', '90bde2ff8953.png', 'c0a117de7d0a.png', '7c2e852171c0.png', '42af7282349b.png', '674057ab250c.png', '8688f3d0fcaf.png', 'd6f6bdfd8011.png', '7adfb8fc0621.png', 'cc3d2e961768.png', '8958a4d17b7e.png', '26999ebc21de.png', '07a1c7073982.png', 'a8dea22ef903.png', '41345cec5957.png', '35beb47fe159.png', '79ade634c633.png', '7e4019ac7f5a.png', '79d44db3da2d.png', 'a4d41c495666.png', '415d5c5e785f.png', '6c6505a0c637.png', '0fb1053285cf.png', '838b3e4d0bb4.png', '13d411c85ffd.png', '8e76054f0831.png', '32d7d360d891.png', '6f460f9968c7.png', 'd801c0a66738.png', '38e0e28d35d3.png', '51269b77d312.png', '2927665214e1.png',