# torch `nn.Embedding`

In [2]:
import nltk

nltk.download('punkt')
nltk.download('punkt_tab')
nltk.download('stopwords')

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\Playdata\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to
[nltk_data]     C:\Users\Playdata\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\Playdata\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


True

## 사전학습된 임베딩을 사용하지 않는 경우



In [1]:
sentences = [
    'nice great best amazing',  # 긍정 문장 예시
    'stop lies',                # 부정/비판 문장 예시
    'pitiful nerd',             # 부정 문장 예시
    'excellent work',           # 긍정 문장 예시
    'supreme quality',          # 긍정 문장 예시
    'bad',                      # 부정 문장 예시
    'highly respectable'        # 긍정 문장 예시
]                               # 분류 모델에 넣을 입력 문장 리스트(list[str])
labels = [1, 0, 0, 1, 1, 0, 1]  # 각 문장에 대한 이진 라벨(1=긍정, 0=부정)

In [None]:
# 토큰화
from nltk.tokenize import word_tokenize

tokenized_sentences = [word_tokenize(sent) for sent in sentences]   # 각 문장을 토큰 리스트(list(list[str]))로 변환
tokenized_sentences

[['nice', 'great', 'best', 'amazing'],
 ['stop', 'lies'],
 ['pitiful', 'nerd'],
 ['excellent', 'work'],
 ['supreme', 'quality'],
 ['bad'],
 ['highly', 'respectable']]

In [None]:
# 단어 사전 생성 + 정수 인코딩
from collections import Counter

tokens = [token for sent in tokenized_sentences for token in sent]           # 문장 리스트를 평탄화하여 전체 토큰 리스트 생성
word_counts = Counter(tokens)                                                # 전체 토큰 등장 빈도 계산      
print(word_counts)                                                           # 토큰별 빈도 딕셔너리 형태

word_to_index = {word: index + 2 for index, word in enumerate(tokens)}       # 토큰을 순서대로 인덱싱(+2 : 특수토큰용)
word_to_index['<PAD>'] = 0                                                   # 패딩 토큰(길이 맞추기용)
word_to_index['<UNK>'] = 1                                                   # OOV 토큰 (처리 불가 단어 대체)
word_to_indexP = dict(sorted(word_to_index.items(), key=lambda x: x[1]))     # 인덱스를 기준으로 정렬
print(word_to_index)                                                         # 단어 -> 인덱스 사전

vocab_size = len(word_to_index)                                              # 전체 어휘 수(특수 토큰 포함)
vocab_size

Counter({'nice': 1, 'great': 1, 'best': 1, 'amazing': 1, 'stop': 1, 'lies': 1, 'pitiful': 1, 'nerd': 1, 'excellent': 1, 'work': 1, 'supreme': 1, 'quality': 1, 'bad': 1, 'highly': 1, 'respectable': 1})
{'nice': 2, 'great': 3, 'best': 4, 'amazing': 5, 'stop': 6, 'lies': 7, 'pitiful': 8, 'nerd': 9, 'excellent': 10, 'work': 11, 'supreme': 12, 'quality': 13, 'bad': 14, 'highly': 15, 'respectable': 16, '<PAD>': 0, '<UNK>': 1}


17

In [None]:
# 정수 인코딩 함수 : 토큰화된 문장 리스트를 단어 -> 인덱스 사전으로 정수 시퀀스 (list[list(int)])로 변ㅋ환
def texts_to_sequence(sentences, word_to_index):
    sequences = []
    
    for sent in sentences:                                  # 문장 단위로 반복
        sequence = []
        
        for token in sent:
            if token in word_to_index:
                sequence.append(word_to_index[token])       # 해당 단어 인덱스 추가
            else:
                sequence.append(word_to_index['<UNK>'])     # 사전에 없으면 UNK 토큰

        sequences.append(sequence)  

    return sequences

sequences = texts_to_sequence(tokenized_sentences, word_to_index)
sequences

[[2, 3, 4, 5], [6, 7], [8, 9], [10, 11], [12, 13], [14], [15, 16]]

In [None]:
import numpy as np

# 서로 다른 길이의 정수 시퀀스를 0(<PAD>)으로 채워 (문장수, maxlen) 형태에 맞춰주는 함수
def pad_sequences(sequences, maxlen):
    padded_sequences = np.zeros((len(sequences), maxlen), dtype=int)    # (문장수 x maxlen) 크기의 0 패딩 배열
    
    for index, seq in enumerate(sequences):                             # 각 문장 시퀀스 순회
        padded_sequences[index, :len(seq)] = seq[:maxlen]               # 앞에서부터 시퀀스 채운다. 길면 maxlen까지만 채워 자른다.
        
    return padded_sequences                                             # 패딩 작업 완료된 2D 배열


padded_sequences = pad_sequences(sequences, maxlen=4)                   # 모든 문장 길이 4로 패딩/자르기
padded_sequences                                                        # (문장 수, 4) 형태

