In [1]:
%load_ext autoreload

In [2]:
%autoreload 2

In [3]:
import mynnlib
from mynnlib import *

# Root folder of the Odonata dataset; the cells below create "data" and "val" under it.
dataset_dir = "insect-dataset/odonata"

# Class-name patterns, applied with re.match in the counting cells further down.
# Matches class names that end in "-early".
early_regex = r"^.*-(early)$"
# Matches class names that end in "-spp", "-genera" or "-genera-spp".
unidentified_regex = r"^.*-(spp|genera|genera-spp)$"
# Union of the two patterns above.
early_or_unidentified_regex = r"^.*-(early|spp|genera|genera-spp)$"

# Create Dataset

In [4]:
# Start from a clean slate: drop any previous "data" tree, then recreate it empty.
data_root = f"{dataset_dir}/data"
if os.path.exists(data_root):
    shutil.rmtree(data_root)
os.makedirs(data_root)

In [5]:
# Merge early-stage and imago classes: images from "<name>-early" go into "<name>".
# Merge unnamed classes suffixed "-0" the same way: "<name>-0" goes into "<name>".

src_dir = "insect-dataset/src/indianodonata.org"
# Anchored at the end of the name so that only a trailing "-early" / "-0" is stripped.
# An unanchored pattern would also rewrite names that merely contain "-0" or "-early",
# creating one folder and then copying into a different one that does not exist.
merge_suffix_regex = r"-(early|0)$"
for class_dir in os.listdir(src_dir):
    src_class_dir = f"{src_dir}/{class_dir}"
    if not os.path.isdir(src_class_dir):
        # Stray files next to the class folders are not classes.
        continue
    dst_class_dir = f"{dataset_dir}/data/{re.sub(merge_suffix_regex, '', class_dir)}"
    # The target is created even for an empty source class; the "list empty classes"
    # cell further down reports and removes those folders.
    os.makedirs(dst_class_dir, exist_ok=True)
    for file in os.listdir(src_class_dir):
        shutil.copy2(f"{src_class_dir}/{file}", f"{dst_class_dir}/{file}")

In [6]:
def copy_data_from(sources, add_early=False, dest_root=None):
    """Copy images from additional source trees into the dataset's class folders.

    Only classes that already exist under the destination are topped up; source
    folders with no matching destination class are ignored.

    Args:
        sources: iterable of source root folders, each containing one sub-folder
            per class.
        add_early: when True, a source folder named "<class>-early" is also copied,
            into a destination folder of the same name (created if missing).
        dest_root: destination root containing the class folders. Defaults to
            f"{dataset_dir}/data".

    Returns:
        None. Prints the number of copied images and of source class folders used.
    """
    if dest_root is None:
        dest_root = f"{dataset_dir}/data"
    class_cnt = 0
    img_cnt = 0
    for more_data_dir in sources:
        # os.listdir returns a list, so "-early" folders created below are not
        # visited again while processing the same source.
        for class_dir in os.listdir(dest_root):
            src_class_dir = f"{more_data_dir}/{class_dir}"
            if os.path.exists(src_class_dir):
                class_cnt += 1
                for file in os.listdir(src_class_dir):
                    shutil.copy2(f"{src_class_dir}/{file}", f"{dest_root}/{class_dir}/{file}")
                    img_cnt += 1
            src_early_dir = f"{more_data_dir}/{class_dir}-early"
            if add_early and os.path.exists(src_early_dir):
                class_cnt += 1
                dst_early_dir = f"{dest_root}/{class_dir}-early"
                # Create the class folder itself (not a folder named after an image);
                # exist_ok keeps repeated runs and multiple sources from failing.
                os.makedirs(dst_early_dir, exist_ok=True)
                for file in os.listdir(src_early_dir):
                    shutil.copy2(f"{src_early_dir}/{file}", f"{dst_early_dir}/{file}")
                    img_cnt += 1
    print(f"{img_cnt} images added into {class_cnt} classes")

In [7]:
# Top up the existing classes with the iNaturalist images; with add_early=False the
# "<class>-early" source folders are not copied.
copy_data_from(["insect-dataset/src/odonata.inaturalist.org"], add_early=False)

32463 images added into 354 classes


In [8]:
# Delete every file in the working copy whose extension is not a supported image format.
remove_file_cnt = 0
# In a raw string a single backslash escapes the dot; "\\." would instead require a
# literal backslash in the file name, so nothing would ever be recognised as valid.
valid_file_regex = r"^.*\.(jpg|jpeg|png|ppm|bmp|pgm|tif|tiff|webp)$"
for class_dir in os.listdir(f"{dataset_dir}/data"):
    for file in os.listdir(f"{dataset_dir}/data/{class_dir}"):
        # Case-insensitive so that e.g. ".JPG" counts as a supported image as well.
        if not re.match(valid_file_regex, file, re.IGNORECASE):
            os.remove(f"{dataset_dir}/data/{class_dir}/{file}")
            remove_file_cnt += 1
