In [1]:
import pandas as pd
import numpy as np

from ckiptagger import WS, POS
from tqdm.notebook import tqdm

In [2]:
df_train = pd.read_csv('news_clustering_train.tsv', sep='\t')
df_test = pd.read_csv('news_clustering_test.tsv', sep='\t')

In [3]:
train_titles = {row['index']: row['title'] for _, row in df_train.iterrows()}
train_classes = {row['index']: row['class'] for _, row in df_train.iterrows()}

test_titles = {row['index']: row['title'] for _, row in df_test.iterrows()}
test_classes = {row['index']: row['class'] for _, row in df_test.iterrows()}

In [4]:
all_news_class = ['體育', '財經', '科技', '旅遊', '農業', '遊戲']

# 斷詞 + POS

In [5]:
ws = WS('D07data')
pos = POS('D07data')

In [6]:
train_title_cuts = {}
for index, title in tqdm(train_titles.items()):
    word_s = ws([title])
    word_p = pos(word_s)
    train_title_cuts[index] = list(zip(word_s[0], word_p[0]))

HBox(children=(IntProgress(value=0, max=1800), HTML(value='')))




In [7]:
test_title_cuts = {}
for index, title in tqdm(test_titles.items()):
    word_s = ws([title])
    word_p = pos(word_s)
    test_title_cuts[index] = list(zip(word_s[0], word_p[0]))

HBox(children=(IntProgress(value=0, max=600), HTML(value='')))




# 尋找降維的詞向量：PPMI + SVD

In [8]:
word2index = {}
index2word = {}
n = 0
for index in train_title_cuts:
    for word, flag in train_title_cuts[index]:
        if word in word2index:
            continue
        word2index[word] = n 
        index2word[n] = word
        n += 1

In [9]:
len(word2index)

6690

如果使用one-hot就需要這麼大的維度的詞向量

In [13]:
# 建立Co-Matrix

vocab_size = len(word2index)
co_matrix = np.zeros(shape=(vocab_size, vocab_size), dtype=np.int32)

window_size = 1
for pairs in train_title_cuts.values():
    indices = [word2index[word] for word, _ in pairs]
    for center_i, center_id in enumerate(indices):
        context_ids = indices[max(0,center_i-window_size):center_i]
        for left_word_id in context_ids:
            co_matrix[left_word_id, center_id] += 1
            co_matrix[center_id, left_word_id] += 1
        

In [14]:
co_matrix

array([[0, 2, 0, ..., 0, 0, 0],
       [2, 0, 1, ..., 0, 0, 0],
       [0, 1, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]])

In [15]:
# 建立PPMI

def get_ppmi(co_matrix: np.ndarray, eps: float=1e-8):
    M = np.zeros_like(co_matrix)
    N = np.sum(co_matrix)
    S = co_matrix.sum(axis=0)
    
    for i in range(co_matrix.shape[0]):
        for j in range(co_matrix.shape[1]):
            
            pmi = np.log2(co_matrix[i,j]*N / (S[i]*S[j] + eps))
            M[i,j] = max(0, pmi)
    return M

ppmi = get_ppmi(co_matrix)

  # This is added back by InteractiveShellApp.init_path()


In [16]:
ppmi

array([[ 0,  9,  0, ...,  0,  0,  0],
       [ 9,  0, 11, ...,  0,  0,  0],
       [ 0, 11,  0, ...,  0,  0,  0],
       ...,
       [ 0,  0,  0, ...,  0,  0,  0],
       [ 0,  0,  0, ...,  0,  0,  0],
       [ 0,  0,  0, ...,  0,  0,  0]])

In [18]:
# 進行SVD分解，並得到降維的詞向量

from sklearn.decomposition import TruncatedSVD

# 使用`TruncatedSVD`進行降維，降維到dim=1000
svd = TruncatedSVD(n_components=1000)
word_vectors = svd.fit_transform(ppmi)