array([[ 2,  3,  4,  5],
       [ 6,  7,  0,  0],
       [ 8,  9,  0,  0],
       [10, 11,  0,  0],
       [12, 13,  0,  0],
       [14,  0,  0,  0],
       [15, 16,  0,  0]])

In [6]:
padded_sequences.shape

(7, 4)

In [None]:
# Pytorch 텍스트 분류 모델 : Embedding + RNN + Linear로 이진분류(logit) 출력
import torch
import torch.nn as nn         # 신경망 레이어
import torch.optim as optim   # 옵티마이저 (활성화함수)
from torch.utils.data import DataLoader, TensorDataset # 배치로더 / 데이터셋 유틸

class SimpleNet(nn.Module):
    # 정수 시퀀스를 임베딩 -> RNN -> 선형층으로 처리해 이진 분류 logit(1개)를 출력
    def __init__(self, vocab_size, embedding_dim, hidden_size):
        super().__init__()                 # nn.Module 초기화
        self.embedding = nn.Embedding(     # 단어 ID를 밀집 벡터로 변환하는 임베딩 층
            num_embeddings = vocab_size,   # 단어 사전 크기 (어휘 수)
            embedding_dim = embedding_dim, # 임베딩 차원
            padding_idx = 0                # PAD(0) 인덱스는 0 그대로 사용
        )
        self.rnn = nn.RNN(
            embedding_dim,                 # 입력 차원
            hidden_size,                   # 은닉 상태 차원
            batch_first = True             # 배치 차원이 첫번째
        )
        self.out = nn.Linear(
            hidden_size, 1                 # 마지막 은닉 상태를 1차원 logit으로 변환
        )
    
    def forward(self, x):
        embedded = self.embedding(x)       # (batch, seq_len) -> (batch, seq_len, embedding_dim)
        out, h_n = self.rnn(embedded)      # h_n : (num_layers * directions , batch, hidden_size)
        out = self.out(h_n.squeeze(0))     # (batch_size, hidden_size) -> (batch, 1)
        return out                         # 출력 : 시그모이드 전 logit(확률이 아님)

embedding_dim = 100                        # 단어 벡터 차원 설정
model = SimpleNet(vocab_size, embedding_dim, hidden_size = 16) #  어휘 크기 / 임베딩 차원 / 은닉크기로 모델 생성
model

SimpleNet(
  (embedding): Embedding(17, 100, padding_idx=0)
  (rnn): RNN(100, 16, batch_first=True)
  (out): Linear(in_features=16, out_features=1, bias=True)
)

In [16]:
from torchinfo import summary # 모델 구조를 표 형태로 요약

summary(model)                # model의 레이어 구성 / 파라미터 수를 요약

Layer (type:depth-idx)                   Param #
SimpleNet                                --
├─Embedding: 1-1                         1,700
├─RNN: 1-2                               1,888
├─Linear: 1-3                            17
Total params: 3,605
Trainable params: 3,605
Non-trainable params: 0

In [17]:
# 임베딩 가중치 확인 : 학습 전/후 Embedding 테이블과 단어별 벡터 조회

import pandas as pd

# 학습 전 임베딩 벡터
wv = model.embedding.weight.data    # Embedding 층의 가중치 행렬(단어ID x 임베딩 차원) 추출
print(wv.shape)                     # (vocab_size, embedding_dim)

# 특정 단어 벡터
vocab = word_to_index.keys()        # 단어사전에서 단어만 뽑아온다.
pd.DataFrame(wv, index=vocab)