print(f"Removed {remove_file_cnt} unsupported files")

Removed 0 unsupported files


In [9]:
# Print each class folder that contains no files, delete it, and report the total.
empty_class_cnt = 0
for class_dir in os.listdir(f"{dataset_dir}/data"):
    class_path = f"{dataset_dir}/data/{class_dir}"
    if os.listdir(class_path):
        # Folder has content: keep it.
        continue
    print(class_dir)
    shutil.rmtree(class_path)
    empty_class_cnt += 1
print(f"\nRemoved {empty_class_cnt} empty classes")

aciagrion-azureum
acrogomphus-mohani
aeshna-donaldi
agriocnemis-dabreui
agrionoptera-dorothea
anisogomphus-orites
aristocypha-immaculata
asiagomphus-personatus
bayadera-kali
bayadera-longicauda
burmargiolestes-laidlawi
caconeura-gomphoides
caconeura-obscura
calicnemia-mukherjeei
calicnemia-pyrrhosoma
cephalaeschna-klapperichi
cephalaeschna-masoni
chlorogomphus-schmidti
chloropetalia-olympicus
coeliccia-dorothea
coeliccia-prakritiae
coeliccia-rossi
coeliccia-sarbottama
coeliccia-vacca
davidius-kumaonensis
davidius-malloryi
davidius-zallorensis
drepanosticta-annandalei
elattoneura-nihari
enallagma-immsi
epallage-fatime
gynacantha-albistyla
gynacantha-andamanae
gynacantha-apicalis
gynacantha-biharica
gynacantha-odoneli
gynacantha-pallampurica
gynacantha-rammohani
gynacantha-rotundata
himalagrion-exclamatione
ictinogomphus-atrox
idionyx-galeata
idionyx-imbricata
idionyx-intricata
idionyx-minima
idionyx-nadganiensis
idionyx-nilgiriensis
idionyx-periyashola
idionyx-rhinoceroides
ischnura-nul

# Create val dataset

In [10]:
# Recreate an empty "val" folder, discarding any validation split from a previous run.
val_root = f"{dataset_dir}/val"
if os.path.exists(val_root):
    shutil.rmtree(val_root)
os.makedirs(val_root)

In [11]:
# Move a random ~val_data_ratio share of the images from move_src to move_dst.
move_src = "data"
move_dst = "val"
val_data_ratio = 0.03
val_split_seed = 42
val_data_cnt = 0
# Fixed seed + sorted listings: the same images are selected on every
# Restart & Run All (os.listdir order is otherwise arbitrary).
random.seed(val_split_seed)
for class_dir in sorted(os.listdir(f"{dataset_dir}/{move_src}")):
    class_files = sorted(os.listdir(f"{dataset_dir}/{move_src}/{class_dir}"))
    remaining_cnt = len(class_files)
    for file in class_files:
        # Always leave at least one image behind: empty classes were already pruned
        # above, so a class emptied here would stay in move_src without any image.
        if remaining_cnt > 1 and random.random() < val_data_ratio:
            os.makedirs(f"{dataset_dir}/{move_dst}/{class_dir}", exist_ok=True)
            shutil.move(f"{dataset_dir}/{move_src}/{class_dir}/{file}", f"{dataset_dir}/{move_dst}/{class_dir}/")
            remaining_cnt -= 1
            val_data_cnt += 1
print(f"{val_data_cnt} images moved from {move_src} to {move_dst}")

1332 images moved from data to val


# Count

In [12]:
# Number of files per class folder, keyed by class name.
classes = dict(
    (class_dir, len(os.listdir(f"{dataset_dir}/data/{class_dir}")))
    for class_dir in os.listdir(f"{dataset_dir}/data")
)
# Subsets of the tallies, selected by class-name pattern.
early_classes = {name: n for name, n in classes.items() if re.match(early_regex, name)}
unidentified_classes = {name: n for name, n in classes.items() if re.match(unidentified_regex, name)}
# Summary: class counts first, then image counts, each broken down by subset.
print(f"Total Class count : {len(classes):6} ( Unidentified: {len(unidentified_classes):6} / Early-stage: {len(early_classes):6} / Identified-adult: {len(classes) - len(unidentified_classes) - len(early_classes):6} )")
print(f"Total  Data count : {sum(classes.values()):6} ( Unidentified: {sum(unidentified_classes.values()):6} / Early-stage: {sum(early_classes.values()):6} / Identified-adult: {sum(classes.values()) - sum(unidentified_classes.values()) - sum(early_classes.values()):6} )")

