In [17]:
import json
import os
import pickle
import random
import re
from collections import OrderedDict
from itertools import islice

import numpy as np
from tqdm import tqdm

In [2]:
_dir = "data/mvsa/MVSA-multiple/"
data_path = _dir + "data.json"

In [3]:
# Raw MVSA-multiple samples: a list of dicts with at least a "text" field
# (the dumped samples further down also carry "id" and "label").
# JSON is UTF-8 by spec; decode explicitly instead of relying on the locale default.
with open(data_path, "r", encoding="utf-8") as fh:
    data = json.load(fh)
len(data)

12599

In [5]:
def build_glove_vocab(path="pretrained/glove27b/glove.twitter.27B.25d.txt"):
    """Map every token of a GloVe text file to its 0-based line number.

    The line number is the row of that token's vector in the same file.
    NOTE(review): the weight loader reads the 100d file with these indices,
    which assumes all GloVe Twitter files share one line order -- confirm.

    Args:
        path: GloVe ``.txt`` file whose lines look like ``token v1 v2 ...``.

    Returns:
        OrderedDict ``token -> line index`` in file order. If a token occurs
        on several lines, the last occurrence wins. Blank lines are skipped
        but still counted, so indices stay aligned with line numbers.
    """
    vocab = OrderedDict()
    # Stream the (large) file instead of materialising it with readlines();
    # GloVe files are UTF-8, so don't depend on the platform default encoding.
    with open(path, "r", encoding="utf-8") as fh:
        for line_no, line in enumerate(fh):
            fields = line.split()
            if not fields:
                continue
            # NOTE(review): if a line's token is itself whitespace, fields[0]
            # is the first vector component rather than a real token.
            vocab[fields[0]] = line_no
    return vocab
# token -> 0-based line index in the 25d GloVe Twitter file; used below to find
# hashtags that have their own vector and to filter the dataset vocabulary.
glove_vocab = build_glove_vocab()

In [13]:
# Peek at the first five (token, index) pairs without copying the whole
# vocabulary into a temporary list first.
list(islice(glove_vocab.items(), 5))

[('<user>', 0), ('.', 1), (':', 2), ('rt', 3), (',', 4)]

In [6]:
# Hashtags ("#" plus at least one character) that have their own GloVe vector;
# clear_text keeps these whole instead of splitting off the "#".
topic_vocab = {token for token in glove_vocab if token.startswith("#") and len(token) >= 2}
len(topic_vocab)

1629

In [7]:
p1 = re.compile(r"&[a-zA-Z]+;")  # HTML character entities such as "&amp;"
p2 = re.compile(r"([\W])")  # split on non-word chars, keeping each char as its own piece
def clear_text(text):
    """Lower-case and tokenize one tweet.

    Rules for each whitespace-separated chunk:
      * URLs (chunks starting with "http") and chunks starting with an HTML
        entity are dropped;
      * "@mention" becomes "mention";
      * "#tag" is kept whole if it is in ``topic_vocab``; otherwise it is split
        into "#" and "tag". A bare "#" is kept as is;
      * anything else is split on non-word characters, and each separator
        character is kept as its own token.
    Empty-string tokens are never emitted.

    Args:
        text: raw tweet text.

    Returns:
        list[str] of tokens.
    """
    res = []
    for chunk in text.strip().lower().split():
        if chunk.startswith("http") or p1.match(chunk) is not None:
            # NOTE(review): p1.match only tests the start of the chunk, so
            # "&amp;co" is dropped entirely while "a&amp;b" is split into
            # "a", "&", "amp", ";", "b". Kept as-is to preserve the vocabulary.
            continue
        if chunk.startswith("@"):
            mention = chunk[1:]
            if mention:  # a bare "@" would otherwise yield an empty-string token
                res.append(mention)
        elif chunk.startswith("#"):
            if len(chunk) >= 2 and chunk not in topic_vocab:
                # Hashtag without its own vector: emit "#" and the tag text.
                res.append("#")
                res.append(chunk[1:])
            else:
                res.append(chunk)  # known hashtag, or a bare "#"
        else:
            # re.split with a capture group yields "" between adjacent
            # separators and at the ends; skip those.
            res.extend(piece for piece in p2.split(chunk) if piece)
    return res

