<a href="https://colab.research.google.com/github/pd8459/skip_gram/blob/main/skip_gram.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import math
import random
import re

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

In [None]:
# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [None]:
from sklearn.datasets import fetch_20newsgroups

In [None]:
# 20 Newsgroups posts with headers, footers and quoted replies removed,
# shuffled with a fixed random_state; only a 100-post subset is used below.
dataset = fetch_20newsgroups(
    shuffle=True,
    random_state=1,
    remove=('headers', 'footers', 'quotes'),
)
documents = dataset.data[:100]

In [None]:
def load_text(filepath):
    """Return the entire contents of a UTF-8 encoded text file as one string.

    Args:
        filepath: path of the file to read.

    Returns:
        The decoded file contents.
    """
    with open(filepath, encoding='utf-8') as handle:
        contents = handle.read()
    return contents

# Optional local text. The training corpus in this notebook is built from the
# 20 Newsgroups `documents`, so a missing file must not abort "Run All" on a
# fresh runtime; `text` is None in that case.
try:
    text = load_text('text.txt')
except FileNotFoundError:
    text = None

def preprocess(text):
    """Normalize raw text and map every word to an integer id.

    The text is lowercased, every character that is not an ASCII letter or
    whitespace is dropped, and the result is split on whitespace runs. Ids are
    assigned in order of first occurrence, starting at 0.

    Args:
        text: raw input string.

    Returns:
        (corpus, word_to_id, id_to_word):
            corpus: list of word ids in text order.
            word_to_id: dict word -> id.
            id_to_word: dict id -> word (inverse of word_to_id).
        Text containing no letters yields ([], {}, {}).
    """
    cleaned = re.sub(r'[^a-z\s]', '', text.lower())
    # split() with no separator splits on any whitespace run and drops empty
    # strings, so empty or letter-free input produces no words. The previous
    # split(' ') returned [''] there and registered '' as a vocabulary word.
    words = cleaned.split()

    word_to_id = {}
    id_to_word = {}
    for word in words:
        if word not in word_to_id:
            new_id = len(word_to_id)
            word_to_id[word] = new_id
            id_to_word[new_id] = word

    corpus = [word_to_id[word] for word in words]
    return corpus, word_to_id, id_to_word

# Treat all documents as one token stream. The ", " separators are removed by
# preprocess (non-letters dropped, whitespace collapsed), so context windows
# can span the boundary between two posts.
merged_text = ", ".join(documents)
corpus, word_to_id, id_to_word = preprocess(merged_text)

In [None]:
def generate_skipgram_pairs(corpus, window_size):
    """Build (center, context) training pairs for skip-gram.

    Every position i with a full window on both sides (window_size <= i <
    len(corpus) - window_size) is used as a center. It is paired with each
    word at offsets -window_size..-1 and 1..window_size. The first and last
    window_size tokens are never centers. Pairs are ordered by center position,
    then by offset.

    Args:
        corpus: sequence of word ids.
        window_size: number of context words on each side (>= 0).

    Returns:
        list of (center_id, context_id) tuples. The list is empty when the
        corpus is too short to hold one full window.

    Raises:
        ValueError: if window_size is negative.
    """
    if window_size < 0:
        raise ValueError(f"window_size must be >= 0, got {window_size}")

    offsets = [j for j in range(-window_size, window_size + 1) if j != 0]
    pairs = []
    for i in range(window_size, len(corpus) - window_size):
        center = corpus[i]
        for j in offsets:
            pairs.append((center, corpus[i + j]))
    # Return only after every center has been processed. The original returned
    # from inside the loop, so it kept only the first center's pairs, and it
    # returned None when the loop never ran.
    return pairs

# Skip-gram hyperparameters: `window_size` context words on each side of the
# center word, and `K` negative samples drawn per (center, context) pair.
window_size, K = 5, 5

vocab_size = len(word_to_id)  # preprocess assigns ids 0..vocab_size-1
pairs = generate_skipgram_pairs(corpus, window_size)