In [19]:
word_vectors.shape

(6690, 1000)

# 新的詞向量 + Group mean vector: 測試

In [20]:
excluded_flags = [
    'Nh', 'Nep', 'Nes', 'DE', 'T', 'P', 'V_2', 'SHI',
    'Dfa', 'Dfb', 'Da', 'Di', 'Dk',
    'Caa', 'Cab', 'Cba', 'Cbb',
    'COLONCATEGORY', 'COMMACATEGORY', 'DASHCATEGORY', 'DOTCATEGORY', 'ETCCATEGORY', 'EXCLAMATIONCATEGORY',
    'PARENTHESISCATEGORY', 'PAUSECATEGORY', 'PERIODCATEGORY', 'QUESTIONCATEGORY', 'SEMICOLONCATEGORY',
    'SPCHANGECATEGORY', 'WHITESPACE'
]

In [21]:
train_svd_vectors = {}
for index, pairs in train_title_cuts.items():
    selected_word_vectors = []
    for word, flag in pairs:
        if word in word2index and flag not in excluded_flags:
            selected_word_vectors.append(word_vectors[word2index[word], :])
    vector = np.sum(selected_word_vectors, axis=0)
    if np.sum(np.square(vector)) == 0:
        continue
    train_svd_vectors[index] = vector
    

In [22]:
test_svd_vectors = {}
for index, pairs in test_title_cuts.items():
    selected_word_vectors = []
    for word, flag in pairs:
        if word in word2index and flag not in excluded_flags:
            selected_word_vectors.append(word_vectors[word2index[word], :])
    vector = np.sum(selected_word_vectors, axis=0)
    if np.sum(np.square(vector)) == 0:
        continue
    test_svd_vectors[index] = vector

In [23]:
group_vectors = {news_class: [] for news_class in all_news_class}
for index, vector in sorted(train_svd_vectors.items()):
    news_class = train_classes[index]
    group_vectors[news_class].append(vector)

group_mean_vector = {}
for news_class, vectors in group_vectors.items():
    group_mean_vector[news_class] = np.mean(vectors, axis=0)

In [24]:
def cosine_similarity(bow1, bow2):
    len_bow1 = np.sqrt(np.sum(np.square(bow1)))
    len_bow2 = np.sqrt(np.sum(np.square(bow2)))
    return np.sum(bow1 * bow2) / (len_bow1 * len_bow2)

In [25]:
classification = {news_class: [] for news_class in all_news_class}
for index, vector in sorted(test_svd_vectors.items()):
    if np.sum(np.square(vector)) == 0:
        continue

    max_val = -2.0
    max_class = None
    for news_class, ref_vector in group_mean_vector.items():
        val = cosine_similarity(ref_vector, vector)
        if val > max_val:
            max_class = news_class
            max_val = val

    classification[max_class].append(index)

In [26]:
from collections import Counter

for group, ids in classification.items():
    counter = Counter([test_classes[id] for id in ids])
    print('predict', group, ': ', counter)

predict 體育 :  Counter({'體育': 59, '財經': 8, '遊戲': 8, '旅遊': 7, '科技': 4, '農業': 3})
predict 財經 :  Counter({'財經': 62, '科技': 23, '農業': 16, '體育': 10, '遊戲': 8, '旅遊': 5})
predict 科技 :  Counter({'科技': 53, '體育': 15, '財經': 14, '農業': 9, '遊戲': 9, '旅遊': 8})
predict 旅遊 :  Counter({'旅遊': 60, '農業': 10, '科技': 5, '財經': 4, '遊戲': 3, '體育': 2})
predict 農業 :  Counter({'農業': 60, '旅遊': 8, '體育': 4, '遊戲': 4, '財經': 3, '科技': 2})
predict 遊戲 :  Counter({'遊戲': 68, '科技': 12, '旅遊': 10, '財經': 9, '體育': 8, '農業': 1})
