# In [None]:
import copy
import os
import random
import re

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from sklearn.model_selection import KFold
from torchvision import models, transforms
from tqdm import tqdm

######################################
# Reproducibility & Device Setup
######################################
SEED = 42
# Seed every RNG the pipeline draws from: Python's `random`, NumPy (used by
# mixup_data for the Beta sample) and torch on the CPU and all CUDA devices.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(SEED)

_use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if _use_cuda else "cpu")
print("Using device:", device)

######################################
# Offline Data Augmentation (Skip if Already Done)
######################################
original_train_dir = "train/train"
augmented_train_dir = "train_augmented"

# Regenerate only when the output folder is missing or empty; any existing
# content is treated as a finished run and reused as-is.
_needs_augmentation = (not os.path.exists(augmented_train_dir)) or len(os.listdir(augmented_train_dir)) == 0

if _needs_augmentation:
    os.makedirs(augmented_train_dir, exist_ok=True)
    # Each call produces one random variant; the output size is 224x224 because of the final crop.
    offline_aug_transforms = transforms.Compose([
        transforms.Resize((288, 288)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.RandomResizedCrop(224, scale=(0.8, 1.0))
    ])
    num_augmented = 10  # augmented copies written per source image
    print("Starting offline data augmentation...")
    _valid_exts = (".jpg", ".jpeg", ".png")
    for cls in os.listdir(original_train_dir):
        src_dir = os.path.join(original_train_dir, cls)
        if os.path.isdir(src_dir):
            dst_dir = os.path.join(augmented_train_dir, cls)
            os.makedirs(dst_dir, exist_ok=True)
            for fname in os.listdir(src_dir):
                if not fname.lower().endswith(_valid_exts):
                    continue
                src_path = os.path.join(src_dir, fname)
                try:
                    pil_img = Image.open(src_path).convert("RGB")
                except Exception as err:
                    # Unreadable files are reported and skipped, not fatal.
                    print(f"Error opening image {src_path}: {err}")
                    continue
                # Keep the (RGB-converted) original alongside its variants.
                pil_img.save(os.path.join(dst_dir, fname))
                stem, ext = os.path.splitext(fname)
                for k in range(num_augmented):
                    augmented = offline_aug_transforms(pil_img)
                    augmented.save(os.path.join(dst_dir, f"{stem}_aug{k}{ext}"))
    print("Offline data augmentation complete. Augmented images saved in:", augmented_train_dir)
else:
    print("Using existing augmented dataset in:", augmented_train_dir)

######################################
# Online Transformations for Training & Validation
# (AutoAugment ImageNet policy plus flip, rotation and random erasing; no GaussianBlur)
######################################
# Standard ImageNet per-channel normalization statistics, shared by both pipelines.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

train_transforms = transforms.Compose([
    # PIL stage: fixed resize, then stochastic photometric/geometric policies.
    transforms.Resize((224, 224)),
    transforms.AutoAugment(transforms.AutoAugmentPolicy.IMAGENET),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    # Tensor stage: convert, normalize, then erase a random patch of the normalized tensor.
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value='random'),
])

# Deterministic evaluation pipeline: resize + normalize only.
val_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])

######################################
# Custom Dataset (instead of datasets.ImageFolder, so numeric class folders sort numerically)
######################################
class CustomImageDataset(torch.utils.data.Dataset):
    """Image-classification dataset laid out as ``root/<class_name>/<image>``.

    Class folders are ordered numerically when every folder name parses as an
    int (so "10" follows "9"), otherwise lexicographically. ``class_to_idx``
    maps each folder name to its position in that order. Only ``.jpg``,
    ``.jpeg`` and ``.png`` files (case-insensitive) are indexed, in sorted
    file-name order, so the sample list is the same on every machine. A seeded
    split of it is therefore reproducible.

    Args:
        root: directory containing one sub-directory per class.
        transform: optional callable applied to each RGB PIL image.
    """

    _IMG_EXTENSIONS = (".jpg", ".jpeg", ".png")

    def __init__(self, root, transform=None):
        self.transform = transform
        self.samples = []  # list of (image_path, class_index)
        class_names = [d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))]
        try:
            sorted_class_names = sorted(class_names, key=int)
        except ValueError:  # at least one folder name is not an integer
            sorted_class_names = sorted(class_names)
        self.class_to_idx = {name: i for i, name in enumerate(sorted_class_names)}
        for class_name in sorted_class_names:
            folder = os.path.join(root, class_name)
            label = self.class_to_idx[class_name]
            # os.listdir order is filesystem-dependent; sort for reproducibility.
            for fname in sorted(os.listdir(folder)):
                if fname.lower().endswith(self._IMG_EXTENSIONS):
                    self.samples.append((os.path.join(folder, fname), label))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        """Return ``(image, class_index)``; ``image`` is transformed if a transform is set."""
        path, label = self.samples[index]
        # The context manager releases the file handle as soon as pixels are loaded.
        with Image.open(path) as img:
            image = img.convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, label

