In [1]:
from json import load
import sys
sys.path.append("/home/ly/workspace/mmsa")

import json
import os
import pickle
import collections
import numpy as np
from typing import *
from tqdm import tqdm
from collections import OrderedDict
from utils.tokenization import BasicTokenizer
from utils.load_yelp import *
# Fixed seed for numpy's global (legacy) RNG so any numpy-based randomness in
# this notebook is reproducible. Other RNGs (e.g. the `random` module) are not seeded here.
seed = 1024
np.random.seed(seed)

In [2]:
# Root directory of the Yelp/VistaNet dataset, relative to the current working directory.
base_dir = os.path.join("data","yelp-vistanet")

In [3]:
def load_data(path: str = "data/yelp-vistanet/clear_data.pickle"):
    """Load the preprocessed Yelp review splits from a pickle file.

    Args:
        path: Location of the pickle file. Defaults to the file this notebook
            has always read, ``data/yelp-vistanet/clear_data.pickle``.

    Returns:
        The unpickled object. Later cells index it by ``"train"``,
        ``"valid"`` and ``"test"``.
    """
    # NOTE: pickle.load can execute arbitrary code -- only load trusted files.
    with open(path, "rb") as r:
        return pickle.load(r)
class YelpSimpleTokenizer(BasicTokenizer):
    """Sentence-aware tokenizer for Yelp review text.

    Review text stores sentences separated by ``'|||'``. Each sentence is
    tokenized with the parent ``BasicTokenizer``, so results are 2-D
    (sentences x tokens). When a vocab is supplied, tokens can be mapped to ids.
    """

    def __init__(self, vocab: Optional[Dict[str, int]] = None, do_lower_case: bool = True) -> None:
        """
        Args:
            vocab: token -> id mapping. Optional, but required by ``to_idx``.
            do_lower_case: forwarded to ``BasicTokenizer``.
        """
        super().__init__(do_lower_case)
        self.SENT_DELIMITER = '|||'
        self.vocab = vocab
        # Id used for out-of-vocabulary tokens: len(vocab) + 1.
        # NOTE(review): presumably real tokens use ids 1..len(vocab) and 0 is
        # padding -- confirm against load_glove_vocab / idx2token.
        self.UNK = len(vocab) + 1 if vocab is not None else None

    def tokenize(self, text: str) -> List[List[str]]:
        """Split ``text`` on the sentence delimiter and tokenize each sentence.

        Always returns a 2-D list (one token list per sentence). Fragments
        that produce no tokens are dropped, so no sentence is ever empty.
        """
        res = []
        for sent in text.split(self.SENT_DELIMITER):
            # Splitting on the delimiter can produce empty strings.
            if not sent:
                continue
            tokens = super().tokenize(sent)
            # A non-empty fragment can still yield no tokens (presumably
            # whitespace-only text such as " " between two delimiters);
            # keeping it would create an empty sentence.
            if tokens:
                res.append(tokens)
        return res

    def _getidx(self, token: str):
        """Map one token to its vocab id, falling back to ``self.UNK``."""
        return self.vocab.get(token, self.UNK)

    def to_idx(self, text: str) -> List[List[int]]:
        """Tokenize ``text`` and map every token to its id (2-D, per sentence)."""
        assert self.vocab is not None, "No vocab!"
        return [[self._getidx(token) for token in sent] for sent in self.tokenize(text)]

In [4]:
data = load_data()
vocab = load_glove_vocab()
glove_tokenizer = YelpSimpleTokenizer(vocab["token2idx"], do_lower_case=True)
# Sanity check: the tokenizer's OOV id must be a valid index into idx2token,
# otherwise an id table sized by idx2token would be indexed out of range.
# NOTE(review): assumes the embedding matrix is sized len(idx2token) -- confirm.
assert glove_tokenizer.UNK < len(vocab["idx2token"]), "UNK id out of range for idx2token"
len(vocab["token2idx"]), len(vocab["idx2token"]), len(vocab["glove_idx"])

(42822, 42824, 42822)

In [6]:
def check_photo(i):
    """Return whether the photo with id ``i`` exists under ``base_dir``.

    Photos are sharded into sub-directories named after the first two
    characters of the id: ``<base_dir>/photos/<i[:2]>/<i>.jpg``.
    """
    shard = i[:2]
    filename = i + ".jpg"
    photo_path = os.path.join(base_dir, "photos", shard, filename)
    return os.path.exists(photo_path)

def build_glove_data(tokenizer, reviews:List[dict]):
    """Convert raw reviews to id form, keeping only photos that exist on disk.

    Args:
        tokenizer: object exposing ``to_idx(text)`` (e.g. ``YelpSimpleTokenizer``).
        reviews: dicts with ``"Text"``, ``"Photos"`` and ``"Rating"`` keys.

    Returns:
        One ``{"Text", "Photos", "Rating"}`` dict per review, in input order.
        As a side effect, prints the total number of retained photos.
    """
    converted = []
    n_images = 0
    for review in tqdm(reviews):
        text_ids = tokenizer.to_idx(review["Text"])
        # Drop photo ids whose image file is missing.
        kept_photos = [photo_id for photo_id in review["Photos"] if check_photo(photo_id)]
        n_images += len(kept_photos)
        converted.append({
            "Text": text_ids,
            "Photos": kept_photos,
            "Rating": review["Rating"],
        })
    print(f"Image num : {n_images}")
    return converted

In [9]:
%%time
# Convert each split to token ids with the GloVe-vocab tokenizer.
# Each build_glove_data call prints that split's retained-photo count.
glove_data = {}
for key in ["train", "valid", "test"]:
    glove_data[key] = build_glove_data(glove_tokenizer, data[key])

100%|██████████| 35445/35445 [01:23<00:00, 423.84it/s]
  1%|▏         | 57/4430 [00:00<00:07, 561.97it/s]

Image num : 132715


100%|██████████| 4430/4430 [00:08<00:00, 508.59it/s]
  1%|▏         | 64/4430 [00:00<00:06, 635.72it/s]

Image num : 16408


100%|██████████| 4430/4430 [00:08<00:00, 551.55it/s]

Image num : 16304
CPU times: user 1min 25s, sys: 1.58 s, total: 1min 27s
Wall time: 1min 40s





In [10]:
# Write the converted splits to <base_dir>/glove_data.pickle ("wb" overwrites any existing file).
# HIGHEST_PROTOCOL is the most compact format, but older Python versions may be unable to read it.
path = os.path.join(base_dir, "glove_data.pickle")
with open(path, "wb") as w:
    pickle.dump(glove_data, w, protocol=pickle.HIGHEST_PROTOCOL)