In [None]:
"""
This notebook contains code for all 5 experiments:
1. Baseline CNN
2. Fine-tuned ResNet50
3. Fine-tuned EfficientNet-B0
4. ResNet50 + Data Augmentation
5. EfficientNet-B0 + Data Augmentation

Dataset: https://doi.org/10.34740/KAGGLE/DSV/12745533

Expected folder structure:
├── Training/
│   ├── glioma_tumor/
│   ├── meningioma_tumor/
│   ├── pituitary_tumor/
│   └── no_tumor/
└── Testing/
    ├── glioma_tumor/
    ├── meningioma_tumor/
    ├── pituitary_tumor/
    └── no_tumor/

Note on Reproducibility:
Random seeds are fixed and cuDNN is put in deterministic mode to maximize reproducibility.
However, some GPU operations remain non-deterministic and training is inherently stochastic,
so minor variations in results (typically within ±2% accuracy) may occur between runs.
"""


In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
from tqdm import tqdm
from collections import Counter
from torch.utils.data import random_split, Subset
from sklearn.model_selection import StratifiedShuffleSplit
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import random
import copy
import time
from torch.utils.data import WeightedRandomSampler

# Fix every RNG source used by this notebook (Python, NumPy, PyTorch) to one seed.
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)
del _seed_fn

if torch.cuda.is_available():
    # Seed the current GPU and every other visible GPU.
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    # Give up cuDNN autotuning in exchange for deterministic kernel choice.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

# Dataset roots, resolved against the working directory - UPDATE THESE if needed.
_cwd = os.getcwd()
train_dir, test_dir = (os.path.join(_cwd, split) for split in ('Training', 'Testing'))



In [None]:

# ============================================================================
# 1. Baseline Model
# ============================================================================

