# Skip-Gram with Negative Sampling(SGNS)

    기존 학습 과정 : 임베딩 테이블에 있는 모든 단어에 대한 임베딩 벡터 값을 업데이트한다.
    SGNS : 전체 단어 집합이 아닌 일부 단어 집합에만 집중해서 학습하는 방법.
           중심 단어에 대해 실제 주변 단어와 무작위로 뽑은 단어를 모아 단어 집합을 만들고, 각각 1 또는 0으로 레이블링한다.
           중심 단어와 주변 단어가 모두 입력이 되며, 두 단어가 실제로 윈도우 크기 내에 존재하는 이웃 관계인지에 대한 확률을 예측한다.

## 20뉴스 그룹 이용해 실습

### 1) 데이터 전처리

In [1]:
from sklearn.datasets import fetch_20newsgroups

# Fetch the 20 Newsgroups data with headers, footers and quoted replies
# stripped; shuffling uses a fixed seed (random_state=1).
dataset = fetch_20newsgroups(
    shuffle=True,
    random_state=1,
    remove=('headers', 'footers', 'quotes'),
)
doc = dataset.data

n_docs = len(doc)
print("total sample : ", n_docs)

total sample :  11314


In [2]:
import pandas as pd

news_df = pd.DataFrame({'doc': doc})

# Replace every non-alphabetic character with a space.
# regex=True must be explicit: pandas >= 2.0 treats the pattern as a literal
# string by default, in which case nothing would be replaced (older versions
# emitted a FutureWarning about this default change).
news_df['clean_doc'] = news_df['doc'].str.replace(r"[^a-zA-Z]", " ", regex=True)
# Keep only words longer than 3 characters.
news_df['clean_doc'] = news_df['clean_doc'].apply(
    lambda x: ' '.join(w for w in x.split() if len(w) > 3)
)
# Lower-case everything.
news_df['clean_doc'] = news_df['clean_doc'].apply(lambda x: x.lower())
# Turn empty strings into NaN so they can be counted/dropped as missing values.
news_df.replace("", float("NaN"), inplace=True)

  news_df['clean_doc'] = news_df['doc'].str.replace("[^a-zA-Z]"," ")


In [3]:
# Count missing values per column (isna is the same check as isnull).
news_df.isna().sum()

doc          218
clean_doc    319
dtype: int64

In [4]:
# Remove, in place, every row that contains a NaN in any column.
news_df.dropna(inplace=True)

remaining = len(news_df)
print("total sample :", remaining)

total sample : 10995


In [5]:
# Remove stopwords
from nltk.corpus import stopwords

stop_words = stopwords.words('english')
# Membership is tested once per word of every document, so look words up in
# a set (O(1)) rather than scanning the stopword list each time.
stop_word_set = set(stop_words)
tokenized_doc = news_df['clean_doc'].apply(lambda x: x.split())
tokenized_doc = tokenized_doc.apply(
    lambda x: [word for word in x if word not in stop_word_set]
)
tokenized_doc = tokenized_doc.to_list()

# Drop samples that contain one word or fewer
import numpy as np
# Filter with a comprehension instead of np.delete: the token lists have
# different lengths, and NumPy >= 1.24 raises ValueError when asked to build
# an array from such ragged nested lists. The result is now a plain list of
# token lists rather than an object ndarray.
tokenized_doc = [sent for sent in tokenized_doc if len(sent) > 1]

print("total samples : ", len(tokenized_doc))

total samples :  10940


In [6]:
# 정수 인코딩
from tensorflow.keras.preprocessing.text import Tokenizer

# Fit the Keras tokenizer on the already-tokenized documents.
tokenizer = Tokenizer()
tokenizer.fit_on_texts(tokenized_doc)

# word -> index mapping, and its inverse index -> word.
word2idx = tokenizer.word_index
idx2word = dict(zip(word2idx.values(), word2idx.keys()))

# Integer-encode every document with the fitted vocabulary.
encoded = tokenizer.texts_to_sequences(tokenized_doc)

# One more than the number of distinct words
# (presumably to leave room for index 0, which word_index does not use).
vocab_size = 1 + len(word2idx)
print("size of words set : ", vocab_size)

size of words set :  64277


### 2) 네거티브 샘플링을 통한 데이터셋 구성
    
    keras 에서 제공하는 skipgrams 사용.
    시간이 많이 걸리므로 상위 10개의 뉴스그룹 샘플에 대해서만 수행. (단, 아래 코드는 `encoded` 전체에 대해 수행하고 있으므로, 10개로 제한하려면 `encoded[:10]` 을 사용한다.)

In [7]:
from tensorflow.keras.preprocessing.sequence import skipgrams

# Negative sampling: build the skip-gram (pairs, labels) dataset for every
# encoded sample, using a window size of 10.
skip_grams = [
    skipgrams(seq, vocabulary_size=vocab_size, window_size=10)
    for seq in encoded
]

In [8]:
# Inspect the dataset generated for the first sample, skip_grams[0]
pairs, labels = skip_grams[0][0], skip_grams[0][1]

# Show up to five (word, word) -> label examples. Zipping over slices
# avoids an IndexError when the sample produced fewer than five pairs.
for (first_word, second_word), label in zip(pairs[:5], labels[:5]):
    print("({:s} ({:d}), {:s} ({:d})) -> {:d}".format(
          idx2word[first_word], first_word,
          idx2word[second_word], second_word,
          label))

(guilt (4989), rgammon (41237)) -> 0
(degree (1530), europeans (4520)) -> 1
(occured (4294), marv (30356)) -> 0
(power (68), vessels (12350)) -> 0
(clearly (661), vpic (3948)) -> 0


### 3) SGNS 구현

In [9]:
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Embedding, Reshape, Activation, Input
from tensorflow.keras.layers import Dot


# Dimensionality of each embedding vector.
embed_size = 100

# Embedding table for the center word: one word index in, one vector out.
w_inputs = Input(shape=(1,), dtype='int32')
word_embedding = Embedding(input_dim=vocab_size, output_dim=embed_size)(w_inputs)

# Separate embedding table for the context word.
c_inputs = Input(shape=(1,), dtype='int32')
context_embedding = Embedding(input_dim=vocab_size, output_dim=embed_size)(c_inputs)

# Dot product of the two embeddings along the embedding axis, flattened to a
# single value per example and squashed to a probability with a sigmoid.
similarity = Dot(axes=2)([word_embedding, context_embedding])
dot_product = Reshape((1,), input_shape=(1, 1))(similarity)
output = Activation('sigmoid')(dot_product)

# Two-input model trained as a binary classifier on the 0/1 pair labels.
model = Model(inputs=[w_inputs, c_inputs], outputs=output)
model.summary()
model.compile(loss='binary_crossentropy', optimizer='adam')

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 1)]          0                                            
__________________________________________________________________________________________________
input_2 (InputLayer)            [(None, 1)]          0                                            
__________________________________________________________________________________________________
embedding (Embedding)           (None, 1, 100)       6427700     input_1[0][0]                    
__________________________________________________________________________________________________
embedding_1 (Embedding)         (None, 1, 100)       6427700     input_2[0][0]                    
______________________________________________________________________________________________

In [None]:
# NOTE (original author): the kernel keeps dying while running this cell.
# Train for 5 epochs, treating each sample's skip-gram pairs as one batch and
# printing the loss summed over all batches of the epoch.
for epoch in range(1, 6):
    loss = 0
    for couples, pair_labels in skip_grams:
        # A sample with no pairs cannot be split into two input columns;
        # skip it instead of failing with an IndexError.
        if not couples:
            continue
        # Convert the pair list once into an (n_pairs, 2) array and slice its
        # columns, rather than unzipping the same list twice.
        couples = np.asarray(couples, dtype='int32')
        X = [couples[:, 0], couples[:, 1]]
        Y = np.asarray(pair_labels, dtype='int32')
        loss += model.train_on_batch(X, Y)
    print('Epoch :', epoch, 'Loss :', loss)

### 4) 결과 확인하기.
    

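학습된 벡터를 txt 파일로 저장한 후 gensim 을 이용해 로드하면 단어 간 유사도를 쉽게 구할 수 있다. (주의: 위 학습 셀이 `In [None]` 으로 남아 있어 학습이 끝까지 실행되지 않았을 수 있으며, 그 경우 아래 유사도 결과는 학습되지 않은 임베딩에서 나온 값일 수 있다.)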

In [10]:
# Save the vectors to vector.txt in word2vec text format
# First weight array of the model — presumably the center-word embedding
# table (TODO confirm against the layer order in model.summary()).
vectors = model.get_weights()[0]

# `with` closes the file even if a write fails; the encoding is explicit so
# the file does not depend on the platform default.
with open('vector.txt', 'w', encoding='utf-8') as f:
    # Header line: number of words written and the vector dimensionality.
    f.write('{} {}\n'.format(vocab_size - 1, embed_size))
    # One line per word: the word followed by its space-separated vector.
    for word, i in tokenizer.word_index.items():
        f.write('{} {}\n'.format(word, ' '.join(map(str, vectors[i]))))

# Load vector.txt back as gensim KeyedVectors (text format, not binary)

import gensim

w2v = gensim.models.KeyedVectors.load_word2vec_format(
    'vector.txt',
    binary=False,
)

# Words whose vectors are most similar to the vector of "doctor".
w2v.most_similar(positive=['doctor'])


[('contoller', 0.4101942777633667),
 ('xtappgeterrordatabase', 0.4040236175060272),
 ('cradle', 0.40330061316490173),
 ('xichang', 0.3876398205757141),
 ('protestants', 0.384991317987442),
 ('intertestamental', 0.38475894927978516),
 ('reductions', 0.3805429935455322),
 ('empt', 0.3667444586753845),
 ('asai', 0.35873445868492126),
 ('fisk', 0.3551829159259796)]

In [11]:
# Ten most similar words to "police" (topn=10 is gensim's default).
w2v.most_similar(positive=['police'], topn=10)

[('blinker', 0.39932137727737427),
 ('snore', 0.3857707381248474),
 ('comin', 0.3791111707687378),
 ('marvel', 0.3787321448326111),
 ('lakshmivarahan', 0.3746788501739502),
 ('cdfsga', 0.3690260350704193),
 ('starnet', 0.3647075295448303),
 ('entranced', 0.3577590882778168),
 ('pradhaan', 0.356963187456131),
 ('rainstorm', 0.35132846236228943)]