In [11]:
%load_ext autoreload

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
%autoreload 2

In [12]:
import mynnlib
from mynnlib import *

# Root directory of the cicada dataset. Later cells build the "data", "val"
# and "test" sub-directories and the checkpoint file paths underneath it.
dataset_dir = "insect-dataset/cicada"

# Class-name patterns, applied with re.match() to class directory names:
#   early_regex                 -> names ending in "-early"
#   unidentified_regex          -> names ending in "-spp", "-genera" or "-genera-spp"
#   early_or_unidentified_regex -> union of the two patterns above
early_regex = r"^.*-(early)$"
unidentified_regex = r"^.*-(spp|genera|genera-spp)$"
early_or_unidentified_regex = r"^.*-(early|spp|genera|genera-spp)$"

# Create Dataset

In [20]:
# Rebuild the training directory from scratch: drop whatever a previous run
# left behind, then create it again empty.
data_root = f"{dataset_dir}/data"
if os.path.exists(data_root):
    shutil.rmtree(data_root)
os.makedirs(data_root)

In [21]:
# Copy the indiancicadas.org source images into the training directory while
# folding variant folders into their base class:
#   "<class>-early" (early stage) -> "<class>"
#   "<class>-0"     (unnamed)     -> "<class>"
#
# The suffix is stripped only at the END of the folder name, and the same
# computed target is used both for creating the directory and for copying.
# (Previously the substitution was unanchored while the match was anchored, so
# a name containing "-0"/"-early" in the middle created one directory and then
# tried to copy into a different, non-existent one.)
#
# NOTE: files with the same name in "<class>" and "<class>-early" overwrite
# each other in the merged directory.

src_dir = "insect-dataset/src/indiancicadas.org"
merge_suffix_regex = r"(?:-(?:early|0))+$"
for class_dir in os.listdir(src_dir):
    merged_class = re.sub(merge_suffix_regex, "", class_dir)
    dst_class_dir = f"{dataset_dir}/data/{merged_class}"
    # Created even when the source folder is empty, exactly as before.
    os.makedirs(dst_class_dir, exist_ok=True)
    for file in os.listdir(f"{src_dir}/{class_dir}"):
        shutil.copy2(f"{src_dir}/{class_dir}/{file}", f"{dst_class_dir}/{file}")

In [22]:
def copy_data_from(sources, add_early=False, dest_root=None):
    """Copy extra images into classes that already exist in the dataset.

    For every directory in ``sources`` and every class folder currently
    present under ``dest_root``, files from the same-named folder in the
    source are copied into the class folder.  With ``add_early=True`` a
    source folder called ``"<class>-early"`` is copied as well, into a
    ``"<class>-early"`` folder under ``dest_root`` that is created on demand.

    Args:
        sources: iterable of source root directories to pull images from.
        add_early: also import ``"<class>-early"`` folders when True.
        dest_root: directory holding the class folders; defaults to
            ``f"{dataset_dir}/data"`` (module-level ``dataset_dir``).

    Prints a one-line summary of how many files were copied into how many
    class folders.  A class folder is counted once per source it was found in.
    """
    if dest_root is None:
        dest_root = f"{dataset_dir}/data"

    class_cnt = 0
    img_cnt = 0
    for more_data_dir in sources:
        # os.listdir() returns a snapshot, so "-early" folders created below
        # are not visited as base classes during this pass.
        for class_dir in os.listdir(dest_root):
            variants = [class_dir]
            if add_early:
                variants.append(f"{class_dir}-early")
            for variant in variants:
                src_class_dir = f"{more_data_dir}/{variant}"
                if not os.path.isdir(src_class_dir):
                    continue
                dst_class_dir = f"{dest_root}/{variant}"
                # The target is the class folder itself.  The old code built
                # the path from a stale/unbound `file` variable and therefore
                # raised or created a directory named after an image.
                os.makedirs(dst_class_dir, exist_ok=True)
                class_cnt += 1
                for file in os.listdir(src_class_dir):
                    shutil.copy2(f"{src_class_dir}/{file}", f"{dst_class_dir}/{file}")
                    img_cnt += 1
    print(f"{img_cnt} images added into {class_cnt} classes")

In [23]:
# Top up the existing classes with the iNaturalist images; "-early" folders
# from that source are not imported.
copy_data_from(
    sources=["insect-dataset/src/cicada.inaturalist.org"],
    add_early=False,
)

