# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import transforms, datasets
from torchvision.models import densenet121
import pandas as pd
import numpy as np
import os
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, classification_report
from sklearn.model_selection import StratifiedKFold
from tqdm import tqdm
import cv2
import random
import seaborn as sns
from torchvision.transforms import InterpolationMode

# =============================================================================
# 1) SEED & HELPER
# =============================================================================
def set_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random``, NumPy and PyTorch (CPU plus the current and
    all CUDA devices), exports ``PYTHONHASHSEED`` and configures cuDNN for
    deterministic behaviour with benchmarking switched off.

    Args:
        seed: Value handed to every generator.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Deterministic kernels on, autotuner off.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

# Fixed seed, applied once before any data splitting or model construction.
SEED = 42
set_seed(SEED)
print(f"Random seed diatur ke {SEED}")

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Menggunakan device: {device}")

# =============================================================================
# 2) DATASET & TRANSFORM (mempertahankan teknik pemrosesan Anda)
# =============================================================================
class ImageDataset(Dataset):
    """Image dataset for a labelled class-folder tree or a flat test folder.

    With ``is_test=False`` the samples and class names come from
    ``datasets.ImageFolder(root_dir)`` and each item is ``(image, label)``.
    With ``is_test=True`` the ``.png``/``.jpg``/``.jpeg`` files directly
    inside ``root_dir`` are listed in sorted order and each item is
    ``(image, file_name)``.

    Args:
        root_dir: Directory to read from.
        transform: Optional callable applied to the RGB PIL image.
        is_test: Selects the flat, unlabelled test layout.
    """

    _IMAGE_SUFFIXES = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir, transform=None, is_test=False):
        self.root_dir = root_dir
        self.transform = transform
        self.is_test = is_test

        if is_test:
            # Sorted listing keeps the test order reproducible.
            self.image_paths = [
                os.path.join(root_dir, entry)
                for entry in sorted(os.listdir(root_dir))
                if entry.lower().endswith(self._IMAGE_SUFFIXES)
            ]
            return

        self.dataset = datasets.ImageFolder(root_dir)
        self.samples = self.dataset.samples
        self.classes = self.dataset.classes

    def __len__(self):
        return len(self.image_paths if self.is_test else self.samples)

    def _read_rgb(self, img_path):
        """Open ``img_path`` as RGB and apply the transform when one is set."""
        picture = Image.open(img_path).convert('RGB')
        return self.transform(picture) if self.transform else picture

    def __getitem__(self, idx):
        if self.is_test:
            img_path = self.image_paths[idx]
            return self._read_rgb(img_path), os.path.basename(img_path)

        img_path, label = self.samples[idx]
        return self._read_rgb(img_path), label

class ApplyCLAHE(object):
    """Transform that runs OpenCV's CLAHE on a PIL image and returns a PIL image.

    Arrays of shape ``H x W x 3`` are converted RGB -> LAB, CLAHE is applied
    to the L plane only, and the result is converted back to RGB. Any other
    array is passed to CLAHE directly.

    Args:
        clip_limit: Forwarded to ``cv2.createCLAHE`` as ``clipLimit``.
        tile_grid_size: Forwarded to ``cv2.createCLAHE`` as ``tileGridSize``.
    """

    def __init__(self, clip_limit=2.0, tile_grid_size=(8, 8)):
        self.clip_limit = clip_limit
        self.tile_grid_size = tile_grid_size
        self.clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)

    def __call__(self, img):
        pixels = np.array(img)
        is_three_channel = pixels.ndim == 3 and pixels.shape[2] == 3

        if not is_three_channel:
            # Everything that is not H x W x 3 goes straight through CLAHE.
            return Image.fromarray(self.clahe.apply(pixels))

        # Equalise only the L plane; the a/b planes are merged back untouched.
        lightness, chan_a, chan_b = cv2.split(cv2.cvtColor(pixels, cv2.COLOR_RGB2LAB))
        equalised_lab = cv2.merge((self.clahe.apply(lightness), chan_a, chan_b))
        return Image.fromarray(cv2.cvtColor(equalised_lab, cv2.COLOR_LAB2RGB))