torch.Size([17, 100])


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,...,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99
nice,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
great,0.775794,-0.194863,-0.323649,-0.186236,-0.120143,1.671384,0.844475,0.873787,-0.102356,1.111574,-0.133635,1.713191,-0.715437,0.069271,-1.314278,0.085946,1.054635,-0.026516,0.24447,0.057606,0.474399,-1.993125,1.187491,0.464123,1.338911,-0.34566,-1.069142,-0.073286,0.482998,2.641071,-0.497173,-0.062149,0.550014,1.11841,-0.28218,0.097933,-0.958356,-0.988669,0.233976,-0.037655,...,-1.339609,0.220197,-0.534401,0.669685,2.429523,2.062786,-0.643933,-0.711514,-1.430425,-0.795139,1.001662,1.378816,1.829479,-1.333533,1.211764,-1.112946,1.363706,-0.252723,-0.426644,1.232487,-0.036241,0.753172,0.43666,0.258547,0.004333,-0.406172,-2.660337,0.687404,-1.292208,-0.914667,-1.169202,0.706188,-0.273713,2.256105,-2.815936,0.156267,-1.035854,1.407637,-0.696666,-1.541372
best,0.880386,0.544978,0.251571,1.287426,0.25879,0.977891,-1.050327,1.534952,-0.31319,0.422316,-1.418592,1.921553,-0.297401,-0.324536,-2.084427,0.591316,0.098162,-0.838386,-0.210605,0.583819,-0.378457,0.551645,-0.043866,-0.027884,1.494357,0.261525,0.260176,-2.011613,0.501413,0.201501,2.203593,-0.540236,1.856648,0.973109,-0.668786,0.680517,-1.117998,-0.159861,-0.830109,0.398681,...,0.533331,0.646759,0.546802,-0.375187,-0.216122,-0.123625,0.156251,0.412533,-0.386952,0.014226,-0.10855,-0.739278,-0.834107,-1.499457,1.177024,0.296533,0.25794,1.648079,-0.031395,1.23994,-0.290355,0.445231,-1.400327,1.94163,0.782338,1.705711,1.039511,0.126029,-0.282044,-1.086391,0.61487,0.903278,0.96204,0.178664,-0.620035,1.28964,-0.977955,-1.066367,1.264904,1.515571
amazing,-0.482203,-1.514841,0.573388,1.294246,-0.141386,-1.527407,-0.090894,0.211856,0.768626,-0.575384,0.736074,0.14412,1.01527,0.343379,0.077584,0.575889,-0.657602,0.545769,-1.880338,0.885284,0.205546,1.353168,-1.465541,1.693276,-0.661827,-0.921864,-3.147268,1.60228,-0.247682,-0.708077,-0.054856,0.876009,-0.711294,-0.506791,0.119807,-0.359548,0.87317,0.010822,1.163037,-0.25411,...,-0.350184,-0.144217,-0.012687,1.990358,-0.636763,1.955903,-0.0511,-0.039844,-0.972721,0.853346,-0.738375,-1.284361,0.23663,-0.989357,-0.328943,0.744524,0.293493,-1.147221,0.257053,1.314096,1.674895,-1.553812,-0.555219,0.613866,-0.744844,-0.128494,-0.35707,0.693233,0.812784,-1.013538,-0.433409,-0.530956,-0.845212,-1.259531,0.405191,0.822183,1.277288,-0.528976,-1.469818,-2.064824
stop,-0.465107,0.807339,0.0346,-0.993048,-0.762271,1.569123,0.008205,-0.825188,0.881909,-0.764524,1.255022,-0.45496,-0.001288,-1.161705,-0.855213,0.98335,1.79983,-0.172322,-0.495682,0.930947,-1.907112,-0.0435,1.271795,0.153612,1.662385,-0.596851,-0.169525,0.357262,-0.065617,0.287547,1.017096,0.116832,-0.149508,-0.955815,-1.032361,1.188748,-0.381669,0.606815,-1.080582,-1.41149,...,0.56823,-0.500958,0.247356,0.161332,1.284875,0.817403,0.121264,0.439143,-1.40431,0.475928,-1.926452,0.358652,0.408809,-0.651928,-1.878429,-1.750882,-0.412977,-0.798207,0.279826,-0.788834,-0.399392,-0.641278,-2.129856,-0.617535,0.172147,-0.691182,-0.237168,1.863462,-0.298035,-0.090845,-0.715637,-0.058978,-0.674597,-0.325651,0.757723,-0.480383,0.338044,2.098855,-0.352852,0.888461
lies,0.209052,-0.849466,-0.067222,-0.692221,-1.353825,-0.167834,-0.630096,0.002504,-1.581282,0.37046,0.482342,-1.074471,-2.020626,-0.207465,-1.170002,-0.416874,-0.217819,-0.739021,0.114173,-0.664361,-0.868051,0.634751,-0.289991,0.35416,0.88984,0.549405,-0.767899,-0.389358,-0.548047,0.109175,0.128485,-0.246259,1.149912,1.495458,-0.635587,-1.543033,0.901147,-0.365757,-1.345393,0.128872,...,0.379396,-0.043553,-1.396819,-0.258506,-1.031098,-0.259572,-1.15458,-1.342501,-0.934142,-0.287301,-1.134779,0.721175,-1.336361,0.030684,0.405205,-0.455942,-0.958824,-0.722825,-0.594669,-0.469122,0.703961,1.229537,0.918397,-0.195124,1.111732,0.299944,0.766068,0.547975,-1.840689,-1.017564,0.045298,-0.02638,-2.752438,1.093626,0.277874,-0.644846,-0.51878,-1.052457,-1.469079,-0.018341
pitiful,1.840178,-0.499861,-1.052492,0.567679,0.56538,-0.064665,0.749571,-0.462899,1.296457,1.630518,1.741418,0.080681,0.907792,-0.037061,-0.714092,-0.131203,1.198142,-0.171694,-0.322398,0.340406,0.409167,-0.308406,-1.303023,-0.657259,-1.34103,0.177849,0.676288,1.134787,-0.056571,1.533202,0.929121,0.385078,0.663952,0.31625,-1.746564,0.230567,-0.005842,0.720549,-3.121509,0.633064,...,0.961692,-0.064796,-1.055362,-0.409348,0.30931,1.163396,0.62279,0.749045,-0.374006,0.738836,0.564056,1.924635,0.589641,-0.516324,1.475651,-0.796467,-1.618379,0.241081,2.30275,0.412842,-0.747896,0.119906,0.07321,-0.041104,0.375421,-0.557361,-0.099334,-0.713493,-2.084928,1.151414,0.752574,-0.187914,0.898261,-0.126317,-1.138474,0.318433,2.478432,1.559519,0.074078,0.242984
nerd,-0.040889,-1.263668,-1.484635,-1.529543,0.24479,-0.189859,1.546927,-0.701815,0.126868,0.954621,-0.215214,-1.906773,-0.458584,0.374746,0.820955,-0.614607,-0.656457,-0.300699,-0.474318,-0.891234,-0.262245,0.984514,-0.115808,-1.194289,-0.482119,1.130238,-0.144326,0.892314,-0.763715,0.248944,-1.067972,-1.115433,-0.80409,2.644451,0.013447,2.05988,1.538976,0.439843,0.210279,0.407809,...,0.71685,0.39955,-0.589025,0.589231,0.942851,-1.845391,-1.270614,0.436259,0.769167,0.02165,0.117579,-1.397538,-0.109113,-0.090183,0.041412,0.62311,1.564511,1.186734,-0.173256,0.029839,-0.05682,-1.019324,-0.684135,0.901315,2.343261,0.382312,0.500134,0.409483,0.739303,-1.452829,-0.163135,-1.066609,-1.512901,0.254164,0.74887,-1.271425,0.141821,-1.078734,-0.537534,0.206237
excellent,-0.141044,0.367432,-1.740816,-0.77268,-0.884874,-0.980222,0.018667,-0.243648,0.486145,0.362261,2.355379,0.467203,0.998574,0.185857,-1.6079,1.114627,-0.978788,0.950492,-0.878775,1.169639,-1.061287,1.394224,-0.139489,1.433773,-1.423423,1.086332,-0.933373,0.67548,-1.590089,-0.649141,0.091636,0.059193,0.397155,-1.133233,-0.797245,-0.224359,-0.400199,0.031862,-1.065524,-1.686958,...,0.443944,-0.400523,-0.077229,-0.809964,-0.675451,-1.804889,-0.481853,-0.261289,1.074996,-0.228264,0.097753,-0.433253,1.33921,-2.105231,-0.491323,0.253188,2.190663,0.612754,-0.448705,1.125045,1.092168,2.060359,-1.272265,0.505646,-0.435757,-1.020588,0.160759,1.025814,0.541286,-0.337916,-0.248467,-0.515574,-1.580935,-0.42419,-0.370452,1.007201,0.351547,-0.456819,-0.033226,-0.079322
work,0.375592,2.202576,-0.679846,-0.146635,-1.132112,-0.040998,1.268483,0.744403,-0.146779,-0.52276,-0.507,-2.697689,1.040195,1.465421,-1.435217,-0.772629,0.118911,-0.069034,0.487303,-0.328601,0.154591,-0.419681,-0.127474,0.212925,0.424153,0.073481,-0.61972,-0.281288,0.248176,-1.856938,-1.417135,-0.283913,-0.994068,0.269628,-0.438877,-0.40922,0.289088,-0.081232,0.409475,0.650914,...,1.367074,1.511298,0.22267,-0.153669,-0.141149,0.143551,-1.888577,0.837668,0.634942,-1.005419,-0.368261,-1.503276,0.715695,0.041575,0.365894,0.094418,0.742555,-0.292024,-0.421559,0.63899,1.192882,-0.450639,-0.841768,0.715896,0.212123,-0.95192,0.499701,0.558598,-0.445627,0.973964,0.94483,0.234452,-0.415458,0.872105,0.93584,-1.668889,1.213822,0.779819,-1.46028,1.755072


