In [1]:
import os
import torch
from torchvision import datasets
from torchvision.transforms import v2
from torch.utils.data import DataLoader, random_split

## EDA

In [2]:
DATA_DIR = "/kaggle/input/eurosat-rgb/EuroSAT_RGB"
BATCH_SIZE = 64
EPOCHS = 10
LR = 1e-4
NUM_CLASSES = 10
IMG_SIZE = 64
NUM_WORKERS = 4

In [4]:
transform = v2.Compose([
    v2.Resize((IMG_SIZE, IMG_SIZE)),
    v2.ToTensor(),
    v2.Normalize(mean=[0.3444, 0.3809, 0.4082], std=[0.1459, 0.1132, 0.1137])
])



In [5]:
ds = datasets.ImageFolder(DATA_DIR, transform=transform)
ds

Dataset ImageFolder
    Number of datapoints: 27000
    Root location: /kaggle/input/eurosat-rgb/EuroSAT_RGB
    StandardTransform
Transform: Compose(
                 Resize(size=[64, 64], interpolation=InterpolationMode.BILINEAR, antialias=True)
                 ToTensor()
                 Normalize(mean=[0.3444, 0.3809, 0.4082], std=[0.1459, 0.1132, 0.1137], inplace=False)
           )

In [6]:
train_size = int(0.8 * len(ds))
valid_size = len(ds) - train_size
train_size, valid_size

(21600, 5400)

In [7]:
train_ds, valid_ds = random_split(ds, [train_size, valid_size])

In [8]:
train_dl = DataLoader(train_ds, batch_size=64, shuffle=True)
valid_dl = DataLoader(valid_ds, batch_size=64)

### Train

In [9]:
import torch.nn as nn
from torchvision.models import resnet18, ResNet18_Weights

In [10]:
torch.cuda.is_available()

True

In [15]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_num_threads(2) 

In [16]:
model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)

In [17]:
model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)
model = model.to(device)

In [18]:
loss_func = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [19]:
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0

    for imgs, labels in train_dl:
        imgs, labels = imgs.to(device, non_blocking=True), labels.to(device, non_blocking=True)

        optimizer.zero_grad()
        preds = model(imgs)
        loss = loss_func(preds, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f"[Epoch {epoch+1}/{EPOCHS}] Train Loss: {total_loss:.4f}")

[Epoch 1/10] Train Loss: 98.5493
[Epoch 2/10] Train Loss: 27.1437
[Epoch 3/10] Train Loss: 15.7292
[Epoch 4/10] Train Loss: 9.9024
[Epoch 5/10] Train Loss: 8.6789
[Epoch 6/10] Train Loss: 7.2103
[Epoch 7/10] Train Loss: 10.0679
[Epoch 8/10] Train Loss: 9.0689
[Epoch 9/10] Train Loss: 6.4266
[Epoch 10/10] Train Loss: 6.9214


### Validation

In [21]:
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for imgs, labels in valid_dl:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            preds = torch.argmax(outputs, dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    
    acc = correct / total
    print(f"Validation Accuracy: {acc:.3f}")

Validation Accuracy: 0.967


### Train Results (Iterative)
1. Resnet18, Adam, CrossEntropyLoss, lr = 1e-4, 10 epochs => Accuracy: 0.967

### Save

In [29]:
MODEL_PATH = "resnet18_eurosat.pth"
torch.save(model.state_dict(), MODEL_PATH)

### Predict

In [25]:
from PIL import Image

In [31]:
classes = sorted(os.listdir("/kaggle/input/eurosat-rgb/EuroSAT_RGB"))
classes

['AnnualCrop',
 'Forest',
 'HerbaceousVegetation',
 'Highway',
 'Industrial',
 'Pasture',
 'PermanentCrop',
 'Residential',
 'River',
 'SeaLake']

In [32]:
def predict_image(image_path):
    image = Image.open(image_path).convert("RGB")
    input_tensor = transform(image).unsqueeze(0).to(device)  # [B, C, H, W]

    with torch.no_grad():
        output = model(input_tensor)
        probs = torch.nn.functional.softmax(output[0], dim=0)
    
    predicted_class = classes[probs.argmax().item()]
    confidence = probs.max().item()

    return predicted_class, round(confidence, 3), probs.cpu().numpy()

In [33]:
predict_image('/kaggle/input/eurosat-rgb/EuroSAT_RGB/Forest/Forest_1.jpg')

('Forest',
 1.0,
 array([2.17702137e-07, 9.99968886e-01, 1.63910954e-06, 9.21126048e-06,
        4.90011587e-07, 1.23717355e-05, 1.23592395e-06, 2.09931204e-06,
        3.68005999e-06, 1.61437256e-07], dtype=float32))