In [3]:
import numpy as np
import math
import cv2
import sys
import scipy.io as sio
import matplotlib.pyplot as plt
import os
#import h5py

In [9]:
def add_margin(img, face_loc):
    crop_h = int(0.4 * (face_loc[3] - face_loc[1]))
    crop_w = int(0.4 * (face_loc[2] - face_loc[0]))
    img_h = img.shape[0]
    img_w = img.shape[1]
    #new_crop_low = crop_h + face_loc[1] - crop_h
    new_crop_high = crop_h + face_loc[3] + crop_h
    #new_crop_left = crop_w + face_loc[0] - crop_w
    new_crop_right = crop_w + face_loc[2] + crop_w
    replicate = cv2.copyMakeBorder(img, crop_h, crop_h, crop_w, crop_w, cv2.BORDER_REPLICATE)
    crop_face_img = replicate[face_loc[1] : new_crop_high, face_loc[0] : new_crop_right]
    return crop_face_img

#产生新图像
def add_augment_img(img):
    translateX = np.random.uniform(-0.1, 0.1)
    translateY = np.random.uniform(-0.1, 0.1)
    scale = np.random.uniform(0.9, 1.1)
    rotate_angle = np.random.uniform(-10, 10)
    
    def translate(image, x, y):
        # 定义平移矩阵
        M = np.float32([[1, 0, x], [0, 1, y]])
        shifted = cv2.warpAffine(image, M, (image.shape[1], image.shape[0]))

        # 返回转换后的图像
        return shifted
    
    # 定义旋转rotate函数
    def rotate(image, angle, center=None, scale=1.0):
        # 获取图像尺寸
        (h, w) = image.shape[:2]

        # 若未指定旋转中心，则将图像中心设为旋转中心
        if center is None:
            center = (w / 2, h / 2)

        # 执行旋转
        M = cv2.getRotationMatrix2D(center, angle, scale)
        rotated = cv2.warpAffine(image, M, (w, h))

        # 返回旋转后的图像
        return rotated
    
    img = translate(img, translateX, translateY)
    cv2.imshow("translate", img)
    
    rotated = rotate(img, rotate_angle, scale = scale)
    
    print(translateX, translateY, scale, rotate_angle)
    cv2.imshow("getRotationMatrix2D", rotated)
    return rotated

    
def load_wiki(wiki_path, num_data = None, target_size = (224,224)):
    mat_path = wiki_path + 'wiki_with_age.mat'
    
    data = sio.loadmat(mat_path)
    wiki_data = data['wiki'][0][0]
   
    if num_data is None:
        num_data = len(wiki_data[6][0])
        
    
    X_data_age = np.zeros([num_data, 224, 224, 3],dtype = "uint8")
    y_data_age = np.zeros([num_data],dtype = "uint8")
    
    X_data_gender = np.zeros([num_data, 224, 224, 3],dtype = "uint8")
    y_data_gender = np.zeros([num_data],dtype = "uint8")
    
    data_count = 0
    #len(wiki_data[6][0])
    num_every_age = int(num_data / 100)
    counter_every_age = np.zeros(100)
    index_every_age = np.zeros(100, num_every_age)
    
    num_every_gender = int(num_data / 2)
    counter_every_gender = np.zeros(2)
    index_every_gender = np.zeros(2, num_every_gender)
    
    for i in range(num_data):

        face_score =wiki_data[6][0][i]
        if face_score != float("-inf"):         #如果face_score == -inf说明不存在脸,这个比例比较大，因此先排除
            full_path = wiki_path + wiki_data[2][0][i][0]
            img = cv2.imread(full_path)
            age = wiki_data[8][0][i]            #有些age在正常值之外要排除
            date_of_birth = wiki_data[0][0][i]  #下面的657438是出生于1800年的Matlab serial date number
            gender = wiki_data[3][0][i]         #有一些gender==None会引起一场，需要排除
            if img is not None and gender == gender and date_of_birth > 657438 and age >= 0 and age <= 100:
                face_loc = wiki_data[5][0][i][0]
                face_loc = face_loc.astype("int32")
                roi_img = add_margin(img, face_loc)    
                face_img = cv2.resize(roi_img, target_size)
                
                index_every_age[age, counter_every_age[age]] = data_count
                counter_every_age[age] += 1
                
                index_every_gender[gender, counter_every_gender[gender]] = data_count
                counter_every_gender[gender] += 1
                
                full_X_data[data_count] = face_img                
               
                data_count += 1
                
    for cur_age in range(100):            
        X_data_age[cur_age : cur_age * num_every_age] = full_X_data[index_every_age[cur_age, 0 : num_every_age]]
        if counter_every_age[cur_age] < num_every_age:
            for cur_augment in range(num_every_age - counter_every_age[cur_age]):
                cur_choice = np.random.choice(counter_every_age[cur_age], replace = True)
                np.append(X_data_age, add_augment_img(full_X_data[cur_choice]))
        y_data_age[cur_age : cur_age * num_every_age] = cur_age
        
    for cur_gender in range(2):            
        X_data_gender[cur_gender : cur_gender * num_every_gender] = \
                                full_X_data[index_every_gender[cur_gender, 0 : num_every_gender]]
        if counter_every_gender[cur_gender] < num_every_gender:
            for cur_augment in range(num_every_gender - counter_every_gender[cur_gender]):
                cur_choice = np.random.choice(counter_every_gender[cur_gender], replace = True)
                np.append(X_data_gender, add_augment_img(full_X_data[cur_choice]))
        y_data_gender[cur_gender : cur_gender * num_every_gender] = cur_gender
    
    
    return X_data_age, y_data_age, X_data_gender, y_data_gender

