In [None]:
import os
import sys

import torchvision

# Directory used as the import root for the local `utils` package:
# the file's own directory when run as a script, otherwise (e.g. inside a
# notebook kernel, where `__file__` is not defined) the current working directory.
script_dir = os.path.dirname(__file__) if '__file__' in globals() else os.getcwd()
sys.path.append(os.path.join(script_dir, './'))

from torch.utils.data import DataLoader

from utils.sokoto_dataset import SOKOTODataset
from utils.polyu_dataset import PolyUDatasetContactless

from utils.pca_module import pca_data_construction, pca_generation, pca_collate_fn

import csv

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

from tqdm import tqdm

# Fixed random seed; passed as `random_state` to train_test_split below so the
# train/validation split is the same on every run.
seed = 42

In [None]:
###########################################################################################
##################################    PolyU    ############################################
###########################################################################################
# Alternative dataset configuration, currently disabled. Uncomment this section
# (and comment out the Sokoto section below) to export PolyU instead.

# train_transform = torchvision.transforms.Compose([
#     torchvision.transforms.RandomRotation(20),
#     torchvision.transforms.ToTensor(),
#     torchvision.transforms.Normalize((0.5), (1.0)),
# ])

# polyu_location = "./../images/Cross_Fingerprint_Images_Database/processed_contactless_2d_fingerprint_images"
# trainset = PolyUDatasetContactless(root_dir=polyu_location, type='train', transform=train_transform)
# validationset = PolyUDatasetContactless(root_dir=polyu_location, type='val', transform=train_transform)


# dataset_name = "polyu"
# num_ids = 400

###########################################################################################
##################################    Sokoto    ###########################################
###########################################################################################
# Active dataset configuration.

# Per-sample preprocessing: random rotation, tensor conversion, resize to 96x96,
# then normalisation with mean 0.5 and std 1.0.
train_transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.RandomRotation(20),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Resize((96, 96)),
        torchvision.transforms.Normalize(0.5, 1.0),
    ]
)

scoofing_db = "./../images/SCOOF_DB/SOCOFing/Real"
labels_file = "./metadata/sokoto_meta.txt"
sokoto = SOKOTODataset(
    root_dir=scoofing_db,
    labels_file=labels_file,
    transform=train_transform,
)

# Shuffled 60/40 split, reproducible through `seed`.
# NOTE(review): train_test_split indexes its input eagerly, so each sample is
# presumably loaded (and randomly rotated) once here rather than on every
# access -- confirm this is intended.
trainset, validationset = train_test_split(
    sokoto,
    test_size=0.4,
    random_state=seed,
    shuffle=True,
)

num_ids = 600
dataset_name = "sokoto"


###########################################################################################
##################################    END    ##############################################
###########################################################################################

In [None]:
# Name of the target model; used below only as a folder name in the export path.
model_type = "lenet5"

# Size forwarded to pca_collate_fn by collate_fn.
new_size = (28, 28)

# Component setting forwarded to pca_generation.
# Alternative kept for reference: pca_elements = 784
# NOTE(review): a float here presumably means "fraction of variance to keep"
# -- confirm against utils.pca_module.
pca_elements = 0.95

# Build the PCA input from the training split, then obtain the PCA model and
# scaler that collate_fn relies on (presumably fitted here -- confirm in
# utils.pca_module).
x, y = pca_data_construction(trainset)
x_pca, pca_model, total_pca, scaler = pca_generation(
    x,
    StandardScaler(),
    pca_elements,
)

def collate_fn(batch):
    """Collate ``batch`` via ``pca_collate_fn`` using the notebook-level PCA state.

    Thin one-argument adapter suitable for ``DataLoader(collate_fn=...)``: it
    forwards the module-level ``pca_model``, ``scaler`` and ``new_size`` and
    always passes ``return_components=False``.

    Args:
        batch: the list of samples handed over by the DataLoader.

    Returns:
        Whatever ``pca_collate_fn`` returns for this batch.
    """
    collated = pca_collate_fn(batch, pca_model, scaler, new_size, return_components=False)
    return collated

