# In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import pandas as pd
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import warnings
# Silence every warning globally to keep notebook output compact.
# NOTE(review): this also hides deprecation and library warnings; consider
# narrowing the filter to the specific warnings that are noisy.
warnings.filterwarnings('ignore')

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Dataset locations - adjust these to match your folder layout.
TRAIN_PATH = "path/to/train"  # replace with your train folder path
VAL_PATH = "path/to/val"      # replace with your validation folder path
TEST_PATH = "path/to/test"    # replace with your test folder path

# Training hyperparameters.
BATCH_SIZE = 16
LEARNING_RATE = 1e-4
NUM_EPOCHS = 50
NUM_CLASSES = 5
PATIENCE = 10   # epochs without val-accuracy improvement before early stopping
IMG_SIZE = 224  # images are resized to IMG_SIZE x IMG_SIZE pixels

class CustomDataset(Dataset):
    """Image dataset read from a directory tree.

    Labelled mode (``is_test=False``): ``root_dir`` contains one sub-folder per
    class. Sub-folders get indices 0, 1, ... in sorted name order, and every
    image inside a sub-folder is paired with that folder's index.

    Test mode (``is_test=True``): ``root_dir`` contains the images directly and
    each image is paired with its file name instead of a label.

    Args:
        root_dir: Directory to scan.
        transform: Optional callable applied to each loaded RGB PIL image.
        is_test: Selects test mode (see above).
        extensions: Iterable of file suffixes (case-insensitive) treated as images.
    """

    def __init__(self, root_dir, transform=None, is_test=False,
                 extensions=('.png', '.jpg', '.jpeg')):
        self.root_dir = root_dir
        self.transform = transform
        self.is_test = is_test
        self.extensions = tuple(ext.lower() for ext in extensions)
        # (path, target) pairs; target is a class index, or the file name in test mode.
        self.samples = []
        # Filled in labelled mode; defined (empty) in test mode too so the
        # attributes always exist.
        self.classes = []
        self.class_to_idx = {}

        if is_test:
            for file_name in self._image_files(root_dir):
                self.samples.append((os.path.join(root_dir, file_name), file_name))
        else:
            # Sorted listings make class indices and sample order reproducible;
            # os.listdir returns entries in arbitrary order.
            for class_name in sorted(os.listdir(root_dir)):
                class_path = os.path.join(root_dir, class_name)
                if not os.path.isdir(class_path):
                    continue
                idx = len(self.classes)
                self.classes.append(class_name)
                self.class_to_idx[class_name] = idx
                for file_name in self._image_files(class_path):
                    self.samples.append((os.path.join(class_path, file_name), idx))

    def _image_files(self, directory):
        """Return sorted names of regular files in ``directory`` with an image extension."""
        return sorted(
            name for name in os.listdir(directory)
            if name.lower().endswith(self.extensions)
            and os.path.isfile(os.path.join(directory, name))
        )

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, target)``: the (transformed) image and its label or file name."""
        img_path, target = self.samples[idx]
        # The context manager closes the file handle PIL opened; convert()
        # forces the pixel data to load before the file is closed.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, target

# Standard ImageNet channel statistics used to normalise inputs for the
# pretrained backbone.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training transforms with random geometric and photometric augmentation.
train_transforms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=15),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomPerspective(distortion_scale=0.2, p=0.3),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Deterministic preprocessing (resize + normalise) shared by validation and test.
val_transforms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])
test_transforms = val_transforms

# Create datasets
print("Creating datasets...")
train_dataset = CustomDataset(TRAIN_PATH, transform=train_transforms, is_test=False)
val_dataset = CustomDataset(VAL_PATH, transform=val_transforms, is_test=False)
test_dataset = CustomDataset(TEST_PATH, transform=test_transforms, is_test=True)

print(f"Train samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}")

# Labels are folder indices assigned independently for each split. The two
# splits must therefore contain the same class folders, and the classifier
# head needs exactly one output per class.
if train_dataset.class_to_idx != val_dataset.class_to_idx:
    raise ValueError(
        f"Train/val class folders differ: {sorted(train_dataset.class_to_idx)} "
        f"vs {sorted(val_dataset.class_to_idx)}"
    )
if len(train_dataset.class_to_idx) != NUM_CLASSES:
    raise ValueError(
        f"Found {len(train_dataset.class_to_idx)} class folders in {TRAIN_PATH}, "
        f"but NUM_CLASSES = {NUM_CLASSES}"
    )

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

# DINOv2 Model Class
class DINOv2Classifier(nn.Module):
    """DINOv2 ViT backbone (loaded via torch.hub) followed by an MLP classification head.

    Args:
        num_classes: Number of output logits.
        model_name: torch.hub entry point in ``facebookresearch/dinov2``
            (e.g. ``'dinov2_vits14'``).
        freeze_backbone: When True (default), the backbone's parameters are
            frozen, it runs under ``torch.no_grad()``, and it stays in eval
            mode even when ``.train()`` is called, so only the head is trained.
    """

    # Fallback feature widths, used only if the backbone exposes no ``embed_dim``.
    _FEATURE_DIMS = {'vits14': 384, 'vitb14': 768, 'vitl14': 1024, 'vitg14': 1536}

    def __init__(self, num_classes=5, model_name='dinov2_vits14', freeze_backbone=True):
        super().__init__()
        self.freeze_backbone = freeze_backbone

        # Load DINOv2 model (fetched from torch.hub).
        self.backbone = torch.hub.load('facebookresearch/dinov2', model_name)

        if freeze_backbone:
            for param in self.backbone.parameters():
                param.requires_grad = False
            self.backbone.eval()

        feature_dim = self._feature_dim(self.backbone, model_name)

        # Classification head
        self.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(feature_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, num_classes)
        )

    @classmethod
    def _feature_dim(cls, backbone, model_name):
        """Return the backbone's output width: ``backbone.embed_dim`` if it is an int, else by name."""
        embed_dim = getattr(backbone, 'embed_dim', None)
        if isinstance(embed_dim, int):
            return embed_dim
        for key, dim in cls._FEATURE_DIMS.items():
            if key in model_name:
                return dim
        return 384  # default for unrecognised names

    def train(self, mode=True):
        """Set train/eval mode; a frozen backbone is always kept in eval mode."""
        super().train(mode)
        if self.freeze_backbone:
            # Prevents any dropout/stochastic layers in the frozen backbone from
            # becoming active when the head is put into training mode.
            self.backbone.eval()
        return self

    def forward(self, x):
        """Return class logits for the image batch ``x``."""
        if self.freeze_backbone:
            # No autograd graph is needed through frozen weights.
            with torch.no_grad():
                features = self.backbone(x)
        else:
            features = self.backbone(x)
        return self.classifier(features)

