In [1]:
import os
import glob
import random
import json
from PIL import Image
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import torchvision.transforms as T
import torchvision.models as models

# -----------------------------
# Set seed for reproducibility
# -----------------------------
random.seed(42)
torch.manual_seed(42)

# -----------------------------
# Path and File Collection
# -----------------------------
image_dir = "/kaggle/input/treesattfg/"
all_files = glob.glob(f"{image_dir}/**/*.tif", recursive=True)

print("Number of images found:", len(all_files))
if len(all_files) == 0:
    raise ValueError("No .tif files found. Check the folder structure!")

# -----------------------------
# Load class list from JSON
# -----------------------------
def load_species_list(json_path):
    with open(json_path, 'r') as f:
        return json.load(f)

species_list = load_species_list("/kaggle/input/treesattfg/species_list.json")
species_list = sorted(species_list)

# -----------------------------
# Extract species from filename
# -----------------------------
def extract_species(filename):
    basename = os.path.basename(filename)
    parts = basename.split('_')
    species_parts = []
    for part in parts:
        if part.isdigit():
            break
        species_parts.append(part)
    species_name = "_".join(species_parts)
    if species_name.endswith("_spec."):
        species_name = species_name.split("_")[0]
    return species_name

# -----------------------------
# Dataset Definition
# -----------------------------
class TreeSpeciesDataset(Dataset):
    def __init__(self, files, transform=None):
        self.image_files = files
        self.transform = transform
        self.species_to_idx = {s: i for i, s in enumerate(species_list)}
        self.idx_to_species = {v: k for k, v in self.species_to_idx.items()}
        self.labels = [extract_species(f) for f in self.image_files]
        self.targets = [self.species_to_idx[label] for label in self.labels]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        image = Image.open(self.image_files[idx]).convert("RGB")
        label = self.targets[idx]
        if self.transform:
            image = self.transform(image)
        return image, label

# -----------------------------
# Split Data
# -----------------------------
train_files, testval_files = train_test_split(all_files, test_size=0.3, random_state=42)
val_files, test_files = train_test_split(testval_files, test_size=0.5, random_state=42)

print(f"Train: {len(train_files)}, Val: {len(val_files)}, Test: {len(test_files)}")

# -----------------------------
# Transform
# -----------------------------
transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225])
])

# -----------------------------
# Create Datasets
# -----------------------------
train_dataset = TreeSpeciesDataset(train_files, transform=transform)
val_dataset = TreeSpeciesDataset(val_files, transform=transform)
test_dataset = TreeSpeciesDataset(test_files, transform=transform)

# -----------------------------
# Create DataLoaders
# -----------------------------
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# -----------------------------
# Save sample label mappings
# -----------------------------
output_path = "species_sample.txt"
with open(output_path, "w") as f:
    for i in range(min(10, len(train_dataset))):
        file = train_dataset.image_files[i]
        label_idx = train_dataset.targets[i]
        label_name = train_dataset.labels[i]
        line = f"{os.path.basename(file)} --> {label_name} (Index: {label_idx})\n"
        f.write(line)

Number of images found: 50381
Train: 35266, Val: 7557, Test: 7558


In [2]:
# -----------------------------
# Model Setup
# -----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_classes = len(train_dataset.species_to_idx)

model = models.resnet18(weights=None)
model.fc = nn.Linear(model.fc.in_features, num_classes)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

# -----------------------------
# Train and Evaluation Functions
# -----------------------------
def train(model, loader):
    model.train()
    total_loss = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)

def evaluate(model, loader):
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            preds = torch.argmax(outputs, dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return correct / total

# -----------------------------
# Training Loop
# -----------------------------
EPOCHS = 25
for epoch in range(EPOCHS):
    train_loss = train(model, train_loader)
    val_acc = evaluate(model, val_loader)
    scheduler.step()
    print(f"Epoch {epoch+1}/{EPOCHS} | Train Loss: {train_loss:.4f} | Val Acc: {val_acc:.4f}")

# -----------------------------
# Save Trained Model
# -----------------------------
torch.save(model.state_dict(), "resnet18_tree_species.pth")
print("✅ Model saved as 'resnet18_tree_species.pth'")

# -----------------------------
# Final Test Accuracy
# -----------------------------
test_acc = evaluate(model, test_loader)
print(f"🎯 Final Test Accuracy: {test_acc:.4f}")

Epoch 1/25 | Train Loss: 1.9596 | Val Acc: 0.4069
Epoch 2/25 | Train Loss: 1.6585 | Val Acc: 0.4080
Epoch 3/25 | Train Loss: 1.4615 | Val Acc: 0.5031
Epoch 4/25 | Train Loss: 1.3091 | Val Acc: 0.5721
Epoch 5/25 | Train Loss: 1.1433 | Val Acc: 0.5145
Epoch 6/25 | Train Loss: 0.7558 | Val Acc: 0.5279
Epoch 7/25 | Train Loss: 0.4640 | Val Acc: 0.5683
Epoch 8/25 | Train Loss: 0.2496 | Val Acc: 0.5502
Epoch 9/25 | Train Loss: 0.1288 | Val Acc: 0.5534
Epoch 10/25 | Train Loss: 0.0851 | Val Acc: 0.5505
Epoch 11/25 | Train Loss: 0.0383 | Val Acc: 0.5718
Epoch 12/25 | Train Loss: 0.0251 | Val Acc: 0.5676
Epoch 13/25 | Train Loss: 0.0254 | Val Acc: 0.5613
Epoch 14/25 | Train Loss: 0.0222 | Val Acc: 0.5551
Epoch 15/25 | Train Loss: 0.0234 | Val Acc: 0.5732
Epoch 16/25 | Train Loss: 0.0114 | Val Acc: 0.5755
Epoch 17/25 | Train Loss: 0.0092 | Val Acc: 0.5739
Epoch 18/25 | Train Loss: 0.0088 | Val Acc: 0.5664
Epoch 19/25 | Train Loss: 0.0083 | Val Acc: 0.5701
Epoch 20/25 | Train Loss: 0.0054 | Val A