In [None]:
# ============================================
# 0. Imports & Environment Setup
# ============================================
import os
import sys
import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torch.cuda.amp import GradScaler, autocast

from torchvision import transforms, models
import timm

from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import f1_score

from tqdm import tqdm
import copy

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# ============================================
# 1. Constants & Paths
# ============================================
# Imagenet normalization stats
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]

# Mapping soil-type names → integer labels
label_map = {
    'Alluvial soil': 0,
    'Black Soil'   : 1,
    'Clay soil'    : 2,
    'Red soil'     : 3
}
inv_label_map = {v:k for k,v in label_map.items()}

# Filepaths for train metadata and image directories
TRAIN_CSV = 'soil_classification-2025/train_labels.csv'
TRAIN_DIR = 'soil_classification-2025/train'
TEST_DIR  = 'soil_classification-2025/test'

# ============================================
# 2. Read & Inspect CSV Labels
# ============================================
df = pd.read_csv(TRAIN_CSV)
print("Total images:", len(df))
print(df.head())

# Show class distribution
dist = df['soil_type'].value_counts().rename_axis('soil_type').reset_index(name='count')
print("\nClass distribution:")
print(dist)



  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda
Total images: 1222
           image_id      soil_type
0  img_ed005410.jpg  Alluvial soil
1  img_0c5ecd2a.jpg  Alluvial soil
2  img_ed713bb5.jpg  Alluvial soil
3  img_12c58874.jpg  Alluvial soil
4  img_eff357af.jpg  Alluvial soil

Class distribution:
       soil_type  count
0  Alluvial soil    528
1       Red soil    264
2     Black Soil    231
3      Clay soil    199


In [None]:
# Print versions for debugging
print("Python:", sys.version)
print("Torch version:", torch.__version__, "CUDA available:", torch.cuda.is_available())


Python: 3.9.21 | packaged by conda-forge | (main, Dec  5 2024, 13:41:22) [MSC v.1929 64 bit (AMD64)]
Torch version: 2.6.0+cu118 CUDA available: True


In [None]:
# ============================================
# 3. Data Transforms
# ============================================
# Augmentations + normalization for training
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
])

# Only resizing + normalization for validation
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
])

In [None]:
# ============================================
# 4. Custom Dataset Definition
# ============================================
class SoilDataset(Dataset):
    """
    PyTorch Dataset for loading soil images and labels.
    Expects a DataFrame with columns ['image_id', 'soil_type'].
    """
    def __init__(self, df, img_dir, transform=None):
        self.df = df.reset_index(drop=True)
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        # Load image and convert to RGB
        img = Image.open(f"{self.img_dir}/{row['image_id']}").convert('RGB')
        # Apply transforms if provided
        if self.transform:
            img = self.transform(img)
        # Map soil-type string → integer label
        label = label_map[row['soil_type']]
        return img, label


In [None]:
# ============================================
# 5. Train/Validation Split & Sampling
# ============================================
# Stratified split to preserve class ratios
train_df, val_df = train_test_split(
    df,
    test_size=0.2,
    stratify=df['soil_type'],
    random_state=42
)
print(f"Train size: {len(train_df)}, Val size: {len(val_df)}")

# Instantiate datasets
train_dataset = SoilDataset(train_df, TRAIN_DIR, transform=train_transform)
val_dataset   = SoilDataset(val_df,   TRAIN_DIR, transform=val_transform)


Train size: 977, Val size: 245


In [None]:

# Compute balanced class weights for FocalLoss
train_targets = train_df['soil_type'].map(label_map).values
classes       = np.unique(train_targets)
cw            = compute_class_weight('balanced', classes=classes, y=train_targets)
class_weights = torch.tensor(cw, dtype=torch.float).to(device)
print("Class weights:", class_weights)


Class weights: tensor([0.5788, 1.3203, 1.5362, 1.1576], device='cuda:0')


In [None]:
# Set up WeightedRandomSampler to correct for imbalance in DataLoader
class_sample_counts = np.bincount(train_targets)
weights             = 1.0 / class_sample_counts
samples_weight      = weights[train_targets]
samples_weight      = torch.from_numpy(samples_weight).double()

sampler = WeightedRandomSampler(
    weights=samples_weight,
    num_samples=len(samples_weight),
    replacement=True
)


