In [1]:
%load_ext autoreload

In [2]:
%autoreload 2

In [3]:
import mynnlib
from mynnlib import *

# Root folder of the odonata dataset; data/ (train) and val/ are created under it.
dataset_dir = "insect-dataset/odonata"

# Class-folder name suffix patterns:
#   early_regex                  -> "<name>-early"
#   unidentified_regex           -> "<name>-spp", "<name>-genera", "<name>-genera-spp"
#   early_or_unidentified_regex  -> union of the two above
early_regex, unidentified_regex, early_or_unidentified_regex = (
    r"^.*-(early)$",
    r"^.*-(spp|genera|genera-spp)$",
    r"^.*-(early|spp|genera|genera-spp)$",
)

# Create Dataset

In [46]:
# Rebuild data/ from scratch so a re-run never mixes in files left over from a previous run.
fresh_data_dir = f"{dataset_dir}/data"
if os.path.exists(fresh_data_dir):
    shutil.rmtree(fresh_data_dir)  # wipes everything currently under data/
os.makedirs(fresh_data_dir)

In [47]:
# Merge early-stage ("<class>-early") and imago ("<class>") folders from indianodonata.org
# into a single "<class>" folder under data/.
src_dir = "insect-dataset/src/indianodonata.org"
for class_dir in os.listdir(src_dir):
    src_class_path = f"{src_dir}/{class_dir}"
    if not os.path.isdir(src_class_path):
        continue  # stray files in the source root would make os.listdir below raise
    # Strip only a *trailing* "-early" (same convention as early_regex); re.sub would also
    # strip "-early" from the middle of a name and then copy into a folder never created.
    dst_class_path = f"{dataset_dir}/data/{class_dir.removesuffix('-early')}"
    # The folder is created even when the source class is empty (same as before).
    os.makedirs(dst_class_path, exist_ok=True)
    for file in os.listdir(src_class_path):
        # NOTE: an early-stage and an imago image with the same filename collide here;
        # whichever is copied last overwrites the other.
        shutil.copy2(f"{src_class_path}/{file}", f"{dst_class_path}/{file}")

In [48]:
def copy_data_from(sources, add_early=False, dst_dir=None):
    """Copy extra images for classes that already exist in the destination folder.

    The class list is taken from the folders present in ``dst_dir`` *before* any copying.
    For each source root, a folder with the same class name is copied file-by-file into
    ``dst_dir/<class>``. With ``add_early``, a ``<class>-early`` folder in the source is
    also copied into ``dst_dir/<class>-early``, which is created if missing.

    Args:
        sources: iterable of source root directories, each holding one folder per class.
        add_early: also copy ``<class>-early`` folders.
        dst_dir: destination root; defaults to ``f"{dataset_dir}/data"``.

    Prints the number of files copied and the number of distinct class folders copied from.
    Files with the same name in several sources overwrite each other (last source wins).
    """
    if dst_dir is None:
        dst_dir = f"{dataset_dir}/data"

    # Snapshot the wanted class folders once. Re-listing dst_dir per source would pick up
    # the "-early" folders created for an earlier source, copy them twice and inflate counts.
    wanted = []
    for class_dir in os.listdir(dst_dir):
        wanted.append(class_dir)
        if add_early:
            wanted.append(f"{class_dir}-early")
    wanted = list(dict.fromkeys(wanted))  # de-duplicate, keep order

    copied_classes = set()
    img_cnt = 0
    for more_data_dir in sources:
        for name in wanted:
            src_path = f"{more_data_dir}/{name}"
            if not os.path.exists(src_path):
                continue
            dst_path = f"{dst_dir}/{name}"
            # Previously this created f".../{name}/{file}" with a stale/undefined `file`,
            # i.e. a spurious directory named after an image; create the class folder instead.
            os.makedirs(dst_path, exist_ok=True)
            copied_classes.add(name)
            for file in os.listdir(src_path):
                shutil.copy2(f"{src_path}/{file}", f"{dst_path}/{file}")
                img_cnt += 1
    print(f"{img_cnt} images added into {len(copied_classes)} classes")

In [49]:
# Top up the existing class folders with iNaturalist images ("-early" folders are not copied).
copy_data_from(sources=["insect-dataset/src/inaturalist.org"])

31089 images added into 332 classes


In [50]:
# Drop class folders under data/ that received no images at all.
empty_class_dirs = [
    name
    for name in os.listdir(f"{dataset_dir}/data")
    if not os.listdir(f"{dataset_dir}/data/{name}")
]
for name in empty_class_dirs:
    shutil.rmtree(f"{dataset_dir}/data/{name}")
removed_cnt = len(empty_class_dirs)
print(f"Removed {removed_cnt} empty classes")

Removed 132 empty classes


