## Create train/val/test split files for the ARL Protocol 3 dataset

In [8]:
import numpy as np
import os
import shutil
import os.path as osp
import random
import glob

In [6]:
# Dataset root and the sub-directories read from / written to.
root = '../../dataset/odin_data/'
src = 'VIS_align'   # directory holding the source images
dest = 'splits'     # directory receiving the split / gallery / probe lists

# Split indices handled in this run (an earlier configuration used
# `splits = [0,1,2,3,4]`).
splits = [5, 6, 7]

# Number of identities per partition; the test partition receives every id
# left over after train and val are drawn.
num_train = 71
num_val = 25
num_test = 25

src_path, dest_path = (os.path.join(root, sub) for sub in (src, dest))
# Alias read by the gallery/probe and analysis cells.
split_dir = dest_path

In [3]:
def get_full_id(x):
    """Return the third '_'-separated field of file name ``x``.

    This field acts as a per-(person, glasses) identity: its final character
    is a glasses flag (the analysis cell counts anything other than '0' as a
    glasses image).
    """
    fields = x.split('_')
    return fields[2]


def get_id(x):
    """Return the person id of ``x``: its full id minus the trailing flag.

    Dropping the glasses flag groups glasses / no-glasses images of the same
    person under one id.
    """
    return get_full_id(x)[:-1]

In [4]:
# List every source image and the unique person ids.
# Both are sorted: os.listdir order is arbitrary, and the iteration order of a
# set of strings changes between interpreter sessions (hash randomization).
# Without sorting, the seeded np.random.choice calls that draw the splits
# would not reproduce the same splits from run to run.
files = sorted(os.listdir(src_path))
ids = sorted({get_id(x) for x in files})

print('Total len:', len(files))
print('Unique persons:', len(ids))

Total len: 5419
Unique persons: 121


In [5]:
# Generate train/val/test identity splits (only ids are saved).
# Later cells and experiments reuse the split files, so an existing file is
# never overwritten unless OVERWRITE_SPLITS is set to True.
OVERWRITE_SPLITS = False

np.random.seed(0)
os.makedirs(dest_path, exist_ok=True)


def _write_lines(path, items):
    """Write each element of ``items`` to ``path`` on its own line."""
    with open(path, 'w') as f:
        f.writelines('{}\n'.format(item) for item in items)


# Saving only ids in split files
for split in splits:
    # Sample even when nothing will be written, so the global RNG stream (and
    # therefore every later split) does not depend on which files exist.
    train_set = np.random.choice(ids, size=num_train, replace=False)
    rem_set = set(ids) - set(train_set)
    # Sort set-derived sequences: set iteration order varies between sessions,
    # which would make the seeded draw non-reproducible.
    val_set = np.random.choice(sorted(rem_set), size=num_val, replace=False)
    test_set = sorted(rem_set - set(val_set))

    assert not set(test_set) & set(train_set)
    assert not set(val_set) & set(train_set)
    assert not set(val_set) & set(test_set)
    if len(test_set) != num_test:
        # The test partition is "whatever is left", so flag a size mismatch.
        print('Warning: test split has {} ids, expected {}'.format(len(test_set), num_test))

    partitions = (
        ('Train', 'train', train_set),
        ('Val', 'val', val_set),
        ('Test', 'test', test_set),
    )
    for label, _, members in partitions:
        print(label, 'len:', len(members))

    for label, prefix, members in partitions:
        path = os.path.join(dest_path, '{}_{}.txt'.format(prefix, split))
        if osp.exists(path) and not OVERWRITE_SPLITS:
            print(label, 'split exists, not overwritten:', path)
            continue
        _write_lines(path, members)
        print(label, 'ids saved to', path)

Train len: 71
Val len: 25
Test len: 25
Train ids saved to ../../dataset/odin_data/splits/train_5.txt
Val ids saved to ../../dataset/odin_data/splits/val_5.txt
Test ids saved to ../../dataset/odin_data/splits/test_5.txt
Train len: 71
Val len: 25
Test len: 25
Train ids saved to ../../dataset/odin_data/splits/train_6.txt
Val ids saved to ../../dataset/odin_data/splits/val_6.txt
Test ids saved to ../../dataset/odin_data/splits/test_6.txt
Train len: 71
Val len: 25
Test len: 25
Train ids saved to ../../dataset/odin_data/splits/train_7.txt
Val ids saved to ../../dataset/odin_data/splits/val_7.txt
Test ids saved to ../../dataset/odin_data/splits/test_7.txt


## Create gallery and probe lists for each test split

In [6]:
def variation(x):
    """Return the capture-variation code of file name ``x``.

    The code is the third-from-last '_'-separated field. The gallery cell
    compares it against 'b' (baseline), 'e' (expression) and 'p' (pose).
    """
    fields = x.rsplit('_', 3)
    return fields[-3]