# Initialize model
print("Loading DINOv2 model...")
model = DINOv2Classifier(num_classes=NUM_CLASSES, model_name='dinov2_vits14')
model = model.to(device)

# Optimise every parameter that still requires gradients. With a frozen
# backbone that is exactly the classification head.
trainable_params = [p for p in model.parameters() if p.requires_grad]
print(f"Trainable parameters: {sum(p.numel() for p in trainable_params):,}")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(trainable_params, lr=LEARNING_RATE, weight_decay=0.01)

# Halve the learning rate after 5 epochs without validation-loss improvement.
# `verbose` is deprecated for PyTorch LR schedulers (and the global warnings
# filter hides that deprecation), so it is not passed. Read
# optimizer.param_groups[0]['lr'] to monitor the current rate.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

# Training function
def train_epoch(model, train_loader, criterion, optimizer, device, show_progress=True):
    """Run one training epoch over ``train_loader``.

    Args:
        model: Network to train; it is switched to train mode here.
        train_loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss function. Assumed to return the batch *mean* (the
            ``nn.CrossEntropyLoss`` default); it is re-weighted by batch size.
        optimizer: Optimizer that steps the trainable parameters.
        device: Device each batch is moved to.
        show_progress: Wrap the loader in a tqdm progress bar.

    Returns:
        ``(mean_loss, accuracy_percent)`` averaged over every sample in the epoch.

    Raises:
        ValueError: If ``train_loader`` yields no samples.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    total = 0

    batches = tqdm(train_loader, desc='Training') if show_progress else train_loader
    for images, labels in batches:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # Weight each batch's mean loss by its size so a short final batch
        # does not skew the epoch average.
        loss_sum += loss.item() * batch_size
        total += batch_size
        correct += (outputs.detach().argmax(dim=1) == labels).sum().item()

        if show_progress:
            # Running averages over the samples seen so far (tqdm's own `.n`
            # counter only updates on refresh, so it is not used here).
            seen = max(total, 1)
            batches.set_postfix({
                'Loss': f'{loss_sum / seen:.4f}',
                'Acc': f'{100. * correct / seen:.2f}%'
            })

    if total == 0:
        raise ValueError("train_loader yielded no samples")
    return loss_sum / total, 100. * correct / total

# Validation function
def validate(model, val_loader, criterion, device, show_progress=True):
    """Evaluate ``model`` on ``val_loader`` without updating it.

    Args:
        model: Network to evaluate; it is switched to eval mode here.
        val_loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss function. Assumed to return the batch *mean* (the
            ``nn.CrossEntropyLoss`` default); it is re-weighted by batch size.
        device: Device each batch is moved to.
        show_progress: Wrap the loader in a tqdm progress bar.

    Returns:
        ``(mean_loss, accuracy_percent)`` averaged over every validation sample.

    Raises:
        ValueError: If ``val_loader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        batches = tqdm(val_loader, desc='Validation') if show_progress else val_loader
        for images, labels in batches:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            batch_size = labels.size(0)
            # Sample-weighted accumulation: a short last batch counts proportionally.
            loss_sum += loss.item() * batch_size
            total += batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()

            if show_progress:
                seen = max(total, 1)
                batches.set_postfix({
                    'Loss': f'{loss_sum / seen:.4f}',
                    'Acc': f'{100. * correct / seen:.2f}%'
                })

    if total == 0:
        raise ValueError("val_loader yielded no samples")
    return loss_sum / total, 100. * correct / total