In [18]:
# Pytorch 학습 준비 : 텐서 변환 -> DataLoader 구성 -> 손실함수/옵티마이저 설정
X = torch.tensor(padded_sequences,dtype=torch.long)          # 입력 시퀀스(정수 ID)를 LongTensor로 변환
y = torch.tensor(labels, dtype=torch.float).unsqueeze(1)     # 라벨을 float으로 변환 후(N,) -> (N,1) 로 차원 맞춤

dataset = TensorDataset(X, y)                                  #(X,y) 쌍을 Dataset 객체로 묶음
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)   # 배치 단위로 섞여서 공급하는 로더

criterion = nn.BCEWithLogitsLoss()                             # 출력 logit과 정답(0/1)로 이진분류 손실 계산(시그모이드 포함)
optimizer = optim.Adam(model.parameters(), lr=0.005)           # 모델 파라미터는 Adam으로 업데이트

BCEWithLogitsLoss을 사용할떄에는 모델 출력이 Sigmoid를 거치지 않은 logit이어야 한다.

In [None]:
# 학습 루프: 미니배치 단위로 20 epoch 학습하며 평균 손실 출력
for epoch in range(20):
    epoch_loss = 0                          # 손실 누적
    
    for x_batch, y_batch in dataloader:     # 미니배치 단위로(X, y)가져오기
        optimizer.zero_grad()               # 이전 배치 기울기 초기화
        output = model(x_batch)             # 순전파로 logit 계산
        loss = criterion(output, y_batch)   # 예측 logit과 정답으로 손실 계산
        loss.backward()                     # 역전파로 기울기 계싼
        optimizer.step()                    # 파라미터 업데이트
        
        epoch_loss += loss.item()           # 배치손실을 float으로 누적
        
    print(f"Epoch {epoch + 1}: Loss {epoch_loss/ len(dataloader)}") # epoch별 평균 손실 계산