In [51]:
# Count (and optionally delete) files in data/ whose extension is not a supported image type.
# The extension list matches torchvision's IMG_EXTENSIONS; presumably mynnlib loads images
# that way — TODO confirm.
delete_unsupported = False  # deletion was disabled in this cell before; set True to actually delete
remove_file_cnt = 0
# `\.` is a literal dot. The previous `\\.` inside a raw string required a literal backslash
# before the extension, so no ordinary filename matched (every file would have been flagged).
valid_file_regex = r"^.*\.(jpg|jpeg|png|ppm|bmp|pgm|tif|tiff|webp)$"
for class_dir in os.listdir(f"{dataset_dir}/data"):
    for file in os.listdir(f"{dataset_dir}/data/{class_dir}"):
        # IGNORECASE so ".JPG", ".Png", ... count as valid too.
        if not re.match(valid_file_regex, file, re.IGNORECASE):
            if delete_unsupported:
                os.remove(f"{dataset_dir}/data/{class_dir}/{file}")
            remove_file_cnt += 1  # was `+= 0`, so the summary always reported 0
print(f"{'Removed' if delete_unsupported else 'Found'} {remove_file_cnt} unsupported files")

Removed 0 unsupported files


# Create validation dataset

In [52]:
# Start the validation split from an empty val/ folder on every run.
fresh_val_dir = f"{dataset_dir}/val"
if os.path.exists(fresh_val_dir):
    shutil.rmtree(fresh_val_dir)  # removes any previous split
os.makedirs(fresh_val_dir)

In [53]:
# Hold out a random ~3% of each class's images as the validation split (moved data/ -> val/).
move_src = "data"
move_dst = "val"
val_data_ratio = 0.03
val_split_seed = 42  # fixed seed => the same split on every Restart & Run All
val_rng = random.Random(val_split_seed)  # private RNG; leaves the global `random` state untouched
val_data_cnt = 0
# sorted(): os.listdir order is arbitrary, so even a seeded draw could pick different files without it.
for class_dir in sorted(os.listdir(f"{dataset_dir}/{move_src}")):
    class_files = sorted(os.listdir(f"{dataset_dir}/{move_src}/{class_dir}"))
    kept_cnt = len(class_files)
    for file in class_files:
        # Always keep at least one image in data/. An emptied class folder loses all its training
        # examples and makes torchvision's ImageFolder raise FileNotFoundError (presumably how
        # mynnlib loads folders — TODO confirm).
        if kept_cnt > 1 and val_rng.random() < val_data_ratio:
            os.makedirs(f"{dataset_dir}/{move_dst}/{class_dir}", exist_ok=True)
            shutil.move(f"{dataset_dir}/{move_src}/{class_dir}/{file}", f"{dataset_dir}/{move_dst}/{class_dir}/")
            kept_cnt -= 1
            val_data_cnt += 1
# NOTE(review): val/ only gets folders for classes that had at least one file drawn. If class
# indices are derived per folder tree, train/val indices could differ — confirm mynnlib maps by name.
print(f"{val_data_cnt} images moved from {move_src} to {move_dst}")

1295 images moved from data to val


# Count

In [54]:
# Per-class image counts in data/, bucketed by class-name suffix.
classes = {
    name: len(os.listdir(f"{dataset_dir}/data/{name}"))
    for name in os.listdir(f"{dataset_dir}/data")
}
early_classes = {}
unidentified_classes = {}
for name, count in classes.items():
    if re.match(early_regex, name):
        early_classes[name] = count
    if re.match(unidentified_regex, name):
        unidentified_classes[name] = count

total_cls, unid_cls, early_cls = len(classes), len(unidentified_classes), len(early_classes)
total_img = sum(classes.values())
unid_img = sum(unidentified_classes.values())
early_img = sum(early_classes.values())
print(f"Total Class count : {total_cls:6} ( Unidentified: {unid_cls:6} / Early-stage: {early_cls:6} / Identified-adult: {total_cls - unid_cls - early_cls:6} )")
print(f"Total  Data count : {total_img:6} ( Unidentified: {unid_img:6} / Early-stage: {early_img:6} / Identified-adult: {total_img - unid_img - early_img:6} )")

Total Class count :    405 ( Unidentified:     27 / Early-stage:      0 / Identified-adult:    378 )
Total  Data count :  42919 ( Unidentified:    296 / Early-stage:      0 / Identified-adult:  42623 )


In [55]:
# Among classes that are neither early-stage nor unidentified, collect those with very few images.
img2_class = []
img5_class = []
for class_dir in os.listdir(f"{dataset_dir}/data"):
    if re.match(early_or_unidentified_regex, class_dir):
        continue
    img_cnt = len(os.listdir(f"{dataset_dir}/data/{class_dir}"))
    if img_cnt <= 2:
        img2_class.append(class_dir)
    if img_cnt <= 5:
        img5_class.append(class_dir)
