In [1]:
import os
import sys
import random
import numpy as np
import pandas as pd

from tqdm import tqdm
from itertools import combinations

In [2]:
dataset_path = "/gpfs/data/gpfs0/k.fedyanin/space/IJB/aligned_data_for_fusion/big"
RESULT_META_DIR = "/gpfs/data/gpfs0/k.fedyanin/space/IJB/aligned_data_for_fusion/metadata_refuse_verification/val_test"

val_portion = 0.5        # fraction of identities assigned to the validation split
NUM_PAIRS = 1000         # total pairs per split file (positives + negatives)
POSITIVE_PORTION = 0.5   # fraction of NUM_PAIRS that are same-identity pairs

# Seed both RNGs used below (np.random.shuffle for the identity split, the
# `random` module for pair sampling) so the sampling is reproducible on re-run.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)

In [3]:
def get_identities_split(dataset_path=dataset_path, val_portion=val_portion):
    """Randomly split the identity folders under ``dataset_path`` into two disjoint groups.

    Args:
        dataset_path: root folder whose sub-directories are identities.
        val_portion: fraction of identities (rounded down) placed in the first group.

    Returns:
        ``(val_identities, test_identities)`` as numpy object arrays of folder names.
    """
    # Keep only sub-directories: downstream pair sampling calls os.listdir on
    # every identity, which would fail on a stray regular file.
    # Sort because os.listdir order is arbitrary; without it a seeded shuffle
    # still yields a filesystem-dependent split.
    identity_names = sorted(
        name for name in os.listdir(dataset_path)
        if os.path.isdir(os.path.join(dataset_path, name))
    )
    identities = np.array(identity_names, dtype=object)

    np.random.shuffle(identities)
    split_idx = int(val_portion * len(identities))

    return identities[:split_idx], identities[split_idx:]

