In [85]:
import torch
import pandas as pd
import numpy as np
import pickle
import torchvision.models as models

In [86]:
# Embedding width; selects which files are used (passed to exShapeMatrix as `d`
# and interpolated into the './shape2vec.{dim}d' path when the matrix is loaded).
dim = 50

In [8]:
def exShapeMatrix(m_path='ShapeNet.pth', d=300):
    """Extract one weight matrix from a checkpoint and re-save it on its own.

    Every occurrence of ``'Net'`` in ``m_path`` is rewritten to ``'Net_<d>d'``
    (so the default becomes ``'ShapeNet_<d>d.pth'``). The tensor stored under
    ``['model']['classifier.0.weight']`` in that checkpoint is then written to
    ``shape2vec.<d>d`` in the working directory under the key ``'shape2vec'``.

    Args:
        m_path: Checkpoint path template containing ``'Net'``.
        d: Dimensionality tag used in both the input and the output file name.

    Returns:
        None. Prints ``Saved!`` once the file has been written.
    """
    checkpoint_path = m_path.replace('Net', f'Net_{str(d)}d')
    weights = torch.load(checkpoint_path)['model']['classifier.0.weight']
    torch.save({'shape2vec': weights}, f'shape2vec.{d}d')
    print("Saved!")
# Disabled alternative: move the matrix to CPU and convert it to a NumPy array.
# M = pretrained_dict['model']['classifier.0.weight'].cpu().detach().numpy()

# Export the embedding file for the dimensionality configured in `dim`.
exShapeMatrix(d=dim)

Saved!


In [4]:
def get_label_dict():
    """Load the label mapping pickled in ``./chinese_labels``.

    Returns:
        The unpickled object. The rest of this notebook uses it as a dict
        indexed by integer id that yields a character (``id2char``).

    Raises:
        OSError: If the file cannot be opened.
        pickle.UnpicklingError: If the file content is not a valid pickle.
    """
    # NOTE(security): pickle.load can execute arbitrary code while loading;
    # only use a label file that comes from a trusted source.
    # The `with` block closes the handle even when unpickling raises.
    with open('./chinese_labels', 'rb') as f:
        return pickle.load(f)

# Forward (id -> character) and inverse (character -> id) lookup tables.
id2char = get_label_dict()
char2id = {char: idx for idx, char in id2char.items()}

def getAllReady():
    """Load the saved shape-embedding matrix for the current global ``dim``.

    Reads ``./shape2vec.<dim>d`` with ``torch.load`` and returns the tensor
    stored under ``'shape2vec'``; its shape is printed as a quick sanity check.

    The previous version also re-read the label pickle into two local
    dictionaries that were never used or returned; that dead work is removed.

    Returns:
        The tensor stored under ``'shape2vec'``.
    """
    M = torch.load(f'./shape2vec.{dim}d')['shape2vec']
    print(M.shape)
    return M
    
# Notebook-level embedding matrix; the similarity helpers below index it by character id.
M = getAllReady()

torch.Size([3755, 50])


In [87]:
def getTopKSim(q_char, K=20):
    """Return the characters whose shape embeddings are closest to ``q_char``.

    Similarity is the cosine similarity between rows of the global matrix ``M``.

    Args:
        q_char: Query character; must be a key of ``char2id``.
        K: Maximum number of neighbours to return.

    Returns:
        list[str]: Up to ``K`` characters, most similar first. The query
        character itself is never part of the result.

    Raises:
        KeyError: If ``q_char`` is not in ``char2id``.
    """
    query_id = char2id[q_char]
    # Score every row in one batched call instead of one torch call per row.
    query_rows = M[query_id].unsqueeze(0).expand_as(M)
    scores = torch.cosine_similarity(query_rows, M, dim=-1).tolist()
    # sorted() is stable, so rows with equal scores stay in ascending id order.
    ranked = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)
    # Drop the query by id, not by position: with tied scores (e.g. duplicate
    # rows) the query is not guaranteed to be ranked first.
    neighbours = [row for row in ranked if row != query_id][:K]
    return [id2char[row] for row in neighbours]