print(f"{len(img2_class):6} classes with <=2 images")
print(f"{len(img5_class):6} classes with <=5 images")

    17 classes with <=2 images
    48 classes with <=5 images


In [56]:
# Distinct genera, taking the first "-"-separated token of each class name as its genus.
# NOTE(review): unidentified classes (…-spp / …-genera) also contribute their first token;
# confirm that token is always a genus name.
generas = {class_name.split('-')[0] for class_name in classes}
print(f"Genera count: {len(generas)}")

Genera count: 143


# Train

## Model A (ResNet-152)

In [57]:
# Phased training of a resnet152 model via mynnlib.
# Phase 1 builds a fresh model; each later phase re-prepares the existing `model_data`
# with the phase's `robustness` value (semantics live in mynnlib — presumably augmentation
# strength, TODO confirm). `break_at_val_acc_diff` is forwarded to mynnlib's `train`.
phase_epochs = 5
phase_batch_size = 32
phase_image_size = 224
training_params = [
    {"idx": idx, "robustness": robustness, "break_at_val_acc_diff": break_diff}
    for idx, robustness, break_diff in [
        (1, 0.2, 0.05),
        (2, 0.5, 0.02),
        (3, 1.0, 0.01),
        (4, 2.0, -0.000001),
        (5, 2.0, -0.000001),
        (6, 2.0, -0.000001),
    ]
]
phase_train_dir = f'{dataset_dir}/data'
phase_val_dir = f'{dataset_dir}/val'
for param in training_params:
    phase_idx = param["idx"]
    print(f"Phase {phase_idx}:")
    if phase_idx != 1:
        # Relies on phase 1 having run earlier in this loop to create `model_data`.
        model_data = prepare_for_retraining(model_data, phase_train_dir, phase_val_dir,
                                            batch_size=phase_batch_size, image_size=phase_image_size,
                                            robustness=param["robustness"], silent=True)
    else:
        model_data = init_model_for_training(phase_train_dir, phase_val_dir,
                                             batch_size=phase_batch_size, arch="resnet152",
                                             image_size=phase_image_size, robustness=param["robustness"],
                                             lr=1e-4, weight_decay=1e-4, silent=True)
    # "###" is left in the name for mynnlib's `train` to fill in (presumably per epoch — TODO confirm).
    phase_checkpoint = f"{dataset_dir}/checkpoint.odonata.ta.ep{phase_idx:02}###.pth"
    train(model_data, phase_epochs, phase_checkpoint,
          break_at_val_acc_diff=param["break_at_val_acc_diff"])

Phase 1:
Epoch    1 /    5  | Train Loss: 2.5392 Acc: 0.4311  | Val Loss: 1.2621 Acc: 0.6672  | Elapsed time: 0:22:38.730906
Epoch    2 /    5  | Train Loss: 0.9245 Acc: 0.7414  | Val Loss: 0.8889 Acc: 0.7367  | Elapsed time: 0:41:19.308152
Epoch    3 /    5  | Train Loss: 0.5251 Acc: 0.8412  | Val Loss: 0.8412 Acc: 0.7591  | Elapsed time: 0:52:42.493943
Phase 2:
Epoch    1 /    5  | Train Loss: 1.5526 Acc: 0.6038  | Val Loss: 1.0386 Acc: 0.7174  | Elapsed time: 0:11:43.037567
Epoch    2 /    5  | Train Loss: 1.2540 Acc: 0.6713  | Val Loss: 1.0289 Acc: 0.7042  | Elapsed time: 0:23:20.584905
Phase 3:
Epoch    1 /    5  | Train Loss: 1.1648 Acc: 0.6935  | Val Loss: 0.8529 Acc: 0.7529  | Elapsed time: 0:11:37.235206
Epoch    2 /    5  | Train Loss: 1.0919 Acc: 0.7115  | Val Loss: 0.8612 Acc: 0.7598  | Elapsed time: 0:23:09.610912
Phase 4:
Epoch    1 /    5  | Train Loss: 0.9742 Acc: 0.7442  | Val Loss: 0.7642 Acc: 0.7838  | Elapsed time: 0:11:41.106220
Epoch    2 /    5  | Train Loss: 0.8

In [58]:
# Reload a checkpoint saved by the phased training above (the "ep06" prefix presumably means phase 6).
# weights_only=False unpickles arbitrary Python objects, which can execute code; only load
# checkpoints this notebook produced itself.
final_checkpoint = f"{dataset_dir}/checkpoint.odonata.ta.ep060000.pth"
model_data = torch.load(final_checkpoint, weights_only=False)