Epoch 1: Loss 0.6762917786836624
Epoch 2: Loss 0.5656344443559647
Epoch 3: Loss 0.43073007464408875
Epoch 4: Loss 0.3385210931301117
Epoch 5: Loss 0.24175537377595901
Epoch 6: Loss 0.18178850412368774
Epoch 7: Loss 0.12561515346169472
Epoch 8: Loss 0.08815594390034676
Epoch 9: Loss 0.0681375078856945
Epoch 10: Loss 0.05171221029013395
Epoch 11: Loss 0.04189129080623388
Epoch 12: Loss 0.03407358378171921
Epoch 13: Loss 0.02891269465908408
Epoch 14: Loss 0.025184732396155596
Epoch 15: Loss 0.022377862595021725
Epoch 16: Loss 0.020712945610284805
Epoch 17: Loss 0.018600626848638058
Epoch 18: Loss 0.017010850366204977
Epoch 19: Loss 0.015507802832871675
Epoch 20: Loss 0.014427006943151355


In [None]:
# 평가 / 예측 : 학습된 모델로 확률 -> 0/1 예측값 생성 후 정답과 비교
model.eval()                                # 평가 모드
with torch.no_grad():                       # 기울기 계산 비활성화
    output = model(X)                       # 전체 샘플에 대한 예측 logit 계산
    prob = torch.sigmoid(output)            # logit에 0~1 확률로 변환
    pred = (prob >= 0.5).int()              # 임계값 0.5 기준으로 이진 분류(0/1) 예측값 생성

print(labels)
print(pred.squeeze().detach().numpy())      # 예측 라벨을 1차원 numpy 배열로 변환

[1, 0, 0, 1, 1, 0, 1]
[1 0 0 1 1 0 1]


## 사전학습된 임베딩을 사용하는 경우

In [23]:
from gensim.models import KeyedVectors
model_wv = KeyedVectors.load_word2vec_format("GoogleNews-vectors-negative300.bin.gz", binary=True)
model_wv.vectors.shape

(3000000, 300)

In [24]:
# 임베딩 메트릭스 초기화 : 사전학습 벡터로 Embedding 레이어를 채우기 위한 준비
print(len(word_to_index))           # 어휘 크기 (vocab_size) 확인

# (vocab_size, embedding_dim) 크기의 0 행렬 생성
embedding_matrix = np.zeros((len(word_to_index), model_wv.vectors.shape[1]))
print(embedding_matrix.shape)

17
(17, 300)


In [None]:
# 사전학습 임베딩 매핑 : 내 단어사전을 GoogleNews 벡터로 채워 Embedding_matrix 구성
model_wv.key_to_index['bad']                    # 'bad'의 내부 인덱스 확인(706)
model_wv.vectors[240]                           # 특정 인덱스 벡터 직접 조회


# 단어가 사전학습 모델에 있으면 임베딩 벡터(np.narray)를 반환 없으면 None 반환
def get_word_embedding(word):
    if word in model_wv:                        # 사전 학습 단어가 존재하면
        return model_wv[word]                   # 해당 단어 임베딩 벡터 반환
    else:
        return None
    
# get_word_embedding('bad')
for word, index in word_to_index.items():       # 내 단어사전(단어-> 인덱스)를 순회
    if index >= 2:                              # 특수토큰 제외
        emb = get_word_embedding(word)          # 사전학습 임베딩에서 해당 단어 벡터 조회
        if emb is not None:                     # 벡터가 존재하면
            embedding_matrix[index] = emb       # 내 인덱스 위치에 사전학습 벡터를 복사해서 채운다.
            