#### Test

In [88]:
# 50d
for q_char in ("鹿", "于", "茵", "少", "饮"):
    print(f"{q_char}：", getTopKSim(q_char, 20))

鹿： ['蔑', '席', '衷', '商', '胞', '展', '度', '厘', '底', '宦', '腕', '胰', '虎', '窟', '寇', '蓖', '腥', '庭', '彪', '脯']
于： ['干', '子', '丁', '吁', '巧', '云', '壬', '古', '乎', '蹬', '予', '晋', '牙', '手', '天', '霞', '订', '责', '舌', '寸']
茵： ['苗', '萤', '菌', '卤', '酋', '窗', '商', '谊', '苞', '芭', '萄', '笛', '值', '砖', '亩', '苟', '囱', '首', '宦', '茧']
少： ['吵', '沙', '父', '小', '尘', '乡', '山', '今', '纱', '炒', '步', '夕', '刃', '尖', '抄', '诊', '仅', '立', '砂', '仪']
饮： ['钦', '炊', '饥', '坎', '饭', '伙', '忱', '吹', '饺', '次', '恢', '饱', '欲', '砍', '收', '蚀', '价', '欧', '欢', '饶']


In [None]:
# 100d
# NOTE: getTopKSim reads the global M, so this only differs from the 50d cell
# if M was reloaded with dim = 100 beforehand.
for q_char in ("鹿", "于", "茵", "少", "饮"):
    print(f"{q_char}：", getTopKSim(q_char, 20))

In [None]:
# 300d
# NOTE: getTopKSim reads the global M, so this only differs from the 50d cell
# if M was reloaded with dim = 300 beforehand.
for q_char in ("鹿", "于", "茵", "少", "饮"):
    print(f"{q_char}：", getTopKSim(q_char, 20))

#### 形似字字典 (dictionary of visually similar characters)

In [None]:
# Generate the dictionary. The earlier per-pair loop was noted as taking ~2 hrs;
# this version makes one batched similarity call per query character.
def getTopKSimDict(K=20):
    """Build a mapping from every known character to its nearest neighbours.

    For each character in ``char2id`` the cosine similarity of its row in the
    global matrix ``M`` against all rows is computed in a single call, and the
    ``K`` best-scoring other characters are kept.

    Args:
        K: Maximum number of neighbours stored per character.

    Returns:
        dict[str, list[str]]: character -> up to ``K`` neighbours, most similar
        first. A character never appears in its own list.
    """
    sim_dct = {}
    for q_char, query_id in char2id.items():
        query_rows = M[query_id].unsqueeze(0).expand_as(M)
        scores = torch.cosine_similarity(query_rows, M, dim=-1).tolist()
        # Stable sort: rows with equal scores stay in ascending id order.
        ranked = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)
        # Exclude the query by id; with tied scores it need not be ranked first.
        neighbours = [row for row in ranked if row != query_id][:K]
        sim_dct[q_char] = [id2char[row] for row in neighbours]
    return sim_dct

# Build the neighbour dictionary once and persist it to ./sim_dct with torch.save.
sim_dct = getTopKSimDict()
torch.save(sim_dct, "./sim_dct")

In [89]:
# Load
def get_sim_dict():
    """Read the object saved in the file ``sim_dct`` via ``torch.load`` and return it."""
    return torch.load('sim_dct')

# Neighbour dictionary loaded from disk, so the generation cell need not be re-run.
sim_dct = get_sim_dict()