6998 images added into 154 classes


In [24]:
# Delete files whose extension is not a supported image format.
#
# Fixes over the previous version:
#   * the pattern used `\\.` inside a raw string, which matches a literal
#     backslash followed by any character, so NO ordinary file name was ever
#     considered valid; it is now `\.` (a literal dot);
#   * extensions are compared case-insensitively (".JPG" is a valid image);
#   * the file is actually removed and counted, so the summary line is true.
# Only the rebuilt copy under {dataset_dir}/data is touched, never the sources.
remove_file_cnt = 0
valid_file_regex = r"^.*\.(jpg|jpeg|png|ppm|bmp|pgm|tif|tiff|webp)$"
for class_dir in os.listdir(f"{dataset_dir}/data"):
    for file in os.listdir(f"{dataset_dir}/data/{class_dir}"):
        if not re.match(valid_file_regex, file, re.IGNORECASE):
            os.remove(f"{dataset_dir}/data/{class_dir}/{file}")
            remove_file_cnt += 1
print(f"Removed {remove_file_cnt} unsupported files")

Removed 0 unsupported files


In [25]:
# Print every class folder that holds no files and delete it, then report
# how many folders were dropped.
data_root = f"{dataset_dir}/data"
empty_class_cnt = 0
for class_dir in os.listdir(data_root):
    class_path = f"{data_root}/{class_dir}"
    if os.listdir(class_path):
        continue  # class still has files, keep it
    print(class_dir)
    shutil.rmtree(class_path)
    empty_class_cnt += 1
print(f"\nRemoved {empty_class_cnt} empty classes")

abricta-brunnea
abricta-pusilla
abroma-apicalis
balinta-pulchella
balinta-sanguiniventris
becquartina-goera
calcagninus-divaricatus
calcagninus-nilgirensis
callogaeana-annamensis
chremistica-germana
chremistica-viridis
cicada-complex
cicada-conspurcata
cicada-olivierana
cicadatra-anoea
cicadatra-gingat
cicadatra-intermedia
cicadatra-karachiensis
cicadatra-minuta
cicadatra-raja
cicadatra-walkeri
cicadetta-inglisi
cicadetta-intermedia
cicadetta-minuta
cryptotympana-auropilosa
dundubia-ensifera
dundubia-myitkyinensis
emathia-aegrota
emathia-dup-aegrota
eopycna-himalayana
eopycna-minor
eopycna-montana
eopycna-spp
eopycna-verna
gaeana-consors
huechys-haematica
hyalessa-melanoptera
hyalessa-virescens
khimbya-cuneata
kumanga-sandaracata
linguacicada-continuata
meimuna-pallida
meimuna-velitaris
melampsalta-mogannia
mogannia-aurea
mogannia-spurcata
mogannia-venutissima
neotanna-thalia
orientopsaltria-beaudouini
panka-simulata
paratanna-parata
platylomia-juno
platylomia-lemoultii
platylomia-mali

# Create val dataset

In [26]:
# Start the validation split from an empty directory: remove a leftover one
# from an earlier run, then create it fresh.
val_root = f"{dataset_dir}/val"
if os.path.exists(val_root):
    shutil.rmtree(val_root)
os.makedirs(val_root)

In [27]:
# Move a random ~1% of the images from the training directory into the
# validation directory, keeping the class folder structure.
#
# Changes over the previous version:
#   * the RNG is seeded and directory listings are sorted, so Restart & Run
#     All produces the same split every time;
#   * the last remaining image of a class is never moved, so the split cannot
#     leave an empty class folder behind in the training directory.
move_src = "data"
move_dst = "val"
val_data_ratio = 0.01
val_split_seed = 42
val_data_cnt = 0
random.seed(val_split_seed)
for class_dir in sorted(os.listdir(f"{dataset_dir}/{move_src}")):
    class_files = sorted(os.listdir(f"{dataset_dir}/{move_src}/{class_dir}"))
    remaining = len(class_files)
    for file in class_files:
        if remaining <= 1:
            break  # keep at least one training image per class
        if random.random() < val_data_ratio:
            os.makedirs(f"{dataset_dir}/{move_dst}/{class_dir}", exist_ok=True)
            shutil.move(f"{dataset_dir}/{move_src}/{class_dir}/{file}", f"{dataset_dir}/{move_dst}/{class_dir}/")
            remaining -= 1
            val_data_cnt += 1
