In [10]:
import os

import torch
from PIL import Image
from torch.utils.data import DataLoader, random_split
from torchvision import datasets
from torchvision.transforms import v2

## Data Processing

In [30]:
# --- Data ---
DATA_DIR = "../data/EuroSAT_RGB"  # root directory passed to datasets.ImageFolder
IMG_SIZE = 64                     # images are resized to IMG_SIZE x IMG_SIZE
NUM_CLASSES = 10                  # output size of the replacement fc head

# --- Training ---
BATCH_SIZE = 64                   # intended batch size for the DataLoaders
EPOCHS = 10
LR = 1e-4                         # learning rate intended for the Adam optimizer
NUM_WORKERS = 4                   # intended for DataLoader(num_workers=...)

In [39]:
# Preprocessing applied to every sample of the dataset (both train and validation
# splits come from it): resize, convert the image to a float32 tensor in [0, 1],
# then standardize each RGB channel.
# v2.ToTensor is deprecated; ToImage + ToDtype(float32, scale=True) is its
# documented replacement and produces the same [0, 1] float tensor.
transform = v2.Compose([
    v2.Resize((IMG_SIZE, IMG_SIZE)),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    # NOTE(review): presumably per-channel mean/std of the EuroSAT RGB images — confirm source.
    v2.Normalize(mean=[0.3444, 0.3809, 0.4082], std=[0.1459, 0.1132, 0.1137]),
])



In [12]:
# ImageFolder derives the class labels from the sub-directory names under DATA_DIR.
ds = datasets.ImageFolder(DATA_DIR, transform=transform)
# The classifier head is sized by NUM_CLASSES, so fail fast if the data disagrees.
assert len(ds.classes) == NUM_CLASSES, (
    f"expected {NUM_CLASSES} classes, found {len(ds.classes)}: {ds.classes}"
)

In [13]:
# Display the dataset summary (sample count, root directory, transform pipeline) as the cell output.
ds

Dataset ImageFolder
    Number of datapoints: 27000
    Root location: ../data/EuroSAT_RGB
    StandardTransform
Transform: Compose(
                 Resize(size=[64, 64], interpolation=InterpolationMode.BILINEAR, antialias=True)
                 ToTensor()
                 Normalize(mean=[0.3444, 0.3809, 0.4082], std=[0.1459, 0.1132, 0.1137], inplace=False)
           )

In [16]:
# 80/20 split. The validation set is whatever int() truncation leaves over,
# so the two sizes always add up to len(ds).
valid_size = len(ds) - int(0.8 * len(ds))
train_size = len(ds) - valid_size
(train_size, valid_size)

(21600, 5400)

In [17]:
# Seed the split so train/validation membership is identical on every
# Restart & Run All; otherwise validation numbers are not comparable across runs.
SPLIT_SEED = 42
train_ds, valid_ds = random_split(
    ds,
    [train_size, valid_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [18]:
# Use the configured batch size / worker count instead of hard-coded values.
# Pinned host memory is what lets the `non_blocking=True` copies in the training
# loop actually overlap with compute; it only matters when a CUDA device exists.
PIN_MEMORY = torch.cuda.is_available()
train_dl = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
)
valid_dl = DataLoader(
    valid_ds,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
)

### Train ResNet-18

In [20]:
import torch.nn as nn
from torchvision.models import resnet18, ResNet18_Weights

In [41]:
# Train on the GPU when one is visible, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Use every core for CPU ops. os.cpu_count() returns None when the count cannot
# be determined, and set_num_threads(None) raises, so fall back to 1.
torch.set_num_threads(os.cpu_count() or 1)

In [42]:
# Start fine-tuning from torchvision's ImageNet-1k (V1) pretrained ResNet-18 checkpoint.
pretrained_weights = ResNet18_Weights.IMAGENET1K_V1
model = resnet18(weights=pretrained_weights)

In [43]:
# Swap the ImageNet classification head for a freshly initialised linear layer
# emitting NUM_CLASSES logits, then move the whole network to `device`.
backbone_features = model.fc.in_features
model.fc = nn.Linear(backbone_features, NUM_CLASSES)
model = model.to(device)

In [44]:
# CrossEntropyLoss takes raw logits and integer class labels.
loss_func = nn.CrossEntropyLoss()
# Use the learning rate from the config cell rather than a duplicated literal.
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

In [None]:
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0

    for imgs, labels in train_dl:
        imgs = imgs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad()
        preds = model(imgs)
        loss = loss_func(preds, labels)
        loss.backward()
        optimizer.step()

        # loss_func uses the default reduction='mean' (per-batch average), so
        # weight by batch size to get a true per-sample mean even when the
        # final batch is smaller than the rest.
        total_loss += loss.item() * imgs.size(0)

    # Report the mean loss per training sample; the raw sum grows with the
    # number of batches and is not comparable across batch sizes.
    epoch_loss = total_loss / len(train_dl.dataset)
    print(f"[Epoch {epoch+1}/{EPOCHS}] Train Loss: {epoch_loss:.4f}")

### Validation

In [None]:
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            preds = torch.argmax(outputs, dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    
    acc = correct / total
    print(f"Validation Accuracy: {acc:.3f}")

### Save

In [None]:
MODEL_PATH = "model/resnet18_eurosat.pth"
# torch.save does not create missing directories; make sure the parent exists
# (fall back to "." when MODEL_PATH has no directory component).
os.makedirs(os.path.dirname(MODEL_PATH) or ".", exist_ok=True)
torch.save(model.state_dict(), MODEL_PATH)

### Predict

In [None]:
# Inference reuses the `transform` pipeline built for the dataset so resizing and
# normalization exactly match what the model was trained on (the previous copy
# here referenced an undefined `transforms` module and could silently drift).


def predict_image(image_path):
    """Classify a single image file with the fine-tuned ResNet-18.

    Args:
        image_path: Path to an image file readable by PIL.

    Returns:
        Tuple ``(predicted_class, confidence, probs)``: the class name from
        ``ds.classes`` with the highest softmax probability, that probability
        rounded to 3 decimals, and the full probability vector as a NumPy array.
    """
    # Decode the same way ImageFolder's default loader does (PIL, converted to
    # RGB); the `with` block closes the file handle once the pixels are loaded.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    input_tensor = transform(image).unsqueeze(0).to(device)  # add batch dim -> [1, C, H, W]

    # BatchNorm layers must use their running statistics at inference time.
    model.eval()
    with torch.no_grad():
        logits = model(input_tensor)
    probs = torch.softmax(logits[0], dim=0)

    confidence, class_idx = probs.max(dim=0)
    predicted_class = ds.classes[class_idx.item()]
    return predicted_class, round(confidence.item(), 3), probs.cpu().numpy()