In [27]:
pd.DataFrame(embedding_matrix, index=word_to_index.keys())

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,...,260,261,262,263,264,265,266,267,268,269,270,271,272,273,274,275,276,277,278,279,280,281,282,283,284,285,286,287,288,289,290,291,292,293,294,295,296,297,298,299
nice,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
great,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
best,0.158203,0.105957,-0.189453,0.386719,0.083496,-0.267578,0.083496,0.113281,-0.104004,0.178711,-0.123535,-0.222656,-0.018066,-0.253906,0.131836,0.085938,0.161133,0.11084,-0.11084,-0.085938,0.026733,0.345703,0.151367,-0.00415,0.10498,0.049072,-0.069824,0.086426,0.031982,-0.028442,-0.157227,0.118652,0.361328,0.001732,0.052979,-0.234375,0.117676,0.086426,-0.01123,0.259766,...,-0.081055,-0.066895,-0.318359,-0.084473,0.135742,0.0625,0.070801,-0.142578,-0.112793,0.014526,-0.066895,0.038818,0.194336,0.095215,0.11377,-0.124512,0.137695,-0.188477,-0.052246,0.158203,0.098633,-0.043701,-0.060547,0.216797,0.040771,-0.146484,-0.189453,-0.251953,-0.168945,-0.086426,-0.085449,0.189453,-0.146484,0.134766,-0.040771,0.032715,0.089355,-0.267578,0.008362,-0.213867
amazing,0.071777,0.208008,-0.028442,0.178711,0.132812,-0.099609,0.096191,-0.116699,-0.008545,0.148438,-0.033447,-0.185547,0.041016,-0.089844,0.021729,0.069336,0.180664,0.222656,-0.100586,-0.069336,0.000104,0.160156,0.040771,0.07373,0.15332,0.067871,-0.103027,0.041748,0.042725,-0.110352,-0.066895,0.041992,0.25,0.212891,0.15918,0.014465,-0.048828,0.013977,0.003555,0.209961,...,-0.068359,-0.139648,-0.15918,-0.017944,0.02124,0.07373,0.130859,-0.080566,0.029907,0.015564,-0.166016,0.150391,-0.006775,0.010132,0.114746,-0.148438,-0.045898,-0.139648,-0.173828,-0.042725,-0.058105,0.052246,-0.111328,0.084473,-0.025513,0.140625,-0.181641,0.017212,-0.137695,-0.014771,-0.011475,0.064453,-0.289062,-0.048096,-0.199219,-0.071289,0.064453,-0.167969,-0.020874,-0.142578
stop,-0.126953,0.021973,0.287109,0.15332,0.12793,0.032715,-0.115723,-0.029541,0.15332,0.011292,0.139648,-0.086914,0.257812,0.07373,-0.018921,0.125,0.09082,-0.001556,-0.031982,-0.145508,0.047607,0.173828,-0.146484,0.006012,0.030273,0.040771,-0.066406,0.18457,0.097168,-0.10498,0.024902,0.056396,0.165039,0.09082,0.185547,0.225586,-0.039795,-0.167969,-0.069336,0.019653,...,-0.178711,0.120605,-0.035889,0.095703,0.152344,0.003998,-0.059082,-0.032471,-0.054199,-0.005493,-0.045654,-0.001526,-0.050293,0.255859,0.04834,-0.019409,-0.12793,-0.088379,-0.225586,0.087402,0.205078,0.085938,0.066406,0.108398,-0.191406,0.070312,-0.163086,-0.002472,0.020264,0.001701,0.006439,-0.033936,-0.166016,-0.016846,-0.048584,-0.022827,-0.152344,-0.101562,-0.090332,0.088379
lies,0.07373,0.004059,-0.135742,0.022095,0.180664,-0.046631,0.224609,-0.229492,-0.040039,0.225586,-0.124023,-0.243164,-0.036621,-0.287109,0.077148,0.224609,0.261719,0.196289,-0.155273,0.085449,-0.095703,0.289062,0.044678,-0.133789,0.117676,-0.165039,0.041016,0.078125,0.131836,-0.292969,-0.044434,0.129883,0.275391,0.231445,0.169922,-0.114258,-0.025879,-0.143555,0.075684,0.322266,...,-0.183594,-0.12207,-0.031128,0.071777,0.068848,-0.011475,0.267578,-0.172852,0.047852,-0.010681,-0.056152,0.117188,0.120605,0.125977,0.075195,-0.206055,0.169922,-0.018799,0.024292,-0.069336,0.075195,-0.004669,-0.328125,0.025024,-0.120605,0.026611,-0.12793,-0.095215,0.157227,0.098145,0.018433,-0.02124,-0.25,-0.020142,-0.310547,-0.207031,-0.006317,-0.141602,-0.150391,-0.137695
pitiful,-0.057861,0.013184,0.115234,0.069824,-0.306641,-0.044678,0.048584,0.152344,0.073242,-0.100098,0.155273,-0.121582,-0.006744,0.028198,-0.172852,0.000759,-0.084961,-0.066406,0.28125,-0.146484,0.172852,0.211914,-0.21582,0.126953,0.118652,0.102051,0.094727,0.026123,0.217773,-0.120605,-0.025513,-0.181641,-0.339844,-0.114258,-0.035645,0.014954,0.092773,-0.108887,0.128906,0.149414,...,0.108398,0.146484,0.154297,0.001656,-0.01123,-0.106445,-0.175781,0.076172,-0.013794,-0.324219,0.149414,0.298828,0.246094,0.030273,0.234375,0.030151,-0.255859,-0.283203,0.055176,-0.275391,0.116699,0.037842,0.180664,-0.012756,-0.235352,0.131836,-0.228516,-0.088867,0.199219,-0.203125,0.100098,0.171875,-0.113281,0.064453,-0.115723,0.048096,-0.004822,0.086426,0.029907,0.007812
nerd,0.149414,-0.012817,0.328125,0.025513,0.017334,0.19043,0.188477,-0.143555,-0.09082,0.206055,-0.300781,0.157227,-0.155273,0.126953,0.151367,0.114746,-0.139648,0.077148,-0.208984,0.101074,-0.145508,0.016357,0.015503,0.217773,0.388672,-0.154297,-0.229492,0.308594,-0.009033,-0.166016,0.07666,-0.023315,-0.220703,0.15332,-0.020996,0.025757,0.009583,0.206055,0.382812,0.118652,...,-0.170898,-0.154297,0.174805,0.134766,0.115234,0.457031,-0.000443,-0.132812,-0.166016,-0.333984,-0.129883,0.098145,-0.119141,-0.148438,-0.058838,0.273438,0.075195,-0.068359,-0.129883,-0.014709,0.01709,-0.439453,0.147461,-0.063965,0.072266,-0.078613,-0.115723,0.314453,-0.10791,0.024292,-0.308594,0.183594,-0.202148,0.031494,-0.164062,-0.201172,0.080078,-0.105469,0.149414,0.157227
excellent,0.269531,0.253906,-0.020996,0.060303,-0.010925,0.217773,0.139648,-0.057617,0.3125,0.253906,-0.037354,-0.028687,-0.043457,0.137695,0.18457,0.125,0.19043,-0.005463,-0.15625,0.026489,0.107422,0.021729,0.001503,-0.036621,0.328125,0.008789,-0.253906,0.306641,0.077148,-0.03125,-0.191406,-0.003677,-0.07373,-0.09668,-0.15625,0.134766,-0.039551,0.074707,-0.021484,0.149414,...,0.069824,0.204102,0.077148,0.120117,0.020386,0.10498,-0.173828,-0.143555,-0.052002,-0.125,-0.186523,0.386719,0.189453,-0.056641,-0.195312,-0.147461,-0.08252,-0.041016,-0.138672,0.285156,0.028931,-0.177734,0.089355,0.00145,-0.318359,-0.070312,-0.026855,-0.058594,0.124512,-0.08252,-0.063477,0.132812,-0.094238,0.089355,-0.06543,-0.016235,-0.10791,-0.072266,-0.094238,0.028809
work,0.265625,-0.207031,-0.026611,0.419922,-0.208984,0.390625,0.164062,0.063965,0.149414,-0.0177,0.02417,-0.148438,0.283203,-0.226562,-0.161133,-0.028564,-0.035645,-0.054688,0.145508,-0.112305,-0.188477,0.098633,0.066406,0.040771,0.030151,-0.200195,-0.061279,0.158203,0.029907,-0.098633,0.248047,-0.012695,0.031738,0.298828,0.005371,-0.097168,0.310547,-0.142578,0.226562,0.273438,...,0.139648,-0.096191,0.09082,-0.12793,0.041016,0.146484,0.238281,0.023926,-0.029907,-0.117676,-0.207031,0.021851,0.291016,-0.002625,-0.255859,-0.200195,0.179688,-0.220703,-0.223633,0.339844,0.134766,0.011597,-0.04248,0.318359,0.227539,-0.078613,-0.151367,-0.15625,0.017822,0.062988,0.21582,0.125,-0.227539,-0.310547,-0.112793,-0.09668,0.255859,0.124023,-0.030273,0.082031


