In [2]:
dataset_dir = "insect-dataset/moth"

In [21]:
import shutil
import os
import time
import datetime
import random
import numpy as np
from pathlib import Path
from PIL import Image
import pprint
import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision import models
import torch.nn as nn
import torch.nn.functional as F

In [6]:
def split_data_for_train_and_val(data_dir, test_dir, val_dir, train_dir, test_data_weight, val_data_weight, min_file_cnt_for_val):
    train_data_cnt = 0
    val_data_cnt = 0
    test_data_cnt = 0
    class_cnt = 0
    
    for class_dir in Path(data_dir).iterdir():
        if class_dir.is_dir() and os.listdir(class_dir):
            class_cnt = class_cnt + 1
            file_count = sum(1 for file in class_dir.iterdir() if file.is_file())
            for file in Path(class_dir).iterdir():
                if file.is_file():
                    random_float = random.random()
                    class_dir_name = class_dir.name
                    if file_count >= min_file_cnt_for_val and random_float < test_data_weight:
                        target_dir = test_dir
                        test_data_cnt = test_data_cnt + 1
                    elif file_count >= min_file_cnt_for_val and random_float < test_data_weight + val_data_weight:
                        target_dir = val_dir
                        val_data_cnt = val_data_cnt + 1
                    else:
                        target_dir = train_dir
                        train_data_cnt = train_data_cnt + 1
                    target_dir_path = f"{target_dir}/{class_dir_name}"
                    if not os.path.exists(target_dir_path):
                        os.makedirs(target_dir_path)
                    shutil.copy(file, target_dir_path)

    print(f"Class count: {class_cnt}")
    print(f"Training data count: {train_data_cnt}")
    print(f"Validation data count: {val_data_cnt}")
    print(f"Test data count: {test_data_cnt}")

In [7]:
def init_model_for_training(train_dir, val_dir, batch_size):
    transform = {
        'train': transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]),
        'val': transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]),
    }
    training_datasets = {
        'train': datasets.ImageFolder(root=train_dir, transform=transform['train']),
        'val': datasets.ImageFolder(root=val_dir, transform=transform['val']),
    }
    dataloaders = {
        'train': DataLoader(training_datasets['train'], batch_size=batch_size, shuffle=True),
        'val': DataLoader(training_datasets['val'], batch_size=batch_size, shuffle=False),
    }
    class_names = training_datasets['train'].classes
    
    model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
    num_classes = len(class_names)
    num_features = model.fc.in_features
    model.fc = nn.Linear(num_features, num_classes)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

    return {
        'model': model,
        'device': device,
        'transform': transform,
        'datasets': training_datasets,
        'dataloaders': dataloaders,
        'class_names': class_names,
        'num_classes': num_classes,
        'num_features': num_features,
        'criterion': criterion,
        'optimizer': optimizer,
        'scheduler': scheduler
    }