class BrainTumorCNN(nn.Module):
    """Small three-stage CNN used as the from-scratch baseline.

    Each stage is Conv3x3 (stride 1, padding 1) -> ReLU -> MaxPool2x2, so every
    stage halves the spatial side length (with floor division). The flattened
    feature map feeds a two-layer MLP classifier with dropout.

    Args:
        num_classes: Number of output logits.
        input_size: Side length in pixels of the square 3-channel input.
            Defaults to 224, matching the baseline's ``Resize((224, 224))``;
            the first linear layer is sized from it.

    Raises:
        ValueError: If ``input_size`` is too small to survive three 2x2 pools.
    """

    def __init__(self, num_classes, input_size=224):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )

        # Padding=1 convs keep the size; each of the three max-pools floors it by 2.
        feature_side = input_size
        for _ in range(3):
            feature_side //= 2
        if feature_side < 1:
            raise ValueError(
                f"input_size={input_size} is too small; need at least 8 pixels per side")

        self.classifier = nn.Sequential(
            nn.Linear(128 * feature_side * feature_side, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        """Map a (batch, 3, H, W) image tensor to (batch, num_classes) logits."""
        x = self.features(x)
        x = torch.flatten(x, 1)
        return self.classifier(x)

print('Baseline CNN model defined')

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Baseline hyperparameters.
batch_size, num_epochs, learning_rate = 32, 10, 1e-3
num_classes = 4  # glioma, meningioma, no_tumor, pituitary

# Early-stopping state: training halts after `patience` epochs without a new best val loss.
patience = 5
best_val_loss, patience_counter = float('inf'), 0

# Transforms: resize to the baseline CNN's fixed 224x224 input; Normalize(0.5, 0.5)
# maps ToTensor's [0, 1] pixel range to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

# Datasets
base_train_dataset = datasets.ImageFolder(root=train_dir, transform=transform)
test_dataset       = datasets.ImageFolder(root=test_dir,  transform=transform)

# ImageFolder derives label indices from the class sub-folder names, so both
# splits must contain the same class folders or test labels silently mean
# different classes. The model head size is hard-coded via num_classes.
assert base_train_dataset.classes == test_dataset.classes, (
    f"Train/test class folders differ: {base_train_dataset.classes} vs {test_dataset.classes}")
assert len(base_train_dataset.classes) == num_classes, (
    f"Expected {num_classes} classes, found {len(base_train_dataset.classes)}: "
    f"{base_train_dataset.classes}")

# Labels for stratified split
y = [label for _, label in base_train_dataset.samples]

# Stratified 80/20 train/val split of the Training folder (seeded like the rest of the notebook)
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=SEED)
train_idx, val_idx = next(sss.split(np.arange(len(y)), y))

# Build subsets
train_dataset = Subset(base_train_dataset, train_idx)
val_dataset   = Subset(base_train_dataset, val_idx)

# Sanity checks: indices are in range, disjoint, and together cover every sample.
assert max(train_idx) < len(base_train_dataset) and max(val_idx) < len(base_train_dataset)
assert set(train_idx).isdisjoint(val_idx)
assert len(train_idx) + len(val_idx) == len(base_train_dataset)

# DataLoaders. A seeded generator fixes the shuffle order; pinned host memory
# only helps when batches are copied to a CUDA device.
g = torch.Generator().manual_seed(SEED)
loader_kwargs = dict(batch_size=batch_size, num_workers=2, pin_memory=device.type == "cuda")
train_loader = DataLoader(train_dataset, shuffle=True, generator=g, **loader_kwargs)
val_loader   = DataLoader(val_dataset,   shuffle=False, **loader_kwargs)
test_loader  = DataLoader(test_dataset,  shuffle=False, **loader_kwargs)

print(f"Train: {len(train_dataset)} | Val: {len(val_dataset)} | Test: {len(test_dataset)}")

# Get class distribution of the training subset
train_classes = [base_train_dataset.samples[i][1] for i in train_idx]
class_counts = Counter(train_classes)

print("Training dataset class distribution:")
for class_idx in sorted(class_counts):
    class_name = base_train_dataset.classes[class_idx]
    print(f"{class_name}: {class_counts[class_idx]}")

model = BrainTumorCNN(num_classes=num_classes).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop with early stopping on validation loss. The weights from the
# best-validation epoch are snapshotted and restored after the loop, so the
# test evaluation scores that checkpoint rather than whatever the last epoch produced.
train_losses = []
val_losses = []
best_w = copy.deepcopy(model.state_dict())

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for i, (images, labels) in enumerate(tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} (Training)')):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Weight by batch size so the epoch loss is a per-sample mean.
        running_loss += loss.item() * images.size(0)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        if (i + 1) % 20 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}')

    epoch_train_loss = running_loss / len(train_loader.dataset)
    train_accuracy = 100. * correct / total
    train_losses.append(epoch_train_loss)
    print(f'Epoch {epoch+1}/{num_epochs} | Train Loss: {epoch_train_loss:.4f} | Train Accuracy: {train_accuracy:.2f}%')

    # Validation loop
    model.eval()
    running_val_loss = 0.0
    correct_val = 0
    total_val = 0
    with torch.no_grad():
        for images_val, labels_val in tqdm(val_loader, desc=f'Epoch {epoch+1}/{num_epochs} (Validation)'):
            images_val, labels_val = images_val.to(device), labels_val.to(device)
            outputs_val = model(images_val)
            loss_val = criterion(outputs_val, labels_val)

            running_val_loss += loss_val.item() * images_val.size(0)
            _, predicted_val = outputs_val.max(1)
            total_val += labels_val.size(0)
            correct_val += predicted_val.eq(labels_val).sum().item()

    epoch_val_loss = running_val_loss / len(val_loader.dataset)
    val_accuracy = 100. * correct_val / total_val
    val_losses.append(epoch_val_loss)
    print(f'Epoch {epoch+1}/{num_epochs} | Validation Loss: {epoch_val_loss:.4f} | Validation Accuracy: {val_accuracy:.2f}%')

    # Early Stopping Logic: snapshot weights on every new best validation loss.
    if epoch_val_loss < best_val_loss:
        best_val_loss = epoch_val_loss
        best_w = copy.deepcopy(model.state_dict())
        patience_counter = 0
    else:
        patience_counter += 1
        print(f"Patience: {patience_counter}/{patience}")

    if patience_counter >= patience:
        print("Early stopping triggered!")
        break