In [7]:
def get_wiki_data(X_data, y_data, num_training=49000, num_validation=1000, num_test=1000, subtract_mean = True):
   
    X_train = X_data[:-(num_validation + num_test)]
    y_train = y_data[:-(num_validation + num_test)]
    X_val = X_data[X_train.shape[0] : -num_test]
    y_val = y_data[X_train.shape[0] : -num_test]
    X_test = X_data[-num_test:]
    y_test = y_data[-num_test:]
    # Normalize the data: subtract the mean image
    if subtract_mean:
        mean_image = np.mean(X_data, axis=0).astype("uint8")
        X_train -= mean_image
        X_val -= mean_image
        X_test -= mean_image
    # Transpose so that channels come first
    X_train = X_train.transpose(0, 3, 1, 2).copy()
    X_val = X_val.transpose(0, 3, 1, 2).copy()
    X_test = X_test.transpose(0, 3, 1, 2).copy()

    # Package data into a dictionary
    return {
      'X_train': X_train, 'y_train': y_train,
      'X_val': X_val, 'y_val': y_val,
      'X_test': X_test, 'y_test': y_test,
    }
    

In [10]:
#不带h5py的
wiki_path = ""
if sys.platform == "linux" :
    wiki_path = "/devdata/wiki/"
else:
    wiki_path = "G:\\MachineLearning\\wiki\\wiki\\"
    #wiki_path = "D:\\Z\\wiki\\"

mat_path = wiki_path + 'wiki_with_age.mat'

X_data, y_data  = load_wiki(wiki_path, 8)
    
wiki_cropface_dataset = get_wiki_data(X_data, y_data, num_training=49000, num_validation=10, num_test=0)

cv2.waitKey()
cv2.destroyAllWindows()

0.09507326165652905 0.053617105734280524 0.9176089286351296 -4.9122970267128
-0.025307055127768943 0.07990387675439001 0.9685518348045866 4.220507087966306
-0.06175113416187144 -0.006301875829342357 0.9893250103070833 4.072922702103096
0.0040480108870521725 -0.02681021195504596 1.0037402448139587 -0.4593939171329424
-0.04062166173265811 -0.0008792262335024503 0.9631703099811444 3.8945957781463942
0.09209542830796588 0.08714034991786054 1.069964880925668 0.4513897763161019
-0.025393573052195056 -0.05434495519057945 1.0458827820972143 -3.6315535684953515


In [11]:

wiki_path = ""
if sys.platform == "linux" :
    wiki_path = "/devdata/wiki/"
else:
    wiki_path = "G:\\MachineLearning\\wiki\\wiki\\"

mat_path = wiki_path + 'wiki_with_age.mat'

# Create a new file
f = None
X_data = None
y_data = None
#wiki_crop_dataset = f.create_dataset('wiki_cropface_data', dtype = "int32")
if os.path.exists('/devdata/wiki_cropface_data.h5'):
    f = h5py.File('/devdata/wiki_cropface_data.h5', "r")
    wiki_cropface_group = f["wiki_cropface_group"]
    X_data = np.array(wiki_cropface_group["X_data"][:])
    y_data = np.array(wiki_cropface_group["y_data"][:])
else:
    f = h5py.File('/devdata/wiki_cropface_data.h5',"w")
    X_data, y_data  = load_wiki(wiki_path, 1)
    wiki_cropface_group = f.create_group("wiki_cropface_group")
    wiki_cropface_group.create_dataset('X_data', dtype = "uint8", data = X_data)
    wiki_cropface_group.create_dataset('y_data', dtype = "uint8", data = y_data)
    
print(X_data.shape)
wiki_cropface_dataset = get_wiki_data(X_data, y_data, num_training=49000, num_validation=10, num_test=0)
f.close()
"""
test_img = X_data[0,:,:,:]
cv2.imshow("test_img", test_img)
cv2.waitKey()
cv2.destroyAllWindows()"""


NameError: name 'h5py' is not defined