In [21]:
import numpy as np
import random
from PIL import Image

class DataProcessor():
    """Parse the ml-1m user/movie/rating files and serve numpy mini-batches.

    Construction reads every file, joins each rating with its user and movie
    records into ``self.dataset`` and splits that list 90/10 into
    ``self.train_dataset`` and ``self.test_dataset``. :meth:`load_data` turns a
    split into a batch generator.
    """

    # Fixed widths of the zero-padded id lists (the original notes state that
    # 15 title words and 6 categories are the largest counts in the data).
    TITLE_LEN = 15
    CAT_LEN = 6
    # Posters are resized to POSTER_SIZE x POSTER_SIZE pixels.
    POSTER_SIZE = 64

    def __init__(self, use_poster=False):
        """Parse all data files and build the train/test splits.

        Args:
            use_poster: if True, ratings are read from ``new_rating.txt`` and
                :meth:`load_data` loads a poster image for every sample.
        """
        self.use_poster = use_poster

        # Paths of the raw data files.
        user_info_path = '../datasets/ml-1m/users.dat'
        movie_info_path = '../datasets/ml-1m/movies.dat'
        if use_poster:
            rating_info_path = '../datasets/ml-1m/new_rating.txt'
        else:
            rating_info_path = '../datasets/ml-1m/ratings.dat'
        self.post_path = '../datasets/ml-1m/posters/'

        # Running maxima, updated by get_user_info() while parsing.
        self.max_user_id = 0
        self.max_user_age = 0
        self.max_user_job = 0

        self.user_info = self.get_user_info(user_info_path)
        self.movie_info, self.movie_titles, self.movie_cats = self.get_movie_info(movie_info_path)

        # Largest movie id and largest title-word / category index.
        self.max_movie_id = np.max([int(k) for k in self.movie_info])
        self.max_movie_title = np.max(list(self.movie_titles.values()))
        self.max_movie_cat = np.max(list(self.movie_cats.values()))

        self.rating_info = self.get_rating_info(rating_info_path)

        self.dataset = self.get_dataset(user_info=self.user_info,
                                        movie_info=self.movie_info, rating_info=self.rating_info)

        # 90/10 split in dataset order. get_dataset() emits samples grouped by
        # user, so the test split holds the ratings of the last users rather
        # than a random sample.
        split = int(len(self.dataset) * 0.9)
        self.train_dataset = self.dataset[:split]
        self.test_dataset = self.dataset[split:]

        print('用户数据量：{}，电影数据量：{}'.format(len(self.user_info), len(self.movie_info)))
        print('构建的数据集总量：{}，其中训练集：{}，测试集：{}'.format(len(self.dataset),
                                                   len(self.train_dataset), len(self.test_dataset)))

    def get_user_info(self, path):
        """Parse the user file into ``{user_id_str: {...}}``.

        Each non-blank line is ``::``-separated; fields 0-3 are user id,
        gender ('F' -> 1, anything else -> 0), age and job. Also updates
        ``self.max_user_id``, ``self.max_user_age`` and ``self.max_user_job``.
        """
        def gender2num(gender):
            return 1 if gender == 'F' else 0

        user_info = {}
        with open(path, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:  # tolerate blank (e.g. trailing) lines
                    continue
                fields = line.split('::')
                user_id = fields[0]
                uid, age, job = int(user_id), int(fields[2]), int(fields[3])
                user_info[user_id] = {
                    'user_id': uid,
                    'gender': gender2num(fields[1]),
                    'age': age,
                    'job': job
                }
                self.max_user_id = max(self.max_user_id, uid)
                self.max_user_age = max(self.max_user_age, age)
                self.max_user_job = max(self.max_user_job, job)

        return user_info

    def get_movie_info(self, path):
        """Parse the movie file.

        Each non-blank line is ``id::Title (YYYY)::Cat1|Cat2|...``. Title words
        and category names get 1-based integer ids in order of first
        appearance; id 0 is reserved for padding.

        Returns:
            ``(movie_info, movie_titles, movie_cats)`` where ``movie_info`` maps
            the movie id string to its record (title ids padded/truncated to
            ``TITLE_LEN``, category ids padded/truncated to ``CAT_LEN``), and the
            other two map each word / category to its id.
        """
        movie_info, movie_titles, movie_cats = {}, {}, {}

        with open(path, 'r', encoding='ISO-8859-1') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                fields = line.split('::')
                v_id = fields[0]
                v_title = fields[1][:-7]   # drop the trailing " (YYYY)"
                v_year = fields[1][-5:-1]  # the YYYY inside the parentheses
                v_cat = fields[2].split('|')

                words = v_title.split()
                for w in words:
                    if w not in movie_titles:
                        movie_titles[w] = len(movie_titles) + 1
                for c in v_cat:
                    if c not in movie_cats:
                        movie_cats[c] = len(movie_cats) + 1

                # Truncate as well as pad, so an unexpectedly long title or
                # category list cannot break the fixed-width reshape in
                # load_data().
                title = [movie_titles[w] for w in words][:self.TITLE_LEN]
                title += [0] * (self.TITLE_LEN - len(title))
                cat = [movie_cats[c] for c in v_cat][:self.CAT_LEN]
                cat += [0] * (self.CAT_LEN - len(cat))

                movie_info[v_id] = {
                    'movie_id': int(v_id),
                    'title': title,
                    'cat': cat,
                    'year': int(v_year)
                }

        return movie_info, movie_titles, movie_cats

    def get_rating_info(self, path):
        """Parse ``user::movie::score[::...]`` lines into ``{user: {movie: score}}``.

        Ids stay strings and scores become floats. A repeated (user, movie)
        pair keeps the last score seen.
        """
        rating_info = {}
        with open(path, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                user_id, movie_id, score = line.split('::')[:3]
                rating_info.setdefault(user_id, {})[movie_id] = float(score)

        return rating_info

    def get_dataset(self, user_info, movie_info, rating_info):
        """Join every rating with its user and movie records.

        Samples are emitted user by user, in ``rating_info`` order. Raises
        ``KeyError`` if a rating references an unknown user or movie id.
        """
        return [
            {
                'user_info': user_info[user_id],
                'movie_info': movie_info[movie_id],
                'score': score
            }
            for user_id, user_ratings in rating_info.items()
            for movie_id, score in user_ratings.items()
        ]

    def load_data(self, dataset=None, mode='train', batch_size=256):
        """Return a generator function that yields mini-batches of ``dataset``.

        Args:
            dataset: list of samples as built by :meth:`get_dataset`. When None,
                ``self.train_dataset`` is used for ``mode == 'train'`` and
                ``self.test_dataset`` otherwise.
            mode: ``'train'`` shuffles the sample order each time the returned
                generator is started; any other value keeps dataset order.
            batch_size: samples per batch (default 256).

        Returns:
            A zero-argument generator function. Each batch is a tuple
            ``([user_id, gender, age, job], [movie_id, title, cat, poster], score)``
            of numpy arrays. Only full batches are produced; the trailing
            ``len(dataset) % batch_size`` samples are dropped.

        Raises:
            ValueError: if ``batch_size`` is not positive.
        """
        if batch_size <= 0:
            raise ValueError('batch_size must be positive, got {}'.format(batch_size))
        if dataset is None:
            dataset = self.train_dataset if mode == 'train' else self.test_dataset

        index_list = list(range(len(dataset)))

        def data_generator():
            if mode == 'train':
                random.shuffle(index_list)
            # Start offsets of every *full* batch; the remainder is skipped.
            for start in range(0, len(index_list) - batch_size + 1, batch_size):
                samples = [dataset[i] for i in index_list[start:start + batch_size]]
                yield self._build_batch(samples)

        return data_generator

    def _build_batch(self, samples):
        """Stack a list of samples into the (user, movie, score) array tuple."""
        n = len(samples)
        users = [s['user_info'] for s in samples]
        movies = [s['movie_info'] for s in samples]

        user_id_arr = np.array([u['user_id'] for u in users])
        user_gender_arr = np.array([u['gender'] for u in users])
        user_age_arr = np.array([u['age'] for u in users])
        user_job_arr = np.array([u['job'] for u in users])

        movie_id_arr = np.array([m['movie_id'] for m in movies])
        movie_title_arr = np.reshape(np.array([m['title'] for m in movies]),
                                     [n, 1, self.TITLE_LEN]).astype(np.int64)
        movie_cat_arr = np.reshape(np.array([m['cat'] for m in movies]),
                                   [n, self.CAT_LEN]).astype(np.int64)

        if self.use_poster:
            posters = np.array([self._load_poster(m['movie_id']) for m in movies])
            # (N, H, W, C) -> (N, C, H, W). A plain reshape would interleave
            # pixels across channels instead of moving the channel axis.
            movie_poster_arr = (np.transpose(posters, (0, 3, 1, 2)) / 127.5 - 1).astype(np.float32)
        else:
            movie_poster_arr = np.array([0.])  # placeholder when posters are unused

        # Scores are truncated to int before the float32 cast.
        score_arr = np.reshape(np.array([int(s['score']) for s in samples]),
                               [-1, 1]).astype(np.float32)

        return [user_id_arr, user_gender_arr, user_age_arr, user_job_arr], \
               [movie_id_arr, movie_title_arr, movie_cat_arr, movie_poster_arr], score_arr

    def _load_poster(self, movie_id):
        """Load ``mov_id<movie_id>.jpg`` from ``self.post_path`` as an RGB array."""
        path = self.post_path + 'mov_id{}.jpg'.format(movie_id)
        # The context manager closes the underlying file handle.
        with Image.open(path) as img:
            # convert('RGB') also normalizes grayscale/palette images to 3 channels.
            poster = img.resize((self.POSTER_SIZE, self.POSTER_SIZE)).convert('RGB')
        return np.array(poster)