# Restore the best-validation weights (same policy as the fine-tuning experiments).
model.load_state_dict(best_w)
print(f"Restored weights from the epoch with the best validation loss ({best_val_loss:.4f})")

# Plot training and validation loss curves
plt.figure(figsize=(10, 5))
plt.plot(train_losses, label='Training Loss')
plt.plot(val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Baseline CNN - Training and Validation Loss')
plt.legend()
plt.grid(True)
plt.show()

# Evaluation on Test Set: collect true labels and argmax predictions for every image.
model.eval()
all_labels = []
all_predictions = []
with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images.to(device))
        # Labels never need to visit the device; only the inputs go to the model.
        all_labels.extend(labels.numpy())
        all_predictions.extend(outputs.argmax(dim=1).cpu().numpy())

# Pin the label set to every class index so the report and confusion matrix keep
# all rows/columns, aligned with the tick labels, even if a class is absent.
class_names = test_dataset.classes
label_ids = list(range(len(class_names)))

# Classification Report
print("Classification Report (Baseline CNN):")
print(classification_report(all_labels, all_predictions, labels=label_ids, target_names=class_names))

# Confusion Matrix
cm = confusion_matrix(all_labels, all_predictions, labels=label_ids)
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Baseline CNN - Confusion Matrix')
plt.show()


In [None]:
# ============================================================================
# 2. Fine-tune Pretrained Model (ResNet50)
# ============================================================================

model_name    = "resnet50"
epochs        = 40
batch_size    = 32
max_lr        = 3e-4
weight_decay  = 1e-4
num_workers   = 2
seed          = 42
device        = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Re-seed every RNG so this experiment does not depend on what earlier cells consumed.
random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
use_cuda      = device.type == "cuda"

# Early Stopping Parameters (the counters are re-initialised again before training)
patience = 10
best_val_loss = float('inf')
patience_counter = 0

# Both supported backbones use ImageNet normalisation and the same resize -> crop
# geometry, so a single validation covers both the transforms and the model choice.
if model_name not in ("resnet50", "efficientnet_b0"):
    raise ValueError("Unsupported model_name")
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
crop_size = 224
resize_size = 256