In [29]:
# Pytorch 텍스트 분류 모델 : Embedding + RNN + Linear 로 이진 분류 (logit) 출력
import torch
import torch.nn as nn           # 신경망 레이어
import torch.optim as optim     # 옵티마이저(활성화 함수)
from torch.utils.data import DataLoader, TensorDataset  # 배치 로더/ 데이터셋 유틸

class SimpleNet(nn.Module):
    # 정수 시퀀스를 임베딩 -> RNN -> 선형층으로 처리해 이진 분류 logit(1개)를 출력
    def __init__(self,vocab_size, embedding_dim, hidden_size):
        super().__init__()                  # nn.Module 초기화
        self.embedding = nn.Embedding(      # 단어 ID를 밀집 벡터로 변환하는 임베딩 층
            num_embeddings = vocab_size,    # 단어 사전 크기 (어휘 수)
            embedding_dim = embedding_dim,  # 임베딩 차원
            padding_idx = 0                 # PAD(0) 인덱스는 0 그대로 사용
        )

        # 사전 학습된 임베딩 벡터로 초기화 : Embedding 가중치를 사전학습 행렬로 덮어쓰기
        self.embedding.weight = nn.Parameter(torch.tensor(embedding_matrix, dtype = torch.float))

        self.rnn = nn.RNN(embedding_dim, hidden_size, batch_first = True)   # 압력(배치, 길이, 차원) 형태의 RNN
        self.out = nn.Linear(hidden_size, 1)    # 마지막 은닉 상태를 1차원 logit으로 변환

    def forward(self, x):
        embedded = self.embedding(x)        # (batch, seq_len) -> (batch, seq_len, embedding_dim)
        out, h_n = self.rnn(embedded)       # h_n : (num_layers * directions, batch, hidden_size)
        out = self.out(h_n.squeeze(0))      # (batch_size, hidden_size) -> (batch, 1)
        return out # 출력 : 시그모이드 전 logit(확률이 아님)
    