print(f"{val_data_cnt} images moved from {move_src} to {move_dst}")

84 images moved from data to val


# Count

In [28]:
# Map every training class to its number of files, then carve out the
# early-stage and unidentified subsets by class-name pattern.
classes = {
    class_dir: len(os.listdir(f"{dataset_dir}/data/{class_dir}"))
    for class_dir in os.listdir(f"{dataset_dir}/data")
}


def _classes_matching(pattern):
    """Return the sub-dict of `classes` whose names match `pattern`."""
    return {name: cnt for name, cnt in classes.items() if re.match(pattern, name)}


early_classes = _classes_matching(early_regex)
unidentified_classes = _classes_matching(unidentified_regex)

# Summary: number of classes and number of images, split by category.
n_cls, n_unid_cls, n_early_cls = len(classes), len(unidentified_classes), len(early_classes)
n_img, n_unid_img, n_early_img = (
    sum(classes.values()),
    sum(unidentified_classes.values()),
    sum(early_classes.values()),
)
print(f"Total Class count : {n_cls:6} ( Unidentified: {n_unid_cls:6} / Early-stage: {n_early_cls:6} / Identified-adult: {n_cls - n_unid_cls - n_early_cls:6} )")
print(f"Total  Data count : {n_img:6} ( Unidentified: {n_unid_img:6} / Early-stage: {n_early_img:6} / Identified-adult: {n_img - n_unid_img - n_early_img:6} )")

Total Class count :    217 ( Unidentified:      0 / Early-stage:      0 / Identified-adult:    217 )
Total  Data count :   7933 ( Unidentified:      0 / Early-stage:      0 / Identified-adult:   7933 )


In [29]:
# File count per class, skipping early-stage and unidentified classes, then
# the lists of classes with at most 2 and at most 5 images.
adult_class_sizes = {
    class_dir: len(os.listdir(f"{dataset_dir}/data/{class_dir}"))
    for class_dir in os.listdir(f"{dataset_dir}/data")
    if not re.match(early_or_unidentified_regex, class_dir)
}
img2_class = [name for name, size in adult_class_sizes.items() if size <= 2]
img5_class = [name for name, size in adult_class_sizes.items() if size <= 5]
print(f"{len(img2_class):6} classes with <=2 images")
print(f"{len(img5_class):6} classes with <=5 images")

    55 classes with <=2 images
    78 classes with <=5 images


In [30]:
# The genus is the part of the class name before the first "-".
generas = {class_name.partition('-')[0] for class_name in classes}
print(f"Genera count: {len(generas)}")

Genera count: 65


# Train

### Model A (resnet-50)

In [15]:
# Six training phases for Model A (resnet50). Each phase is described by
# (robustness, break_at_val_acc_diff); "idx" is the 1-based phase number.
training_params = [
    {"idx": phase_no, "robustness": robustness, "break_at_val_acc_diff": acc_diff}
    for phase_no, (robustness, acc_diff) in enumerate(
        [(0.2, 0.05), (0.5, 0.02), (1.0, 0.01)] + [(2.0, -0.000001)] * 3,
        start=1,
    )
]
for param in training_params:
    phase = param["idx"]
    print(f"Phase {phase}:")
    data_dirs = (f'{dataset_dir}/data', f'{dataset_dir}/val')
    if phase == 1:
        # First phase builds a fresh model.
        model_data = init_model_for_training(
            *data_dirs,
            batch_size=32, arch="resnet50", image_size=224,
            robustness=param["robustness"],
            lr=1e-4, weight_decay=1e-4, silent=True,
        )
    else:
        # Later phases continue from the model of the previous phase.
        model_data = prepare_for_retraining(
            model_data, *data_dirs,
            batch_size=32, image_size=224,
            robustness=param["robustness"], silent=True,
        )
    train(
        model_data, 5,
        f"{dataset_dir}/checkpoint.cicada.ta.ep{phase:02}###.pth",
        break_at_val_acc_diff=param["break_at_val_acc_diff"],
    )