In [None]:
%%time
# Tokenize every tweet in place. Samples whose "text" is no longer a string
# (already cleaned or already turned into ids) are skipped, so re-running this
# cell does not crash on list.strip().
for sample in data:
    if isinstance(sample["text"], str):
        sample["text"] = clear_text(sample["text"])
data[:5]

In [None]:
# Seed the global RNG before shuffling so the train/valid/test split (and the
# random UNK ids and UNK rows drawn later with `random`) are reproducible on
# Restart & Run All. NOTE: re-running only this cell shuffles the
# already-shuffled list again and yields a different order.
SHUFFLE_SEED = 42
random.seed(SHUFFLE_SEED)
random.shuffle(data)
data[:5]

In [10]:
# 3/5 train, 1/5 valid, rest test; the test slice also absorbs the
# len(data) % 5 leftover samples.
fifth = len(data) // 5
glove_data = {
    "train": data[:3 * fifth],
    "valid": data[3 * fifth:4 * fifth],
    "test": data[4 * fifth:],
}
tuple(len(split) for split in glove_data.values())

(7557, 2519, 2523)

In [11]:
def build_freq(freq, data):
    """Accumulate token counts from ``data`` into ``freq`` (mutated in place).

    Args:
        freq: dict token -> count; updated and also returned.
        data: iterable of samples, each with a ``"text"`` list of tokens.

    Returns:
        The same ``freq`` object. New tokens are inserted in first-seen order,
        which acts as the tie-breaker when tokens are later sorted by count.
    """
    tokens = (token for sample in data for token in sample["text"])
    for token in tokens:
        if token in freq:
            freq[token] += 1
        else:
            freq[token] = 1
    return freq
# Vocabulary candidates come from the train and valid splits only (test is
# excluded). Original note: "all 26273", presumably the count when the test
# split is included as well.
freq = {}
for split in ("train", "valid"):
    build_freq(freq, glove_data[split])
len(freq)

22796

In [14]:
def build_vocab_from_glove(freq_dict, glove_vocab, dir):
    """Build the dataset vocabulary, restricted to tokens that have a GloVe vector.

    Tokens are ordered by descending count and numbered from 1. Ties keep
    ``freq_dict``'s insertion order because the sort is stable. Id 0 is left
    unused (presumably reserved for padding -- TODO confirm). The result is
    also pickled to ``<dir>/glove_vocab.pickle``.

    Args:
        freq_dict: dict token -> count.
        glove_vocab: mapping token -> row index in the GloVe file.
        dir: output directory (a trailing separator is optional).

    Returns:
        dict with ``"token2idx"`` (OrderedDict token -> 1-based id) and
        ``"glove_idx"`` (list; ``glove_idx[k]`` is the GloVe row of the token
        whose id is ``k + 1``).
    """
    # Drop tokens without a GloVe vector, then put the most frequent first.
    in_glove = [item for item in freq_dict.items() if item[0] in glove_vocab]
    in_glove.sort(key=lambda item: item[1], reverse=True)
    token2idx = OrderedDict()
    glove_idx = []
    for idx, (token, _count) in enumerate(in_glove, start=1):
        token2idx[token] = idx
        glove_idx.append(glove_vocab[token])  # row to copy from the GloVe matrix
    d = {"token2idx": token2idx, "glove_idx": glove_idx}
    # os.path.join (not string concatenation): a `dir` without a trailing slash
    # must not produce "<dir>glove_vocab.pickle" next to the directory.
    with open(os.path.join(dir, "glove_vocab.pickle"), "wb") as o:
        pickle.dump(d, o, protocol=pickle.HIGHEST_PROTOCOL)
    return d
# Build and pickle the vocabulary (original note: the pickle is ~221.4 KB).
vocab = build_vocab_from_glove(freq_dict=freq, glove_vocab=glove_vocab, dir=_dir)
len(vocab["token2idx"])

12488