In [None]:
# ============================================
# 6. Loss Function: Focal Loss
# ============================================
class FocalLoss(nn.Module):
    """
    Implements focal loss for multi-class classification.
    down-weights well-classified examples (gamma > 0).
    """
    def __init__(self, gamma=2.0, weight=None):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.weight = weight

    def forward(self, input, target):
        # Compute log-softmax & probabilities
        log_softmax = F.log_softmax(input, dim=1)  # [B, C]
        probs = torch.exp(log_softmax)             # [B, C]

        # Select the log-prob and prob for the true class
        logpt = log_softmax.gather(1, target.unsqueeze(1)).squeeze(1)  # [B]
        pt    = probs.gather(1, target.unsqueeze(1)).squeeze(1)       # [B]

        # Focal loss formula
        loss = -((1 - pt) ** self.gamma) * logpt

        # Apply per-class weighting if provided
        if self.weight is not None:
            loss = loss * self.weight[target]

        return loss.mean()

criterion = FocalLoss(gamma=2.0, weight=class_weights)

In [None]:
# ============================================
# 7. Model Definition: Swin Transformer
# ============================================
class SwinClassifier(nn.Module):
    """
    Swin Transformer backbone + simple linear classifier head.
    """
    def __init__(self, model_name='swin_tiny_patch4_window7_224', num_classes=4, pretrained=True):
        super(SwinClassifier, self).__init__()
        # Load pre-trained backbone with no head
        self.backbone = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=0
        )
        # Add a new linear layer for our 4 soil classes
        self.classifier = nn.Linear(self.backbone.num_features, num_classes)

    def forward(self, x):
        features = self.backbone(x)                # [B, backbone_features]
        logits   = self.classifier(features)       # [B, num_classes]
        return logits


In [None]:
# Instantiate model and move to device
model = SwinClassifier(num_classes=4).to(device)
print("Model:", model)

# ============================================
# 8. Optimizer & Scheduler
# ============================================
# Different LRs for backbone vs. classifier head
optimizer = optim.AdamW([
    {'params': model.backbone.parameters(), 'lr': 1e-5},
    {'params': model.classifier.parameters(), 'lr': 1e-4}
], weight_decay=0.01)

# Reduce LR on plateau of validation loss
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=3,
    verbose=True
)


In [None]:


batch_size = 32

train_loader = DataLoader(
    train_dataset,       # from your Step 2
    batch_size=batch_size,
    sampler=sampler,
    num_workers=0        # use 0 first to debug; bump to 2–4 when stable
)

val_loader = DataLoader(
    val_dataset,         # from your Step 2
    batch_size=batch_size,
    shuffle=False,
    num_workers=0
)



In [None]:
# ============================================
# 9. Sanity-Check Data Loaders
# ============================================
# Grab one batch from each loader to confirm shapes
train_batch = next(iter(train_loader))
val_batch   = next(iter(val_loader))
print("Train Loader:", train_batch[0].shape, train_batch[1].shape)
print("Val Loader:  ", val_batch[0].shape,   val_batch[1].shape)
print("Criterion:", criterion)
print("Optimizer:", optimizer)