Phase 1:
Epoch    1 /    5  | Train Loss: 4.4782 Acc: 0.1412  | Val Loss: 3.8467 Acc: 0.1875  | Elapsed time: 0:00:39.764082
Epoch    2 /    5  | Train Loss: 3.1831 Acc: 0.3162  | Val Loss: 3.0110 Acc: 0.3438  | Elapsed time: 0:01:06.269717
Epoch    3 /    5  | Train Loss: 2.4004 Acc: 0.4675  | Val Loss: 2.3262 Acc: 0.5312  | Elapsed time: 0:01:32.488642
Epoch    4 /    5  | Train Loss: 1.7762 Acc: 0.6245  | Val Loss: 1.8954 Acc: 0.6250  | Elapsed time: 0:01:58.846555
Epoch    5 /    5  | Train Loss: 1.3094 Acc: 0.7420  | Val Loss: 1.5509 Acc: 0.6562  | Elapsed time: 0:02:25.624755
Phase 2:
Epoch    1 /    5  | Train Loss: 2.0995 Acc: 0.5409  | Val Loss: 1.5360 Acc: 0.6875  | Elapsed time: 0:00:25.839088
Epoch    2 /    5  | Train Loss: 1.7627 Acc: 0.6172  | Val Loss: 1.4179 Acc: 0.7188  | Elapsed time: 0:00:52.120409
Epoch    3 /    5  | Train Loss: 1.5564 Acc: 0.6612  | Val Loss: 1.4157 Acc: 0.6875  | Elapsed time: 0:01:18.515850
Phase 3:
Epoch    1 /    5  | Train Loss: 1.6346 Acc: 

### Model B (resnet-101)

In [31]:
# Six training phases for Model B (resnet101). Each phase is described by
# (robustness, break_at_val_acc_diff); "idx" is the 1-based phase number.
training_params = [
    {"idx": phase_no, "robustness": robustness, "break_at_val_acc_diff": acc_diff}
    for phase_no, (robustness, acc_diff) in enumerate(
        [(0.2, 0.05), (0.5, 0.02), (1.0, 0.01)] + [(2.0, -0.000001)] * 3,
        start=1,
    )
]
for param in training_params:
    phase = param["idx"]
    print(f"Phase {phase}:")
    data_dirs = (f'{dataset_dir}/data', f'{dataset_dir}/val')
    if phase == 1:
        # First phase builds a fresh model.
        model_data = init_model_for_training(
            *data_dirs,
            batch_size=32, arch="resnet101", image_size=224,
            robustness=param["robustness"],
            lr=1e-4, weight_decay=1e-4, silent=True,
        )
    else:
        # Later phases continue from the model of the previous phase.
        model_data = prepare_for_retraining(
            model_data, *data_dirs,
            batch_size=32, image_size=224,
            robustness=param["robustness"], silent=True,
        )
    train(
        model_data, 5,
        f"{dataset_dir}/checkpoint.cicada.tb.ep{phase:02}###.pth",
        break_at_val_acc_diff=param["break_at_val_acc_diff"],
    )

Phase 1:
Epoch    1 /    5  | Train Loss: 3.4608 Acc: 0.2665  | Val Loss: 2.1670 Acc: 0.4762  | Elapsed time: 0:02:05.581155
Epoch    2 /    5  | Train Loss: 1.6958 Acc: 0.5942  | Val Loss: 1.5121 Acc: 0.6429  | Elapsed time: 0:03:36.464370
Epoch    3 /    5  | Train Loss: 0.9535 Acc: 0.7648  | Val Loss: 1.2624 Acc: 0.6667  | Elapsed time: 0:05:09.525064
Phase 2:
Epoch    1 /    5  | Train Loss: 1.5493 Acc: 0.6033  | Val Loss: 1.2991 Acc: 0.6310  | Elapsed time: 0:01:38.251262
Epoch    2 /    5  | Train Loss: 1.2566 Acc: 0.6749  | Val Loss: 1.2669 Acc: 0.6905  | Elapsed time: 0:03:16.885391
Epoch    3 /    5  | Train Loss: 1.1231 Acc: 0.7082  | Val Loss: 1.1394 Acc: 0.7143  | Elapsed time: 0:04:55.921881
Epoch    4 /    5  | Train Loss: 0.9947 Acc: 0.7457  | Val Loss: 1.1135 Acc: 0.7024  | Elapsed time: 0:06:37.681167
Phase 3:
Epoch    1 /    5  | Train Loss: 0.8878 Acc: 0.7790  | Val Loss: 1.0842 Acc: 0.7143  | Elapsed time: 0:01:41.590388
Epoch    2 /    5  | Train Loss: 0.8209 Acc: 