In [None]:
import random
def get_negative_samples(context_word, vocab_size, K):
    """Draw K negative word ids uniformly from [0, vocab_size), excluding context_word.

    Sampling is uniform with rejection via `random.randint`. Duplicates within
    one draw are possible, and the center word is not excluded.

    Args:
        context_word: id of the positive context word that must not be drawn.
        vocab_size: number of word ids; candidates are 0..vocab_size-1.
        K: number of negatives to return.

    Returns:
        list[int] of length K (empty when K <= 0).

    Raises:
        ValueError: if K > 0 but no id other than context_word exists. In that
            case the rejection loop would never terminate.
    """
    if K <= 0:
        return []
    # With e.g. vocab_size == 1 and context_word == 0 there is no eligible id,
    # and the loop below would spin forever instead of failing.
    excluded = 1 if 0 <= context_word < vocab_size else 0
    if vocab_size - excluded < 1:
        raise ValueError(
            f"cannot draw negatives: no word id other than {context_word} "
            f"in a vocabulary of size {vocab_size}"
        )

    neg_samples = []
    while len(neg_samples) < K:
        neg = random.randint(0, vocab_size - 1)
        if neg != context_word:
            neg_samples.append(neg)
    return neg_samples

class SkipGramNegDataset(Dataset):
    """Map-style dataset yielding (center, context, negatives) LongTensors.

    Negatives are not stored. Every __getitem__ call draws a new set of K ids
    through get_negative_samples, so the same (center, context) pair receives
    freshly sampled negatives each time it is fetched.
    """

    def __init__(self, pairs, vocab_size, K):
        # pairs: indexable sequence of (center_id, context_id) tuples.
        self.pairs, self.vocab_size, self.K = pairs, vocab_size, K

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        center_id, context_id = self.pairs[idx]
        negative_ids = get_negative_samples(context_id, self.vocab_size, self.K)
        # 0-d tensors for the two ids, and a 1-d tensor holding the K negatives.
        return tuple(
            torch.tensor(value, dtype=torch.long)
            for value in (center_id, context_id, negative_ids)
        )



In [None]:
# Wrap the (center, context) pairs so each fetched item also carries K sampled
# negatives, then batch and shuffle them for training.
# NOTE: this rebinds `dataset`, replacing the 20 Newsgroups object loaded earlier.
dataset = SkipGramNegDataset(pairs=pairs, vocab_size=vocab_size, K=K)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

In [None]:
# Sanity-check a single batch: center and context are 1-d id tensors of
# length batch, negatives is a (batch, K) id tensor.
for center, context, negatives in dataloader:
    print("Center:", center)
    print("Context:", context)
    print("Negatives:", negatives)
    break

Center: tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5])
Context: tensor([ 2,  4,  1,  8, 10,  3,  7,  6,  0,  9])
Nagatives: tensor([[3478, 1486, 4252, 1628, 2070],
        [3313, 4732, 2775,  742,  369],
        [4781, 4195, 1196, 2527, 2176],
        [ 222, 4434,  177, 2263, 3389],
        [ 650, 1164,  367, 2512, 3853],
        [3385, 4116, 1268, 4519, 4864],
        [3149, 1966, 2105, 4272, 4004],
        [1397, 1742, 2648,  693, 4443],
        [4366,  153, 1821, 1594, 2845],
        [2554, 2775, 3990,  635,  688]])


In [None]:
import torch.nn.functional as F

class SkipGramNegSampling(nn.Module):
    """Skip-gram with negative sampling (SGNS).

    Holds two embedding tables. `input_embedding` holds the center-word vectors
    (the ones queried by most_similar). `output_embedding` scores the positive
    context word and the sampled negatives.
    """

    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        self.input_embedding = nn.Embedding(vocab_size, embedding_dim)
        self.output_embedding = nn.Embedding(vocab_size, embedding_dim)

    def forward(self, center, pos_context, neg_contexts):
        """Return the mean negative-sampling loss over the batch.

        Args:
            center: LongTensor (B,) of center word ids.
            pos_context: LongTensor (B,) of observed context word ids.
            neg_contexts: LongTensor (B, K) of sampled negative word ids.

        Returns:
            Scalar tensor equal to
            -mean_b[ log sigma(v_c . u_o) + sum_k log sigma(-v_c . u_k) ].
        """
        v_center = self.input_embedding(center)         # (B, D)
        v_pos = self.output_embedding(pos_context)      # (B, D)
        v_neg = self.output_embedding(neg_contexts)     # (B, K, D)

        pos_score = torch.sum(v_center * v_pos, dim=1)  # (B,)
        pos_loss = F.logsigmoid(pos_score)

        # Squeeze only the trailing singleton dimension. A bare squeeze() also
        # drops B when B == 1 (e.g. a final batch of size 1) or K when K == 1,
        # and the .sum(1) below then fails with a dimension error.
        neg_score = torch.bmm(v_neg, v_center.unsqueeze(2)).squeeze(2)  # (B, K)
        neg_loss = F.logsigmoid(-neg_score).sum(1)                       # (B,)

        return -(pos_loss + neg_loss).mean()

