# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import transforms, datasets
from torchvision.models import densenet121
import pandas as pd
import numpy as np
import os
from PIL import Image
from sklearn.metrics import f1_score
from tqdm import tqdm
import cv2
import random
from collections import Counter
from sklearn.model_selection import StratifiedKFold

# =============================================================================
# 1. FUNGSI HELPER DAN PENGATURAN SEED
# =============================================================================

def set_seed(seed):
    """Make runs reproducible by seeding every random number generator the notebook uses.

    Seeds Python's ``random``, NumPy and PyTorch (CPU plus every CUDA device), exports
    ``PYTHONHASHSEED``, and switches cuDNN to deterministic, non-autotuned kernels.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
    # Trade cuDNN speed (autotuning, nondeterministic kernels) for run-to-run reproducibility.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False


# Single seed for the whole notebook; it is also reused as the StratifiedKFold random_state.
SEED = 42
set_seed(SEED)
print(f"Random seed diatur ke {SEED}")

# =============================================================================
# 2. DEFINISI MODEL, DATASET, DAN FUNGSI HELPER
# =============================================================================

# --- DEFINISI MODEL CHEXNET ---
class CheXNetModel(nn.Module):
    """DenseNet-121 classifier with a dropout + linear head, optionally warm-started from a checkpoint.

    Args:
        num_classes: number of logits produced by the new classifier head.
        pretrained: forwarded to torchvision's ``densenet121`` builder (ImageNet weights when True).
        checkpoint_path: optional path to a PyTorch checkpoint whose backbone weights are copied
            into the model via ``load_pretrained_weights``; ignored if the file does not exist.
    """

    def __init__(self, num_classes, pretrained=True, checkpoint_path=None):
        super().__init__()
        # Use the locally installed torchvision builder (already imported at the top of the file)
        # instead of torch.hub, which additionally has to fetch the pytorch/vision repo from GitHub.
        self.densenet = densenet121(pretrained=pretrained)
        num_features = self.densenet.classifier.in_features
        # Replace the ImageNet head with a regularised head sized for this task.
        self.densenet.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(num_features, num_classes),
        )
        if checkpoint_path and os.path.exists(checkpoint_path):
            self.load_pretrained_weights(checkpoint_path)

    def load_pretrained_weights(self, checkpoint_path):
        """Copy compatible backbone weights from ``checkpoint_path`` into ``self.densenet``.

        Accepts a raw state dict or a dict wrapping one under ``'state_dict'`` or ``'model'``.
        A ``'module.'`` key prefix (typically left by ``nn.DataParallel``) is stripped. Only tensors
        whose name AND shape match the current backbone are copied. The classifier head is always
        skipped because it is sized for this task. Errors are printed instead of raised.
        """
        try:
            checkpoint = torch.load(checkpoint_path, map_location='cpu')
            state_dict = checkpoint.get('state_dict', checkpoint.get('model', checkpoint))
            new_state_dict = {k[7:] if k.startswith('module.') else k: v for k, v in state_dict.items()}
            model_dict = self.densenet.state_dict()
            # Filtering on shape as well as name: a shape mismatch would make load_state_dict
            # raise after copying only part of the tensors.
            pretrained_dict = {
                k: v
                for k, v in new_state_dict.items()
                if k in model_dict
                and 'classifier' not in k
                and getattr(v, 'shape', None) == model_dict[k].shape
            }
            if not pretrained_dict:
                print(f"No compatible CheXNet weights found in {checkpoint_path}; keeping current weights.")
                return
            model_dict.update(pretrained_dict)
            self.densenet.load_state_dict(model_dict, strict=False)
            print(
                f"Pre-trained CheXNet weights loaded from {checkpoint_path} "
                f"({len(pretrained_dict)}/{len(model_dict)} tensors)"
            )
        except Exception as e:
            print(f"Could not load CheXNet weights: {e}. Using ImageNet weights instead.")

    def forward(self, x):
        """Return raw class logits for the input batch ``x``."""
        return self.densenet(x)

# --- DEFINISI DATASET CUSTOM (Versi yang mendukung is_test) ---
class CustomImageDataset(Dataset):
    """Image dataset for both labelled training/validation data and unlabelled test data.

    * ``is_test=False``: ``root_dir`` is read with ``torchvision.datasets.ImageFolder`` (one
      sub-folder per class). Items are ``(image, label)``, with ``transform`` applied to the image.
    * ``is_test=True``: ``root_dir`` is a flat folder of ``.png``/``.jpg``/``.jpeg`` files, sorted
      by path. Items are ``(PIL RGB image, filename)``. ``transform`` is NOT applied, because the
      test-time-augmentation transforms are applied by the prediction loop instead.
    """

    def __init__(self, root_dir, transform=None, is_test=False):
        self.root_dir = root_dir
        self.transform = transform
        self.is_test = is_test

        if not is_test:
            self.dataset = datasets.ImageFolder(root_dir)
            self.samples = self.dataset.samples
            self.classes = self.dataset.classes
        else:
            self.image_paths = sorted(
                os.path.join(root_dir, f)
                for f in os.listdir(root_dir)
                if f.lower().endswith(('.png', '.jpg', '.jpeg'))
            )

    def __len__(self):
        return len(self.image_paths) if self.is_test else len(self.samples)

    @staticmethod
    def _load_rgb(path):
        # convert() decodes the pixels into a new image, so the file can be closed right away.
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, idx):
        if self.is_test:
            img_path = self.image_paths[idx]
            # TTA transforms are applied at prediction time, not here.
            return self._load_rgb(img_path), os.path.basename(img_path)
        img_path, label = self.samples[idx]
        image = self._load_rgb(img_path)
        if self.transform is not None:
            image = self.transform(image)
        return image, label

# --- CUSTOM TRANSFORM UNTUK CLAHE ---
class ApplyCLAHE:
    """Transform that applies CLAHE (Contrast Limited Adaptive Histogram Equalization) to a PIL image.

    For 3-channel input (assumed RGB, since the dataset converts images with ``.convert('RGB')``),
    the image is converted RGB -> LAB. CLAHE is applied to the L (lightness) channel only, and the
    result is converted back to RGB, so the A/B colour channels pass through unchanged. A 2-D
    (single-channel) image is equalised directly.

    Args:
        clip_limit: contrast clipping threshold passed to ``cv2.createCLAHE``.
        tile_grid_size: grid of tiles the image is divided into for local equalisation.
    """

    def __init__(self, clip_limit=2.0, tile_grid_size=(8, 8)):
        self.clip_limit = clip_limit
        self.tile_grid_size = tile_grid_size
        self.clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)

    def __call__(self, img):
        img_np = np.array(img)
        if img_np.ndim == 3:
            lab = cv2.cvtColor(img_np, cv2.COLOR_RGB2LAB)
            # Split once and reuse the A/B channels rather than re-splitting for each of them.
            l_chan, a_chan, b_chan = cv2.split(lab)
            lab_clahe = cv2.merge((self.clahe.apply(l_chan), a_chan, b_chan))
            return Image.fromarray(cv2.cvtColor(lab_clahe, cv2.COLOR_LAB2RGB))
        return Image.fromarray(self.clahe.apply(img_np))

    def __repr__(self):
        return f"{type(self).__name__}(clip_limit={self.clip_limit}, tile_grid_size={self.tile_grid_size})"

# --- DEFINISI FOCAL LOSS ---
class FocalLoss(nn.Module):
    """Multi-class focal loss built on cross-entropy.

    ``loss_i = alpha[y_i] * (1 - p_i) ** gamma * CE_i``, where ``p_i = exp(-CE_i)`` is the
    softmax probability of the true class ``y_i``.

    Args:
        gamma: focusing parameter. ``0`` reduces to (optionally alpha-weighted) cross-entropy.
        alpha: optional per-class weights (list, tuple, array or tensor) indexed by target class.
        reduction: ``'mean'`` (default), ``'sum'``, or ``'none'`` for per-sample losses.
            Any other value falls back to ``'sum'``.
    """

    def __init__(self, gamma=2.0, alpha=None, reduction='mean'):
        super().__init__()
        self.gamma = gamma
        self.reduction = reduction
        if alpha is not None and not isinstance(alpha, torch.Tensor):
            alpha = torch.as_tensor(alpha, dtype=torch.float)
        self.alpha = alpha

    def forward(self, inputs, targets):
        ce_loss = nn.functional.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = (1 - pt) ** self.gamma * ce_loss
        if self.alpha is not None:
            # alpha is a plain attribute (not a buffer), so move it lazily to the logits' device.
            if self.alpha.device != inputs.device:
                self.alpha = self.alpha.to(inputs.device)
            focal_loss = self.alpha.gather(0, targets) * focal_loss
        if self.reduction == 'none':
            return focal_loss
        if self.reduction == 'mean':
            return focal_loss.mean()
        return focal_loss.sum()

# =============================================================================
# 3. KONFIGURASI DAN PERSIAPAN DATA
# =============================================================================
# --- Hyperparameters ---
IMG_SIZE = (384, 384)   # target size passed to transforms.Resize
BATCH_SIZE = 16
EPOCHS = 15
LEARNING_RATE = 0.001   # base value; the optimizer in the CV loop is built with LEARNING_RATE / 10
N_SPLITS = 4            # number of StratifiedKFold folds

# --- Data locations ---
TRAIN_PATH = "/kaggle/input/final-lung-disease/train/train"
TEST_PATH = "/kaggle/input/final-lung-disease/test/test"
# NOTE(review): this points at a Keras .h5 file, but CheXNetModel reads checkpoints with
# torch.load, which expects a PyTorch checkpoint. Expect that load to fail and the model to keep
# its torchvision weights. Confirm, or supply a PyTorch CheXNet checkpoint instead.
CHEXNET_WEIGHTS = '/kaggle/input/chexnet-weights/brucechou1983_CheXNet_Keras_0.3.0_weights.h5'

# Prefer the GPU when one is visible.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Training pipeline: resize -> CLAHE contrast enhancement -> light augmentation
# (random horizontal flip, random rotation within +/-10 degrees) -> tensor -> ImageNet mean/std normalisation.
train_transform = transforms.Compose(
    [
        transforms.Resize(IMG_SIZE),
        ApplyCLAHE(),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)
# Validation pipeline: the same deterministic preprocessing, without the random augmentation.
val_transform = transforms.Compose(
    [
        transforms.Resize(IMG_SIZE),
        ApplyCLAHE(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

# Index the labelled training folder once. Its sample order defines the indices that
# StratifiedKFold splits, and its targets provide the stratification labels.
full_dataset_obj = datasets.ImageFolder(TRAIN_PATH)
class_names = full_dataset_obj.classes
num_classes = len(full_dataset_obj.classes)
print(f"Classes: {class_names} ({num_classes} classes)")

X = np.arange(len(full_dataset_obj))       # sample indices 0..N-1
y = np.asarray(full_dataset_obj.targets)   # class index of every sample

# =============================================================================
# 4. FUNGSI TRAINING & VALIDATION
# =============================================================================
def _run_epoch(model, loader, criterion, device, optimizer=None, desc="Training"):
    """Run one full pass over ``loader`` and return ``(mean_loss, macro_f1)``.

    When ``optimizer`` is given, the model is in train mode and is updated after every batch.
    Otherwise it runs in eval mode with gradient tracking disabled.

    The loss is averaged per sample (batch loss weighted by batch size). This assumes
    ``criterion`` returns the mean loss over the batch.
    """
    training = optimizer is not None
    model.train(training)
    running_loss, all_preds, all_labels = 0.0, [], []
    with torch.set_grad_enabled(training):
        for images, labels in tqdm(loader, desc=desc):
            images, labels = images.to(device), labels.to(device)
            if training:
                optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()
            running_loss += loss.item() * images.size(0)
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    epoch_loss = running_loss / len(loader.dataset)
    return epoch_loss, f1_score(all_labels, all_preds, average='macro')


def train_epoch(model, loader, criterion, optimizer, device):
    """Train ``model`` for one epoch. Returns ``(mean_loss, macro_f1)`` over the training samples."""
    return _run_epoch(model, loader, criterion, device, optimizer=optimizer, desc="Training")


def validate_epoch(model, loader, criterion, device):
    """Evaluate ``model`` without gradient updates. Returns ``(mean_loss, macro_f1)``."""
    return _run_epoch(model, loader, criterion, device, desc="Validation")

# =============================================================================
# 5. LOOP CROSS-VALIDATION
# =============================================================================
skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)
all_folds_best_f1 = []
saved_model_paths = []

for fold, (train_idx, val_idx) in enumerate(skf.split(X, y)):
    print(f"\n{'='*20} FOLD {fold + 1}/{N_SPLITS} {'='*20}")

    # Separate dataset objects over the same folder, so train and val get different transforms.
    train_subset = Subset(CustomImageDataset(TRAIN_PATH, transform=train_transform), train_idx)
    val_subset = Subset(CustomImageDataset(TRAIN_PATH, transform=val_transform), val_idx)
    train_loader = DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_subset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True)

    # Inverse-frequency class weights, computed from this fold's training labels only.
    train_labels = y[train_idx]
    class_counts = Counter(train_labels)
    # max(count, 1) avoids a ZeroDivisionError if a class is absent from this fold's training split.
    class_weights = [len(train_labels) / (num_classes * max(class_counts[i], 1)) for i in range(num_classes)]
    class_weights_tensor = torch.tensor(class_weights, dtype=torch.float).to(device)
    print(f"Class weights for this fold: {np.round(class_weights, 2)}")

    model = CheXNetModel(num_classes=num_classes, checkpoint_path=CHEXNET_WEIGHTS).to(device)
    criterion = FocalLoss(gamma=2.0, alpha=class_weights_tensor)
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE / 10, weight_decay=1e-4)
    # Halve the LR once val macro-F1 stops improving (patience=3). `verbose` is omitted because
    # it is deprecated in recent PyTorch releases; the current LR is printed each epoch instead.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3)

    # Start below any reachable F1 (>= 0) so that epoch 1 always writes a checkpoint. Starting at
    # 0.0 would leave no file on disk for a fold whose val F1 never exceeds 0, even though its path
    # is still recorded for the ensemble.
    best_fold_f1 = -1.0
    model_save_path = f'best_model_fold_{fold+1}.pth'

    for epoch in range(EPOCHS):
        print(f"\n--- Epoch {epoch + 1}/{EPOCHS} ---")
        train_loss, train_f1 = train_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_f1 = validate_epoch(model, val_loader, criterion, device)
        print(f"Train Loss: {train_loss:.4f} | Train F1: {train_f1:.4f}")
        print(f"Val Loss  : {val_loss:.4f} | Val F1  : {val_f1:.4f}")
        scheduler.step(val_f1)
        print(f"Learning rate: {optimizer.param_groups[0]['lr']:.2e}")
        if val_f1 > best_fold_f1:
            best_fold_f1 = val_f1
            torch.save(model.state_dict(), model_save_path)
            print(f"🚀 Model saved to {model_save_path} with F1: {best_fold_f1:.4f}")

    all_folds_best_f1.append(best_fold_f1)
    saved_model_paths.append(model_save_path)
    print(f"\nBest F1 for Fold {fold + 1} was: {best_fold_f1:.4f}")

# =============================================================================
# 6. CROSS-VALIDATION RESULTS
# =============================================================================
print(f"\n\n{'='*20} HASIL CROSS-VALIDATION {'='*20}")
for fold_number, fold_f1 in enumerate(all_folds_best_f1, start=1):
    print(f"Fold {fold_number}: {fold_f1:.4f}")
average_f1 = np.mean(all_folds_best_f1)
print(f"📊 Rata-rata F1-score terbaik dari {N_SPLITS}-fold CV: {average_f1:.4f}")
print("=" * 55)

# =============================================================================
# 7. PREDIKSI FINAL DENGAN ENSEMBLE DARI 4-FOLD MODELS
# =============================================================================
print(f"\nMemulai proses prediksi dengan TTA dan ensemble dari {len(saved_model_paths)} model...")

# --- Load every fold's best checkpoint ---
ensemble_models = []
for path in saved_model_paths:
    # pretrained=False: every weight is overwritten by the fold checkpoint, so ImageNet weights aren't needed.
    model = CheXNetModel(num_classes=num_classes, pretrained=False).to(device)
    # map_location lets checkpoints written on another device (e.g. GPU -> CPU session) load cleanly.
    model.load_state_dict(torch.load(path, map_location=device))
    model.eval()
    ensemble_models.append(model)
print(f"✅ Berhasil memuat {len(ensemble_models)} model untuk ensembling.")

# --- Definisikan transformasi TTA ---
# Test-time-augmentation views: identity, horizontal flip, and a random rotation (up to 10 degrees).
# Every view shares the validation preprocessing (resize -> CLAHE -> tensor -> ImageNet normalisation).
# NOTE: the rotation view is random, so its output depends on the global RNG state.
tta_transforms = [
    transforms.Compose(
        [transforms.Resize(IMG_SIZE), ApplyCLAHE()]
        + augmentation
        + [transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]
    )
    for augmentation in (
        [],                                        # Base
        [transforms.RandomHorizontalFlip(p=1.0)],  # Horizontal Flip
        [transforms.RandomRotation(degrees=10)],   # Rotation
    )
]

# --- Siapkan data test ---
test_dataset = CustomImageDataset(TEST_PATH, is_test=True)
# The test dataset yields raw PIL images (TTA transforms are applied below), which the default
# collate function cannot batch. This collate keeps them as Python objects and returns
# (images, filenames) tuples. Batch size 1: one image's TTA views are scored at a time.
test_loader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=lambda batch: tuple(zip(*batch)),
)

# --- Run prediction ---
predictions_indices, filenames = [], []
with torch.no_grad():
    for images, batch_filenames in tqdm(test_loader, desc="Prediksi Test"):
        image, filename = images[0], batch_filenames[0]

        # Build the TTA views once per image and score them as one batch with every model.
        # The batch mean of the softmax is the TTA average; it is then averaged across models.
        tta_batch = torch.stack([t_transform(image) for t_transform in tta_transforms]).to(device)
        ensemble_probs = torch.zeros(1, num_classes, device=device)
        for model in ensemble_models:
            probs = nn.functional.softmax(model(tta_batch), dim=1)
            ensemble_probs += probs.mean(dim=0, keepdim=True)

        final_probs = ensemble_probs / len(ensemble_models)
        predicted_idx = final_probs.argmax(dim=1)

        predictions_indices.append(predicted_idx.item())
        filenames.append(filename)

# --- Buat file submission ---
# Competition label encoding. Keys must match the ImageFolder class-folder names; a class missing
# from this mapping would be written as -1, so warn about it instead of failing silently.
class_to_number = {'COVID': 0, 'Normal': 1, 'Viral Pneumonia': 2}
unmapped_classes = [name for name in class_names if name not in class_to_number]
if unmapped_classes:
    print(f"⚠️ Warning: no submission label for classes {unmapped_classes}; their predictions will be -1.")
final_predictions = [class_to_number.get(class_names[idx], -1) for idx in predictions_indices]
# IDs are filenames without their extension. splitext handles every extension the test dataset
# accepts (.png/.jpg/.jpeg, any case) and strips only the trailing suffix.
filenames_no_ext = [os.path.splitext(f)[0] for f in filenames]

results_df = pd.DataFrame({'ID': filenames_no_ext, 'Predicted': final_predictions})
submission_path = 'submission.csv'
results_df.to_csv(submission_path, index=False)

print(f"\n✅ Prediksi selesai! File disimpan di: {submission_path}")
print("\nPreview 10 prediksi pertama:")
print(results_df.head(10))

print("\nDistribusi prediksi:")
print(results_df['Predicted'].value_counts(normalize=True).sort_index() * 100)