# In [None]:
import os
import torch
from torch.optim import AdamW
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
from PIL import Image
from transformers import ViTForImageClassification, ViTFeatureExtractor
from tqdm import tqdm
from sklearn.metrics import accuracy_score

# Select the compute device: CUDA when torch reports it available, else CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# ==========================
# Dataset Configuration
# ==========================
# Root folder containing one sub-directory per class. Defaults to the Kaggle
# input path; set the DATASET_PATH environment variable to train on data
# elsewhere without editing the notebook.
dataset_path = os.environ.get("DATASET_PATH", "/kaggle/input/nthuddd2/train_data")
print(f"Dataset path: {dataset_path}")

# Fail fast with a clear message. isdir (rather than exists) also rejects a
# path that exists but is a regular file, which would otherwise surface later
# as a confusing NotADirectoryError from os.listdir.
if not os.path.isdir(dataset_path):
    raise FileNotFoundError(f"Dataset not found at {dataset_path}. Please ensure the dataset is uploaded to Kaggle.")

# ==========================
# Custom Dataset Class
# ==========================
class CustomDataset(Dataset):
    """Image-folder dataset in which each sub-directory of ``root_dir`` is a class.

    Labels are indices into the sorted list of non-hidden sub-directory names
    (e.g. ``['drowsy', 'notdrowsy']`` -> 0, 1). Samples are every non-hidden
    regular file inside a class directory, listed in sorted order so that
    dataset indices are reproducible across runs.

    Each item is a dict with ``"pixel_values"`` (the RGB image after
    ``transform``) and ``"labels"`` (a scalar ``torch.long`` tensor). If an
    item cannot be loaded or transformed, the error is printed and ``None`` is
    returned; the DataLoader's collate function must drop such items.

    Args:
        root_dir: Directory whose immediate sub-directories are the classes.
        transform: Optional callable applied to the RGB ``PIL.Image``.

    Raises:
        FileNotFoundError: if ``root_dir`` contains no class sub-directories.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []
        # Only visible directories are classes. Stray files (.DS_Store, a
        # README) or notebook artefacts such as .ipynb_checkpoints must not
        # abort loading or become an extra label.
        self.classes = sorted(
            name
            for name in os.listdir(root_dir)
            if not name.startswith(".") and os.path.isdir(os.path.join(root_dir, name))
        )
        if not self.classes:
            raise FileNotFoundError(f"No class directories found in {root_dir}.")

        # Load image paths and labels
        for label, class_name in enumerate(self.classes):
            class_dir = os.path.join(root_dir, class_name)
            # Sorted: os.listdir order is arbitrary, and a reproducible index
            # order is required for a seeded train/validation split.
            for file_name in sorted(os.listdir(class_dir)):
                file_path = os.path.join(class_dir, file_name)
                if file_name.startswith(".") or not os.path.isfile(file_path):
                    continue  # skip hidden files and nested directories
                self.image_paths.append(file_path)
                self.labels.append(label)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        label = self.labels[idx]
        try:
            # The context manager closes the underlying file handle;
            # convert() forces the pixel data to be read before it closes.
            with Image.open(image_path) as img:
                image = img.convert("RGB")
            if self.transform:
                image = self.transform(image)
            return {"pixel_values": image, "labels": torch.tensor(label, dtype=torch.long)}
        except Exception as e:
            # Deliberate best-effort: report and skip a corrupt or unreadable
            # file instead of aborting the whole epoch.
            print(f"Error loading image {image_path}: {e}")
            return None

# ==========================
# Define Transformations
# ==========================
# Pre-trained ViT checkpoint; its feature extractor provides the per-channel
# normalization statistics used in the preprocessing pipeline below.
model_name = "google/vit-base-patch16-224-in21k"
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)

# Normalization with the checkpoint's own mean/std.
normalize = transforms.Normalize(
    mean=feature_extractor.image_mean,
    std=feature_extractor.image_std,
)

# Preprocessing pipeline: resize to 224x224 (the "224" in model_name),
# convert the PIL image to a float tensor, then normalize.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ]
)

# ==========================
# Load Dataset and Split into Train/Validation
# ==========================
def _skip_failed_samples(batch):
    """Collate a batch after removing samples that failed to load.

    CustomDataset.__getitem__ returns None for an unreadable image. The
    default collate function cannot batch None, so without this filter a
    single bad file would crash the epoch. Returns None when every sample in
    the batch failed; the training and validation loops skip such batches.
    """
    batch = [sample for sample in batch if sample is not None]
    if not batch:
        return None
    return torch.utils.data.dataloader.default_collate(batch)


dataset = CustomDataset(dataset_path, transform=transform)
train_size = int(0.8 * len(dataset))  # 80% for training
val_size = len(dataset) - train_size  # 20% for validation
if train_size == 0 or val_size == 0:
    raise ValueError(
        f"Dataset at {dataset_path} has {len(dataset)} image(s); "
        "too few for an 80/20 train/validation split."
    )

# A fixed seed keeps the partition identical across runs, so validation
# metrics are comparable between experiments.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(dataset, [train_size, val_size], generator=split_generator)

train_loader = DataLoader(
    train_dataset, batch_size=16, shuffle=True, num_workers=2, collate_fn=_skip_failed_samples
)
val_loader = DataLoader(
    val_dataset, batch_size=16, shuffle=False, num_workers=2, collate_fn=_skip_failed_samples
)

# ==========================
# Load Model
# ==========================
# Pre-trained ViT with a 2-label classification head (drowsy, notdrowsy),
# placed on the selected device. Module.to() moves the parameters and returns
# the same module, so the call can be chained.
model = ViTForImageClassification.from_pretrained(model_name, num_labels=2).to(device)

# ==========================
# Define Training Setup
# ==========================
learning_rate = 2e-4
optimizer = AdamW(model.parameters(), lr=learning_rate)
# Standalone loss module. The training loop in this file takes its loss from
# the model's own outputs.loss, so this object is not used there.
criterion = CrossEntropyLoss()

# Training Loop: up to num_epochs epochs of train + validate. The weights with
# the lowest validation loss are checkpointed, and training stops after
# early_stopping_patience consecutive epochs without improvement.
num_epochs = 10
best_loss = float("inf")
early_stopping_patience = 5
early_stopping_counter = 0

for epoch in range(num_epochs):
    # ---- Training pass ----
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    num_train_batches = 0  # batches actually processed; all-failed batches are skipped

    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} (Training)"):
        if batch is None:
            continue  # every sample in this batch failed to load

        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)

        optimizer.zero_grad()
        # Passing labels makes the model compute its classification loss itself.
        outputs = model(pixel_values=pixel_values, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_train_batches += 1
        predicted = outputs.logits.argmax(dim=1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

    if total == 0:
        raise RuntimeError(f"Epoch {epoch+1}: no training samples could be loaded.")

    # Average over processed batches (not len(train_loader)) so skipped
    # batches do not deflate the reported loss.
    avg_loss = total_loss / num_train_batches
    train_accuracy = 100 * correct / total
    print(f"Epoch {epoch+1}: Training Loss: {avg_loss:.4f}, Training Accuracy: {train_accuracy:.2f}%")

    # ---- Validation pass ----
    model.eval()
    val_correct = 0
    val_total = 0
    val_loss = 0
    num_val_batches = 0

    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} (Validation)"):
            if batch is None:
                continue  # every sample in this batch failed to load

            pixel_values = batch["pixel_values"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(pixel_values=pixel_values, labels=labels)
            val_loss += outputs.loss.item()
            num_val_batches += 1
            predicted = outputs.logits.argmax(dim=1)
            val_correct += (predicted == labels).sum().item()
            val_total += labels.size(0)

    if val_total == 0:
        raise RuntimeError(f"Epoch {epoch+1}: no validation samples could be loaded.")

    val_accuracy = 100 * val_correct / val_total
    avg_val_loss = val_loss / num_val_batches
    print(f"Epoch {epoch+1}: Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%")

    # ---- Checkpointing / early stopping ----
    if avg_val_loss < best_loss:
        best_loss = avg_val_loss
        early_stopping_counter = 0
        torch.save(model.state_dict(), "/kaggle/working/best_model.pth")
        print("Saved best model.")
    else:
        early_stopping_counter += 1
        if early_stopping_counter >= early_stopping_patience:
            print(f"Early stopping triggered at epoch {epoch+1}.")
            break

# ==========================
# Save Model to Kaggle Output Folder
# ==========================
# After the loop, the in-memory model holds the LAST epoch's weights. After
# early stopping these are several epochs past the best validation loss, so
# restore the best checkpoint (when one was written) before exporting.
best_checkpoint = "/kaggle/working/best_model.pth"
if os.path.exists(best_checkpoint):
    model.load_state_dict(torch.load(best_checkpoint, map_location=device, weights_only=True))
    print(f"Restored best model weights from {best_checkpoint}")

# Export in Hugging Face format so the model can be reloaded with from_pretrained.
model.save_pretrained("/kaggle/working/trained_vit_model")
feature_extractor.save_pretrained("/kaggle/working/trained_vit_model")
print("Model and feature extractor saved to /kaggle/working/trained_vit_model")