In [90]:
# Similarity between two characters
def charSim(c1="于", c2="干"):
    """Cosine similarity between the shape embeddings of two characters.

    Args:
        c1: First character.
        c2: Second character.

    Returns:
        The result of ``torch.cosine_similarity`` on the two rows of the global
        ``M``, or ``None`` when either character has no entry in ``char2id``.
    """
    first_id = char2id.get(c1)
    if first_id is None:
        return None
    second_id = char2id.get(c2)
    if second_id is None:
        return None
    return torch.cosine_similarity(M[first_id], M[second_id], dim=-1)

In [91]:
# Spot-check two pairs taken from the recorded neighbour lists
# (干 is listed first for 于, and 谨 first for 望).
print(charSim(c1="于", c2="干"))
print(charSim(c1="望", c2="谨"))

tensor(0.5860, device='cuda:0')
tensor(0.5262, device='cuda:0')


#### albert

In [92]:
# pip install transformers==2.2.2
from transformers import BertTokenizer, AlbertForMaskedLM
import torch
import copy
from torch.nn.functional import softmax

# Load the tokenizer and the masked-LM model from the same pretrained checkpoint name;
# both are read as globals by correctAll.
pretrained = 'voidful/albert_chinese_tiny'
tokenizer = BertTokenizer.from_pretrained(pretrained)
model = AlbertForMaskedLM.from_pretrained(pretrained)

In [93]:
def correctAll(sent="", top_k=10, sim_threshold=0.5):
    """Correct shape-confusable typos in ``sent`` one character at a time.

    Each position is masked in turn and the global masked-LM ``model`` proposes
    replacements. The highest-ranked proposal whose shape similarity to the
    original character (``charSim``) exceeds ``sim_threshold`` is written back
    into the sentence, so later positions are predicted from the already
    corrected text. Positions with no acceptable proposal are left unchanged.

    Args:
        sent: Sentence to correct; must contain more than one character.
        top_k: Number of model proposals examined per position.
        sim_threshold: Minimum shape similarity for accepting a proposal.

    Returns:
        str: The sentence after all accepted replacements. Each replacement
        that changes a character is also printed as ``old -> new``.

    Raises:
        AssertionError: If ``sent`` has fewer than two characters.
    """
    assert(len(sent) > 1)
    # Ask the tokenizer for its mask token/id instead of hard-coding 103.
    mask_token = tokenizer.mask_token
    mask_id = tokenizer.mask_token_id
    for i in range(len(sent)):
        msk_char = sent[i]
        msk_sent = sent[:i] + mask_token + sent[i+1:]
        # Encode once; the ids give both the mask position and the model input.
        token_ids = tokenizer.encode(msk_sent, add_special_tokens=True)
        maskpos = token_ids.index(mask_id)
        input_ids = torch.tensor(token_ids).unsqueeze(0)  # Batch size 1
        # Inference only: no labels (the loss was never used) and no autograd
        # graph. Without labels the first output is the prediction scores.
        with torch.no_grad():
            prediction_scores = model(input_ids)[0]

        _, indices = torch.topk(prediction_scores[0, maskpos], k=top_k, dim=0)

        for idx in indices:
            predicted_token = tokenizer.convert_ids_to_tokens([idx.item()])[0]
            # charSim returns None for tokens without a shape embedding.
            sim = charSim(c1=msk_char, c2=predicted_token)
            if sim is not None and sim > sim_threshold:
                if msk_char != predicted_token:
                    print(f"{msk_char} -> {predicted_token}")
                sent = sent[:i] + predicted_token + sent[i+1:]
                break
    return sent


In [94]:
# Demo: two shape-confusable typos (recorded output: 令 -> 今, 请 -> 情).
print(correctAll("令天心请不错！"))

令 -> 今
请 -> 情
今天心情不错！


In [64]:
# Demo: one typo in a sentence that also contains digits and '%' (recorded output: 金 -> 全).
print(correctAll("金国90%以上进口冷链食品可追溯"))

金 -> 全
全国90%以上进口冷链食品可追溯