train_tf = transforms.Compose([
    transforms.RandomResizedCrop(crop_size, scale=(0.85, 1.0)),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

val_tf = transforms.Compose([
    transforms.Resize(resize_size),
    transforms.CenterCrop(crop_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

test_tf = val_tf

# Datasets
base_train_dataset = datasets.ImageFolder(root=train_dir, transform=train_tf)
test_dataset       = datasets.ImageFolder(root=test_dir,  transform=test_tf)

num_classes = len(base_train_dataset.classes)
print(f"Number of classes: {num_classes}")
print(f"Classes: {base_train_dataset.classes}")
# Label indices come from the class folder names; both splits must agree or test labels are misread.
assert test_dataset.classes == base_train_dataset.classes, (
    f"Train/test class folders differ: {base_train_dataset.classes} vs {test_dataset.classes}")

# Stratified split of the Training folder. `all_targets` avoids reusing the name
# `y`, which the training loop rebinds to label batches.
all_targets = np.array([label for _, label in base_train_dataset.samples])
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=seed)
train_idx, val_idx = next(sss.split(np.arange(len(all_targets)), all_targets))

train_dataset = Subset(base_train_dataset, train_idx)
# Validation images come from the same folder but go through the deterministic val_tf.
base_val_dataset = datasets.ImageFolder(root=train_dir, transform=val_tf)
val_dataset   = Subset(base_val_dataset, val_idx)

assert set(train_idx).isdisjoint(val_idx)

# Class weights + sampler
# NOTE(review): class imbalance is corrected twice - by the class-weighted loss and
# by the WeightedRandomSampler, whose 1/count weights already draw every class with
# equal total probability. Together they over-weight minority classes; consider
# keeping only one of the two mechanisms.
train_targets = torch.as_tensor(all_targets[train_idx], dtype=torch.long)
class_counts = torch.bincount(train_targets, minlength=num_classes).float().clamp(min=1.0)

# Inverse-frequency loss weights, scaled so a perfectly balanced split gives all ones.
cls_weights = (class_counts.sum() / (num_classes * class_counts)).to(device)

# Per-sample sampling weight = 1 / count(sample's class), looked up in one gather.
subset_weights = (1.0 / class_counts)[train_targets].double()
sampler = WeightedRandomSampler(weights=subset_weights,
                                num_samples=len(subset_weights),
                                replacement=True)

# DataLoaders (pinned host memory only helps when copying to a CUDA device)
loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=use_cuda)
train_loader = DataLoader(train_dataset, sampler=sampler, **loader_kwargs)
val_loader   = DataLoader(val_dataset,   shuffle=False, **loader_kwargs)
test_loader  = DataLoader(test_dataset,  shuffle=False, **loader_kwargs)

# Model: ImageNet-pretrained backbone with its last layer replaced by a num_classes head.
if model_name == "resnet50":
    from torchvision.models import ResNet50_Weights
    weights = ResNet50_Weights.IMAGENET1K_V2
    model = models.resnet50(weights=weights)
    in_features = model.fc.in_features
    model.fc = nn.Linear(in_features, num_classes)
else:  # "efficientnet_b0" - model_name was validated above
    from torchvision.models import EfficientNet_B0_Weights
    weights = EfficientNet_B0_Weights.IMAGENET1K_V1
    model = models.efficientnet_b0(weights=weights)
    in_features = model.classifier[-1].in_features
    model.classifier[-1] = nn.Linear(in_features, num_classes)

model = model.to(device)

# Loss / Optimizer / Scheduler
criterion = nn.CrossEntropyLoss(weight=cls_weights)
optimizer = torch.optim.AdamW(model.parameters(), lr=max_lr, weight_decay=weight_decay)

# OneCycleLR is stepped once per batch, so it needs the number of batches per epoch.
steps_per_epoch = len(train_loader)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=max_lr,
    epochs=epochs,
    steps_per_epoch=steps_per_epoch,
    pct_start=0.2,
    div_factor=10.0,
    final_div_factor=10.0
)

# Loss scaling for mixed precision; a no-op on CPU.
scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)

# Eval helper
@torch.no_grad()
def evaluate(model, loader, loss_fn=None, device=None):
    """Run ``model`` over ``loader`` in eval mode and collect metrics.

    Args:
        model: Classifier whose outputs are treated as per-class logits
            (argmax over dim 1 gives the prediction).
        loader: Iterable of ``(inputs, integer_labels)`` batches.
        loss_fn: Loss used for the reported average loss. Defaults to the
            notebook-global ``criterion`` (the original behaviour).
        device: Device that batches are moved to. Defaults to the device of
            the model's first parameter (CPU for a parameter-less model).

    Returns:
        ``(accuracy, labels, predictions, avg_loss)``: accuracy as a fraction
        in [0, 1]; labels and predictions as NumPy arrays in loader order;
        avg_loss as each batch's loss weighted by its size, divided by the
        total number of samples.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    if loss_fn is None:
        loss_fn = criterion  # backward-compatible fallback to the notebook-global loss
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    total, correct = 0, 0
    all_preds, all_labels = [], []
    running_loss = 0.0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        # Mixed precision only on CUDA, matching the training loops.
        with torch.autocast(device_type=device.type, enabled=device.type == "cuda"):
            logits = model(x)
            loss = loss_fn(logits, y)
        running_loss += loss.item() * x.size(0)
        preds = logits.argmax(dim=1)
        all_preds.append(preds.cpu())
        all_labels.append(y.cpu())
        correct += (preds == y).sum().item()
        total += y.size(0)

    if total == 0:
        raise ValueError("evaluate() received a loader with no samples")

    all_preds = torch.cat(all_preds).numpy()
    all_labels = torch.cat(all_labels).numpy()
    avg_loss = running_loss / total
    return (correct / total), all_labels, all_preds, avg_loss

# Train loop: OneCycleLR steps once per batch; the weights with the lowest
# validation loss are kept and restored at the end.
best_w = copy.deepcopy(model.state_dict())
best_val_loss = float('inf')
patience_counter = 0
t0 = time.time()
train_losses = []
val_losses = []
use_amp = device.type == "cuda"

for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    seen = 0
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad(set_to_none=True)
        with torch.autocast(device_type=device.type, enabled=use_amp):
            logits = model(x)
            loss = criterion(logits, y)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
        # Weight by batch size so the train loss is a per-sample mean, like the
        # validation loss returned by evaluate(); both are plotted on one axis.
        running_loss += loss.item() * x.size(0)
        seen += x.size(0)

    train_loss = running_loss / max(1, seen)
    train_losses.append(train_loss)

    val_acc, _, _, val_loss = evaluate(model, val_loader)
    val_losses.append(val_loss)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_w = copy.deepcopy(model.state_dict())
        patience_counter = 0
    else:
        patience_counter += 1

    print(f"Epoch {epoch+1:03d}/{epochs} | train_loss={train_loss:.4f} | val_loss={val_loss:.4f} | val_acc={val_acc:.4f} "
          f"| lr={scheduler.get_last_lr()[0]:.2e}")

    if patience_counter >= patience:
        print(f"Early stopping triggered after {patience} epochs with no improvement.")
        break

model.load_state_dict(best_w)
print(f"\nTraining done in {(time.time()-t0)/60:.1f} min.")

# Plot
plt.figure(figsize=(10, 5))
plt.plot(train_losses, label='Training Loss')
plt.plot(val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Fine-tuned ResNet50 - Training and Validation Loss')
plt.legend()
plt.grid(True)
plt.show()

# One class list for report and plot. Explicit label ids keep every row/column
# present and aligned with the names even if a class never occurs.
assert test_dataset.classes == base_train_dataset.classes, (
    f"Train/test class folders differ: {base_train_dataset.classes} vs {test_dataset.classes}")
class_names = list(base_train_dataset.classes)
label_ids = list(range(len(class_names)))

# Final test report
test_acc, y_true, y_pred, _ = evaluate(model, test_loader)
print(f"\nFinal Test Acc: {test_acc:.4f}")
print("\nClassification report (TEST):")
print(classification_report(y_true, y_pred, labels=label_ids, target_names=class_names, digits=3))

cm = confusion_matrix(y_true, y_pred, labels=label_ids)
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Fine-tuned ResNet50 - Confusion Matrix')
plt.show()

print("\nConfusion matrix (TEST):")
print(cm)


In [None]:

# ============================================================================
# 3. Data Augmentation + ResNet50
# ============================================================================

model_name    = "resnet50"
epochs        = 40
batch_size    = 32
max_lr        = 3e-4
weight_decay  = 1e-4
num_workers   = 2
seed          = 42
device        = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Re-seed every RNG so this experiment does not depend on what earlier cells consumed.
random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
use_cuda      = device.type == "cuda"

# Early Stopping Parameters (the counters are re-initialised again before training)
patience = 10
best_val_loss = float('inf')
patience_counter = 0

# Both supported backbones use ImageNet normalisation and the same resize -> crop
# geometry, so a single validation covers both the transforms and the model choice.
if model_name not in ("resnet50", "efficientnet_b0"):
    raise ValueError("Unsupported model_name")
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
crop_size = 224
resize_size = 256

# Training-time augmentation.
# NOTE(review): GaussianBlur is applied to every training image (not wrapped in
# RandomApply) but never at val/test time, which shifts the train distribution.
# NOTE(review): the noise Lambda is not picklable; with num_workers > 0 it likely
# fails where DataLoader workers use the "spawn" start method (Windows/macOS).
train_tf = transforms.Compose([
    transforms.RandomResizedCrop(crop_size, scale=(0.85, 1.0)),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),
    transforms.GaussianBlur(kernel_size=3),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1), shear=5),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
    # Small additive Gaussian noise in normalised space.
    transforms.Lambda(lambda x: x + 0.01 * torch.randn_like(x)),
])

val_tf = transforms.Compose([
    transforms.Resize(resize_size),
    transforms.CenterCrop(crop_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

test_tf = val_tf

# Datasets
base_train_dataset = datasets.ImageFolder(root=train_dir, transform=train_tf)
test_dataset       = datasets.ImageFolder(root=test_dir,  transform=test_tf)

num_classes = len(base_train_dataset.classes)
print(f"Number of classes: {num_classes}")
print(f"Classes: {base_train_dataset.classes}")
# Label indices come from the class folder names; both splits must agree or test labels are misread.
assert test_dataset.classes == base_train_dataset.classes, (
    f"Train/test class folders differ: {base_train_dataset.classes} vs {test_dataset.classes}")

# Stratified split of the Training folder. `all_targets` avoids reusing the name
# `y`, which the training loop rebinds to label batches.
all_targets = np.array([label for _, label in base_train_dataset.samples])
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=seed)
train_idx, val_idx = next(sss.split(np.arange(len(all_targets)), all_targets))

train_dataset = Subset(base_train_dataset, train_idx)
# Validation images come from the same folder but go through the deterministic val_tf.
base_val_dataset = datasets.ImageFolder(root=train_dir, transform=val_tf)
val_dataset   = Subset(base_val_dataset, val_idx)

assert set(train_idx).isdisjoint(val_idx)

# Class weights + sampler
# NOTE(review): class imbalance is corrected twice - by the class-weighted loss and
# by the WeightedRandomSampler, whose 1/count weights already draw every class with
# equal total probability. Together they over-weight minority classes; consider
# keeping only one of the two mechanisms.
train_targets = torch.as_tensor(all_targets[train_idx], dtype=torch.long)
class_counts = torch.bincount(train_targets, minlength=num_classes).float().clamp(min=1.0)

# Inverse-frequency loss weights, scaled so a perfectly balanced split gives all ones.
cls_weights = (class_counts.sum() / (num_classes * class_counts)).to(device)

# Per-sample sampling weight = 1 / count(sample's class), looked up in one gather.
subset_weights = (1.0 / class_counts)[train_targets].double()
sampler = WeightedRandomSampler(weights=subset_weights,
                                num_samples=len(subset_weights),
                                replacement=True)

# DataLoaders (pinned host memory only helps when copying to a CUDA device)
loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=use_cuda)
train_loader = DataLoader(train_dataset, sampler=sampler, **loader_kwargs)
val_loader   = DataLoader(val_dataset,   shuffle=False, **loader_kwargs)
test_loader  = DataLoader(test_dataset,  shuffle=False, **loader_kwargs)

# Model: ImageNet-pretrained backbone with its last layer replaced by a num_classes head.
if model_name == "resnet50":
    from torchvision.models import ResNet50_Weights
    weights = ResNet50_Weights.IMAGENET1K_V2
    model = models.resnet50(weights=weights)
    in_features = model.fc.in_features
    model.fc = nn.Linear(in_features, num_classes)
else:  # "efficientnet_b0" - model_name was validated above
    from torchvision.models import EfficientNet_B0_Weights
    weights = EfficientNet_B0_Weights.IMAGENET1K_V1
    model = models.efficientnet_b0(weights=weights)
    in_features = model.classifier[-1].in_features
    model.classifier[-1] = nn.Linear(in_features, num_classes)

model = model.to(device)

# Loss / Optimizer / Scheduler
criterion = nn.CrossEntropyLoss(weight=cls_weights)
optimizer = torch.optim.AdamW(model.parameters(), lr=max_lr, weight_decay=weight_decay)

# OneCycleLR is stepped once per batch, so it needs the number of batches per epoch.
steps_per_epoch = len(train_loader)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=max_lr,
    epochs=epochs,
    steps_per_epoch=steps_per_epoch,
    pct_start=0.2,
    div_factor=10.0,
    final_div_factor=10.0
)

# Loss scaling for mixed precision; a no-op on CPU.
scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)

# Train loop: OneCycleLR steps once per batch; the weights with the lowest
# validation loss are kept and restored at the end.
# NOTE(review): evaluate() is defined in the fine-tuned ResNet50 cell; that cell
# must have been run first in this kernel.
best_w = copy.deepcopy(model.state_dict())
best_val_loss = float('inf')
patience_counter = 0
t0 = time.time()
train_losses = []
val_losses = []
use_amp = device.type == "cuda"

for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    seen = 0
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad(set_to_none=True)
        with torch.autocast(device_type=device.type, enabled=use_amp):
            logits = model(x)
            loss = criterion(logits, y)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
        # Weight by batch size so the train loss is a per-sample mean, like the
        # validation loss returned by evaluate(); both are plotted on one axis.
        running_loss += loss.item() * x.size(0)
        seen += x.size(0)

    train_loss = running_loss / max(1, seen)
    train_losses.append(train_loss)

    val_acc, _, _, val_loss = evaluate(model, val_loader)
    val_losses.append(val_loss)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_w = copy.deepcopy(model.state_dict())
        patience_counter = 0
    else:
        patience_counter += 1

    print(f"Epoch {epoch+1:03d}/{epochs} | train_loss={train_loss:.4f} | val_loss={val_loss:.4f} | val_acc={val_acc:.4f} "
          f"| lr={scheduler.get_last_lr()[0]:.2e}")

    if patience_counter >= patience:
        print(f"Early stopping triggered after {patience} epochs with no improvement.")
        break

model.load_state_dict(best_w)
print(f"\nTraining done in {(time.time()-t0)/60:.1f} min.")

# Plot
plt.figure(figsize=(10, 5))
plt.plot(train_losses, label='Training Loss')
plt.plot(val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('ResNet50 + Data Augmentation - Training and Validation Loss')
plt.legend()
plt.grid(True)
plt.show()

# One class list for report and plot. Explicit label ids keep every row/column
# present and aligned with the names even if a class never occurs.
assert test_dataset.classes == base_train_dataset.classes, (
    f"Train/test class folders differ: {base_train_dataset.classes} vs {test_dataset.classes}")
class_names = list(base_train_dataset.classes)
label_ids = list(range(len(class_names)))

# Final test report
test_acc, y_true, y_pred, _ = evaluate(model, test_loader)
print(f"\nFinal Test Acc: {test_acc:.4f}")
print("\nClassification report (TEST):")
print(classification_report(y_true, y_pred, labels=label_ids, target_names=class_names, digits=3))

cm = confusion_matrix(y_true, y_pred, labels=label_ids)
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('ResNet50 + Data Augmentation - Confusion Matrix')
plt.show()

print("\nConfusion matrix (TEST):")
print(cm)


In [None]:
# ============================================================================
# 4. Data Augmentation + EfficientNet-B0
# ============================================================================

model_name    = "efficientnet_b0"
epochs        = 40
batch_size    = 32
max_lr        = 3e-4
weight_decay  = 1e-4
num_workers   = 2
seed          = 42
device        = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(seed); np.random.seed(seed)

patience = 10
best_val_loss = float('inf')
patience_counter = 0

# Define transforms
if model_name == "resnet50":
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    crop_size = 224
    resize_size = 256
elif model_name == "efficientnet_b0":
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    crop_size = 224
    resize_size = 256
else:
    raise ValueError("Unsupported model_name")

train_tf = transforms.Compose([
    transforms.RandomResizedCrop(crop_size, scale=(0.85, 1.0)),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness