In [23]:
import torch
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler
import numpy as np
import pandas as pd
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import f1_score
import timm
from tqdm import tqdm

# --- Config ---
IMG_SIZE = 224
BATCH_SIZE = 128
EPOCHS = 25
LEARNING_RATE = 3e-4
NUM_CLASSES = 4
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SAVE_PATH = "best_convnext_model.pth"
PATIENCE = 7

# Label mapping
label2idx = {'Alluvial soil': 0, 'Black Soil': 1, 'Clay soil': 2, 'Red soil': 3}

# --- Model ---
model = timm.create_model('convnext_tiny', pretrained=True, num_classes=NUM_CLASSES, drop_rate=0.4)
model = model.to(DEVICE)

# --- Class Weights ---
class_weights = compute_class_weight(class_weight='balanced', classes=np.unique(df.soil_type), y=df.soil_type)
weights = torch.tensor([class_weights[label2idx[k]] for k in label2idx], dtype=torch.float32).to(DEVICE)

# --- Loss, Optimizer, Scheduler ---
criterion = nn.CrossEntropyLoss(weight=weights, label_smoothing=0.1)
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
scaler = GradScaler()

# --- Optional: MixUp Function ---
def mixup_data(x, y, alpha=1.0):
    '''Returns mixed inputs, pairs of targets, and lambda'''
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1
    batch_size = x.size()[0]
    index = torch.randperm(batch_size).to(DEVICE)
    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

# --- Training Loop ---
best_val_acc = 0
early_stop_counter = 0

for epoch in range(1, EPOCHS + 1):
    model.train()
    train_loss, train_correct = 0, 0
    train_preds, train_labels = [], []

    for images, labels in tqdm(train_loader, desc=f"Epoch {epoch} [Train]"):
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()

        # Apply MixUp
        images, targets_a, targets_b, lam = mixup_data(images, labels, alpha=0.4)

        with autocast():
            outputs = model(images)
            loss = mixup_criterion(criterion, outputs, targets_a, targets_b, lam)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        preds = outputs.argmax(1)
        # For accuracy, use targets_a
        train_loss += loss.item() * images.size(0)
        train_correct += (preds == targets_a).sum().item()
        train_preds.extend(preds.cpu().numpy())
        train_labels.extend(targets_a.cpu().numpy())

    train_acc = train_correct / len(train_loader.dataset)
    train_f1 = f1_score(train_labels, train_preds, average='weighted')

    # --- Validation ---
    model.eval()
    val_loss, val_correct = 0, 0
    val_preds, val_labels = [], []

    with torch.no_grad():
        for images, labels in tqdm(val_loader, desc=f"Epoch {epoch} [Val]"):
            images, labels = images.to(DEVICE), labels.to(DEVICE)

            with autocast():
                outputs = model(images)
                loss = criterion(outputs, labels)

            preds = outputs.argmax(1)
            val_loss += loss.item() * images.size(0)
            val_correct += (preds == labels).sum().item()
            val_preds.extend(preds.cpu().numpy())
            val_labels.extend(labels.cpu().numpy())

    val_acc = val_correct / len(val_loader.dataset)
    val_f1 = f1_score(val_labels, val_preds, average='weighted')

    print(f"Epoch {epoch} | Train Acc: {train_acc:.4f} | Train F1: {train_f1:.4f} | Val Acc: {val_acc:.4f} | Val F1: {val_f1:.4f}")

    scheduler.step()

    # Save best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        early_stop_counter = 0
        torch.save(model.state_dict(), SAVE_PATH)
        print(f"✅ Best model saved at epoch {epoch} with Val Acc: {val_acc:.4f}")
    else:
        early_stop_counter += 1
        if early_stop_counter >= PATIENCE:
            print("⛔ Early stopping triggered.")
            break

print(f"✅ Training complete. Best Val Accuracy: {best_val_acc:.4f}")


  scaler = GradScaler()
  with autocast():
Epoch 1 [Train]: 100%|██████████| 59/59 [00:11<00:00,  5.28it/s]
  with autocast():
Epoch 1 [Val]: 100%|██████████| 8/8 [00:01<00:00,  5.90it/s]


Epoch 1 | Train Acc: 0.3760 | Train F1: 0.3628 | Val Acc: 0.8122 | Val F1: 0.8152
✅ Best model saved at epoch 1 with Val Acc: 0.8122


  with autocast():
Epoch 2 [Train]: 100%|██████████| 59/59 [00:11<00:00,  5.35it/s]
  with autocast():
Epoch 2 [Val]: 100%|██████████| 8/8 [00:01<00:00,  5.60it/s]


Epoch 2 | Train Acc: 0.5624 | Train F1: 0.5599 | Val Acc: 0.8571 | Val F1: 0.8595
✅ Best model saved at epoch 2 with Val Acc: 0.8571


  with autocast():
Epoch 3 [Train]: 100%|██████████| 59/59 [00:10<00:00,  5.61it/s]
  with autocast():
Epoch 3 [Val]: 100%|██████████| 8/8 [00:01<00:00,  5.75it/s]


Epoch 3 | Train Acc: 0.6165 | Train F1: 0.6132 | Val Acc: 0.8449 | Val F1: 0.8440


  with autocast():
Epoch 4 [Train]: 100%|██████████| 59/59 [00:10<00:00,  5.68it/s]
  with autocast():
Epoch 4 [Val]: 100%|██████████| 8/8 [00:01<00:00,  5.81it/s]


Epoch 4 | Train Acc: 0.6235 | Train F1: 0.6172 | Val Acc: 0.8776 | Val F1: 0.8786
✅ Best model saved at epoch 4 with Val Acc: 0.8776


  with autocast():
Epoch 5 [Train]: 100%|██████████| 59/59 [00:10<00:00,  5.73it/s]
  with autocast():
Epoch 5 [Val]: 100%|██████████| 8/8 [00:01<00:00,  6.07it/s]


Epoch 5 | Train Acc: 0.6749 | Train F1: 0.6709 | Val Acc: 0.8286 | Val F1: 0.8312


  with autocast():
Epoch 6 [Train]: 100%|██████████| 59/59 [00:10<00:00,  5.66it/s]
  with autocast():
Epoch 6 [Val]: 100%|██████████| 8/8 [00:01<00:00,  5.83it/s]


Epoch 6 | Train Acc: 0.5501 | Train F1: 0.5464 | Val Acc: 0.9020 | Val F1: 0.9020
✅ Best model saved at epoch 6 with Val Acc: 0.9020


  with autocast():
Epoch 7 [Train]: 100%|██████████| 59/59 [00:10<00:00,  5.52it/s]
  with autocast():
Epoch 7 [Val]: 100%|██████████| 8/8 [00:01<00:00,  5.83it/s]


Epoch 7 | Train Acc: 0.6020 | Train F1: 0.5971 | Val Acc: 0.9061 | Val F1: 0.9051
✅ Best model saved at epoch 7 with Val Acc: 0.9061


  with autocast():
Epoch 8 [Train]: 100%|██████████| 59/59 [00:10<00:00,  5.39it/s]
  with autocast():
Epoch 8 [Val]: 100%|██████████| 8/8 [00:01<00:00,  5.74it/s]


Epoch 8 | Train Acc: 0.5881 | Train F1: 0.5853 | Val Acc: 0.8939 | Val F1: 0.8923


  with autocast():
Epoch 9 [Train]: 100%|██████████| 59/59 [00:10<00:00,  5.51it/s]
  with autocast():
Epoch 9 [Val]: 100%|██████████| 8/8 [00:01<00:00,  5.68it/s]


Epoch 9 | Train Acc: 0.6438 | Train F1: 0.6409 | Val Acc: 0.9184 | Val F1: 0.9177
✅ Best model saved at epoch 9 with Val Acc: 0.9184


  with autocast():
Epoch 10 [Train]: 100%|██████████| 59/59 [00:10<00:00,  5.62it/s]
  with autocast():
Epoch 10 [Val]: 100%|██████████| 8/8 [00:01<00:00,  5.92it/s]


Epoch 10 | Train Acc: 0.6620 | Train F1: 0.6597 | Val Acc: 0.9224 | Val F1: 0.9220
✅ Best model saved at epoch 10 with Val Acc: 0.9224


  with autocast():
Epoch 11 [Train]: 100%|██████████| 59/59 [00:10<00:00,  5.67it/s]
  with autocast():
Epoch 11 [Val]: 100%|██████████| 8/8 [00:01<00:00,  5.95it/s]


Epoch 11 | Train Acc: 0.5763 | Train F1: 0.5749 | Val Acc: 0.8898 | Val F1: 0.8891


  with autocast():
Epoch 12 [Train]: 100%|██████████| 59/59 [00:10<00:00,  5.65it/s]
  with autocast():
Epoch 12 [Val]: 100%|██████████| 8/8 [00:01<00:00,  5.53it/s]


Epoch 12 | Train Acc: 0.5881 | Train F1: 0.5854 | Val Acc: 0.9224 | Val F1: 0.9223


  with autocast():
Epoch 13 [Train]: 100%|██████████| 59/59 [00:10<00:00,  5.60it/s]
  with autocast():
Epoch 13 [Val]: 100%|██████████| 8/8 [00:01<00:00,  6.05it/s]


Epoch 13 | Train Acc: 0.6079 | Train F1: 0.6070 | Val Acc: 0.9347 | Val F1: 0.9345
✅ Best model saved at epoch 13 with Val Acc: 0.9347


  with autocast():
Epoch 14 [Train]: 100%|██████████| 59/59 [00:10<00:00,  5.57it/s]
  with autocast():
Epoch 14 [Val]: 100%|██████████| 8/8 [00:01<00:00,  5.62it/s]


Epoch 14 | Train Acc: 0.6299 | Train F1: 0.6291 | Val Acc: 0.9224 | Val F1: 0.9217


  with autocast():
Epoch 15 [Train]: 100%|██████████| 59/59 [00:10<00:00,  5.56it/s]
  with autocast():
Epoch 15 [Val]: 100%|██████████| 8/8 [00:01<00:00,  6.06it/s]


Epoch 15 | Train Acc: 0.5404 | Train F1: 0.5389 | Val Acc: 0.9510 | Val F1: 0.9514
✅ Best model saved at epoch 15 with Val Acc: 0.9510


  with autocast():
Epoch 16 [Train]: 100%|██████████| 59/59 [00:10<00:00,  5.55it/s]
  with autocast():
Epoch 16 [Val]: 100%|██████████| 8/8 [00:01<00:00,  5.66it/s]


Epoch 16 | Train Acc: 0.6251 | Train F1: 0.6240 | Val Acc: 0.9306 | Val F1: 0.9306


  with autocast():
Epoch 17 [Train]: 100%|██████████| 59/59 [00:10<00:00,  5.57it/s]
  with autocast():
Epoch 17 [Val]: 100%|██████████| 8/8 [00:01<00:00,  5.62it/s]


Epoch 17 | Train Acc: 0.6352 | Train F1: 0.6350 | Val Acc: 0.9673 | Val F1: 0.9674
✅ Best model saved at epoch 17 with Val Acc: 0.9673


  with autocast():
Epoch 18 [Train]: 100%|██████████| 59/59 [00:10<00:00,  5.62it/s]
  with autocast():
Epoch 18 [Val]: 100%|██████████| 8/8 [00:01<00:00,  5.85it/s]


Epoch 18 | Train Acc: 0.6197 | Train F1: 0.6176 | Val Acc: 0.9551 | Val F1: 0.9552


  with autocast():
Epoch 19 [Train]: 100%|██████████| 59/59 [00:10<00:00,  5.66it/s]
  with autocast():
Epoch 19 [Val]: 100%|██████████| 8/8 [00:01<00:00,  6.09it/s]


Epoch 19 | Train Acc: 0.5972 | Train F1: 0.5957 | Val Acc: 0.9224 | Val F1: 0.9215


  with autocast():
Epoch 20 [Train]: 100%|██████████| 59/59 [00:10<00:00,  5.60it/s]
  with autocast():
Epoch 20 [Val]: 100%|██████████| 8/8 [00:01<00:00,  5.93it/s]


Epoch 20 | Train Acc: 0.6476 | Train F1: 0.6470 | Val Acc: 0.9469 | Val F1: 0.9467


  with autocast():
Epoch 21 [Train]: 100%|██████████| 59/59 [00:10<00:00,  5.58it/s]
  with autocast():
Epoch 21 [Val]: 100%|██████████| 8/8 [00:01<00:00,  5.67it/s]


Epoch 21 | Train Acc: 0.6610 | Train F1: 0.6597 | Val Acc: 0.9755 | Val F1: 0.9756
✅ Best model saved at epoch 21 with Val Acc: 0.9755


  with autocast():
Epoch 22 [Train]: 100%|██████████| 59/59 [00:10<00:00,  5.54it/s]
  with autocast():
Epoch 22 [Val]: 100%|██████████| 8/8 [00:01<00:00,  5.45it/s]


Epoch 22 | Train Acc: 0.6449 | Train F1: 0.6433 | Val Acc: 0.9755 | Val F1: 0.9756


  with autocast():
Epoch 23 [Train]: 100%|██████████| 59/59 [00:10<00:00,  5.59it/s]
  with autocast():
Epoch 23 [Val]: 100%|██████████| 8/8 [00:01<00:00,  5.88it/s]


Epoch 23 | Train Acc: 0.6299 | Train F1: 0.6284 | Val Acc: 0.9755 | Val F1: 0.9756


  with autocast():
Epoch 24 [Train]: 100%|██████████| 59/59 [00:10<00:00,  5.59it/s]
  with autocast():
Epoch 24 [Val]: 100%|██████████| 8/8 [00:01<00:00,  5.78it/s]


Epoch 24 | Train Acc: 0.6545 | Train F1: 0.6532 | Val Acc: 0.9347 | Val F1: 0.9343


  with autocast():
Epoch 25 [Train]: 100%|██████████| 59/59 [00:10<00:00,  5.61it/s]
  with autocast():
Epoch 25 [Val]: 100%|██████████| 8/8 [00:01<00:00,  5.90it/s]


Epoch 25 | Train Acc: 0.6508 | Train F1: 0.6505 | Val Acc: 0.9796 | Val F1: 0.9797
✅ Best model saved at epoch 25 with Val Acc: 0.9796
✅ Training complete. Best Val Accuracy: 0.9796