# =============================================================================
# 3) KONFIGURASI
# =============================================================================
# Size handed to transforms.Resize in every pipeline below.
IMG_SIZE = (384, 384)
BATCH_SIZE = 16
# Upper bound on full fine-tuning epochs per fold (early stopping may end sooner).
EPOCHS = 15
# Used as-is for the classifier warm-up; the full fine-tune uses LEARNING_RATE/10.
LEARNING_RATE = 1e-3
# Number of splits for StratifiedKFold.
NUM_FOLDS = 4

# Data paths & weights of the previous 5-class model
TRAIN_PATH = "/kaggle/input/final-srifoton-25-machine-learning-competition/train/train"  # now contains 3 classes
TEST_PATH  = "/kaggle/input/final-srifoton-25-machine-learning-competition/test/test"
PREV_5CLS_CHECKPOINT = "/kaggle/input/weights/model_5cls.pth"  # REPLACE with the path to your previous weights

# Training pipeline: resize, CLAHE, then random flip / affine / colour jitter,
# followed by tensor conversion and per-channel normalisation.
train_transform = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    ApplyCLAHE(clip_limit=2.0),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomAffine(
        degrees=10,
        translate=(0.05, 0.05),
        scale=(0.95, 1.05),
        shear=(-5, 5, -5, 5),
        interpolation=InterpolationMode.BILINEAR,
        fill=0),
    transforms.ColorJitter(brightness=0.15, contrast=0.15),
    transforms.ToTensor(),
    # Same mean/std as in the validation/test pipeline below
    # (presumably the usual ImageNet statistics — confirm).
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Validation/test pipeline: same resize, CLAHE and normalisation, no random augmentation.
val_test_transform = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    ApplyCLAHE(clip_limit=2.0),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

print("Memuat dataset...")
# Loaded without a transform: this instance is only used for its class names
# and its (path, label) sample list.
full_train_dataset = ImageDataset(TRAIN_PATH, transform=None)
class_names = full_train_dataset.classes
num_classes = len(class_names)
print(f"Kelas: {class_names} (n={num_classes})")

# Split the (path, label) pairs into parallel sequences; `labels` drives the
# stratified K-fold split further down.
samples = full_train_dataset.samples
paths, labels = zip(*samples)
labels = np.array(labels)

# Unlabelled test images, served in sorted file order (shuffle=False).
test_dataset = ImageDataset(TEST_PATH, transform=val_test_transform, is_test=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
print(f"Test batches: {len(test_loader)}")

# =============================================================================
# 4) MEMBANGUN MODEL DARI BOOTSTRAP 5-KELAS → GANTI HEAD → 3-KELAS
# =============================================================================
def load_any_checkpoint(ckpt_path, map_location='cpu'):
    """Load a checkpoint file and return the weights stored in it.

    Accepts a bare state dict as well as dict wrappers that keep the weights
    under ``'state_dict'``, ``'model'`` or ``'model_state_dict'`` (checked in
    that order, so inputs handled before keep resolving the same way).

    Args:
        ckpt_path: Path of the file to read with ``torch.load``.
        map_location: Forwarded to ``torch.load``; defaults to ``'cpu'``.

    Returns:
        The nested weights when a known wrapper key is present, otherwise the
        loaded object unchanged.
    """
    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load(ckpt_path, map_location=map_location)
    if isinstance(ckpt, dict):
        for wrapper_key in ('state_dict', 'model', 'model_state_dict'):
            if wrapper_key in ckpt:
                return ckpt[wrapper_key]
    return ckpt

def build_backbone_from_checkpoint(ckpt_path, num_classes_new):
    """Build a DenseNet121 with a fresh head, warm-started from an older checkpoint.

    The network starts from the torchvision pretrained weights, its classifier
    is replaced by ``Dropout(0.5) -> Linear(in_features, num_classes_new)``,
    and every ``features.*`` tensor of the checkpoint whose name and shape
    match the new model is copied in. The old classifier is never transferred.

    Args:
        ckpt_path: Path to the previous checkpoint (assumed to be a DenseNet121).
        num_classes_new: Number of outputs of the new classifier.

    Returns:
        The model, already moved to the module-level ``device``. When the
        checkpoint cannot be read or nothing in it matches, a warning is
        printed and the pretrained initialisation is kept.
    """
    # 1) Fresh backbone with a new classification head.
    # NOTE(review): `pretrained=True` is the legacy torchvision argument;
    # newer releases prefer `weights=` — confirm against the installed version.
    model = densenet121(pretrained=True)
    in_features = model.classifier.in_features
    model.classifier = nn.Sequential(
        nn.Dropout(0.5),
        nn.Linear(in_features, num_classes_new)
    )
    model = model.to(device)

    # 2) Best-effort transfer of the old backbone weights.
    try:
        old_sd = load_any_checkpoint(ckpt_path, map_location='cpu')
        model_sd = model.state_dict()

        transferable = {}
        for key, tensor in old_sd.items():
            # Strip wrapper prefixes: 'module.' (DataParallel/DDP) and 'model.'
            # (weights saved through a wrapper that stores the net as self.model).
            name = key
            for prefix in ('module.', 'model.'):
                if name.startswith(prefix):
                    name = name[len(prefix):]

            # 3) Keep backbone tensors only; the old classifier is dropped.
            if not name.startswith('features.'):
                continue
            target = model_sd.get(name)
            # A shape check keeps one incompatible tensor from aborting the whole transfer.
            if target is not None and target.shape == getattr(tensor, 'shape', None):
                transferable[name] = tensor

        if not transferable:
            # Routed through the handler below so the fallback is reported
            # instead of a misleading success message.
            raise ValueError("tidak ada weight 'features.*' yang cocok di checkpoint")

        model.load_state_dict(transferable, strict=False)
        print(f"Backbone terinisialisasi dari checkpoint 5-kelas: {ckpt_path}")
        print(f"Jumlah tensor 'features' yang disalin: {len(transferable)}")
    except Exception as e:
        print(f"PERINGATAN: Gagal memuat backbone dari checkpoint: {e}")
        print("Model akan mulai dari ImageNet pretrained.")

    return model

# Wrapper used for training
class MyBackboneFineTune(nn.Module):
    """Thin ``nn.Module`` wrapper around the network built from the old checkpoint.

    The wrapped network is stored as ``self.model``, so the wrapper's
    state-dict keys carry a ``model.`` prefix.
    """

    def __init__(self, num_classes, prev_ckpt):
        """Build the wrapped network.

        Args:
            num_classes: Output size of the new classifier head.
            prev_ckpt: Path of the previous checkpoint used to initialise the backbone.
        """
        super().__init__()
        self.model = build_backbone_from_checkpoint(prev_ckpt, num_classes)

    def forward(self, x):
        """Return the wrapped network's output for ``x``."""
        return self.model(x)

# =============================================================================
# 5) TRAINING UTILS
# =============================================================================
def train_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: Network to train; switched to train mode.
        train_loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(loss, accuracy, f1)``: the loss averaged over batches, the
        accuracy in percent and the macro-averaged F1 score.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    pred_log, target_log = [], []

    for images, labels in tqdm(train_loader, desc="Training"):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Index of the largest logit per row is the predicted class.
        predicted = outputs.max(1)[1]
        loss_sum += loss.item()
        n_seen += labels.size(0)
        n_correct += (predicted == labels).sum().item()

        pred_log.extend(predicted.detach().cpu().numpy())
        target_log.extend(labels.detach().cpu().numpy())

    return (
        loss_sum / len(train_loader),
        100. * n_correct / n_seen,
        f1_score(target_log, pred_log, average='macro'),
    )

@torch.no_grad()
def validate_epoch(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` with gradients disabled.

    Args:
        model: Network to evaluate; switched to eval mode.
        val_loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(loss, accuracy, f1)``: the loss averaged over batches, the
        accuracy in percent and the macro-averaged F1 score.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    pred_log, target_log = [], []

    for images, labels in tqdm(val_loader, desc="Validation"):
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        loss_sum += criterion(outputs, labels).item()

        # Index of the largest logit per row is the predicted class.
        predicted = outputs.max(1)[1]
        n_seen += labels.size(0)
        n_correct += (predicted == labels).sum().item()

        pred_log.extend(predicted.detach().cpu().numpy())
        target_log.extend(labels.detach().cpu().numpy())

    return (
        loss_sum / len(val_loader),
        100. * n_correct / n_seen,
        f1_score(target_log, pred_log, average='macro'),
    )

# Cross-entropy with label smoothing 0.1, shared by every train/validation pass.
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

# =============================================================================
# 6) K-FOLD FINE-TUNING (5-class backbone -> 3-class head)
# =============================================================================
skf = StratifiedKFold(n_splits=NUM_FOLDS, shuffle=True, random_state=SEED)
model_paths, fold_val_accs, fold_val_f1s = [], [], []

# The two dataset views differ only in their transform, so the folder is
# scanned once here instead of twice per fold; folds are carved out via Subset.
train_view = ImageDataset(TRAIN_PATH, transform=train_transform)
val_view = ImageDataset(TRAIN_PATH, transform=val_test_transform)

for fold, (train_idx, val_idx) in enumerate(skf.split(paths, labels)):
    print(f"\n=== Fold {fold+1}/{NUM_FOLDS} ===")

    # Per-fold subsets and loaders
    train_subset = Subset(train_view, train_idx)
    val_subset   = Subset(val_view, val_idx)
    train_loader = DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True,  num_workers=4)
    val_loader   = DataLoader(val_subset,   batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

    # Model initialised from the 5-class weights, with a new 3-class head
    model = MyBackboneFineTune(num_classes=num_classes, prev_ckpt=PREV_5CLS_CHECKPOINT).to(device)

    # Stage 1: freeze the features and train only the classifier (warm-up).
    # NOTE(review): train_epoch calls model.train(), so normalisation layers
    # in the frozen backbone presumably still update their running
    # statistics during this stage — confirm that this is intended.
    for n, p in model.named_parameters():
        p.requires_grad = 'classifier' in n

    head_params = [p for p in model.parameters() if p.requires_grad]
    opt_head = optim.Adam(head_params, lr=LEARNING_RATE)
    epochs_head = 5
    for ep in range(epochs_head):
        print(f"[Fold {fold+1}] Head Epoch {ep+1}/{epochs_head}")
        tr_loss, tr_acc, tr_f1 = train_epoch(model, train_loader, criterion, opt_head, device)
        va_loss, va_acc, va_f1 = validate_epoch(model, val_loader, criterion, device)
        print(f"Head Train: loss={tr_loss:.4f} acc={tr_acc:.2f}% f1={tr_f1:.4f} | "
              f"Val: loss={va_loss:.4f} acc={va_acc:.2f}% f1={va_f1:.4f}")

    # Stage 2: unfreeze everything and fine-tune gently
    for p in model.parameters():
        p.requires_grad = True

    opt_full = optim.AdamW(model.parameters(), lr=LEARNING_RATE/10, weight_decay=1e-4)
    # `verbose` is deprecated on LR schedulers; the learning rate is printed
    # explicitly after every epoch instead.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(opt_full, mode='max', factor=0.5, patience=3)

    train_losses, train_accs, train_f1s = [], [], []
    val_losses, val_accs, val_f1s = [], [], []
    # Start below any reachable accuracy so the first epoch always writes a
    # checkpoint; the loading step later expects one file per fold.
    best_val_acc = -1.0
    best_val_f1 = 0.0
    best_model_path = f'best_model_fold{fold+1}.pth'
    patience, patience_counter = 7, 0

    for ep in range(EPOCHS):
        print(f"[Fold {fold+1}] FT Epoch {ep+1}/{EPOCHS}")
        tr_loss, tr_acc, tr_f1 = train_epoch(model, train_loader, criterion, opt_full, device)
        va_loss, va_acc, va_f1 = validate_epoch(model, val_loader, criterion, device)

        train_losses.append(tr_loss); train_accs.append(tr_acc); train_f1s.append(tr_f1)
        val_losses.append(va_loss);   val_accs.append(va_acc);   val_f1s.append(va_f1)

        scheduler.step(va_acc)
        print(f"Train: loss={tr_loss:.4f} acc={tr_acc:.2f}% f1={tr_f1:.4f} | "
              f"Val: loss={va_loss:.4f} acc={va_acc:.2f}% f1={va_f1:.4f}")
        print(f"LR: {opt_full.param_groups[0]['lr']:.2e}")

        if va_acc > best_val_acc:
            best_val_acc = va_acc
            # F1 of the very checkpoint being saved, so both reported fold
            # metrics describe the same model.
            best_val_f1 = va_f1
            torch.save(model.state_dict(), best_model_path)
            patience_counter = 0
            print(f"** Best Val Acc updated: {best_val_acc:.2f}% (model disimpan: {best_model_path})")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(">> Early stopping.")
                break

    model_paths.append(best_model_path)
    fold_val_accs.append(best_val_acc)
    fold_val_f1s.append(best_val_f1)

# Cross-validation summary
print("\n=== Hasil Cross-Validation ===")
print(f"Rata-rata Val Acc: {np.mean(fold_val_accs):.2f}%")
print(f"Rata-rata Val F1 : {np.mean(fold_val_f1s):.4f}")

# =============================================================================
# 7) MUAT MODEL TERBAIK PER FOLD UNTUK EVALUASI & ENSEMBLE
# =============================================================================
def load_finetuned_model(path_ckpt, num_classes):
    """Rebuild the fine-tuning wrapper and restore a saved fold checkpoint.

    Args:
        path_ckpt: State-dict file to restore.
        num_classes: Output size of the classifier head to build.

    Returns:
        The model on ``device``, in eval mode, with its weights loaded
        strictly from ``path_ckpt``.
    """
    net = MyBackboneFineTune(num_classes=num_classes, prev_ckpt=PREV_5CLS_CHECKPOINT)
    net = net.to(device)

    # strict=True: every key, including the fine-tuned head, must match.
    finetuned_weights = torch.load(path_ckpt, map_location='cpu')
    net.load_state_dict(finetuned_weights, strict=True)
    return net.eval()

models = [load_finetuned_model(p, num_classes) for p in model_paths]

# (Optional) confusion-matrix evaluation on the validation split of the last fold
print("\nEvaluasi (last fold) pada validation set...")
# NOTE(review): the models of the other folds were trained on these samples,
# so this ensemble report is presumably optimistic; treat it as a sanity
# check rather than an unbiased estimate.
last_fold_val_idx = list(skf.split(paths, labels))[-1][1]
val_subset_last = Subset(ImageDataset(TRAIN_PATH, transform=val_test_transform), last_fold_val_idx)
val_loader_last = DataLoader(val_subset_last, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

all_preds, all_labels = [], []
with torch.no_grad():
    # The batch labels get their own name so the module-level `labels` array
    # (used for the stratified split) is not overwritten by this loop.
    for images, batch_labels in tqdm(val_loader_last, desc="Validating"):
        images = images.to(device)
        # Ensemble: average the logits of all models, then softmax
        logits_sum = None
        for m in models:
            out = m(images)
            logits_sum = out if logits_sum is None else (logits_sum + out)
        probs = torch.softmax(logits_sum / len(models), dim=1)
        pred = probs.argmax(1)
        all_preds.extend(pred.cpu().numpy())
        all_labels.extend(batch_labels.numpy())

print("\n" + "="*50)
print("Classification Report (Last Fold, Ensemble)")
print("="*50)
print(classification_report(all_labels, all_preds, target_names=class_names))

cm = confusion_matrix(all_labels, all_preds)
plt.figure(figsize=(8,6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names)
plt.xlabel('Predicted'); plt.ylabel('True')
plt.title('Confusion Matrix (Last Fold, Ensemble)')
plt.tight_layout()
plt.show()

# =============================================================================
# 8) PREDIKSI TEST DENGAN TTA + ENSEMBLE
# =============================================================================
# View 1: plain validation-style preprocessing.
base_transform = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    ApplyCLAHE(clip_limit=2.0),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# View 2: horizontally mirrored image.
hflip_transform = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    ApplyCLAHE(clip_limit=2.0),
    # p=1.0 makes the flip unconditional. The earlier Lambda called F.hflip,
    # but F in this file is torch.nn.functional, which has no hflip.
    transforms.RandomHorizontalFlip(p=1.0),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# View 3: rotation by a random angle within +/-15 degrees, drawn on every call.
rotate_transform = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    ApplyCLAHE(clip_limit=2.0),
    transforms.RandomRotation(degrees=15),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

tta_transforms = [base_transform, hflip_transform, rotate_transform]
print(f"\nMenggunakan {len(tta_transforms)} transformasi untuk TTA per model.")

# Collect the test files
test_image_paths = sorted([os.path.join(TEST_PATH, f)
                           for f in os.listdir(TEST_PATH)
                           if f.lower().endswith(('.png', '.jpg', '.jpeg'))])
print(f"Menemukan {len(test_image_paths)} gambar test.")

predictions_indices, filenames = [], []

with torch.no_grad():
    for image_path in tqdm(test_image_paths, desc="Testing with TTA + Ensemble"):
        fn = os.path.basename(image_path)
        img = Image.open(image_path).convert('RGB')

        # Preprocess every TTA view once and share the batch across all
        # models, instead of repeating resize/CLAHE for each model. The
        # random rotation is therefore drawn once per image.
        views = torch.stack([tta(img) for tta in tta_transforms]).to(device)

        # Accumulate, per model, the probabilities averaged over the TTA views
        ensemble_probs = torch.zeros(1, num_classes, device=device)
        for m in models:
            probs = torch.softmax(m(views), dim=1)
            ensemble_probs += probs.mean(dim=0, keepdim=True)

        avg_probs = ensemble_probs / len(models)
        pred_idx = avg_probs.argmax(1).item()

        predictions_indices.append(pred_idx)
        filenames.append(fn)

# Class name -> number, following the order of class_names
class_to_number = dict(zip(class_names, range(len(class_names))))

# The submitted value is the predicted class index itself.
final_numbers = list(predictions_indices)
filenames_no_ext = [os.path.splitext(name)[0] for name in filenames]

submission_path = 'submission.csv'
results_df = pd.DataFrame({'ID': filenames_no_ext, 'Predicted': final_numbers})
results_df.to_csv(submission_path, index=False)

print(f"\n✅ Prediksi selesai! File disimpan di: {submission_path}")
print("Preview 10 baris:")
print(results_df.head(10))

# Per-class counts and shares of the predictions, ordered by class index.
print("\nDistribusi prediksi:")
counts = results_df['Predicted'].value_counts().sort_index()
percentages = results_df['Predicted'].value_counts(normalize=True).sort_index() * 100
distribution = pd.DataFrame({
    'Count': counts,
    'Percentage': percentages.round(2),
})
print(distribution)