In [65]:
# Demo: recorded output shows 州 -> 洲 and 王 -> 主.
print(correctAll("拜登拟任命亚州事务王管"))

州 -> 洲
王 -> 主
拜登拟任命亚洲事务主管


In [66]:
# Partial success: in the recorded output only the second 小 is changed (小 -> 少);
# the leading 小 and 圆 are left as they are.
print(correctAll("小先队圆全称：中国小年先锋队队员"))

小 -> 少
小先队圆全称：中国少年先锋队队员


In [56]:
# Failure case from an earlier run (In [56]): the recorded output rewrites 员/湾 to 长, 大 and 海
# (presumably 圆圆 / 弯弯 was intended — TODO confirm against the current threshold).
print(correctAll("员员的太阳湾湾的月亮。"))

员 -> 长
员 -> 大
湾 -> 海
长大的太阳海湾的月亮。


#### Dataset for testing

In [12]:
import random

In [13]:
# Scratch: draw one sample to look at what random.random() returns; the value is not stored.
random.random()

0.883338453887229

In [70]:
# Write one "<corrupted>\t<original>" line to pairs_test.txt per line of correct_test.txt.
# NOTE: `mix` must already be defined when this cell runs.
with open("pairs_test.txt", mode="w", encoding='utf8') as p, \
        open("correct_test.txt", encoding='utf8') as f:
    for line in f:
        # `line` keeps its own newline, so strip it from the corrupted copy only.
        noisy = mix(line).replace('\n', '')
        p.write(noisy + '\t' + line)

In [69]:
def mix(sent):
    """Randomly corrupt a sentence with visually similar characters.

    For every character, ``random.random() > 0.8`` decides whether a
    substitution is attempted. If so, and the character has neighbours in
    ``sim_dct``, one candidate is drawn from its first three neighbours and
    substituted only when ``charSim`` rates that same candidate above 0.5.

    The previous version checked one random candidate but then substituted a
    second, independently drawn one, so the similarity check did not apply to
    the character actually written.

    Args:
        sent: Input sentence.

    Returns:
        str: Sentence of the same length with some characters replaced.
    """
    chars = list(sent)
    for idx, ch in enumerate(chars):
        neighbours = sim_dct.get(ch)
        # The random draw comes first so it is made once per character.
        # The truthiness test also skips an empty neighbour list.
        if random.random() > 0.8 and neighbours:
            candidate = random.choice(neighbours[:3])
            sim = charSim(ch, candidate)
            # charSim returns None when a character has no embedding.
            if sim is not None and sim > 0.5:
                chars[idx] = candidate
    return "".join(chars)
# Demo on a headline; in the recorded output 条 was replaced by 绦.
print(mix("传《半条命2：第三章》今年发售无望"))

传《半绦命2：第三章》今年发售无望


In [48]:
# Scratch version of the corruption loop used in `mix`, without the similarity check.
sent = "传《半条命2：第三章》今年发售无望"
sent = list(sent)
new = ""
for idx in range(len(sent)):
    if random.random() > 0.8 and sim_dct.get(sent[idx]) is not None:
        sent[idx] = random.choice(sim_dct[sent[idx]][:2])
# The cell used to end in the incomplete statement `sent = ` (a SyntaxError);
# the recorded output is the character list, so display it.
sent

['传', '《', '伴', '绦', '俞', '2', '：', '第', '三', '章', '》', '今', '年', '发', '售', '无', '望']


In [51]:
# str() of the character list gives the list's repr, not the joined sentence.
str(sent)

"['传', '《', '伴', '绦', '俞', '2', '：', '第', '三', '章', '》', '今', '年', '发', '售', '无', '望']"

In [30]:
# Scratch: sample one of the first three neighbours stored for "风".
random.choice(sim_dct["风"][:3])

'冈'

In [49]:
# Scratch: the first three neighbours stored for "望".
sim_dct["望"][:3]

['谨', '窒', '塑']