In [39]:
# Reload a saved Model B checkpoint (file name suggests phase 06 — TODO confirm
# how train() expands "###").
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects;
# only load checkpoint files you produced yourself.
model_data = torch.load(
    f"{dataset_dir}/checkpoint.cicada.tb.ep060002.pth",
    weights_only=False,
)

### Model C (resnet-34)

In [17]:
# Six training phases for Model C (resnet34). Each phase is described by
# (robustness, break_at_val_acc_diff); "idx" is the 1-based phase number.
training_params = [
    {"idx": phase_no, "robustness": robustness, "break_at_val_acc_diff": acc_diff}
    for phase_no, (robustness, acc_diff) in enumerate(
        [(0.2, 0.05), (0.5, 0.02), (1.0, 0.01)] + [(2.0, -0.000001)] * 3,
        start=1,
    )
]
for param in training_params:
    phase = param["idx"]
    print(f"Phase {phase}:")
    data_dirs = (f'{dataset_dir}/data', f'{dataset_dir}/val')
    if phase == 1:
        # First phase builds a fresh model.
        model_data = init_model_for_training(
            *data_dirs,
            batch_size=32, arch="resnet34", image_size=224,
            robustness=param["robustness"],
            lr=1e-4, weight_decay=1e-4, silent=True,
        )
    else:
        # Later phases continue from the model of the previous phase.
        model_data = prepare_for_retraining(
            model_data, *data_dirs,
            batch_size=32, image_size=224,
            robustness=param["robustness"], silent=True,
        )
    train(
        model_data, 5,
        f"{dataset_dir}/checkpoint.cicada.tc.ep{phase:02}###.pth",
        break_at_val_acc_diff=param["break_at_val_acc_diff"],
    )

Phase 1:
Epoch    1 /    5  | Train Loss: 4.0401 Acc: 0.2157  | Val Loss: 3.0061 Acc: 0.3750  | Elapsed time: 0:00:22.853247
Epoch    2 /    5  | Train Loss: 2.4528 Acc: 0.5065  | Val Loss: 2.2035 Acc: 0.4688  | Elapsed time: 0:00:46.589591
Epoch    3 /    5  | Train Loss: 1.7044 Acc: 0.6770  | Val Loss: 1.6959 Acc: 0.6562  | Elapsed time: 0:01:09.745984
Epoch    4 /    5  | Train Loss: 1.1965 Acc: 0.7899  | Val Loss: 1.4126 Acc: 0.6562  | Elapsed time: 0:01:32.913761
Phase 2:
Epoch    1 /    5  | Train Loss: 2.3284 Acc: 0.4896  | Val Loss: 1.4854 Acc: 0.6250  | Elapsed time: 0:00:24.587938
Epoch    2 /    5  | Train Loss: 1.9940 Acc: 0.5641  | Val Loss: 1.5214 Acc: 0.6250  | Elapsed time: 0:00:49.241880
Phase 3:
Epoch    1 /    5  | Train Loss: 1.8658 Acc: 0.5985  | Val Loss: 1.3553 Acc: 0.7188  | Elapsed time: 0:00:24.544694
Epoch    2 /    5  | Train Loss: 1.6329 Acc: 0.6635  | Val Loss: 1.3041 Acc: 0.7500  | Elapsed time: 0:00:49.237325
Epoch    3 /    5  | Train Loss: 1.6149 Acc: 

### Model D (resnet-152)

In [32]:
# Six training phases for Model D (resnet152). Each phase is described by
# (robustness, break_at_val_acc_diff); "idx" is the 1-based phase number.
training_params = [
    {"idx": phase_no, "robustness": robustness, "break_at_val_acc_diff": acc_diff}
    for phase_no, (robustness, acc_diff) in enumerate(
        [(0.2, 0.05), (0.5, 0.02), (1.0, 0.01)] + [(2.0, -0.000001)] * 3,
        start=1,
    )
]
for param in training_params:
    phase = param["idx"]
    print(f"Phase {phase}:")
    data_dirs = (f'{dataset_dir}/data', f'{dataset_dir}/val')
    if phase == 1:
        # First phase builds a fresh model.
        model_data = init_model_for_training(
            *data_dirs,
            batch_size=32, arch="resnet152", image_size=224,
            robustness=param["robustness"],
            lr=1e-4, weight_decay=1e-4, silent=True,
        )
    else:
        # Later phases continue from the model of the previous phase.
        model_data = prepare_for_retraining(
            model_data, *data_dirs,
            batch_size=32, image_size=224,
            robustness=param["robustness"], silent=True,
        )
    train(
        model_data, 5,
        f"{dataset_dir}/checkpoint.cicada.td.ep{phase:02}###.pth",
        break_at_val_acc_diff=param["break_at_val_acc_diff"],
    )

