forked from Yacalis/celeba-classification
-
Notifications
You must be signed in to change notification settings - Fork 0
/
get_data_dict.py
80 lines (59 loc) · 2.07 KB
/
get_data_dict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Mar 5 14:12:00 2018
@author: Yacalis
"""
import scipy.io as sio
import os
import csv
def get_data_dict(data_dir) -> dict:
    """Load ``imdb.mat`` from *data_dir* and map each filename to its gender.

    The file is read with ``scipy.io.loadmat`` and the record stored under the
    ``'imdb'`` key is unwrapped with ``[0][0]``.  Field 2 of that record is
    read as the per-image filenames (each entry is a one-element array whose
    element is the filename) and field 3 as the matching gender values.

    Args:
        data_dir: Directory containing ``imdb.mat``.

    Returns:
        dict mapping filename (``str``) to the raw gender value taken from the
        ``.mat`` file (element type is whatever ``loadmat`` produced).

    Raises:
        ValueError: if the filename and gender arrays differ in length.
    """
    imdb_mat = os.path.join(data_dir, "imdb.mat")
    print('data file: ', imdb_mat)

    # loading file into memory
    mat_contents = sio.loadmat(imdb_mat)

    # getting the important bit of the file: the single record under 'imdb'
    data = mat_contents['imdb'][0][0]
    filenames = data[2][0]
    genders = data[3][0]
    # zip() would silently truncate on a length mismatch; fail loudly instead
    if len(filenames) != len(genders):
        raise ValueError(
            f"imdb.mat: {len(filenames)} filenames but "
            f"{len(genders)} gender values")

    # turning the arrays into a dict of key:filename, value:gender
    # (duplicate filenames keep the last value seen)
    filename_gender_dict = {
        str(path_cell[0]): gender
        for path_cell, gender in zip(filenames, genders)
    }

    print('number of records from data file: ', len(filename_gender_dict))
    return filename_gender_dict
def get_new_data_dict(data_dir) -> dict:
    """Load ``new_testing_data.csv`` from *data_dir* into a name -> gender dict.

    Each non-blank CSV row must have four columns: sub-folder, filename,
    combined name and gender.  The combined name is used as the key and the
    gender column is converted with ``int()``.  Rows whose gender column is
    not an integer literal (e.g. a header row) are skipped, as are blank
    lines.

    Args:
        data_dir: Directory containing ``new_testing_data.csv``.

    Returns:
        dict mapping combined name (``str``) to gender (``int``).

    Raises:
        ValueError: if a non-blank row does not have exactly four columns.
    """
    data_file = os.path.join(data_dir, "new_testing_data.csv")
    print('data file: ', data_file)

    filename_gender_dict = {}
    # newline='' is what the csv module expects, so quoted fields that
    # contain line breaks are parsed correctly
    with open(data_file, mode='r', newline='') as file:
        for row in csv.reader(file):
            if not row:
                # csv.reader yields [] for blank lines; skip them
                continue
            _sub_folder, _filename, combinedname, gender = row
            try:
                value = int(gender)
            except ValueError:
                # non-integer gender (e.g. the header row): skip this row
                continue
            filename_gender_dict[combinedname] = value

    print('number of records from data file: ', len(filename_gender_dict))
    return filename_gender_dict
def get_celeba_data(data_dir) -> dict:
    """Load ``list_attr.csv`` from *data_dir* into an image-id -> attributes dict.

    Each non-blank CSV row must have four columns: image id, eyeglasses,
    male and smiling.  The three attribute columns are converted with
    ``int()``; rows where any of them is not an integer literal (e.g. a
    header row) are skipped, as are blank lines.

    Args:
        data_dir: Directory containing ``list_attr.csv``.

    Returns:
        dict mapping image id (``str``) to a 1-tuple wrapping the list
        ``[eyeglasses, male, smiling]`` of ints.

    Raises:
        ValueError: if a non-blank row does not have exactly four columns.
    """
    data_file = os.path.join(data_dir, "list_attr.csv")
    print('data file: ', data_file)

    filename_gender_dict = {}
    # newline='' is what the csv module expects, so quoted fields that
    # contain line breaks are parsed correctly
    with open(data_file, mode='r', newline='') as file:
        for row in csv.reader(file):
            if not row:
                # csv.reader yields [] for blank lines; skip them
                continue
            image_id, eyeglasses, male, smiling = row
            try:
                attributes = [int(eyeglasses), int(male), int(smiling)]
            except ValueError:
                # non-integer attribute (e.g. the header row): skip this row
                continue
            # NOTE(review): each value is a 1-tuple wrapping the attribute
            # list, not the bare list. The shape is kept deliberately so
            # existing callers are unaffected — confirm whether callers
            # actually want the bare list.
            filename_gender_dict[image_id] = (attributes,)

    print('number of records from data file: ',
          len(filename_gender_dict))
    return filename_gender_dict