In [1]:
from json import load
import sys
# NOTE(review): hard-coded absolute path to one developer's checkout. The
# `utils.*` imports below presumably resolve through it -- confirm, and consider
# making the location configurable so the notebook runs on other machines.
sys.path.append("/home/ly/workspace/mmsa")

import json
import os
import pickle
import collections
import numpy as np
from typing import *
from tqdm import tqdm

from utils.tokenization import BasicTokenizer
from utils.load_yelp import *
# Seed NumPy's global RNG so NumPy-driven randomness in this kernel is repeatable.
seed = 1024
np.random.seed(seed)

In [2]:
# Load the dataset splits; later cells index the result by "train", "valid"
# and "test". `split622data` is not defined in this notebook -- presumably it
# comes from the `utils.load_yelp` star import (TODO confirm).
data = load_data(split622data)

In [3]:
# Number of reviews in the training split.
len(data["train"])

26583

In [4]:
# Count word frequencies over the training split, then pass the running counts
# into a second call over the validation split. The test split is not counted.
freq_dict = count_word_freq(data["train"])
freq_dict = count_word_freq(data["valid"], freq_dict)

Count word frequency: 100%|██████████| 26583/26583 [00:42<00:00, 619.72it/s]
Count word frequency: 100%|██████████| 8861/8861 [00:14<00:00, 622.52it/s]


In [5]:
# Read the GloVe vocabulary file into token<->index mappings.
# NOTE(review): both names are reassigned by the next cell's
# build_vocab_from_glove(...) call without being read in between, so within the
# visible notebook this cell only serves as a smoke test of the vocab file.
token2idx, idx2token = load_vocab_file("pretrained/glove6B/vocab.txt")

400000it [00:00, 1276335.72it/s]


In [6]:
# Build the working vocabulary from the accumulated word counts. Returns the
# token<->index mappings plus `glove_idx` (presumably the matching row indices
# in the GloVe table -- TODO confirm against build_vocab_from_glove).
token2idx, idx2token, glove_idx = build_vocab_from_glove(freq_dict)

400000it [00:00, 1300177.39it/s]


There are 41267 words in vocab.


In [7]:
# Persist the three vocabulary objects; `base_dir` is passed as the destination
# argument (its value is defined outside this notebook).
save_vocab(base_dir, token2idx, idx2token, glove_idx)

In [8]:
# Reload the saved vocabulary and display the size of each component.
# NOTE(review): the recorded output shows idx2token with two more entries than
# token2idx / glove_idx -- presumably special tokens; verify this is intended.
vocab = load_glove_vocab(base_dir)
len(vocab["token2idx"]), len(vocab["idx2token"]), len(vocab["glove_idx"])

(41267, 41269, 41267)

In [9]:
# Request the GloVe weight matrix for each embedding dimension.
# `w` is overwritten on every iteration, so only the 300-d result stays bound;
# presumably the call is made for a caching side effect -- TODO confirm.
all_d = [50, 100, 200, 300]
for d in all_d:
    w = get_yelp_glove_weight(split622data, d)

Load glove: 100%|██████████| 400000/400000 [00:04<00:00, 93394.92it/s]
Load glove: 100%|██████████| 400000/400000 [00:07<00:00, 51931.99it/s]
Load glove: 100%|██████████| 400000/400000 [00:12<00:00, 31228.32it/s]
Load glove: 100%|██████████| 400000/400000 [00:18<00:00, 21766.67it/s]


In [11]:
# Tokenizer backed by the reloaded token->index mapping, with lower-casing enabled.
glove_tokenizer = YelpSimpleTokenizer(vocab["token2idx"], do_lower_case=True)

In [None]:
# Spot check: index-encode the text of the first training review.
glove_tokenizer.to_idx(data["train"][0]["Text"])

In [10]:
def load_vgg_features(i, root=None): # in fact, every review has exactly three images
    """Load the saved feature array for photo id ``i``, or ``None`` if absent.

    The file is looked up at
    ``<root>/raw/photo_features/<first two chars of i>/<i>.npy``.

    Args:
        i: Photo id string; its first two characters name the shard directory.
        root: Dataset root directory. When omitted, the notebook-level
            ``base_dir`` is used, which is the original behaviour.

    Returns:
        The object produced by ``np.load`` for the file, or ``None`` when no
        file exists at the expected path.
    """
    if root is None:
        # Resolved at call time so explicit-root callers never need the global.
        root = base_dir
    path = os.path.join(root, "raw", "photo_features", i[:2], i + ".npy")
    if not os.path.exists(path):
        return None
    return np.load(path)
def build_glove_and_vgg_data(tokenizer, reviews:List[dict]):
    """Convert raw reviews into samples of token indices plus photo features.

    Args:
        tokenizer: Object exposing ``to_idx(text)``.
        reviews: Review dicts with "Text", "Photos" (photo ids) and "Rating".

    Returns:
        One dict per review, in input order, with keys "Text" (the tokenizer
        output), "Photos" (loaded feature arrays; ids without a feature file
        are skipped) and "Rating" (copied unchanged).
    """
    samples = []
    for review in tqdm(reviews):
        samples.append({
            "Text": tokenizer.to_idx(review["Text"]),
            # Photos whose feature file is missing come back as None; drop them.
            "Photos": [
                feats
                for feats in map(load_vgg_features, review["Photos"])
                if feats is not None
            ],
            "Rating": review["Rating"],
        })
    return samples

