# Embedding

## Load Data

In [None]:
import pandas as pd

# Parquet shard paths of the IMDB dataset, relative to the dataset root.
splits = {
    'train': 'plain_text/train-00000-of-00001.parquet',
    'test': 'plain_text/test-00000-of-00001.parquet',
    'unsupervised': 'plain_text/unsupervised-00000-of-00001.parquet',
}

imdb_root = "hf://datasets/stanfordnlp/imdb/"
df_imdb_train = pd.read_parquet(imdb_root + splits["train"])
df_imdb_test = pd.read_parquet(imdb_root + splits["test"])

# Stack the train and test splits into one frame (the unsupervised split is
# not loaded here) and print a summary of the result.
df_imdb = pd.concat([df_imdb_train, df_imdb_test])
df_imdb.info()

## Create the training data

In [None]:
import nltk
from nltk.corpus import stopwords
from nltk.stem import PorterStemmer

# Fetch the NLTK stop-word corpus; this must run before stopwords.words().
nltk.download('stopwords')
# Kept as a set because tokenize() does a membership test for every word.
stop_words = set(stopwords.words('english'))
# Single shared stemmer instance reused by tokenize().
stemmer = PorterStemmer()

In [None]:
from tqdm import tqdm
from collections import Counter

def tokenize(text):
    """Turn raw text into a list of stemmed, stop-word-free tokens.

    The text is lower-cased and split on whitespace; every word that is not
    in the module-level ``stop_words`` set is passed through the module-level
    ``stemmer`` and kept, in its original order.
    """
    lowered_words = text.lower().split()
    return [
        stemmer.stem(candidate)
        for candidate in lowered_words
        if candidate not in stop_words
    ]

# Accumulate the frequency of every token over all reviews.
counter = Counter()
reviews = df_imdb["text"].values
for review in tqdm(reviews):
    review_tokens = tokenize(review)
    counter.update(review_tokens)

In [None]:
# The 1000 most frequent tokens get ids 1..1000 in descending frequency
# order; id 0 is reserved for every other (unknown) token.
vocab = {}
for rank, (frequent_token, _count) in enumerate(counter.most_common(1000), start=1):
    vocab[frequent_token] = rank
vocab["<UNK>"] = 0

In [None]:
def context_target_pairs(text, window_size=2):
    """Build CBOW ``(context, target)`` id pairs from one text.

    The text is tokenized and mapped to vocabulary ids (unknown tokens map
    to 0). Every position that has ``window_size`` tokens on both sides
    becomes a target; its context is the ``window_size`` ids before it
    followed by the ``window_size`` ids after it.

    Returns a list of ``(context_ids, target_id)`` tuples, where
    ``context_ids`` is a list of ints and ``target_id`` is an int.
    """
    ids = [vocab.get(tok, 0) for tok in tokenize(text)]
    return [
        (
            ids[center - window_size:center]
            + ids[center + 1:center + 1 + window_size],
            ids[center],
        )
        for center in range(window_size, len(ids) - window_size)
    ]

In [None]:
# Collect the (context, target) pairs of every review into one flat list.
# NOTE: the original comment said only a subset is taken for faster training,
# but the loop iterates over all reviews.
pairs = []
for review in tqdm(df_imdb["text"].values):
    pairs += context_target_pairs(review)

In [None]:
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

# Fix PyTorch's random seed so repeated runs draw the same random numbers.
torch.manual_seed(42)

class Word2VecDataset(Dataset):
    """Dataset view over a sequence of ``(context_ids, target_id)`` pairs.

    Each item is returned as two ``torch.long`` tensors: the context ids and
    the target id.
    """

    def __init__(self, pairs):
        # The sequence is stored as given; no copy is made.
        self.pairs = pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        context_ids, target_id = self.pairs[idx]
        context_tensor = torch.tensor(context_ids, dtype=torch.long)
        target_tensor = torch.tensor(target_id, dtype=torch.long)
        return context_tensor, target_tensor

# Wrap the pair list in a Dataset and serve it in shuffled mini-batches.
train_dataset = Word2VecDataset(pairs)
train_loader = DataLoader(
    train_dataset,
    batch_size=256,
    shuffle=True,
)

In [None]:
# Sanity check: pull one batch from the loader and show its sizes together
# with the first context/target example (the trailing expression is the
# notebook cell's displayed output).
items, labels = next(iter(train_loader))
len(items), len(labels), items[0], labels[0]

## Create the model

<img src="https://miro.medium.com/v2/resize:fit:720/format:webp/1*bBETsVNLyjnaFJgM9avkeQ.png">

In [None]:
import torch.nn as nn

# CBOW
class Word2Vec(nn.Module):
    """CBOW model: average the context embeddings, then score every word.

    Args:
        vocab_size: number of rows in the embedding table and number of
            output scores per example.
        embed_size: dimensionality of each embedding vector.
    """

    def __init__(self, vocab_size, embed_size):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, embed_size)
        self.linear = nn.Linear(embed_size, vocab_size)

    def forward(self, context):
        """Return unnormalised scores over the vocabulary for each context.

        ``context`` holds word ids; the embeddings are averaged over
        dimension 1 before the linear projection.
        """
        context_vectors = self.embeddings(context)
        averaged = context_vectors.mean(dim=1)
        return self.linear(averaged)

In [None]:
# One embedding row (and one output score) per vocabulary entry, "<UNK>"
# included; each word is represented by a 100-dimensional vector.
embed_size = 100
model = Word2Vec(vocab_size=len(vocab), embed_size=embed_size)

In [None]:
# Train on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model.to(device)
device

In [None]:
import torch.optim as optim

# Cross-entropy over the model's vocabulary scores, optimised with Adam.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
    model.parameters(),
    lr=1e-3,
)

In [None]:
# Number of mini-batches the loader yields per epoch.
len(train_loader)

In [None]:
# Train the CBOW model for a fixed number of passes over the pair dataset.
epochs = 5
for epoch in range(epochs):
    # A tqdm bar closes itself once its iterable is exhausted, so a fresh bar
    # is created for each epoch; reusing one bar would leave every epoch
    # after the first without progress or loss output.
    pbar = tqdm(train_loader)
    total_loss = []  # per-batch losses of the current epoch
    loss_sum = 0.0   # running sum, avoids re-summing the list every batch
    for context, target in pbar:
        context = context.to(device)
        target = target.to(device)

        optimizer.zero_grad()

        output = model(context)

        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_loss.append(batch_loss)
        loss_sum += batch_loss
        avg_loss = loss_sum / len(total_loss)
        pbar.set_description(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}")

## Save the vectors

In [None]:
# Copy the learned embedding table out of the model as a NumPy array.
embedding_weights = model.embeddings.weight
word_embeddings = embedding_weights.detach().cpu().numpy()
print("Beágyazás mátrix mérete:", word_embeddings.shape)

In [None]:
import csv

# Write one embedding per line to vectors.tsv and the matching word on the
# same line number of metadata.tsv. UTF-8 is requested explicitly so that
# tokens containing non-ASCII characters are written the same way on every
# platform instead of depending on the locale's default encoding.
with open("vectors.tsv", "w", newline='', encoding="utf-8") as v_f, \
        open("metadata.tsv", "w", newline='', encoding="utf-8") as m_f:
    vector_writer = csv.writer(v_f, delimiter='\t')
    metadata_writer = csv.writer(m_f, delimiter='\t')

    for word, idx in vocab.items():
        # Guard against vocabulary ids that have no row in the matrix.
        if idx < len(word_embeddings):
            vector_writer.writerow(word_embeddings[idx])
            metadata_writer.writerow([word])