In [11]:
import os
from pprint import pprint
import time
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import pandas as pd

print("Importing torch ...")
begin = time.time()
import torch
print("Imported torch in {:.2f} seconds".format(time.time() - begin))

from data.turtle_data_loading import get_file_paths, TurtleDataset

Importing torch ...
Imported torch in 0.00 seconds


In [12]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs
# (Restart & Run All) on machines without CUDA.
default_device = "cuda" if torch.cuda.is_available() else "cpu"
torch.set_default_device(default_device)

In [13]:
from turtle_id_test_config import data_path

from turtle_id_test_config import test_scalar_reg_path

In [14]:
num_samples = 100 # number of test images loaded per noise level; 0 for all
size = 256 # passed to get_file_paths; matches the "256" in the loaded folder names (e.g. images_crop_resize_256_greyscale)
sigmas = [0.05, 0.1, 0.15, 0.2, 0.25, 0.3] # noise levels whose noisy image sets are loaded alongside the originals

In [15]:
# Collect test-set image paths. Per the printed output, paths are loaded for the
# original images and for each sigma in `sigmas`; later cells index the result
# as file_paths[sigma][i].
file_paths:dict = get_file_paths(data_path, "test", num_samples, sigmas, size)

Loading original image paths in images_crop_resize_256_greyscale 

100%|██████████| 100/100 [00:00<00:00, 1152281.32it/s]


Loading noisy image paths sigma=0.05 in images_crop_resize_256_greyscale_noisy_0_05 

100%|██████████| 100/100 [00:00<00:00, 1565038.81it/s]


Loading noisy image paths sigma=0.1 in images_crop_resize_256_greyscale_noisy_0_1 

100%|██████████| 100/100 [00:00<00:00, 1001027.21it/s]


Loading noisy image paths sigma=0.15 in images_crop_resize_256_greyscale_noisy_0_15 

100%|██████████| 100/100 [00:00<00:00, 1476867.61it/s]


Loading noisy image paths sigma=0.2 in images_crop_resize_256_greyscale_noisy_0_2 

100%|██████████| 100/100 [00:00<00:00, 1366222.80it/s]


Loading noisy image paths sigma=0.25 in images_crop_resize_256_greyscale_noisy_0_25 

100%|██████████| 100/100 [00:00<00:00, 1451316.26it/s]


Loading noisy image paths sigma=0.3 in images_crop_resize_256_greyscale_noisy_0_3 

100%|██████████| 100/100 [00:00<00:00, 1318963.52it/s]


In [16]:
# file_paths is a dict with keys the sigmas and values the list of file paths
# Check that every list has the same length
# file_paths maps each sigma (and, per the loading output, presumably an entry for
# the original images — confirm against get_file_paths) to a list of file paths.
# Every sigma must be present, since later cells index file_paths[sigma].
missing_sigmas = [sigma for sigma in sigmas if sigma not in file_paths]
assert not missing_sigmas, f"file_paths is missing sigmas {missing_sigmas}"

# Every list must have the same length: num_samples when a subset was requested,
# otherwise (num_samples == 0 means "all") whatever was loaded, as long as all agree.
path_counts = {key: len(paths) for key, paths in file_paths.items()}
expected_count = num_samples or next(iter(path_counts.values()), 0)
for key, count in path_counts.items():
    assert count == expected_count, f"len(file_paths[{key!r}])={count} != {expected_count}"

In [17]:
# Load the images listed in file_paths (per the output: the originals, then one set
# per sigma). default_device is passed in, presumably so the images end up on that
# device — confirm in TurtleDataset.
test_dataset = TurtleDataset(data_path, file_paths, default_device)

Loading original images 

100%|██████████| 100/100 [00:00<00:00, 137.74it/s]


Loading noisy images sigma=0.05 

100%|██████████| 100/100 [00:00<00:00, 115.79it/s]


Loading noisy images sigma=0.1 

100%|██████████| 100/100 [00:00<00:00, 124.72it/s]


Loading noisy images sigma=0.15 

100%|██████████| 100/100 [00:00<00:00, 131.31it/s]


Loading noisy images sigma=0.2 

100%|██████████| 100/100 [00:00<00:00, 130.83it/s]


Loading noisy images sigma=0.25 

100%|██████████| 100/100 [00:00<00:00, 133.79it/s]


Loading noisy images sigma=0.3 

100%|██████████| 100/100 [00:00<00:00, 126.85it/s]


In [18]:
# The dataset holds one item per (image, sigma) pair. Count the pairs from the loaded
# paths rather than num_samples * len(sigmas), which is wrong when num_samples == 0 ("all").
expected_len = sum(len(file_paths[sigma]) for sigma in sigmas)
assert len(test_dataset) == expected_len, f"len(test_dataset)={len(test_dataset)} != {expected_len}"

In [19]:
# Fix results.csv files: rename the mislabelled "sigma" column to "lambda" so the lambda values are not confused with the noise level sigma
def fix_results_csv(expected_rows: int = 81) -> None:
    """Rename the legacy "sigma" column to "lambda" in every per-image results.csv.

    For each noise level in the global ``sigmas`` and each file in
    ``file_paths[sigma]``, loads ``{test_scalar_reg_path}/{file without extension}/results.csv``,
    checks its row count, and renames a "sigma" column to "lambda" if present.
    A CSV is rewritten only when a rename actually happens, so re-running this
    cell is a no-op and already-fixed files are not re-serialized.

    Args:
        expected_rows: number of rows every results.csv must contain.

    Raises:
        AssertionError: if a CSV does not have ``expected_rows`` rows, or already
            has both a "sigma" and a "lambda" column (renaming would duplicate "lambda").
        FileNotFoundError: if a results.csv file does not exist.
    """
    for sigma in sigmas:
        print(f"sigma={sigma}")
        # Iterate the loaded paths instead of range(num_samples): num_samples == 0
        # means "all samples", and range(0) would silently skip every file.
        for file_name in tqdm(file_paths[sigma]):
            # splitext strips only the trailing extension; str.replace would also
            # remove ".ext" occurring anywhere else in the path.
            stem = os.path.splitext(file_name)[0]
            results_csv = f"{test_scalar_reg_path}/{stem}/results.csv"
            df = pd.read_csv(results_csv)

            assert len(df) == expected_rows, f"len(df)={len(df)} != {expected_rows}"

            if "sigma" not in df.columns:
                continue  # already fixed: leave the file untouched

            assert "lambda" not in df.columns, (
                f"{results_csv} has both 'sigma' and 'lambda' columns"
            )
            df.rename(columns={"sigma": "lambda"}).to_csv(results_csv, index=False)
            

In [20]:
# Side effect: may rewrite results.csv files under test_scalar_reg_path on disk.
fix_results_csv()

sigma=0.05


100%|██████████| 100/100 [00:00<00:00, 162.78it/s]


sigma=0.1


100%|██████████| 100/100 [00:00<00:00, 181.32it/s]


sigma=0.15


100%|██████████| 100/100 [00:00<00:00, 179.55it/s]


sigma=0.2


100%|██████████| 100/100 [00:00<00:00, 176.52it/s]


sigma=0.25


100%|██████████| 100/100 [00:00<00:00, 165.24it/s]


sigma=0.3


100%|██████████| 100/100 [00:00<00:00, 178.75it/s]