Total Class count :    427 ( Unidentified:     27 / Early-stage:      0 / Identified-adult:    400 )
Total  Data count :  44256 ( Unidentified:    302 / Early-stage:      0 / Identified-adult:  43954 )


In [13]:
# Collect the classes (excluding early-stage / unidentified ones) with very few images.
img2_class = []
img5_class = []
for class_dir in os.listdir(f"{dataset_dir}/data"):
    if re.match(early_or_unidentified_regex, class_dir):
        continue
    img_cnt = len(os.listdir(f"{dataset_dir}/data/{class_dir}"))
    if img_cnt <= 2:
        img2_class.append(class_dir)
    if img_cnt <= 5:
        img5_class.append(class_dir)
print(f"{len(img2_class):6} classes with <=2 images")
print(f"{len(img5_class):6} classes with <=5 images")

    25 classes with <=2 images
    56 classes with <=5 images


In [14]:
# Distinct class-name prefixes, i.e. the part before the first hyphen.
generas = {class_name.partition('-')[0] for class_name in classes}
print(f"Genera count: {len(generas)}")

Genera count: 145


# Train

### Model A (resnet-152) ***

In [15]:
training_params = [
    { "idx": 1, "robustness": 0.2, "break_at_val_acc_diff": 0.05},
    { "idx": 2, "robustness": 0.5, "break_at_val_acc_diff": 0.02},
    { "idx": 3, "robustness": 1.0, "break_at_val_acc_diff": 0.01},
    { "idx": 4, "robustness": 2.0, "break_at_val_acc_diff": -0.000001},
    { "idx": 5, "robustness": 2.0, "break_at_val_acc_diff": -0.000001},
    { "idx": 6, "robustness": 2.0, "break_at_val_acc_diff": -0.000001}
]
for param in training_params:
    print(f"Phase {param["idx"]}:")
    if param["idx"] == 1:
        model_data = init_model_for_training(f'{dataset_dir}/data', f'{dataset_dir}/val', 
                                             batch_size=32, arch="resnet152", image_size=224, robustness=param["robustness"],
                                             lr=1e-4, weight_decay=1e-4, silent=True)
    else:
        model_data = prepare_for_retraining(model_data, f'{dataset_dir}/data', f'{dataset_dir}/val', 
                                            batch_size=32, image_size=224, robustness=param["robustness"], silent=True)
    train(model_data, 5, f"{dataset_dir}/checkpoint.odonata.ta.ep{param["idx"]:02}###.pth", 
          break_at_val_acc_diff=param["break_at_val_acc_diff"])

Phase 1:
Epoch    1 /    5  | Train Loss: 2.5883 Acc: 0.4226  | Val Loss: 1.2730 Acc: 0.6659  | Elapsed time: 0:14:55.580652
Epoch    2 /    5  | Train Loss: 0.9480 Acc: 0.7312  | Val Loss: 0.9276 Acc: 0.7320  | Elapsed time: 0:26:11.469817
Epoch    3 /    5  | Train Loss: 0.5371 Acc: 0.8314  | Val Loss: 0.8259 Acc: 0.7800  | Elapsed time: 0:38:14.889094
Phase 2:
Epoch    1 /    5  | Train Loss: 1.5621 Acc: 0.5991  | Val Loss: 0.9821 Acc: 0.7200  | Elapsed time: 0:12:02.660414
Epoch    2 /    5  | Train Loss: 1.2620 Acc: 0.6670  | Val Loss: 0.8866 Acc: 0.7523  | Elapsed time: 0:23:54.011407
Epoch    3 /    5  | Train Loss: 1.1369 Acc: 0.6969  | Val Loss: 0.8554 Acc: 0.7590  | Elapsed time: 0:39:33.472592
Phase 3:
Epoch    1 /    5  | Train Loss: 1.1157 Acc: 0.7028  | Val Loss: 0.8933 Acc: 0.7530  | Elapsed time: 0:12:17.588497
Epoch    2 /    5  | Train Loss: 0.9032 Acc: 0.7602  | Val Loss: 0.7417 Acc: 0.7875  | Elapsed time: 0:24:34.782030
Epoch    3 /    5  | Train Loss: 0.8050 Acc: 

In [58]:
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python objects,
# which can execute code on load. Only load checkpoints from a trusted source.
# NOTE(review): this cell's execution count (58) does not follow the cells above, so the
# checkpoint file is presumably written by an earlier training run — confirm it exists
# before running the notebook top to bottom.
model_data = torch.load(f"{dataset_dir}/checkpoint.odonata.ta.ep060000.pth", weights_only=False)