def randomized_round(number):
    """Round ``number`` to a neighbouring integer at random, unbiased in expectation.

    Returns ``floor(number) + 1`` with probability equal to the fractional part
    ``number - floor(number)``, otherwise ``floor(number)``; hence
    ``E[randomized_round(x)] == x`` for negative inputs too. Integers are returned
    unchanged.
    """
    # `//` floors, whereas int() truncates toward zero; flooring keeps the base
    # consistent with the fractional part for negative numbers.
    base = int(number // 1)
    fraction = number - base
    return base + int(random.random() < fraction)

def sample_pairs_from_directory(prefix, directory, n_pairs):
    """Sample up to ``n_pairs`` distinct unordered file pairs from one identity folder.

    Args:
        prefix: root folder that contains ``directory``.
        directory: identity sub-folder; returned paths are ``os.path.join(directory, name)``
            (i.e. relative to ``prefix``).
        n_pairs: requested number of pairs; clipped to the number of available pairs.
            Non-positive values yield an empty list.

    Returns:
        list of 2-tuples of relative paths, sampled without replacement.
    """
    # Sorted so that, under a fixed seed, the sample does not depend on the
    # filesystem's listing order.
    file_names = sorted(os.listdir(os.path.join(prefix, directory)))
    all_pairs = list(combinations(file_names, r=2))
    n_sampled_pairs = max(0, min(n_pairs, len(all_pairs)))
    chosen = random.sample(all_pairs, n_sampled_pairs)
    # Join paths only for the sampled pairs rather than for all O(k^2) candidates.
    return [(os.path.join(directory, left), os.path.join(directory, right)) for left, right in chosen]


def generate_positive_pairs(directories, n_pairs, identities_dir=dataset_path):
    """Sample about ``n_pairs`` same-identity pairs, spread evenly over ``directories``.

    Identities are visited in random order. After visiting ``k`` identities the
    running total is steered toward the pro-rata target ``k * n_pairs / len(directories)``.
    Identities with too few images to supply their share are compensated for by
    later ones.

    Args:
        directories: iterable of identity folder names under ``identities_dir``.
            It is not modified; a shuffled copy is used.
        n_pairs: total number of positive pairs requested.
        identities_dir: root folder containing the identity folders.

    Returns:
        list of 2-tuples of relative image paths. Its length never exceeds
        ``n_pairs``, and equals it whenever the remaining identities hold enough images.
    """
    # Copy so the caller's sequence is not shuffled in place.
    directories = list(directories)
    if not directories or n_pairs <= 0:
        return []
    random.shuffle(directories)

    positive_pairs = []
    for idx, directory in enumerate(tqdm(directories)):
        # Cumulative target after this identity. Multiplying before dividing
        # makes the last target exactly n_pairs.
        target = n_pairs * (idx + 1) / len(directories)
        deficit = target - len(positive_pairs)
        # Sample only when behind schedule. The previous "clip lack at 0, then add
        # the mean" rule never compensated for being ahead, so the total drifted
        # above n_pairs.
        if deficit <= 0:
            continue
        needed_pairs_amount = randomized_round(deficit)
        positive_pairs.extend(sample_pairs_from_directory(identities_dir, directory, needed_pairs_amount))

    return positive_pairs


def generate_negative_pairs(identities, n_pairs, identities_dir=dataset_path):
    """Sample ``n_pairs`` distinct different-identity image pairs.

    Each pair takes one random file from each of two distinct random identities
    (among those whose folder is non-empty). Pairs are ordered ``(left, right)``,
    so ``(a, b)`` and ``(b, a)`` are counted as different pairs.

    Args:
        identities: iterable of identity folder names under ``identities_dir``.
        n_pairs: number of pairs to return.
        identities_dir: root folder containing the identity folders.

    Returns:
        list of ``n_pairs`` unique 2-tuples of relative paths, in sampling order.

    Raises:
        ValueError: if fewer than ``n_pairs`` distinct negative pairs exist, for
            example when fewer than two identities contain files. The previous
            loop would spin forever in that case.
    """
    # List each identity folder once, instead of twice per sampling attempt.
    # Sorted for reproducibility under a fixed seed; the dict also drops
    # duplicate identity names, which could otherwise pair an identity with itself.
    images_by_identity = {}
    for identity in identities:
        file_names = sorted(os.listdir(os.path.join(identities_dir, identity)))
        if file_names:
            images_by_identity[identity] = file_names
    non_empty_identities = list(images_by_identity)

    # Number of ordered cross-identity pairs: T^2 - sum(c_i^2).
    counts = [len(names) for names in images_by_identity.values()]
    total_images = sum(counts)
    max_pairs = total_images * total_images - sum(c * c for c in counts)
    if n_pairs > max_pairs:
        raise ValueError(
            "Cannot sample {} distinct negative pairs: only {} exist among {} non-empty identities".format(
                n_pairs, max_pairs, len(non_empty_identities)))

    # A list plus a seen-set keeps the output order deterministic; list(set)
    # would depend on the per-process string hash seed.
    negative_pairs = []
    seen = set()
    attempts = 0
    while len(negative_pairs) < n_pairs:
        left_id, right_id = random.sample(non_empty_identities, 2)
        pair = (
            os.path.join(left_id, random.choice(images_by_identity[left_id])),
            os.path.join(right_id, random.choice(images_by_identity[right_id])),
        )
        if pair not in seen:
            seen.add(pair)
            negative_pairs.append(pair)

        attempts += 1
        if attempts % 1000 == 0:
            sys.stdout.write("Sampled pairs : {}/{}...\t\r".format(len(negative_pairs), n_pairs))
    sys.stdout.write("Sampled pairs : {}/{}...\t\n".format(len(negative_pairs), n_pairs))
    return negative_pairs

def save_to_file(positive_pairs, negative_pairs, file_path):
    """Write labelled pairs to ``file_path`` as comma-separated ``left,right,label`` lines.

    Positive pairs are written first with label ``1``, then negative pairs with
    label ``0``. No header row is written, and any existing file is overwritten.
    Only the first two elements of each pair are used; they must be strings.
    """
    labelled_groups = ((positive_pairs, "1"), (negative_pairs, "0"))
    with open(file_path, "w") as out_file:
        for pairs, label in labelled_groups:
            for pair in pairs:
                line = ",".join((pair[0], pair[1], label))
                out_file.write(line + "\n")
    

def save_meta_file():
    """Split identities into val/test groups and write one labelled pair CSV per split.

    Reads the notebook config ``dataset_path``, ``val_portion``, ``NUM_PAIRS``,
    ``POSITIVE_PORTION`` and ``RESULT_META_DIR``. Files are written to
    ``RESULT_META_DIR/{split}_pairs_{NUM_PAIRS}_prob_{POSITIVE_PORTION}.csv``
    for ``split`` in ``("val", "test")``.
    """
    val_idxes, test_idxes = get_identities_split(dataset_path=dataset_path, val_portion=val_portion)

    # Round, then subtract, so positives + negatives == NUM_PAIRS exactly.
    # Two separate int() truncations can lose a pair to float error
    # (e.g. int(1000 * (1 - 0.9)) == 99).
    n_positive = int(round(NUM_PAIRS * POSITIVE_PORTION))
    n_negative = NUM_PAIRS - n_positive

    os.makedirs(RESULT_META_DIR, exist_ok=True)
    for split_name, identities in (("val", val_idxes), ("test", test_idxes)):
        file_path = os.path.join(
            RESULT_META_DIR,
            split_name + "_pairs_" + str(NUM_PAIRS) + "_prob_" + str(POSITIVE_PORTION) + ".csv",
        )
        pos_pairs = generate_positive_pairs(identities, n_pairs=n_positive)
        neg_pairs = generate_negative_pairs(identities, n_pairs=n_negative)
        save_to_file(pos_pairs, neg_pairs, file_path=file_path)
    

In [4]:
# Generate the val and test pair CSVs into RESULT_META_DIR using the config cell above.
save_meta_file()

1765it [00:12, 144.30it/s]
57it [00:00, 560.02it/s]

Sampled pairs : 500/500...	

1766it [00:05, 347.64it/s]

Sampled pairs : 500/500...	