Model: SwinClassifier(
  (backbone): SwinTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
      (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
    )
    (layers): Sequential(
      (0): SwinTransformerStage(
        (downsample): Identity()
        (blocks): Sequential(
          (0): SwinTransformerBlock(
            (norm1): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
            (attn): WindowAttention(
              (qkv): Linear(in_features=96, out_features=288, bias=True)
              (attn_drop): Dropout(p=0.0, inplace=False)
              (proj): Linear(in_features=96, out_features=96, bias=True)
              (proj_drop): Dropout(p=0.0, inplace=False)
              (softmax): Softmax(dim=-1)
            )
            (drop_path1): Identity()
            (norm2): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
            (mlp): Mlp(
              (fc1): Linear(in_features=96, out_features=38

In [None]:
# ============================================
# 10. Training Loop
# ============================================
num_epochs = 25
best_acc   = 0.0
scaler     = GradScaler()  # for mixed-precision

# Tracking metrics
train_losses, val_losses = [], []
train_accuracies, val_accuracies = [], []

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    print("-" * 30)

    for phase in ['train', 'val']:
        if phase == 'train':
            model.train()   # enable dropout, etc.
            data_loader = train_loader
        else:
            model.eval()    # disable dropout, etc.
            data_loader = val_loader

        running_loss = 0.0
        running_corrects = 0
        total_samples = 0

        all_preds, all_labels = [], []

        # Iterate batches with progress bar
        with tqdm(data_loader, unit="batch") as tepoch:
            for inputs, labels in tepoch:
                inputs, labels = inputs.to(device), labels.to(device)
                optimizer.zero_grad()

                # Forward + backward (only if training)
                with torch.set_grad_enabled(phase == 'train'):
                    with autocast():  # mixed-precision
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)

                    # Get predictions
                    _, preds = torch.max(outputs, 1)

                    if phase == 'train':
                        # Scale + backward + step for AMP
                        scaler.scale(loss).backward()
                        scaler.step(optimizer)
                        scaler.update()

                # Accumulate loss & accuracy
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data).item()
                total_samples += labels.size(0)

                # For validation F1 calculation
                if phase == 'val':
                    all_preds.extend(preds.cpu().numpy())
                    all_labels.extend(labels.cpu().numpy())

        # Compute epoch metrics
        epoch_loss = running_loss / total_samples
        epoch_acc  = running_corrects / total_samples

        if phase == 'train':
            train_losses.append(epoch_loss)
            train_accuracies.append(epoch_acc)
            print(f"🛠️ Train - Loss: {epoch_loss:.4f} | Acc: {epoch_acc:.4f}")

        else:
            # Validation: compute F1 & adjust LR
            epoch_f1 = f1_score(all_labels, all_preds, average='weighted')
            val_losses.append(epoch_loss)
            val_accuracies.append(epoch_acc)
            lr_scheduler.step(epoch_loss)

            print(f"🧪 Validation - Loss: {epoch_loss:.4f} | Acc: {epoch_acc:.4f} | F1 Score: {epoch_f1:.4f}")

            # Save best model weights by accuracy
            if epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                torch.save(best_model_wts, 'best_soil_model.pth')
                print("✅ Best model updated")

  scaler = GradScaler()



Epoch 1/25
------------------------------


  with autocast():  # AMP
100%|██████████| 31/31 [00:09<00:00,  3.16batch/s]


🛠️ Train - Loss: 0.4546 | Acc: 0.6940


100%|██████████| 8/8 [00:01<00:00,  4.92batch/s]


🧪 Validation - Loss: 0.2163 | Acc: 0.8204 | F1 Score: 0.8161
✅ Best model updated

Epoch 2/25
------------------------------


  with autocast():  # AMP
100%|██████████| 31/31 [00:09<00:00,  3.25batch/s]


🛠️ Train - Loss: 0.1515 | Acc: 0.9048


100%|██████████| 8/8 [00:01<00:00,  5.00batch/s]


🧪 Validation - Loss: 0.1129 | Acc: 0.8776 | F1 Score: 0.8776
✅ Best model updated

Epoch 3/25
------------------------------


  with autocast():  # AMP
100%|██████████| 31/31 [00:09<00:00,  3.40batch/s]


🛠️ Train - Loss: 0.0929 | Acc: 0.9099


100%|██████████| 8/8 [00:01<00:00,  4.91batch/s]


🧪 Validation - Loss: 0.0793 | Acc: 0.9347 | F1 Score: 0.9355
✅ Best model updated

Epoch 4/25
------------------------------


  with autocast():  # AMP
100%|██████████| 31/31 [00:09<00:00,  3.16batch/s]


🛠️ Train - Loss: 0.0544 | Acc: 0.9509


100%|██████████| 8/8 [00:01<00:00,  5.00batch/s]


🧪 Validation - Loss: 0.0592 | Acc: 0.9429 | F1 Score: 0.9437
✅ Best model updated

Epoch 5/25
------------------------------


  with autocast():  # AMP
100%|██████████| 31/31 [00:09<00:00,  3.44batch/s]


🛠️ Train - Loss: 0.0371 | Acc: 0.9570


100%|██████████| 8/8 [00:01<00:00,  4.98batch/s]


🧪 Validation - Loss: 0.0583 | Acc: 0.9469 | F1 Score: 0.9473
✅ Best model updated

Epoch 6/25
------------------------------


  with autocast():  # AMP
100%|██████████| 31/31 [00:09<00:00,  3.44batch/s]


🛠️ Train - Loss: 0.0264 | Acc: 0.9734


100%|██████████| 8/8 [00:01<00:00,  5.04batch/s]


🧪 Validation - Loss: 0.0418 | Acc: 0.9551 | F1 Score: 0.9556
✅ Best model updated

Epoch 7/25
------------------------------


  with autocast():  # AMP
100%|██████████| 31/31 [00:08<00:00,  3.46batch/s]


🛠️ Train - Loss: 0.0217 | Acc: 0.9693


100%|██████████| 8/8 [00:01<00:00,  4.98batch/s]


🧪 Validation - Loss: 0.0443 | Acc: 0.9510 | F1 Score: 0.9517

Epoch 8/25
------------------------------


  with autocast():  # AMP
100%|██████████| 31/31 [00:08<00:00,  3.46batch/s]


🛠️ Train - Loss: 0.0136 | Acc: 0.9867


100%|██████████| 8/8 [00:01<00:00,  5.00batch/s]


🧪 Validation - Loss: 0.0341 | Acc: 0.9551 | F1 Score: 0.9556

Epoch 9/25
------------------------------


  with autocast():  # AMP
100%|██████████| 31/31 [00:08<00:00,  3.51batch/s]


🛠️ Train - Loss: 0.0215 | Acc: 0.9785


100%|██████████| 8/8 [00:01<00:00,  5.04batch/s]


🧪 Validation - Loss: 0.0248 | Acc: 0.9714 | F1 Score: 0.9716
✅ Best model updated

Epoch 10/25
------------------------------


  with autocast():  # AMP
100%|██████████| 31/31 [00:08<00:00,  3.54batch/s]


🛠️ Train - Loss: 0.0225 | Acc: 0.9785


100%|██████████| 8/8 [00:01<00:00,  5.04batch/s]


🧪 Validation - Loss: 0.0528 | Acc: 0.9551 | F1 Score: 0.9557

Epoch 11/25
------------------------------


  with autocast():  # AMP
100%|██████████| 31/31 [00:09<00:00,  3.40batch/s]


🛠️ Train - Loss: 0.0160 | Acc: 0.9836


100%|██████████| 8/8 [00:01<00:00,  4.94batch/s]


🧪 Validation - Loss: 0.0247 | Acc: 0.9673 | F1 Score: 0.9675

Epoch 12/25
------------------------------


  with autocast():  # AMP
100%|██████████| 31/31 [00:08<00:00,  3.47batch/s]


🛠️ Train - Loss: 0.0098 | Acc: 0.9898


100%|██████████| 8/8 [00:01<00:00,  5.02batch/s]


🧪 Validation - Loss: 0.0292 | Acc: 0.9673 | F1 Score: 0.9675

Epoch 13/25
------------------------------


  with autocast():  # AMP
100%|██████████| 31/31 [00:08<00:00,  3.49batch/s]


🛠️ Train - Loss: 0.0063 | Acc: 0.9959


100%|██████████| 8/8 [00:01<00:00,  5.02batch/s]


🧪 Validation - Loss: 0.0223 | Acc: 0.9796 | F1 Score: 0.9796
✅ Best model updated

Epoch 14/25
------------------------------


  with autocast():  # AMP
100%|██████████| 31/31 [00:09<00:00,  3.42batch/s]


🛠️ Train - Loss: 0.0094 | Acc: 0.9898


100%|██████████| 8/8 [00:01<00:00,  5.00batch/s]


🧪 Validation - Loss: 0.0373 | Acc: 0.9796 | F1 Score: 0.9797

Epoch 15/25
------------------------------


  with autocast():  # AMP
100%|██████████| 31/31 [00:08<00:00,  3.46batch/s]


🛠️ Train - Loss: 0.0078 | Acc: 0.9939


100%|██████████| 8/8 [00:01<00:00,  5.02batch/s]


🧪 Validation - Loss: 0.0218 | Acc: 0.9837 | F1 Score: 0.9837
✅ Best model updated

Epoch 16/25
------------------------------


  with autocast():  # AMP
100%|██████████| 31/31 [00:09<00:00,  3.19batch/s]


🛠️ Train - Loss: 0.0089 | Acc: 0.9898


100%|██████████| 8/8 [00:01<00:00,  4.66batch/s]


🧪 Validation - Loss: 0.0199 | Acc: 0.9755 | F1 Score: 0.9756

Epoch 17/25
------------------------------


  with autocast():  # AMP
100%|██████████| 31/31 [00:09<00:00,  3.42batch/s]


🛠️ Train - Loss: 0.0057 | Acc: 0.9908


100%|██████████| 8/8 [00:01<00:00,  4.76batch/s]


🧪 Validation - Loss: 0.0173 | Acc: 0.9837 | F1 Score: 0.9837

Epoch 18/25
------------------------------


  with autocast():  # AMP
100%|██████████| 31/31 [00:09<00:00,  3.43batch/s]


🛠️ Train - Loss: 0.0045 | Acc: 0.9969


100%|██████████| 8/8 [00:01<00:00,  4.98batch/s]


🧪 Validation - Loss: 0.0198 | Acc: 0.9837 | F1 Score: 0.9838

Epoch 19/25
------------------------------


  with autocast():  # AMP
100%|██████████| 31/31 [00:09<00:00,  3.36batch/s]


🛠️ Train - Loss: 0.0032 | Acc: 0.9990


100%|██████████| 8/8 [00:01<00:00,  4.93batch/s]


🧪 Validation - Loss: 0.0189 | Acc: 0.9755 | F1 Score: 0.9757

Epoch 20/25
------------------------------


  with autocast():  # AMP
100%|██████████| 31/31 [00:09<00:00,  3.31batch/s]


🛠️ Train - Loss: 0.0046 | Acc: 0.9939


100%|██████████| 8/8 [00:01<00:00,  4.26batch/s]


🧪 Validation - Loss: 0.0202 | Acc: 0.9878 | F1 Score: 0.9878
✅ Best model updated

Epoch 21/25
------------------------------


  with autocast():  # AMP
100%|██████████| 31/31 [00:09<00:00,  3.15batch/s]


🛠️ Train - Loss: 0.0046 | Acc: 0.9969


100%|██████████| 8/8 [00:01<00:00,  4.86batch/s]


🧪 Validation - Loss: 0.0154 | Acc: 0.9878 | F1 Score: 0.9878

Epoch 22/25
------------------------------


  with autocast():  # AMP
100%|██████████| 31/31 [00:09<00:00,  3.10batch/s]


🛠️ Train - Loss: 0.0028 | Acc: 0.9980


100%|██████████| 8/8 [00:01<00:00,  4.86batch/s]


🧪 Validation - Loss: 0.0225 | Acc: 0.9796 | F1 Score: 0.9797

Epoch 23/25
------------------------------


  with autocast():  # AMP
100%|██████████| 31/31 [00:08<00:00,  3.45batch/s]


🛠️ Train - Loss: 0.0046 | Acc: 0.9949


100%|██████████| 8/8 [00:01<00:00,  4.90batch/s]


🧪 Validation - Loss: 0.0203 | Acc: 0.9796 | F1 Score: 0.9797

Epoch 24/25
------------------------------


  with autocast():  # AMP
100%|██████████| 31/31 [00:08<00:00,  3.44batch/s]


🛠️ Train - Loss: 0.0032 | Acc: 0.9949


100%|██████████| 8/8 [00:01<00:00,  4.94batch/s]


🧪 Validation - Loss: 0.0186 | Acc: 0.9878 | F1 Score: 0.9878

Epoch 25/25
------------------------------


  with autocast():  # AMP
100%|██████████| 31/31 [00:08<00:00,  3.49batch/s]


🛠️ Train - Loss: 0.0025 | Acc: 0.9969


100%|██████████| 8/8 [00:01<00:00,  4.98batch/s]

🧪 Validation - Loss: 0.0201 | Acc: 0.9878 | F1 Score: 0.9878





In [41]:
import os
import pandas as pd
import torch
import torch.nn as nn
import timm
import torchvision.transforms as transforms
from PIL import Image
from tqdm import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SwinClassifier(num_classes=4).to(device)

# 3. Load your trained weights
checkpoint_path = 'best_soil_model.pth'
state_dict = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(state_dict)
model.eval()

# -------------------------------
# 4. Prepare your class labels
# -------------------------------
class_names = ['Alluvial soil', 'Black Soil', 'Clay soil', 'Red soil']

# -------------------------------
# 5. Define the image transforms
# -------------------------------
test_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

# -------------------------------
# 6. Load your test IDs CSV
# -------------------------------
test_csv_path = os.path.join('soil_classification-2025', 'test_ids.csv')
test_df = pd.read_csv(test_csv_path)  
# test_df['image_id'] order is kept as in the file

# -------------------------------
# 7. Run inference
# -------------------------------
predictions = []
for img_name in tqdm(test_df['image_id'], desc="Predicting"):
    img_path = os.path.join('soil_classification-2025', 'test', img_name)
    try:
        img = Image.open(img_path).convert('RGB')
        inp = test_transforms(img).unsqueeze(0).to(device)
        with torch.no_grad():
            logits = model(inp)
            pred_idx = logits.argmax(dim=1).item()
        predictions.append(class_names[pred_idx])
    except Exception as e:
        print(f"Error processing {img_name}: {e}")
        predictions.append("Error")

# -------------------------------
# 8. Attach predictions & save
# -------------------------------
test_df['soil_type'] = predictions
output_path = 'prediction.csv'
test_df.to_csv(output_path, index=False)

print(f"✅ Saved predictions to {output_path} (same order as input CSV)")


Predicting: 100%|██████████| 341/341 [00:03<00:00, 90.94it/s] 

✅ Saved predictions to prediction.csv (same order as input CSV)