KeyboardInterrupt: 

In [None]:
# Loader over the validation split: one sample per batch, in dataset order,
# collated through the PCA-aware collate_fn defined above.
validation_loader = DataLoader(
    validationset,
    # batching / ordering
    batch_size=1,
    shuffle=False,
    collate_fn=collate_fn,
    # worker configuration
    num_workers=2,
    persistent_workers=True,
    pin_memory=True,
)

In [None]:
SAVE_PATH = "../exported_images"

# Create the export root, then one sub-folder per export phase under
# SAVE_PATH/<dataset_name>/<model_type>/.
os.makedirs(SAVE_PATH, exist_ok=True)
for _phase in ("register", "authenticate"):
    os.makedirs(os.path.join(SAVE_PATH, dataset_name, model_type, _phase), exist_ok=True)

In [None]:
def save_image_to_csv(path, img):
    """Write a 1-D or 2-D array to ``path`` as a CSV file.

    A 1-D array is written as a single row of values; a 2-D array is written
    as one CSV row per array row. Any existing file at ``path`` is overwritten.

    Args:
        path: destination file path.
        img: array-like exposing ``ndim`` and iteration (e.g. a numpy array).

    Raises:
        ValueError: if ``img.ndim`` is neither 1 nor 2. The check happens
            before the file is opened, so no empty file is left behind.
    """
    if img.ndim not in (1, 2):
        raise ValueError(
            f"save_image_to_csv expects a 1-D or 2-D array, got ndim={img.ndim}"
        )

    # Treat a 1-D input as a single-row table so both cases share one write path.
    rows = [img] if img.ndim == 1 else img

    # newline='' lets the csv module control line endings itself.
    with open(path, 'w', newline='') as f:
        csv.writer(f).writerows(rows)


In [None]:
def save_dataset_to_csv(dataloader, type = "register"):
    """Export every sample yielded by ``dataloader`` to its own CSV file.

    Files are written to ``SAVE_PATH/<dataset_name>/<model_type>/<type>/`` and
    named ``{sample_num}_{label}.csv``, where ``sample_num`` is a running
    0-based counter over all samples. With ``batch_size=1`` this equals the
    batch index, so file names match the previous behaviour.

    Args:
        dataloader: iterable of ``(img, label)`` batches whose two elements
            expose ``.numpy()`` (presumably torch tensors -- confirm against
            the collate function in use).
        type: name of the sub-folder to write into (e.g. ``"register"`` or
            ``"authenticate"``). The name shadows the builtin but is kept
            because callers pass it by keyword.
    """
    out_dir = os.path.join(SAVE_PATH, dataset_name, model_type, type)
    # Make the function usable even if the directory-setup cell was skipped.
    os.makedirs(out_dir, exist_ok=True)

    sample_num = 0
    # Wrap the dataloader itself rather than enumerate(...): tqdm can then read
    # len(dataloader) and show a total / ETA instead of a bare iteration count.
    for imgs, labels in tqdm(dataloader, desc=type):
        imgs = imgs.numpy()
        labels = labels.numpy()

        # Save every sample of the batch, not only the first one, so nothing is
        # silently dropped when batch_size > 1.
        for img, label in zip(imgs, labels):
            # A 1-D sample is saved as is; anything else keeps only its first
            # slice along the leading axis (same selection as before).
            if img.ndim != 1:
                img = img[0]

            save_image_to_csv(
                os.path.join(out_dir, f"{sample_num}_{label}.csv"),
                img
            )
            sample_num += 1

In [None]:
# Export the validation split once per phase folder.
for _phase in ("register", "authenticate"):
    save_dataset_to_csv(validation_loader, type=_phase)

52it [00:00, 76.29it/s]

160it [00:02, 65.85it/s]
160it [00:02, 62.01it/s]