In [21]:
def load_glove_vocab(_dir):
    """Load the vocabulary pickled by ``build_vocab_from_glove``.

    Args:
        _dir: directory containing ``glove_vocab.pickle`` (a trailing
            separator is optional).

    Returns:
        The unpickled object (``build_vocab_from_glove`` stores a dict with
        ``"token2idx"`` and ``"glove_idx"``).

    Raises:
        FileNotFoundError: if the pickle has not been built yet.
    """
    # Unpickling can execute arbitrary code: only load files this notebook wrote.
    with open(os.path.join(_dir, "glove_vocab.pickle"), "rb") as r:
        return pickle.load(r)

In [18]:
UNK_NUM = 100  # how many ids (and embedding rows) are reserved for unknown words
class GloveTokenizer:
    """Convert token lists into integer ids.

    In-vocabulary tokens take their id from ``glove_vocab`` (ids 1..vocab_size
    when it is the ``token2idx`` built by ``build_vocab_from_glove``). Every
    out-of-vocabulary *occurrence* independently gets a random id in
    ``[vocab_size + 1, vocab_size + unk_num]`` (inclusive) from the global
    ``random`` module, so the same unknown word can map to different ids.
    """

    def __init__(self, glove_vocab, unk_num:int=UNK_NUM):
        self.unk_num = unk_num
        self.vocab = glove_vocab
        self.vocab_size = len(self.vocab)
        # Echo the largest id this tokenizer can emit.
        print(self.vocab_size + self.unk_num)

    def tokenize(self, tokens_list):
        """Return one id per token in ``tokens_list``, in the same order."""
        first_unk = self.vocab_size + 1
        last_unk = self.vocab_size + self.unk_num
        return [
            self.vocab[token] if token in self.vocab else random.randint(first_unk, last_unk)
            for token in tokens_list
        ]
# Known tokens -> ids 1..len(token2idx); unknown tokens -> random ids above that.
tokenizer = GloveTokenizer(glove_vocab=vocab["token2idx"])

12588


In [22]:
def load_glove_weight(d:int, path=None):
    """Load a GloVe Twitter-27B text file into a dense float32 matrix.

    Row ``i`` holds the vector found on line ``i`` of the file.

    Args:
        d: vector dimensionality; selects ``glove.twitter.27B.<d>d.txt``.
        path: optional explicit file path. Defaults to
            ``pretrained/glove27b/glove.twitter.27B.<d>d.txt``.

    Returns:
        np.ndarray of shape ``(number_of_lines, d)``, dtype float32.

    Raises:
        ValueError: if a line has fewer than ``d`` numeric fields.
    """
    if path is None:
        path = os.path.join("pretrained", "glove27b", "glove.twitter.27B." + str(d) + "d.txt")
    # GloVe files are UTF-8; don't depend on the platform default encoding.
    with open(path, "r", encoding="utf-8") as fh:
        # First pass only counts lines so the matrix can be preallocated
        # without holding every line of the file in memory at once.
        n = sum(1 for _ in fh)
        fh.seek(0)
        weight = np.zeros((n, d), dtype=np.float32)
        for i, line in enumerate(tqdm(fh, total=n)):
            # The vector is always the last d fields. Taking them from the end
            # also covers lines whose token is whitespace (only d fields
            # remain after split) or contains whitespace (more than d + 1).
            weight[i] = np.asarray(line.split()[-d:], dtype=np.float32)
    return weight