# Training loop: train, validate, adjust LR, checkpoint the best model, and
# stop early once validation accuracy stalls for PATIENCE epochs.
print("Starting training...")
train_losses = []
train_accuracies = []
val_losses = []
val_accuracies = []

BEST_MODEL_PATH = 'best_dinov2_lung_classifier.pth'
best_val_acc = 0.0
patience_counter = 0
checkpoint_saved = False

for epoch in range(NUM_EPOCHS):
    print(f'\nEpoch {epoch+1}/{NUM_EPOCHS}')
    print('-' * 50)

    # Train
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    train_losses.append(train_loss)
    train_accuracies.append(train_acc)

    # Validate
    val_loss, val_acc = validate(model, val_loader, criterion, device)
    val_losses.append(val_loss)
    val_accuracies.append(val_acc)

    # The scheduler tracks validation loss; report whenever it lowers the LR.
    lr_before = optimizer.param_groups[0]['lr']
    scheduler.step(val_loss)
    lr_after = optimizer.param_groups[0]['lr']

    print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
    print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
    if lr_after < lr_before:
        print(f'Learning rate reduced: {lr_before:.2e} -> {lr_after:.2e}')

    # Save best model. The first epoch always saves, so a checkpoint exists
    # for the reload after training even if validation accuracy stays at 0%.
    if not checkpoint_saved or val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), BEST_MODEL_PATH)
        checkpoint_saved = True
        print(f'New best model saved with validation accuracy: {best_val_acc:.2f}%')
        patience_counter = 0
    else:
        patience_counter += 1

    # Early stopping
    if patience_counter >= PATIENCE:
        print(f'Early stopping triggered after {epoch+1} epochs')
        break

# Reload the best checkpoint, evaluate it on the validation split, and plot
# the training curves plus the validation confusion matrix.
print("\nLoading best model for testing...")
# map_location allows reloading regardless of the device the checkpoint was
# saved from; weights_only=True restricts unpickling to tensors and containers.
model.load_state_dict(
    torch.load('best_dinov2_lung_classifier.pth', map_location=device, weights_only=True)
)

# Final validation evaluation
print("\nFinal validation evaluation...")
model.eval()
all_preds = []
all_labels = []