######################################
# Vision Transformer (ViT-B/16) Model with Dropout
# and Partial Freezing (Unfreeze last 4 layers + heads)
######################################
class ViT_Dropout(nn.Module):
    """Pretrained torchvision ViT-B/16 with a dropout + linear classification head.

    Args:
        num_classes: number of logits produced by the new head.
        dropout_p: dropout probability applied to the head's input features.
        freeze_feature_extractor: when True, only parameters whose names contain
            one of ``_TRAINABLE_NAME_PARTS`` (encoder blocks 8-11 and the head)
            keep ``requires_grad=True``; every other parameter is frozen.
    """

    # Name fragments (substring match) of the parameters left trainable under freezing.
    _TRAINABLE_NAME_PARTS = (
        "encoder.layers.encoder_layer_8",
        "encoder.layers.encoder_layer_9",
        "encoder.layers.encoder_layer_10",
        "encoder.layers.encoder_layer_11",
        "heads",
    )

    def __init__(self, num_classes=100, dropout_p=0.5, freeze_feature_extractor=False):
        super().__init__()
        backbone = models.vit_b_16(pretrained=True)
        feature_dim = backbone.heads.head.in_features
        # Replace torchvision's head with dropout followed by a freshly initialised linear layer.
        backbone.heads = nn.Sequential(
            nn.Dropout(p=dropout_p),
            nn.Linear(feature_dim, num_classes),
        )
        self.model = backbone

        if not freeze_feature_extractor:
            return
        for param_name, param in self.model.named_parameters():
            param.requires_grad = any(part in param_name for part in self._TRAINABLE_NAME_PARTS)

    def forward(self, x):
        """Return the class logits computed by the wrapped ViT for batch ``x``."""
        return self.model(x)