In [None]:
# Seed every RNG the pipeline draws from, so reruns are reproducible.
# torch.manual_seed covers nn.Embedding initialisation (done on the CPU before
# .to(device)), DataLoader shuffling, and CUDA. Python's `random` drives
# get_negative_samples. torch.cuda.manual_seed alone covered none of the first
# two and not `random` either.
SEED = 123
torch.manual_seed(SEED)
random.seed(SEED)

embedding_dim = 100
model = SkipGramNegSampling(vocab_size, embedding_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.003)

epochs = 50
for epoch in range(epochs):
    total_loss = 0.0
    for center, context, negatives in dataloader:
        optimizer.zero_grad()
        loss = model(center.to(device), context.to(device), negatives.to(device))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # Report the mean per-batch loss. The raw sum grows with the number of
    # batches, so it is not comparable across corpus or batch sizes.
    avg_loss = total_loss / max(len(dataloader), 1)
    print(f"Epoch {epoch+1}/{epochs}, Avg loss: {avg_loss:.4f}")

Epoch 1/50, Loss: 32.2958
Epoch 2/50, Loss: 14.6382
Epoch 3/50, Loss: 15.8007
Epoch 4/50, Loss: 15.0683
Epoch 5/50, Loss: 15.5238
Epoch 6/50, Loss: 17.0632
Epoch 7/50, Loss: 26.6278
Epoch 8/50, Loss: 18.4083
Epoch 9/50, Loss: 22.5019
Epoch 10/50, Loss: 17.1314
Epoch 11/50, Loss: 20.0984
Epoch 12/50, Loss: 15.7137
Epoch 13/50, Loss: 19.1127
Epoch 14/50, Loss: 19.2005
Epoch 15/50, Loss: 20.5294
Epoch 16/50, Loss: 21.5845
Epoch 17/50, Loss: 19.0956
Epoch 18/50, Loss: 15.6312
Epoch 19/50, Loss: 20.2344
Epoch 20/50, Loss: 17.1900
Epoch 21/50, Loss: 20.1115
Epoch 22/50, Loss: 25.6098
Epoch 23/50, Loss: 19.9753
Epoch 24/50, Loss: 14.5494
Epoch 25/50, Loss: 19.3087
Epoch 26/50, Loss: 18.0134
Epoch 27/50, Loss: 15.9002
Epoch 28/50, Loss: 15.4950
Epoch 29/50, Loss: 11.0526
Epoch 30/50, Loss: 20.5480
Epoch 31/50, Loss: 21.1355
Epoch 32/50, Loss: 16.3085
Epoch 33/50, Loss: 17.0618
Epoch 34/50, Loss: 17.6909
Epoch 35/50, Loss: 21.8743
Epoch 36/50, Loss: 22.1711
Epoch 37/50, Loss: 15.3769
Epoch 38/5

In [None]:
def most_similar(word, model, word_to_id, id_to_word, topk=5):
    if word not in word_to_id:
      print("단어가 vocabulary에 없습니다.")
      return

    word_id = word_to_id[word]
    with torch.no_grad():
      emb = model.input_embedding.weight
      target_vec = emb[word_id]
      scores = F.cosine_similarity(target_vec.unsqueeze(0), emb)
      topk_ids = torch.topk(scores, topk+1).indices.tolist()