In [13]:
%%time
# Build the token-index + photo-feature samples for every split.
# The recorded run took about 27 minutes of wall time.
glove_vgg_data = {}
for key in ["train", "valid", "test"]:
    glove_vgg_data[key] = build_glove_and_vgg_data(glove_tokenizer, data[key])

100%|██████████| 26583/26583 [16:21<00:00, 27.10it/s]
100%|██████████| 8861/8861 [05:27<00:00, 27.06it/s]
100%|██████████| 8861/8861 [05:25<00:00, 27.21it/s]

CPU times: user 6min 3s, sys: 27.4 s, total: 6min 31s
Wall time: 27min 14s





In [14]:
# Output location for the pickled samples, next to the split data.
# The trailing expression displays the resolved path.
glove_vgg_data_path = os.path.join(split622data, "glove_vgg_data.pickle")
glove_vgg_data_path

'data/yelp-vistanet/622data/glove_vgg_data.pickle'

In [15]:
%%time
# Serialize all three splits into one pickle using the highest protocol available.
# NOTE(review): `w` here rebinds the module-level `w` assigned in the GloVe
# weight loop earlier in the notebook.
with open(glove_vgg_data_path, "wb") as w:
    pickle.dump(glove_vgg_data, w, protocol=pickle.HIGHEST_PROTOCOL)

CPU times: user 1.75 s, sys: 1.51 s, total: 3.26 s
Wall time: 3.83 s


In [16]:
# Record how many photo feature entries each sample kept, across all splits.
# The trailing expression displays the total number of samples.
imgs_num = []
for key in ["train", "valid", "test"]:
    for review in glove_vgg_data[key]:
        imgs_num.append(len(review["Photos"]))
len(imgs_num)

44305

In [17]:
# Array form of the per-sample photo counts, for vectorized comparison below.
a = np.array(imgs_num)

In [18]:
# Number of samples with exactly three photos. In the recorded run this equals
# the total sample count, i.e. every sample has three.
(a == 3).sum()

44305

In [None]:
# Inspect the first processed training sample.
glove_vgg_data["train"][0]

In [20]:
def check_photo(i, root=None):
    """Report whether the JPEG for photo id ``i`` exists on disk.

    The file is looked up at ``<root>/photos/<first two chars of i>/<i>.jpg``.

    Args:
        i: Photo id string; its first two characters name the shard directory.
        root: Dataset root directory. When omitted, the notebook-level
            ``base_dir`` is used, which is the original behaviour.

    Returns:
        ``True`` if the path exists, ``False`` otherwise.
    """
    if root is None:
        # Resolved at call time so explicit-root callers never need the global.
        root = base_dir
    return os.path.exists(os.path.join(root, "photos", i[:2], i + ".jpg"))

def build_glove_data(tokenizer, reviews:List[dict]):
    """Convert raw reviews into samples of token indices plus existing photo ids.

    Prints the total number of photos kept once all reviews are processed.

    Args:
        tokenizer: Object exposing ``to_idx(text)``.
        reviews: Review dicts with "Text", "Photos" (photo ids) and "Rating".

    Returns:
        One dict per review, in input order, with keys "Text" (the tokenizer
        output), "Photos" (ids for which ``check_photo`` is true) and "Rating"
        (copied unchanged).
    """
    samples = []
    image_count = 0
    for review in tqdm(reviews):
        token_ids = tokenizer.to_idx(review["Text"])
        # Keep only the ids whose image file is actually present.
        present_ids = [photo_id for photo_id in review["Photos"] if check_photo(photo_id)]
        image_count += len(present_ids)
        samples.append({
            "Text": token_ids,
            "Photos": present_ids,
            "Rating": review["Rating"],
        })
    print(f"Image num : {image_count}")
    return samples

In [21]:
%%time
# Build the token-index + photo-id samples for every split. No feature arrays
# are loaded here; only the existence of each image file is checked.
glove_data = {}
for key in ["train", "valid", "test"]:
    glove_data[key] = build_glove_data(glove_tokenizer, data[key])

100%|██████████| 26583/26583 [00:45<00:00, 582.08it/s]
  1%|          | 66/8861 [00:00<00:13, 654.06it/s]

Image num : 98493


100%|██████████| 8861/8861 [00:14<00:00, 598.61it/s]
  1%|          | 73/8861 [00:00<00:12, 721.41it/s]

Image num : 33527


100%|██████████| 8861/8861 [00:14<00:00, 594.86it/s]

Image num : 33407
CPU times: user 1min 14s, sys: 1.09 s, total: 1min 15s
Wall time: 1min 15s





In [22]:
# Serialize the photo-id variant of the dataset next to the split data.
# NOTE(review): `path` and `w` are generic module-level names; `w` rebinds the
# variable used by earlier cells.
path = os.path.join(split622data, "glove_data.pickle")
with open(path, "wb") as w:
    pickle.dump(glove_data, w, protocol=pickle.HIGHEST_PROTOCOL)