######################################
# MixUp Functions with adjusted alpha (0.2)
######################################
def mixup_data(x, y, alpha=0.2):
    """Blend every sample in the batch with a randomly chosen partner (MixUp).

    Args:
        x: input batch; its first dimension is the batch dimension.
        y: targets aligned with ``x`` along dimension 0.
        alpha: Beta-distribution parameter; ``alpha <= 0`` disables mixing (lam = 1).

    Returns:
        ``(mixed_x, y_a, y_b, lam)`` where ``mixed_x = lam * x + (1 - lam) * x[perm]``,
        ``y_a`` is ``y`` unchanged and ``y_b`` holds the partners' targets ``y[perm]``.
    """
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1
    perm = torch.randperm(x.size(0)).to(x.device)
    partner_x = x[perm, :]
    mixed_x = lam * x + (1 - lam) * partner_x
    return mixed_x, y, y[perm], lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Return the MixUp loss: ``criterion`` against both target sets, weighted by ``lam`` and ``1 - lam``."""
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b

######################################
# Helper: Create Optimizer
######################################
def create_optimizer(model, config):
    """Build an SGD or AdamW optimizer with separate backbone/head learning rates.

    ``model.model`` is expected to expose ``named_parameters()`` and a ``heads``
    submodule. Parameters whose name contains "heads" form the head group
    (``config['lr_head']``); the rest form the backbone group
    (``config['lr_feature']``). When ``config['freeze_feature_extractor']`` is
    truthy, only parameters with ``requires_grad`` are included.

    Config keys read: ``freeze_feature_extractor``, ``lr_feature``, ``lr_head``,
    ``optimizer_type`` ("SGD" selects SGD, anything else AdamW), ``weight_decay``
    and, for SGD only, ``momentum`` (default 0.9).

    Returns:
        The optimizer, with param groups ordered [backbone, head].
    """
    if config['freeze_feature_extractor']:
        trainable = [(n, p) for n, p in model.model.named_parameters() if p.requires_grad]
        head_params = [p for n, p in trainable if "heads" in n]
        backbone_params = [p for n, p in trainable if "heads" not in n]
    else:
        head_params = list(model.model.heads.parameters())
        backbone_params = [p for n, p in model.model.named_parameters() if "heads" not in n]

    param_groups = [
        {'params': backbone_params, 'lr': config['lr_feature']},
        {'params': head_params, 'lr': config['lr_head']},
    ]

    if config['optimizer_type'] != "SGD":
        return optim.AdamW(param_groups, weight_decay=config['weight_decay'])
    return optim.SGD(param_groups, momentum=config.get('momentum', 0.9), weight_decay=config['weight_decay'])

######################################
# Training & Validation with Early Stopping
######################################
def train_validate(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs=5, patience=2):
    """Train ``model`` with MixUp and early-stop on validation accuracy.

    Each epoch runs one MixUp pass (alpha=0.2) over ``train_loader`` and steps
    ``scheduler`` once. It then measures top-1 accuracy on ``val_loader``
    without MixUp. Training stops once ``patience`` consecutive epochs fail to
    strictly improve on the best validation accuracy. Batches are moved to the
    module-level ``device``.

    Args:
        model: network trained in place.
        train_loader: iterable of ``(inputs, labels)`` training batches.
        val_loader: iterable of ``(inputs, labels)`` validation batches.
        criterion: loss taking ``(logits, targets)``.
        optimizer: optimizer over ``model``'s trainable parameters.
        scheduler: LR scheduler, stepped once per epoch.
        num_epochs: maximum number of epochs.
        patience: epochs without improvement tolerated before stopping.

    Returns:
        ``(best_val_acc, best_model_state)``. ``best_val_acc`` is the best
        validation accuracy. ``best_model_state`` is an independent deep copy
        of ``model.state_dict()`` from that epoch; it is ``None`` only when
        ``num_epochs`` is 0. ``model`` itself keeps the last epoch's weights.
    """
    best_val_acc = 0.0
    best_model_state = None
    epochs_no_improve = 0

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0.0
        total = 0

        for inputs, labels in tqdm(train_loader, desc=f"Training Epoch {epoch+1}/{num_epochs}", leave=False):
            inputs, labels = inputs.to(device), labels.to(device)
            inputs, targets_a, targets_b, lam = mixup_data(inputs, labels, alpha=0.2)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = mixup_criterion(criterion, outputs, targets_a, targets_b, lam)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * inputs.size(0)
            preds = outputs.argmax(dim=1)
            # Under MixUp a prediction earns lam for matching targets_a and (1 - lam) for targets_b.
            correct += lam * (preds == targets_a).sum().item() + (1 - lam) * (preds == targets_b).sum().item()
            total += labels.size(0)

        scheduler.step()
        # Guard against an empty loader instead of dividing by zero.
        train_loss = running_loss / total if total else 0.0
        train_acc = correct / total if total else 0.0

        model.eval()
        correct_val = 0
        total_val = 0
        with torch.no_grad():
            for inputs, labels in tqdm(val_loader, desc=f"Validation Epoch {epoch+1}/{num_epochs}", leave=False):
                inputs, labels = inputs.to(device), labels.to(device)
                preds = model(inputs).argmax(dim=1)
                correct_val += (preds == labels).sum().item()
                total_val += labels.size(0)
        val_acc = correct_val / total_val if total_val else 0.0

        print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}")

        # Always keep the first epoch's weights so a usable state is returned even if accuracy stays at 0.
        if best_model_state is None or val_acc > best_val_acc:
            best_val_acc = val_acc
            # state_dict() returns references to the live tensors; deep-copy so later epochs cannot overwrite it.
            best_model_state = copy.deepcopy(model.state_dict())
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print(f"Early stopping triggered at epoch {epoch+1}")
                break

    return best_val_acc, best_model_state

######################################
# -- Single Hyperparameter Set (last 4 encoder blocks + head trainable) --
######################################
# Hyperparameters consumed by ViT_Dropout and create_optimizer.
best_config = dict(
    optimizer_type='AdamW',
    lr_head=1e-3,                   # learning rate for the classification head
    lr_feature=1e-3,                # learning rate for the trainable encoder parameters
    momentum=0.0,                   # only read by the SGD branch of create_optimizer
    weight_decay=1e-4,
    freeze_feature_extractor=True,  # train only encoder blocks 8-11 + head
)

######################################
# Final Training on the Augmented Dataset (~10% held out for validation)
######################################
print("\nUsing hyperparameters:", best_config)
final_train_dataset = CustomImageDataset(root=augmented_train_dir, transform=train_transforms)
final_model = ViT_Dropout(
    num_classes=100,
    dropout_p=0.5,
    freeze_feature_extractor=best_config['freeze_feature_extractor']
).to(device)

criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = create_optimizer(final_model, best_config)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)
num_epochs_final = 15  # maximum epochs; early stopping may end sooner
patience_final = 5     # epochs without val improvement before stopping
val_fraction = 0.1     # share of *source* images held out for validation


def _source_image_key(sample):
    """Map a ``(path, label)`` sample to ``(label, source stem)``.

    Offline augmentation saves ``<stem>_aug<i><ext>`` next to ``<stem><ext>``.
    Stripping a trailing ``_aug<digits>`` therefore gives an image and all of
    its augmented copies the same key.
    """
    path, label = sample
    stem = os.path.splitext(os.path.basename(path))[0]
    return label, re.sub(r"_aug\d+$", "", stem)


def _grouped_split_indices(samples, fraction, seed):
    """Return ``(train_indices, val_indices)``, keeping each source-image group on one side."""
    groups = {}
    for idx, sample in enumerate(samples):
        groups.setdefault(_source_image_key(sample), []).append(idx)
    keys = sorted(groups)
    random.Random(seed).shuffle(keys)
    n_val = int(fraction * len(keys))
    if len(keys) > 1:
        n_val = max(1, n_val)  # keep at least one group for validation
    train_idx = [i for k in keys[n_val:] for i in groups[k]]
    val_idx = [i for k in keys[:n_val] for i in groups[k]]
    return train_idx, val_idx


print("\nStarting final training on full augmented dataset...")
# Split by source image: near-duplicate augmented copies of one image must not
# end up in both train and validation, or validation accuracy is inflated.
train_indices, val_indices = _grouped_split_indices(final_train_dataset.samples, val_fraction, SEED)

# Subsets of one dataset share a single `.transform`, so each split needs its
# own dataset object. A shallow copy shares the sample list and swaps only the
# transform.
val_source_dataset = copy.copy(final_train_dataset)
val_source_dataset.transform = val_transforms
train_dataset_final = torch.utils.data.Subset(final_train_dataset, train_indices)
val_dataset_final = torch.utils.data.Subset(val_source_dataset, val_indices)

final_train_loader = torch.utils.data.DataLoader(train_dataset_final, batch_size=64, shuffle=True, num_workers=4)
final_val_loader = torch.utils.data.DataLoader(val_dataset_final, batch_size=64, shuffle=False, num_workers=4)

best_val_acc_final = 0.0
best_final_model_state = None
epochs_no_improve = 0

for epoch in range(num_epochs_final):
    final_model.train()
    running_loss = 0.0
    correct = 0.0
    total = 0
    for inputs, labels in tqdm(final_train_loader, desc=f"Final Training Epoch {epoch+1}/{num_epochs_final}"):
        inputs, labels = inputs.to(device), labels.to(device)
        inputs, targets_a, targets_b, lam = mixup_data(inputs, labels, alpha=0.2)
        optimizer.zero_grad()
        outputs = final_model(inputs)
        loss = mixup_criterion(criterion, outputs, targets_a, targets_b, lam)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * inputs.size(0)
        preds = outputs.argmax(dim=1)
        # Under MixUp a prediction earns lam for matching targets_a and (1 - lam) for targets_b.
        correct += lam * (preds == targets_a).sum().item() + (1 - lam) * (preds == targets_b).sum().item()
        total += labels.size(0)

    scheduler.step()
    train_loss = running_loss / total if total else 0.0
    train_acc = correct / total if total else 0.0

    final_model.eval()
    correct_val = 0
    total_val = 0
    with torch.no_grad():
        for inputs, labels in tqdm(final_val_loader, desc=f"Final Validation Epoch {epoch+1}/{num_epochs_final}"):
            inputs, labels = inputs.to(device), labels.to(device)
            preds = final_model(inputs).argmax(dim=1)
            correct_val += (preds == labels).sum().item()
            total_val += labels.size(0)
    val_acc = correct_val / total_val if total_val else 0.0
    print(f"Final Epoch {epoch+1} - Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}")

    # Always keep the first epoch so there is a state to restore even if accuracy stays at 0.
    if best_final_model_state is None or val_acc > best_val_acc_final:
        best_val_acc_final = val_acc
        # state_dict() returns references to the live tensors; without a deep
        # copy the "best" snapshot would keep changing as later epochs train.
        best_final_model_state = copy.deepcopy(final_model.state_dict())
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= patience_final:
            print(f"Final training early stopping at epoch {epoch+1}")
            break

if best_final_model_state is not None:
    final_model.load_state_dict(best_final_model_state)
torch.save(final_model.state_dict(), "transfer_learning_model_final.pth")
print("Final model saved as: transfer_learning_model_final.pth")

######################################
# Test-Time Inference and Submission Generation
######################################
test_dir = "test/test"

class TestDataset(torch.utils.data.Dataset):
    """Unlabelled images from a flat directory, yielded as ``(image, file_name)`` pairs.

    File names are sorted so loader order is deterministic.

    Args:
        test_dir: directory holding the images (no class sub-folders).
        transform: callable applied to each RGB PIL image; skipped when falsy.
        extensions: suffix or tuple of suffixes to include, matched
            case-insensitively. Defaults to ``.jpg`` only.
    """

    def __init__(self, test_dir, transform, extensions=(".jpg",)):
        self.test_dir = test_dir
        self.transform = transform
        if isinstance(extensions, str):  # accept a single suffix as well as a tuple
            extensions = (extensions,)
        suffixes = tuple(ext.lower() for ext in extensions)
        self.image_files = sorted(f for f in os.listdir(test_dir) if f.lower().endswith(suffixes))

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        image_file = self.image_files[idx]
        img_path = os.path.join(self.test_dir, image_file)
        # The context manager releases the file handle as soon as pixels are loaded.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, image_file

test_dataset = TestDataset(test_dir, val_transforms)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4)

# The model predicts class *indices* as assigned by CustomImageDataset. Map each
# index back to its class folder name, so the submitted label is the class and
# not its position in the sorted folder list. The two differ whenever the
# folders are not exactly 0..N-1.
idx_to_class = {idx: name for name, idx in final_train_dataset.class_to_idx.items()}


def _class_label(class_idx):
    """Return the folder name for ``class_idx``, as an int when the name is numeric."""
    name = idx_to_class[class_idx]
    try:
        return int(name)
    except ValueError:
        return name


final_model.eval()
predictions = {}
with torch.no_grad():
    for inputs, image_files in tqdm(test_loader, desc="Predicting on Test Set"):
        outputs = final_model(inputs.to(device))
        preds = outputs.argmax(dim=1)
        for file_name, pred in zip(image_files, preds.cpu().tolist()):
            predictions[file_name] = _class_label(pred)

submission = pd.DataFrame({
    "ID": list(predictions.keys()),
    "Label": list(predictions.values()),
})
submission.to_csv("submission.csv", index=False)

# Order rows by the first run of digits in each file name (so "10.jpg" follows "9.jpg").
# IDs are sorted in memory; re-reading the CSV could coerce purely numeric IDs to ints.
submission_sorted = (
    submission
    .assign(numeric_id=submission["ID"].map(lambda name: int(re.findall(r"\d+", name)[0])))
    .sort_values(by="numeric_id")
    .drop(columns=["numeric_id"])
)
submission_sorted.to_csv("submission_sorted.csv", index=False)
print("Submission file created: submission_sorted.csv")
