In [1]:
import os
import glob
import random
import json
from PIL import Image
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import torchvision.transforms as T
import torchvision.models as models

# -----------------------------
# Set seed for reproducibility
# -----------------------------
random.seed(42)
torch.manual_seed(42)

# -----------------------------
# Path and File Collection
# -----------------------------
image_dir = "/kaggle/input/treesattfg/"
all_files = glob.glob(f"{image_dir}/**/*.tif", recursive=True)

print("Number of images found:", len(all_files))
if len(all_files) == 0:
    raise ValueError("No .tif files found. Check the folder structure!")

# -----------------------------
# Load class list from JSON
# -----------------------------
def load_species_list(json_path):
    with open(json_path, 'r') as f:
        return json.load(f)

species_list = load_species_list("/kaggle/input/treesattfg/species_list.json")
species_list = sorted(species_list)

# -----------------------------
# Extract species from filename
# -----------------------------
def extract_species(filename):
    basename = os.path.basename(filename)
    parts = basename.split('_')
    species_parts = []
    for part in parts:
        if part.isdigit():
            break
        species_parts.append(part)
    species_name = "_".join(species_parts)
    if species_name.endswith("_spec."):
        species_name = species_name.split("_")[0]
    return species_name

# -----------------------------
# Dataset Definition
# -----------------------------
class TreeSpeciesDataset(Dataset):
    def __init__(self, files, transform=None):
        self.image_files = files
        self.transform = transform
        self.species_to_idx = {s: i for i, s in enumerate(species_list)}
        self.idx_to_species = {v: k for k, v in self.species_to_idx.items()}
        self.labels = [extract_species(f) for f in self.image_files]
        self.targets = [self.species_to_idx[label] for label in self.labels]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        image = Image.open(self.image_files[idx]).convert("RGB")
        label = self.targets[idx]
        if self.transform:
            image = self.transform(image)
        return image, label

# -----------------------------
# Split Data
# -----------------------------
train_files, testval_files = train_test_split(all_files, test_size=0.3, random_state=42)
val_files, test_files = train_test_split(testval_files, test_size=0.5, random_state=42)

print(f"Train: {len(train_files)}, Val: {len(val_files)}, Test: {len(test_files)}")

# -----------------------------
# Transform
# -----------------------------
transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225])
])

# -----------------------------
# Create Datasets
# -----------------------------
train_dataset = TreeSpeciesDataset(train_files, transform=transform)
val_dataset = TreeSpeciesDataset(val_files, transform=transform)
test_dataset = TreeSpeciesDataset(test_files, transform=transform)

# -----------------------------
# Create DataLoaders
# -----------------------------
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# -----------------------------
# Save sample label mappings
# -----------------------------
output_path = "/kaggle/working/species_sample.txt"
with open(output_path, "w") as f:
    for i in range(0,len(train_dataset)):
        file = train_dataset.image_files[i]
        label_idx = train_dataset.targets[i]
        label_name = train_dataset.labels[i]
        line = f"{os.path.basename(file)} --> {label_name} (Index: {label_idx})\n"
        f.write(line)

# -----------------------------
# Save test file names
# -----------------------------
test_output_path = "/kaggle/working/test_files.txt"
with open(test_output_path, "w") as f:
    for file in test_dataset.image_files:
        f.write(os.path.basename(file) + "\n")

Number of images found: 50381
Train: 35266, Val: 7557, Test: 7558


In [2]:
# -----------------------------
# Model Setup
# -----------------------------

from torchvision.models import resnet34, ResNet34_Weights

print("🔧 Setting up model...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"✅ Using device: {device}")

num_classes = len(train_dataset.species_to_idx)
print(f"🔢 Number of classes: {num_classes}")


model = resnet34(weights=ResNet34_Weights.DEFAULT)
model.fc = nn.Linear(model.fc.in_features, num_classes)
model = model.to(device)
print("✅ Model created and moved to device.")

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
print("✅ Loss function, optimizer, and scheduler initialized.")

# -----------------------------
# Train and Evaluation Functions
# -----------------------------
def train(model, loader):
    print("🚂 Training...")
    model.train()
    total_loss = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print("✅ Training step completed.")
    return total_loss / len(loader)

def evaluate(model, loader):
    print("🔍 Evaluating...")
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            preds = torch.argmax(outputs, dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    print("✅ Evaluation step completed.")
    return correct / total

# -----------------------------
# Training Loop
# -----------------------------
print("🏁 Starting training loop...")
EPOCHS = 20
for epoch in range(EPOCHS):
    print(f"\n📅 Epoch {epoch + 1}/{EPOCHS}")
    train_loss = train(model, train_loader)
    val_acc = evaluate(model, val_loader)
    scheduler.step()
    print(f"📊 Epoch {epoch+1} Results -> Train Loss: {train_loss:.4f} | Val Acc: {val_acc:.4f}")

# -----------------------------
# Save Trained Model
# -----------------------------
print("\n💾 Saving model...")
torch.save(model.state_dict(), "resnet18_tree_species.pth")
print("✅ Model saved as 'resnet18_tree_species.pth'")

# -----------------------------
# Final Test Accuracy
# -----------------------------
print("\n🧪 Evaluating final test accuracy...")
test_acc = evaluate(model, test_loader)
print(f"🎯 Final Test Accuracy: {test_acc:.4f}")


🔧 Setting up model...
✅ Using device: cuda
🔢 Number of classes: 20


Downloading: "https://download.pytorch.org/models/resnet34-b627a593.pth" to /root/.cache/torch/hub/checkpoints/resnet34-b627a593.pth
100%|██████████| 83.3M/83.3M [00:00<00:00, 172MB/s]


✅ Model created and moved to device.
✅ Loss function, optimizer, and scheduler initialized.
🏁 Starting training loop...

📅 Epoch 1/20
🚂 Training...
✅ Training step completed.
🔍 Evaluating...
✅ Evaluation step completed.
📊 Epoch 1 Results -> Train Loss: 1.3159 | Val Acc: 0.6632

📅 Epoch 2/20
🚂 Training...
✅ Training step completed.
🔍 Evaluating...
✅ Evaluation step completed.
📊 Epoch 2 Results -> Train Loss: 0.9193 | Val Acc: 0.6896

📅 Epoch 3/20
🚂 Training...
✅ Training step completed.
🔍 Evaluating...
✅ Evaluation step completed.
📊 Epoch 3 Results -> Train Loss: 0.6605 | Val Acc: 0.6904

📅 Epoch 4/20
🚂 Training...
✅ Training step completed.
🔍 Evaluating...
✅ Evaluation step completed.
📊 Epoch 4 Results -> Train Loss: 0.4195 | Val Acc: 0.6941

📅 Epoch 5/20
🚂 Training...
✅ Training step completed.
🔍 Evaluating...
✅ Evaluation step completed.
📊 Epoch 5 Results -> Train Loss: 0.2493 | Val Acc: 0.6945

📅 Epoch 6/20
🚂 Training...
✅ Training step completed.
🔍 Evaluating...
✅ Evaluation step c