In [1]:
import pickle
import os
import numpy as np
import torch
import collections
import json
from tqdm import tqdm

import sys
sys.path.append("/home/ly/workspace/mmsa")
from utils.dataset import *
from utils.tokenization import BasicTokenizer
from utils.load_yelp import *

In [3]:
def check_photo(_id:str):
    """Report whether the JPEG file for photo ``_id`` is present on disk.

    The file is looked up at
    ``<DATA_DIR>/photos/<first two characters of _id>/<_id>.jpg``.

    Args:
        _id: Photo identifier.

    Returns:
        The result of ``os.path.exists`` for that path.
    """
    shard = _id[:2]
    filename = _id + ".jpg"
    return os.path.exists(os.path.join(DATA_DIR, "photos", shard, filename))

In [4]:
def read_reviews(file_path:str, clean_data:bool=False) -> List[Dict[str, str]]:
    """Load reviews from a JSON-lines file or from a pickle.

    For a ``.json`` path, every non-blank line is parsed as one JSON review and
    reduced to a dict with the keys ``_id``, ``Text``, ``Photos`` (list of
    photo ids), ``Captions`` (list of captions, parallel to ``Photos``) and
    ``Rating``. For a ``.pickle`` path the unpickled object is returned
    unchanged.

    Args:
        file_path: Path ending in ``.json`` or ``.pickle``.
        clean_data: When true, keep only the photos (and their captions) for
            which ``check_photo`` reports an existing image file. Only applies
            to ``.json`` input.

    Returns:
        The list of review dicts (``.json``), or whatever the pickle holds.

    Raises:
        RuntimeError: If ``file_path`` has neither supported extension.
    """
    if file_path.endswith(".json"):
        reviews = []
        with open(file_path, 'r', encoding="utf-8") as f:
            for line in tqdm(f, "Read json"):
                # Tolerate blank lines, which json.loads would otherwise reject.
                if not line.strip():
                    continue
                review = json.loads(line)
                # Without cleaning every photo is kept; with cleaning only the
                # photos whose image file exists are kept.
                kept = [photo for photo in review['Photos']
                        if not clean_data or check_photo(photo['_id'])]
                reviews.append({'_id': review['_id'],
                                'Text': review['Text'],
                                'Photos': [photo['_id'] for photo in kept],
                                'Captions': [photo["Caption"] for photo in kept],
                                'Rating': review['Rating']})
        return reviews
    if file_path.endswith(".pickle"):
        with open(file_path, 'rb') as f:
            # NOTE: pickle.load can execute arbitrary code; only point this at
            # pickle files you produced yourself.
            return pickle.load(f)
    raise RuntimeError("Illegal file path!")

In [49]:
# Read the train and valid splits plus one test split per city. Passing
# clean_data=True drops photos (and their captions) whose image file is
# missing on disk.
train = read_reviews(train_json, clean_data=True)
valid = read_reviews(valid_json, clean_data=True)
test = {}
for city in cities:
    test[city] = read_reviews(DATA_DIR + "raw/test/" + city + "_test.json", clean_data=True)

Read json: 35435it [00:01, 33412.25it/s]
Read json: 2215it [00:00, 36740.49it/s]
Read json: 315it [00:00, 35542.08it/s]
Read json: 325it [00:00, 30684.27it/s]
Read json: 3730it [00:00, 35776.20it/s]
Read json: 1715it [00:00, 21868.18it/s]
Read json: 570it [00:00, 1210.24it/s]


In [50]:
# Bundle the three splits under the keys "train", "valid" and "test".
data = dict(train=train, valid=valid, test=test)

In [51]:
# Persist the current `data` object to <DATA_DIR>raw/clean_data.pickle using the
# highest pickle protocol available, so later runs can skip re-reading the JSON.
with open(DATA_DIR + "raw/" + "clean_data.pickle", "wb") as o:
    pickle.dump(data, o, protocol=pickle.HIGHEST_PROTOCOL)

In [5]:
def load_data():
    """Read back the object stored in ``<DATA_DIR>raw/clean_data.pickle``.

    Returns:
        Whatever object that pickle file contains.
    """
    pickle_path = DATA_DIR + "raw/" + "clean_data.pickle"
    with open(pickle_path, "rb") as handle:
        return pickle.load(handle)