# unique variations for gallery. glasses and no glasses
# variations = ['0_b_', '1_b_']

In [7]:
# Preference order when picking each identity's single gallery image:
# baseline first, then expression, then pose.
GALLERY_PRIORITY = ('b', 'e', 'p')

for split in splits:
    print('Split', split)
    # Seed per split so each split's gallery/probe lists are reproducible,
    # independent of which other splits are processed in the same run.
    rng = random.Random(split)

    test_path = os.path.join(split_dir, 'test_{}.txt'.format(split))
    gallery_file = osp.join(split_dir, 'gallery_{}.txt'.format(split))
    probe_file = osp.join(split_dir, 'probe_{}.txt'.format(split))

    # One id per non-empty line; tolerant of a missing trailing newline.
    with open(test_path, 'r') as f:
        test_ids = {line.strip() for line in f if line.strip()}
    # Sorted so the seeded sampling below does not depend on os.listdir order.
    test_files = sorted(x for x in files if get_id(x) in test_ids)

    # Treat glasses and no glasses as separate ids
    test_dic = {}
    for x in test_files:
        test_dic.setdefault(get_full_id(x), []).append(x)

    print('Number of test images:', len(test_files))
    print('Number of test ids:', len(test_dic))

    g_files = []
    for id_, id_files in test_dic.items():
        # Randomly pick one image of the most preferred available variation.
        for var in GALLERY_PRIORITY:
            candidates = [x for x in id_files if variation(x) == var]
            if candidates:
                g_files.append(rng.choice(candidates))
                break
        else:
            # No b/e/p image: fall back to any image of this id, so every
            # test identity still has exactly one gallery image.
            print('Warning: no b/e/p image for id', id_, '- using any image')
            g_files.append(rng.choice(id_files))

    gallery_set = set(g_files)
    p_files = [x for x in test_files if x not in gallery_set]

    print('Num gallery:', len(g_files))
    print('Num probe:', len(p_files))

    rng.shuffle(g_files)
    rng.shuffle(p_files)

    with open(gallery_file, 'w') as f:
        f.writelines(x + '\n' for x in g_files)

    with open(probe_file, 'w') as f:
        f.writelines(x + '\n' for x in p_files)

Split 5
Number of test images: 1108
Number of test ids: 31
Num gallery: 31
Num probe: 1077
Split 6
Number of test images: 1172
Number of test ids: 31
Num gallery: 31
Num probe: 1141
Split 7
Number of test images: 1196
Number of test ids: 29
Num gallery: 29
Num probe: 1167


## Analyze splits

In [13]:
A_mode = 'Polar'  # presumably the modality tag found in image file names — confirm
def remove_pose(A_paths):
    """Return a new list of ``A_paths`` without pose images (names containing '_p_').

    Prints a notice after filtering.
    """
    kept = list(filter(lambda path: '_p_' not in path, A_paths))
    print('Removed pose images')
    return kept

def make_files(phase, split):
    """Print (and return) glasses / total image counts for one phase of a split.

    Reads the ids listed in ``<split_dir>/<phase>_<split>.txt`` and counts the
    distinct png/jpg image stems in ``src_path`` whose person id is listed.
    An image counts as a glasses image when the last character of its full id
    is not '0'.

    Returns:
        (glasses_count, total_count), the same numbers that are printed.
    """
    split_file = os.path.join(split_dir, '{}_{}.txt'.format(phase, split))
    # One id per non-empty line; tolerant of a missing trailing newline.
    with open(split_file, 'r') as f:
        ids = {line.strip() for line in f if line.strip()}

    # Pattern has no '**', so the search is non-recursive.
    A_paths = glob.glob(os.path.join(src_path, '*.[pj][np]g'))

    # Distinct file stems (extension dropped) belonging to this phase's ids.
    A_ids = {os.path.splitext(os.path.basename(x))[0] for x in A_paths}
    A_ids = {x for x in A_ids if get_id(x) in ids}

    spects_cnt = sum(1 for x in A_ids if get_full_id(x)[-1] != '0')
    total = len(A_ids)

    print("Phase: {}, Glasses files: {}, Total files: {}".format(phase, spects_cnt, total))
    return spects_cnt, total
    
# Report per-phase image counts for every split.
for split in splits:
    print('Split', split)
    for phase in ('train', 'val', 'test'):
        make_files(phase, split)

Split 5
Phase: train, Glasses files: 516, Total files: 3300
Phase: val, Glasses files: 78, Total files: 1011
Phase: test, Glasses files: 244, Total files: 1108
Split 6
Phase: train, Glasses files: 440, Total files: 3165
Phase: val, Glasses files: 182, Total files: 1082
Phase: test, Glasses files: 216, Total files: 1172
Split 7
Phase: train, Glasses files: 606, Total files: 3425
Phase: val, Glasses files: 111, Total files: 798
Phase: test, Glasses files: 121, Total files: 1196