Phase 1:
Epoch    1 /    5  | Train Loss: 3.2403 Acc: 0.3033  | Val Loss: 2.0221 Acc: 0.5000  | Elapsed time: 0:02:07.260538
Epoch    2 /    5  | Train Loss: 1.4557 Acc: 0.6489  | Val Loss: 1.3021 Acc: 0.6429  | Elapsed time: 0:04:09.869474
Epoch    3 /    5  | Train Loss: 0.7899 Acc: 0.7979  | Val Loss: 1.0924 Acc: 0.6905  | Elapsed time: 0:06:12.231535
Phase 2:
Epoch    1 /    5  | Train Loss: 1.4493 Acc: 0.6331  | Val Loss: 1.2052 Acc: 0.6786  | Elapsed time: 0:02:07.982461
Epoch    2 /    5  | Train Loss: 1.1428 Acc: 0.7030  | Val Loss: 1.1173 Acc: 0.7024  | Elapsed time: 0:04:15.917992
Epoch    3 /    5  | Train Loss: 1.0085 Acc: 0.7362  | Val Loss: 1.0478 Acc: 0.6905  | Elapsed time: 0:06:23.896194
Phase 3:
Epoch    1 /    5  | Train Loss: 0.9487 Acc: 0.7539  | Val Loss: 1.2315 Acc: 0.6310  | Elapsed time: 0:02:07.946518
Epoch    2 /    5  | Train Loss: 0.7555 Acc: 0.8089  | Val Loss: 1.1461 Acc: 0.6905  | Elapsed time: 0:04:16.602304
Epoch    3 /    5  | Train Loss: 0.7001 Acc: 

In [5]:
# Reload a saved Model D checkpoint (file name suggests phase 06 — TODO confirm
# how train() expands "###").
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects;
# only load checkpoint files you produced yourself.
model_data = torch.load(
    f"{dataset_dir}/checkpoint.cicada.td.ep060001.pth",
    weights_only=False,
)

In [7]:
# Evaluate the loaded model on the test folder for k = 3, 5 and 10.
# Only the k=3 run prints per-image predictions and the top-1 accuracy line.
for top_k, report_options in (
    (3, dict(print_preds=True, print_top1_accuracy=True, print_no_match=False)),
    (5, dict(print_preds=False, print_top1_accuracy=False)),
    (10, dict(print_preds=False, print_top1_accuracy=False)),
):
    test_top_k(model_data, f"{dataset_dir}/test", top_k, **report_options)