def get_mvsa_glove_weight(dir, d:int, _uniform:float=0.1, unk_sample_size:int=100000):
    """Return the embedding matrix for the MVSA GloVe vocabulary, cached as .npy.

    Layout, with ``n = len(token2idx)``:
        row 0                  -> zeros (presumably padding -- TODO confirm)
        rows 1 .. n            -> GloVe vectors in token-id order
        rows n+1 .. n+UNK_NUM  -> "unknown" rows; each is the mean of
                                  ``unk_sample_size`` GloVe rows sampled with
                                  the global ``random`` module

    Args:
        dir: dataset directory holding ``glove_vocab.pickle``. The matrix is
            cached there as ``glove27b<d>d.npy``.
        d: embedding dimensionality.
        _uniform: unused; kept for backward compatibility.
        unk_sample_size: GloVe rows averaged per UNK row. It is capped at the
            number of GloVe rows.

    Returns:
        float32 array of shape ``(n + UNK_NUM + 1, d)``.

    NOTE(review): the cache file name depends only on ``d``. Delete the .npy
    after rebuilding the vocabulary or changing UNK_NUM; otherwise a stale
    matrix is returned.
    """
    path = os.path.join(dir, "glove27b" + str(d) + "d.npy")
    if os.path.exists(path):
        return np.load(path)
    glove_weight = load_glove_weight(d)
    vocab = load_glove_vocab(dir)
    n = len(vocab["token2idx"])
    weight = np.zeros((n + UNK_NUM + 1, d), dtype=np.float32)
    weight[1:n+1] = glove_weight[vocab["glove_idx"]]  # id k -> GloVe row glove_idx[k-1]
    glove_size = len(glove_weight)
    sample_size = min(unk_sample_size, glove_size)
    for i in range(UNK_NUM):
        # random.sample accepts any sequence, so pass the range itself instead
        # of rebuilding list(range(glove_size)) on every iteration.
        rows = random.sample(range(glove_size), sample_size)
        weight[n + i + 1] = glove_weight[rows].mean(axis=0)
    np.save(path, weight)  # original note: ~5.4 MB in total
    return weight
# 100-d embedding matrix (cached as glove27b100d.npy in the dataset folder).
weight = get_mvsa_glove_weight(dir=_dir, d=100)
weight.shape

100%|██████████| 1193514/1193514 [00:42<00:00, 28140.41it/s]


(12589, 100)

In [24]:
# Replace each sample's token strings with ids (in place), then pickle the splits.
# Only all-string token lists are converted. Re-running this cell would otherwise
# feed the ids back into the tokenizer, where no int is a vocabulary key, so
# every id would silently be replaced by a random UNK id.
for key in ["train", "valid", "test"]:
    for sample in glove_data[key]:
        if all(isinstance(token, str) for token in sample["text"]):
            sample["text"] = tokenizer.tokenize(sample["text"])
with open(_dir + "glove_data.pickle", "wb") as o:
    pickle.dump(glove_data, o, protocol=pickle.HIGHEST_PROTOCOL)

In [42]:
# First five training samples after tokenization: "text" now holds integer ids.
glove_data["train"][:5]

[{'id': '11401',
  'text': [27,
   7,
   22,
   244,
   364,
   1,
   12519,
   524,
   6,
   3097,
   6093,
   16,
   365,
   46,
   12492,
   16,
   12525,
   2510,
   2,
   1,
   12533],
  'label': 'neutral'},
 {'id': '10832',
  'text': [320,
   1,
   12540,
   1,
   3098,
   1,
   3099,
   1,
   1079,
   1,
   1080,
   1,
   12565,
   1,
   12537,
   1,
   12505],
  'label': 'neutral'},
 {'id': '8336',
  'text': [24,
   600,
   39,
   875,
   6094,
   1467,
   116,
   6095,
   525,
   1,
   12503,
   159,
   27,
   9,
   47,
   76,
   4],
  'label': 'positive'},
 {'id': '13991',
  'text': [287,
   9,
   2150,
   280,
   175,
   662,
   16,
   20,
   54,
   6,
   366,
   550,
   12570,
   12525,
   1,
   12556],
  'label': 'positive'},
 {'id': '18352',
  'text': [12531,
   6096,
   12556,
   15,
   12530,
   15,
   12524,
   3,
   3,
   1,
   72,
   1,
   237,
   1,
   6097,
   1,
   12528,
   1,
   12512,
   1,
   1867,
   1,
   12565],
  'label': 'neutral'}]

In [41]:
def load_mvsa_glove_data(data_dir=None):
    """Load the tokenized train/valid/test splits pickled by this notebook.

    Args:
        data_dir: directory containing ``glove_data.pickle``. Defaults to the
            notebook-level ``_dir``. A trailing separator is optional.

    Returns:
        The unpickled object (this notebook writes a dict with ``"train"``,
        ``"valid"`` and ``"test"`` lists of samples).
    """
    if data_dir is None:
        data_dir = _dir
    # Unpickling can execute arbitrary code: only load files this notebook wrote.
    with open(os.path.join(data_dir, "glove_data.pickle"), "rb") as r:
        return pickle.load(r)