with torch.no_grad():
    for images, labels in tqdm(val_loader, desc='Final Validation'):
        outputs = model(images.to(device))
        all_preds.extend(outputs.argmax(dim=1).cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# Fixed label set (0..NUM_CLASSES-1) so the report and the matrix always list
# every class, even ones absent from this split. Names come from the
# validation folders; indices without a folder fall back to the index itself.
idx_to_class = {idx: name for name, idx in val_dataset.class_to_idx.items()}
label_ids = list(range(NUM_CLASSES))
class_names = [idx_to_class.get(i, str(i)) for i in label_ids]

# Classification report
print("\nClassification Report:")
print(classification_report(all_labels, all_preds, labels=label_ids,
                            target_names=class_names, zero_division=0))
cm = confusion_matrix(all_labels, all_preds, labels=label_ids)

# Plot training history and confusion matrix on one figure.
fig, (ax_loss, ax_acc, ax_cm) = plt.subplots(1, 3, figsize=(15, 5))
epochs = range(1, len(train_losses) + 1)

ax_loss.plot(epochs, train_losses, label='Train Loss')
ax_loss.plot(epochs, val_losses, label='Validation Loss')
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('Loss')
ax_loss.set_title('Training and Validation Loss')
ax_loss.legend()

ax_acc.plot(epochs, train_accuracies, label='Train Accuracy')
ax_acc.plot(epochs, val_accuracies, label='Validation Accuracy')
ax_acc.set_xlabel('Epoch')
ax_acc.set_ylabel('Accuracy (%)')
ax_acc.set_title('Training and Validation Accuracy')
ax_acc.legend()

sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names, ax=ax_cm)
ax_cm.set_title('Confusion Matrix')
ax_cm.set_xlabel('Predicted')
ax_cm.set_ylabel('Actual')

fig.tight_layout()
fig.savefig('training_results.png', dpi=300, bbox_inches='tight')
plt.show()

# Test Time Augmentation (TTA) views. Every view is deterministic, so repeated
# runs over the same image give identical predictions.
_TTA_MEAN = [0.485, 0.456, 0.406]
_TTA_STD = [0.229, 0.224, 0.225]


def _tta_view(*augmentations, resize=None):
    """Compose: square resize -> ``augmentations`` -> tensor -> ImageNet normalisation."""
    side = IMG_SIZE if resize is None else resize
    return transforms.Compose([
        transforms.Resize((side, side)),
        *augmentations,
        transforms.ToTensor(),
        transforms.Normalize(mean=_TTA_MEAN, std=_TTA_STD),
    ])


def _brighten(img, factor=1.1):
    """Scale image brightness by a fixed ``factor`` (deterministic, unlike ColorJitter)."""
    return transforms.functional.adjust_brightness(img, factor)


tta_transforms = [
    # Original
    _tta_view(),
    # Horizontal flip (p=1.0, so it is always applied)
    _tta_view(transforms.RandomHorizontalFlip(p=1.0)),
    # Rotation +10 degrees (degenerate range -> fixed angle)
    _tta_view(transforms.RandomRotation(degrees=(10, 10))),
    # Rotation -10 degrees
    _tta_view(transforms.RandomRotation(degrees=(-10, -10))),
    # Brightness adjustment (+10%, fixed factor)
    _tta_view(transforms.Lambda(_brighten)),
    # Scale variation: enlarge by 10%, then center-crop back to IMG_SIZE
    _tta_view(transforms.CenterCrop(IMG_SIZE), resize=int(IMG_SIZE * 1.1)),
]

def predict_with_tta(model, image_path, transforms_list, device):
    """Predict one image's class by averaging softmax outputs over TTA views.

    The image is decoded once. Each transform in ``transforms_list`` produces
    one view, and all views pass through ``model`` as a single batch.

    Args:
        model: Classifier returning one row of logits per view
            (assumed shape (n_views, num_classes) -- confirm for other models).
        image_path: Path of the image file; opened with PIL and converted to RGB.
        transforms_list: Non-empty sequence of callables, each mapping a PIL
            image to a tensor. All outputs must share one shape so they can be stacked.
        device: Device inference runs on.

    Returns:
        ``(predicted_class, confidence, probabilities)``: the arg-max class
        index, its averaged probability, and the averaged per-class probability
        vector (1-D numpy array).

    Raises:
        ValueError: If ``transforms_list`` is empty.
    """
    if not transforms_list:
        raise ValueError("transforms_list must contain at least one transform")

    model.eval()
    # Decode once and close the file; every view is derived from this copy.
    with Image.open(image_path) as img:
        image = img.convert('RGB')

    # Stack the views into one batch: a single forward pass instead of one per view.
    batch = torch.stack([transform(image) for transform in transforms_list]).to(device)

    with torch.no_grad():
        probabilities = torch.softmax(model(batch), dim=1).cpu().numpy()

    # Average predictions across all augmentations
    avg_probabilities = probabilities.mean(axis=0)
    predicted_class = np.argmax(avg_probabilities)
    confidence = np.max(avg_probabilities)

    return predicted_class, confidence, avg_probabilities