In [8]:
def train(model_data, num_epochs, model_path):
    start_time = time.time()
    for epoch in range(num_epochs):
        print(f"Epoch {(epoch+1):4} / {num_epochs:4}", end=' ')
        for phase in ['train', 'val']:
            if phase == 'train':
                model_data['model'].train()
            else:
                model_data['model'].eval()
            running_loss = 0.0
            running_corrects = 0
            for inputs, labels in model_data['dataloaders'][phase]:
                inputs, labels = inputs.to(model_data['device']), labels.to(model_data['device'])
                model_data['optimizer'].zero_grad()
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model_data['model'](inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = model_data['criterion'](outputs, labels)
                    if phase == 'train':
                        loss.backward()
                        model_data['optimizer'].step()
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
    
            epoch_loss = running_loss / len(model_data['datasets'][phase])
            epoch_acc = running_corrects.double() / len(model_data['datasets'][phase])
            print(f" | {phase.capitalize()} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}", end=' ')
            if phase == 'train':
                model_data['scheduler'].step()
        print(f" | Elapsed time: {datetime.timedelta(seconds=(time.time() - start_time))}")
        torch.save(model_data['model'].state_dict(), model_path)

In [15]:
def predict(image_path, model_data):
    image = Image.open(image_path).convert("RGB")
    image = model_data['transform']['val'](image).unsqueeze(0).to(model_data['device'])
    with torch.no_grad():
        outputs = model_data['model'](image)
        _, preds = torch.max(outputs, 1)
    try:
        return model_data['class_names'][preds[0]]
    except (Exception):
        return None

def predict_top_k(image_path, model_data, k):
    image = Image.open(image_path).convert("RGB")
    image = model_data['transform']['val'](image).unsqueeze(0).to(model_data['device'])
    with torch.no_grad():
        outputs = model_data['model'](image)
        probabilities = F.softmax(outputs, dim=1)
        top_probs, top_indices = torch.topk(probabilities, k)
    try:
        return {model_data['class_names'][top_indices[0][i]]: top_probs[0][i].item() for i in range(0, k)}
    except (Exception):
        return None

In [27]:
def validate_prediction_in_dir(test_dir, model_data):
    total = 0
    success = 0
    failures = {}
    for species_dir in Path(test_dir).iterdir():
        if species_dir.is_dir():
            for file in Path(f"{species_dir}").iterdir():
                if file.is_file():
                    species = file.parts[-2]
                    prediction = predict(file, model_data)
                    is_success = (species==prediction)
                    if not is_success:
                        failures[species] = prediction
                    total = total + 1
                    if is_success:
                        success = success + 1
    return {
        'total': total, 
        'success': success,
        'failures': failures
    }

def test(checkpoint, test_dir, print_failures=True):
    model_data['model'].load_state_dict(torch.load(checkpoint, weights_only=False))
    model_data['model'].eval()
    start_time = time.time()
    prediction = validate_prediction_in_dir(test_dir, model_data)
    print(f"Accuracy: {prediction['success']} / {prediction['total']} -> {100*prediction['success']/prediction['total']:.2f}%")
    print(f"Elapsed time: {datetime.timedelta(seconds=(time.time() - start_time))}")
    if print_failures:
        print("-"*10)
        print("Failures:")
        pprint.pprint(prediction['failures'])

# Extract small dataset and train

In [13]:
def extract_proto_dataset(data_dir, proto_data_dir, limit):
    file_cnt = 0
    for class_dir in Path(data_dir).iterdir():
        if class_dir.is_dir() and os.listdir(class_dir):
            file_count = sum(1 for file in class_dir.iterdir() if file.is_file())
            class_dir_name = class_dir.name
            for file in Path(class_dir).iterdir():
                if file.is_file():
                    target_dir_path = f"{proto_data_dir}/{class_dir_name}"
                    if not os.path.exists(target_dir_path):
                        os.makedirs(target_dir_path)
                    shutil.copy(file, target_dir_path)
                    file_cnt = file_cnt + 1
                    if(file_cnt >= limit):
                        return

In [16]:
extract_proto_dataset(f"{dataset_dir}/data", f"{dataset_dir}/proto/data", 3000)

In [17]:
if os.path.exists(f"{dataset_dir}/proto/test"):
    shutil.rmtree(f"{dataset_dir}/proto/test")
if os.path.exists(f"{dataset_dir}/proto/val"):
    shutil.rmtree(f"{dataset_dir}/proto/val")
if os.path.exists(f"{dataset_dir}/proto/train"):
    shutil.rmtree(f"{dataset_dir}/proto/train")
    
split_data_for_train_and_val(f"{dataset_dir}/proto/data", f"{dataset_dir}/proto/test", f"{dataset_dir}/proto/val", f"{dataset_dir}/proto/train", 0.1, 0.2, 4)

Class count: 199
Training data count: 2108
Validation data count: 603
Test data count: 289


In [18]:
model_data = init_model_for_training(f'{dataset_dir}/proto/train', f'{dataset_dir}/proto/val', 32)

In [19]:
train(model_data, 5, f"{dataset_dir}/proto/checkpoint_latest.pth")

Epoch    1 /    5  | Train Loss: 4.0851 Acc: 0.2106  | Val Loss: 8.0438 Acc: 0.0033  | Elapsed time: 0:00:30.503317
Epoch    2 /    5  | Train Loss: 2.8633 Acc: 0.3629  | Val Loss: 8.0567 Acc: 0.0033  | Elapsed time: 0:00:48.362525
Epoch    3 /    5  | Train Loss: 2.0931 Acc: 0.4881  | Val Loss: 9.6558 Acc: 0.0017  | Elapsed time: 0:01:06.320642
Epoch    4 /    5  | Train Loss: 1.5604 Acc: 0.5916  | Val Loss: 11.3176 Acc: 0.0066  | Elapsed time: 0:01:24.400304
Epoch    5 /    5  | Train Loss: 1.2717 Acc: 0.6551  | Val Loss: 11.8641 Acc: 0.0017  | Elapsed time: 0:01:42.596990


In [29]:
test(f"{dataset_dir}/proto/checkpoint_latest.pth", f"{dataset_dir}/proto/test", False)

Accuracy: 140 / 289 -> 48.44%
Elapsed time: 0:00:02.401079


In [30]:
train(model_data, 95, f"{dataset_dir}/proto/checkpoint_latest.pth")

Epoch    1 /   95  | Train Loss: 0.9895 Acc: 0.7234  | Val Loss: 11.5017 Acc: 0.0017  | Elapsed time: 0:00:17.201074
Epoch    2 /   95  | Train Loss: 0.7013 Acc: 0.8065  | Val Loss: 12.6858 Acc: 0.0033  | Elapsed time: 0:00:34.812139
Epoch    3 /   95  | Train Loss: 0.3409 Acc: 0.9146  | Val Loss: 12.4507 Acc: 0.0017  | Elapsed time: 0:00:53.054634
Epoch    4 /   95  | Train Loss: 0.1925 Acc: 0.9725  | Val Loss: 12.4425 Acc: 0.0017  | Elapsed time: 0:01:11.669036
Epoch    5 /   95  | Train Loss: 0.1476 Acc: 0.9791  | Val Loss: 12.4741 Acc: 0.0017  | Elapsed time: 0:01:30.443407
Epoch    6 /   95  | Train Loss: 0.1316 Acc: 0.9848  | Val Loss: 12.5994 Acc: 0.0017  | Elapsed time: 0:01:48.753914
Epoch    7 /   95  | Train Loss: 0.1140 Acc: 0.9886  | Val Loss: 12.5290 Acc: 0.0017  | Elapsed time: 0:02:06.756473
Epoch    8 /   95  | Train Loss: 0.1009 Acc: 0.9924  | Val Loss: 12.6097 Acc: 0.0017  | Elapsed time: 0:02:24.973099
Epoch    9 /   95  | Train Loss: 0.0866 Acc: 0.9943  | Val Loss:

In [31]:
shutil.copy(f"{dataset_dir}/proto/checkpoint_latest.pth", f"{dataset_dir}/proto/checkpoint_{int(time.time())}.pth")

'insect-dataset/moth/proto/checkpoint_1737815088.pth'

In [32]:
test(f"{dataset_dir}/proto/checkpoint_latest.pth", f"{dataset_dir}/proto/test", False)

Accuracy: 226 / 289 -> 78.20%
Elapsed time: 0:00:02.713052


In [33]:
train(model_data, 100, f"{dataset_dir}/proto/checkpoint_latest.pth")

Epoch    1 /  100  | Train Loss: 0.0633 Acc: 0.9976  | Val Loss: 12.7164 Acc: 0.0017  | Elapsed time: 0:00:16.811785
Epoch    2 /  100  | Train Loss: 0.0673 Acc: 0.9957  | Val Loss: 12.6889 Acc: 0.0017  | Elapsed time: 0:00:34.740604
Epoch    3 /  100  | Train Loss: 0.0659 Acc: 0.9976  | Val Loss: 12.6630 Acc: 0.0017  | Elapsed time: 0:00:53.092367
Epoch    4 /  100  | Train Loss: 0.0640 Acc: 0.9972  | Val Loss: 12.7580 Acc: 0.0017  | Elapsed time: 0:01:10.829420
Epoch    5 /  100  | Train Loss: 0.0653 Acc: 0.9976  | Val Loss: 12.7131 Acc: 0.0017  | Elapsed time: 0:01:28.232424
Epoch    6 /  100  | Train Loss: 0.0612 Acc: 0.9986  | Val Loss: 12.7633 Acc: 0.0017  | Elapsed time: 0:01:45.719025
Epoch    7 /  100  | Train Loss: 0.0642 Acc: 0.9972  | Val Loss: 12.6468 Acc: 0.0017  | Elapsed time: 0:02:02.850533
Epoch    8 /  100  | Train Loss: 0.0651 Acc: 0.9976  | Val Loss: 12.7679 Acc: 0.0017  | Elapsed time: 0:02:20.947548
Epoch    9 /  100  | Train Loss: 0.0641 Acc: 0.9981  | Val Loss:

In [34]:
shutil.copy(f"{dataset_dir}/proto/checkpoint_latest.pth", f"{dataset_dir}/proto/checkpoint_{int(time.time())}.pth")

'insect-dataset/moth/proto/checkpoint_1737817230.pth'

In [35]:
test(f"{dataset_dir}/proto/checkpoint_latest.pth", f"{dataset_dir}/proto/test", False)

Accuracy: 224 / 289 -> 77.51%
Elapsed time: 0:00:02.545990