# Rebind `data` to the object read back from clean_data.pickle.
data = load_data()

In [52]:
# Count word frequencies over train and then valid into one frequency dict
# (the result of the first call is handed to the second call as its second
# argument), then build the three vocabulary objects from it.
freq_dict = count_word_freq(valid, count_word_freq(train))
token2idx, idx2token, glove_idx = build_vocab_from_glove(freq_dict)

Count word frequency: 100%|██████████| 35435/35435 [00:45<00:00, 785.20it/s]
Count word frequency: 100%|██████████| 2215/2215 [00:02<00:00, 807.86it/s]


In [55]:
# Hand the three vocabulary objects to the project helper save_vocab together
# with the raw/ directory. NOTE(review): presumably it writes them to files in
# that directory; the helper is defined outside this notebook, so confirm there.
save_vocab(DATA_DIR + "raw/", token2idx, idx2token, glove_idx)

In [10]:
# Load the vocabulary via the project helper load_glove_vocab; the result is a
# dict with the keys 'token2idx', 'idx2token' and 'glove_idx' (see the output
# of the following cell).
glove_vocab = load_glove_vocab(DATA_DIR + "raw/")

In [11]:
# Sanity check: show which entries the loaded vocabulary dict provides.
glove_vocab.keys()

dict_keys(['token2idx', 'idx2token', 'glove_idx'])

In [12]:
# Build the tokenizer from the token -> index mapping.
# NOTE(review): the meaning of the second positional argument (True) is not
# visible in this notebook; confirm against YelpSimpleTokenizer's signature.
glove_tokenizer = YelpSimpleTokenizer(glove_vocab["token2idx"], True)

In [14]:
%%time
# for key in ["train", "valid"]:
#     for i in data[key]:
#         i["Text"] = glove_tokenizer.to_idx(i["Text"])
for city in cities:
    for i in data["test"][city]:
        i["Text"] = glove_tokenizer.to_idx(i["Text"])

CPU times: user 8.38 s, sys: 3.99 ms, total: 8.38 s
Wall time: 8.38 s


In [19]:
def load_glove_data():
    """Read back the object stored in ``<DATA_DIR>raw/glove_data.pickle``.

    Returns:
        Whatever object that pickle file contains.
    """
    source = DATA_DIR + "raw/" + "glove_data.pickle"
    with open(source, "rb") as handle:
        restored = pickle.load(handle)
    return restored


In [17]:
# Write the current `data` object (as modified in place by the tokenization
# cell) to <DATA_DIR>raw/glove_data.pickle using the highest pickle protocol.
with open(DATA_DIR + "raw/" + "glove_data.pickle", "wb") as o:
    pickle.dump(data, o, protocol=pickle.HIGHEST_PROTOCOL)

In [18]:
# Call the project helper get_yelp_glove_weight on the raw/ directory; the
# displayed result is a float32 array.
# NOTE(review): 100 is presumably the GloVe embedding dimension — confirm
# against the helper's signature.
get_yelp_glove_weight(DATA_DIR + "raw/", 100)

Load glove: 100%|██████████| 400000/400000 [00:06<00:00, 64483.32it/s]


array([[ 0.        ,  0.        ,  0.        , ...,  0.        ,
         0.        ,  0.        ],
       [-0.33979   ,  0.20941   ,  0.46348   , ..., -0.23394   ,
         0.47298   , -0.028803  ],
       [-0.038194  , -0.24487   ,  0.72812   , ..., -0.1459    ,
         0.8278    ,  0.27062   ],
       ...,
       [-0.07382662, -0.09844293,  0.03890738, ...,  0.03119204,
        -0.04567178,  0.04480205],
       [-0.09762643, -0.01279308, -0.09459477, ...,  0.05298445,
         0.09622958,  0.00178859],
       [ 0.08038855,  0.0121604 ,  0.04316022, ...,  0.05453663,
        -0.07731168, -0.02396872]], dtype=float32)

In [20]:
# Reload the object saved in glove_data.pickle.
glove_data = load_glove_data()