# Test predictions with TTA: predict every image in TEST_PATH, write the
# submission/analysis CSVs, and report confidence statistics.
print("\nGenerating test predictions with Test Time Augmentation...")
print("Using TTA with following augmentations:")
print("- Original image")
print("- Horizontal flip")
print("- Rotation +10°")
print("- Rotation -10°")
print("- Brightness adjustment")
print("- Scale variation")

model.eval()

# Sorted test image names; same extension filter as CustomDataset uses.
test_files = sorted(
    file_name for file_name in os.listdir(TEST_PATH)
    if file_name.lower().endswith(('.png', '.jpg', '.jpeg'))
)

predictions = []
filenames = []
confidences = []
all_probabilities = []

# Process each test file with TTA
for file_name in tqdm(test_files, desc='TTA Prediction'):
    predicted_class, confidence, probabilities = predict_with_tta(
        model, os.path.join(TEST_PATH, file_name), tta_transforms, device
    )
    predictions.append(predicted_class)
    filenames.append(file_name)
    confidences.append(confidence)
    all_probabilities.append(probabilities)

if not filenames:
    # np.min / np.percentile below would raise on an empty list.
    print(f"\nNo test images found in {TEST_PATH}; no submission files written.")
else:
    # Rows are already in sorted Id order because test_files is sorted.
    submission_df = pd.DataFrame({
        'Id': filenames,
        'Predicted': predictions,
        'Confidence': confidences
    })

    # Per-class probabilities, for later analysis.
    prob_columns = [f'Prob_Class_{i}' for i in range(NUM_CLASSES)]
    probabilities_df = pd.DataFrame(all_probabilities, columns=prob_columns)
    probabilities_df['Id'] = filenames
    probabilities_df['Predicted'] = predictions

    # Main submission file (Kaggle format), detailed file, and probabilities.
    submission_df[['Id', 'Predicted']].to_csv('submission_tta.csv', index=False)
    submission_df.to_csv('submission_tta_detailed.csv', index=False)
    probabilities_df.to_csv('test_probabilities_tta.csv', index=False)

    print(f"\nTTA Submission files saved:")
    print(f"- submission_tta.csv (Kaggle format)")
    print(f"- submission_tta_detailed.csv (with confidence scores)")
    print(f"- test_probabilities_tta.csv (with all class probabilities)")
    print(f"Total test predictions: {len(submission_df)}")
    print(f"Average confidence: {np.mean(confidences):.4f}")
    print(f"Prediction distribution:")
    print(submission_df['Predicted'].value_counts().sort_index())

    # Display confidence statistics
    print(f"\nConfidence Statistics:")
    print(f"Min confidence: {np.min(confidences):.4f}")
    print(f"Max confidence: {np.max(confidences):.4f}")
    print(f"Mean confidence: {np.mean(confidences):.4f}")
    print(f"Std confidence: {np.std(confidences):.4f}")

    # Predictions below the 10th confidence percentile may deserve manual review.
    low_conf_threshold = np.percentile(confidences, 10)
    low_conf_predictions = submission_df[submission_df['Confidence'] < low_conf_threshold]
    print(f"\nLow confidence predictions (< {low_conf_threshold:.4f}):")
    print(f"Count: {len(low_conf_predictions)}")
    if len(low_conf_predictions) > 0:
        print(low_conf_predictions.head())

print("\nTraining and TTA prediction completed successfully!")
print(f"Best validation accuracy: {best_val_acc:.2f}%")