eopycna-repanda-2 : [32meopycna-repanda[0m(0.885)  platypleura-capitata(0.038)  hamza-ciliaris(0.024)  
eopycna-repanda   : [32meopycna-repanda[0m(0.654)  platypleura-takasagona(0.329)  platypleura-assamensis(0.006)  
----------
Top   1 accuracy: 2/2 -> 100.00%, genus matched: 2/2 -> 100.00%
Top   3 accuracy: 2/2 -> 100.00%, genus matched: 2/2 -> 100.00%
Top   5 accuracy: 2/2 -> 100.00%, genus matched: 2/2 -> 100.00%
Top  10 accuracy: 2/2 -> 100.00%, genus matched: 2/2 -> 100.00%


### Model E (resnet-152 + image transform pipeline fixed) ***

In [4]:
# Six training phases for Model E (resnet152). Each phase is described by
# (robustness, break_at_val_acc_diff); "idx" is the 1-based phase number.
training_params = [
    {"idx": phase_no, "robustness": robustness, "break_at_val_acc_diff": acc_diff}
    for phase_no, (robustness, acc_diff) in enumerate(
        [(0.2, 0.05), (0.5, 0.02), (1.0, 0.01)] + [(2.0, -0.000001)] * 3,
        start=1,
    )
]
# Wall-clock bookkeeping: the cumulative elapsed time is printed after each phase.
start_time = time.time()
print("Started at:", datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
for param in training_params:
    phase = param["idx"]
    print(f"Phase {phase}:")
    data_dirs = (f'{dataset_dir}/data', f'{dataset_dir}/val')
    if phase == 1:
        # First phase builds a fresh model.
        model_data = init_model_for_training(
            *data_dirs,
            batch_size=32, arch="resnet152", image_size=224,
            robustness=param["robustness"],
            lr=1e-4, weight_decay=1e-4, silent=True,
        )
    else:
        # Later phases continue from the model of the previous phase.
        model_data = prepare_for_retraining(
            model_data, *data_dirs,
            batch_size=32, image_size=224,
            robustness=param["robustness"], silent=True,
        )
    train(
        model_data, 5,
        f"{dataset_dir}/checkpoint.cicada.te.ep{phase:02}###.pth",
        break_at_val_acc_diff=param["break_at_val_acc_diff"],
    )
    elapsed = datetime.timedelta(seconds=(time.time() - start_time))
    print(f"Total elapsed time: {elapsed}")

Started at: 2025-03-23 08:32:57
Phase 1:
Epoch    0 /    4  | Train Loss: 3.1861 Acc: 0.3228  | Val Loss: 1.9318 Acc: 0.5714  | Elapsed time: 0:02:35.264527
Epoch    1 /    4  | Train Loss: 1.4204 Acc: 0.6499  | Val Loss: 1.4491 Acc: 0.6310  | Elapsed time: 0:04:30.619668
Epoch    2 /    4  | Train Loss: 0.7573 Acc: 0.8062  | Val Loss: 1.2325 Acc: 0.6190  | Elapsed time: 0:06:27.802004
Total elapsed time: 0:06:31.450745
Phase 2:
Epoch    0 /    4  | Train Loss: 1.4660 Acc: 0.6313  | Val Loss: 1.1841 Acc: 0.6905  | Elapsed time: 0:02:02.647503
Epoch    1 /    4  | Train Loss: 1.2273 Acc: 0.6841  | Val Loss: 1.0426 Acc: 0.6667  | Elapsed time: 0:04:05.797385
Total elapsed time: 0:10:38.044754
Phase 3:
Epoch    0 /    4  | Train Loss: 1.1494 Acc: 0.7017  | Val Loss: 0.9463 Acc: 0.6905  | Elapsed time: 0:02:02.980263
Epoch    1 /    4  | Train Loss: 1.0343 Acc: 0.7352  | Val Loss: 0.9951 Acc: 0.7381  | Elapsed time: 0:04:07.226541
Epoch    2 /    4  | Train Loss: 0.8684 Acc: 0.7770  | Val 

In [3]:
# Reload a saved Model E checkpoint (file name suggests phase 06 — TODO confirm
# how train() expands "###").
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects;
# only load checkpoint files you produced yourself.
model_data = torch.load(
    f"{dataset_dir}/checkpoint.cicada.te.ep060000.pth",
    weights_only=False,
)

In [6]:
# Evaluate the loaded model on the test folder for k = 3, 5 and 10.
# Only the k=3 run prints per-image predictions and the top-1 accuracy line.
for top_k, report_options in (
    (3, dict(print_preds=True, print_top1_accuracy=True, print_no_match=False)),
    (5, dict(print_preds=False, print_top1_accuracy=False)),
    (10, dict(print_preds=False, print_top1_accuracy=False)),
):
    test_top_k(model_data, f"{dataset_dir}/test", top_k, **report_options)

eopycna-repanda-2 : [32meopycna-repanda[0m(0.997)  platypleura-takasagona(0.001)  platypleura-assamensis(0.000)  
eopycna-repanda   : [32meopycna-repanda[0m(0.822)  platypleura-takasagona(0.157)  platypleura-assamensis(0.013)  
----------
Top   1 accuracy: 2/2 -> 100.00%, genus matched: 2/2 -> 100.00%
Top   3 accuracy: 2/2 -> 100.00%, genus matched: 2/2 -> 100.00%
Top   5 accuracy: 2/2 -> 100.00%, genus matched: 2/2 -> 100.00%
Top  10 accuracy: 2/2 -> 100.00%, genus matched: 2/2 -> 100.00%


In [20]:
# Top-3 accuracy of the loaded model over the validation split.
pred = validate_prediction_in_dir_top_k(f"{dataset_dir}/val", model_data, 3)
hits, total = pred['success'], pred['total']
print (f"Top 3 accuracy: {hits}/{total} -> {100*hits/total:.2f}%")

Top 3 accuracy: 76/84 -> 90.48%