embedding_dim = model_wv.vectors.shape[1] # 사전학습 임베딩 차원 (300) 으로 임베딩 차원 설정
model = SimpleNet(vocab_size, embedding_dim, hidden_size= 16)   # 어휘 크기/ 임베딩 차원/ 은닉 크기로 모델 생성
print(model)

criterion = nn.BCEWithLogitsLoss()  # 출력 logit과 정답(0/1)로 이진분류 손실 계산(시그모이드 포함)
optimizer = optim.Adam(model.parameters(), lr=0.005)    # 모델 파라미터는 Adam으로 업데이트

SimpleNet(
  (embedding): Embedding(17, 300, padding_idx=0)
  (rnn): RNN(300, 16, batch_first=True)
  (out): Linear(in_features=16, out_features=1, bias=True)
)


In [30]:
# 학습 루프: 미니배치 단위로 20 epoch 학습하며 평균 손실 출력
for epoch in range(20):
    epoch_loss = 0                          # 손실 누적
    
    for x_batch, y_batch in dataloader:     # 미니배치 단위로(X, y)가져오기
        optimizer.zero_grad()               # 이전 배치 기울기 초기화
        output = model(x_batch)             # 순전파로 logit 계산
        loss = criterion(output, y_batch)   # 예측 logit과 정답으로 손실 계산
        loss.backward()                     # 역전파로 기울기 계싼
        optimizer.step()                    # 파라미터 업데이트
        
        epoch_loss += loss.item()           # 배치손실을 float으로 누적
        
    print(f"Epoch {epoch + 1}: Loss {epoch_loss/ len(dataloader)}") # epoch별 평균 손실 계산

Epoch 1: Loss 0.6801373660564423
Epoch 2: Loss 0.49727021157741547
Epoch 3: Loss 0.35623661428689957
Epoch 4: Loss 0.23015503957867622
Epoch 5: Loss 0.1480734534561634
Epoch 6: Loss 0.09338825568556786
Epoch 7: Loss 0.06634146627038717
Epoch 8: Loss 0.047831411473453045
Epoch 9: Loss 0.03677579294890165
Epoch 10: Loss 0.029688909649848938
Epoch 11: Loss 0.02426537685096264
Epoch 12: Loss 0.021065642591565847
Epoch 13: Loss 0.018339503556489944
Epoch 14: Loss 0.015760740963742137
Epoch 15: Loss 0.014512802474200726
Epoch 16: Loss 0.012776592746376991
Epoch 17: Loss 0.011818057857453823
Epoch 18: Loss 0.011056839255616069
Epoch 19: Loss 0.010238915914669633
Epoch 20: Loss 0.009514382109045982


In [31]:
# 평가 / 예측 : 학습된 모델로 확률 -> 0/1 예측값 생성 후 정답과 비교
model.eval()                                # 평가 모드
with torch.no_grad():                       # 기울기 계산 비활성화
    output = model(X)                       # 전체 샘플에 대한 예측 logit 계산
    prob = torch.sigmoid(output)            # logit에 0~1 확률로 변환
    pred = (prob >= 0.5).int()              # 임계값 0.5 기준으로 이진 분류(0/1) 예측값 생성

print(labels)
print(pred.squeeze().detach().numpy())      # 예측 라벨을 1차원 numpy 배열로 변환

[1, 0, 0, 1, 1, 0, 1]
[1 0 0 1 1 0 1]


사전학습 임베딩을 사용했을 때에도 학습 데이터 분류가 잘 되는지 파악한다.  
만약 틀린 샘플이 있다면 해당 문장이 OOV(0벡터) 비중이 큰지 확인해 봐야한다.  