In [1]:
import os
import numpy as np

In [2]:
def read_file(file_path):
    """Return the text of ``file_path`` with every newline character removed.

    Args:
        file_path: Path of the text file to read.

    Returns:
        The file contents as one string. Lines are concatenated directly;
        nothing is inserted where the newlines used to be, and other
        whitespace (spaces, tabs) is left untouched.
    """
    with open(file_path, "r") as handle:
        text = handle.read()
    # Splitting on "\n" and re-joining with "" removes the newlines only.
    return "".join(text.split("\n"))

In [3]:
def get_file_names_in_directory(dir_path):
    """Collect the names of every directory found under ``dir_path``.

    Despite the function's name, only *directory* names are gathered; plain
    files are ignored. ``os.walk`` descends recursively, so directories nested
    at any depth contribute their (base) names as well.
    NOTE(review): callers appear to expect one entry per immediate
    sub-directory; confirm that nested directories never occur in practice.

    Args:
        dir_path: Root directory to scan. If it does not exist, ``os.walk``
            yields nothing and an empty list is returned.

    Returns:
        A new list of directory base names in ascending sorted order.
    """
    dir_names = []
    for _, dirs, _ in os.walk(dir_path):
        dir_names.extend(dirs)
    # The filesystem lists entries in arbitrary order. Sorting gives a stable
    # starting order so that a seeded shuffle downstream is reproducible
    # across machines and runs.
    return sorted(dir_names)

In [4]:
def store_idxs_in_txt(idxs, file_path):
    """Write every item of ``idxs`` to ``file_path``, one item per line.

    Any missing parent directory of ``file_path`` is created first, so the
    function also works when the output folder does not exist yet. An
    existing file at ``file_path`` is overwritten.

    Args:
        idxs: Iterable of items to write; each is formatted with ``%s``.
        file_path: Destination text file.
    """
    parent_dir = os.path.dirname(file_path)
    # dirname is "" for a bare file name; os.makedirs("") would raise.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(file_path, 'w') as f:
        for idx in idxs:
            f.write("%s\n" % idx)

In [5]:
def create_train_val_test_splits(example_idxs, train_split, val_split, random_seed=0):
    """Randomly partition ``example_idxs`` into train, val and test subsets.

    The examples are shuffled with a private random generator seeded by
    ``random_seed``. The caller's sequence is left unmodified and NumPy's
    global random state is not touched. The permutation is the same one that
    ``np.random.seed(random_seed)`` followed by ``np.random.shuffle`` would
    produce.

    Args:
        example_idxs: Sequence (list or NumPy array) of example identifiers.
        train_split: Fraction of the examples assigned to the train subset.
        val_split: Fraction of the examples assigned to the val subset.
        random_seed: Seed for the shuffle.

    Returns:
        A ``(train_idxs, val_idxs, test_idxs)`` tuple. The train and val sizes
        are ``int(len(example_idxs) * fraction)`` (rounded down); the test
        subset receives all remaining examples.
    """
    # Work on a copy so the caller's data keeps its original order.
    if isinstance(example_idxs, np.ndarray):
        shuffled = example_idxs.copy()
    else:
        shuffled = list(example_idxs)

    # A dedicated RandomState avoids reseeding the process-wide generator.
    rng = np.random.RandomState(random_seed)
    rng.shuffle(shuffled)

    # Split the shuffled examples into train, val, and test
    total = len(shuffled)
    train_split_index = int(total * train_split)
    val_split_index = int(total * val_split) + train_split_index

    train_idxs = shuffled[:train_split_index]
    val_idxs = shuffled[train_split_index:val_split_index]
    test_idxs = shuffled[val_split_index:]

    return train_idxs, val_idxs, test_idxs

In [6]:
def main():
    """Build and save the train/val/test split files.

    Steps:
        1. Read the data root path from ``root.txt``.
        2. Collect the directory names found under the ``02958343`` camera
           directory.
        3. Write the full list of names to ``shapenet_car.txt``.
        4. Split the names 60% / 20% / remainder with a fixed seed and write
           ``train.txt``, ``val.txt`` and ``test.txt`` into the
           ``shapenet_car`` directory.
        5. Print the size of each split and the total.

    Raises:
        FileNotFoundError: If the source directory does not exist. Without
            this check, empty split files would be written silently.
    """
    random_seed = 0
    root_file_path = 'root.txt'
    root_path = read_file(root_file_path)

    dir_path = f'{root_path}/render_shapenet_data/processed_get3d/camera/02958343'
    split_path = f'{root_path}/3dgan_data_split/shapenet_car.txt'
    split_dir = f'{root_path}/3dgan_data_split/shapenet_car'

    # Fail loudly instead of producing empty split files.
    if not os.path.isdir(dir_path):
        raise FileNotFoundError(f'Source directory not found: {dir_path}')

    # Get all the directory names in the processed_get3d directory
    file_names = get_file_names_in_directory(dir_path)

    # Make sure the output locations exist. split_dir sits next to split_path,
    # so creating it also creates the parent directory of split_path.
    os.makedirs(split_dir, exist_ok=True)

    # store the directory names in split_path file
    store_idxs_in_txt(file_names, split_path)

    # create train, val and test splits; pass a copy so a splitter that
    # shuffles its argument in place cannot reorder file_names
    train_examples, val_examples, test_examples = create_train_val_test_splits(
        list(file_names), 0.6, 0.2, random_seed=random_seed
    )

    # store the train, val and test splits in separate text files
    store_idxs_in_txt(train_examples, os.path.join(split_dir, 'train.txt'))
    store_idxs_in_txt(val_examples, os.path.join(split_dir, 'val.txt'))
    store_idxs_in_txt(test_examples, os.path.join(split_dir, 'test.txt'))

    # print the number of examples in each split, and total number of examples
    print(f'Number of examples in train split: {len(train_examples)}')
    print(f'Number of examples in val split: {len(val_examples)}')
    print(f'Number of examples in test split: {len(test_examples)}')
    print(f'Total number of examples: {len(file_names)}')

In [7]:
# Entry-point guard: run the split generation only when this code is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()

Number of examples in train split: 1418
Number of examples in val split: 472
Number of examples in test split: 474
Total number of examples: 2364
