# Continuous Bag of Words (CBOW) Model

In [38]:
corpus = [
    "Natural Language Processing is a fascinating field of study that has been evolving rapidly over the past few decades.",
    "Machine learning provides powerful tools for automating tasks and making predictions from data.",
    "Text data is often messy and unstructured, which makes it challenging to analyze and understand without the right tools.",
    "Deep learning models have shown remarkable success in understanding complex patterns in data, especially for tasks related to NLP.",
    "I love building machine learning models and experimenting with different techniques to improve their performance.",
    "Clean and properly preprocessed data is essential for building successful machine learning models that generalize well."
]

In [39]:
import re
import spacy
from pprint import pprint

nlp = spacy.load("en_core_web_sm")


def clean_text(documents: list[str]):
    cleaned_docs = []
    for doc in documents:
        doc = nlp(re.sub(r"[^\w\s]", "", doc.lower()))
        filtered_text = [token.text for token in doc if not token.is_stop]
        cleaned_docs.append(filtered_text)

    return cleaned_docs

cleaned_corpus = clean_text(corpus)
pprint(cleaned_corpus)


[['natural',
  'language',
  'processing',
  'fascinating',
  'field',
  'study',
  'evolving',
  'rapidly',
  'past',
  'decades'],
 ['machine',
  'learning',
  'provides',
  'powerful',
  'tools',
  'automating',
  'tasks',
  'making',
  'predictions',
  'data'],
 ['text',
  'data',
  'messy',
  'unstructured',
  'makes',
  'challenging',
  'analyze',
  'understand',
  'right',
  'tools'],
 ['deep',
  'learning',
  'models',
  'shown',
  'remarkable',
  'success',
  'understanding',
  'complex',
  'patterns',
  'data',
  'especially',
  'tasks',
  'related',
  'nlp'],
 ['love',
  'building',
  'machine',
  'learning',
  'models',
  'experimenting',
  'different',
  'techniques',
  'improve',
  'performance'],
 ['clean',
  'properly',
  'preprocessed',
  'data',
  'essential',
  'building',
  'successful',
  'machine',
  'learning',
  'models',
  'generalize']]


In [40]:
from collections import Counter

def build_vocab(corpus: list[str]):
    vocab = Counter(term for doc in corpus for term in doc)
    word_to_idx = {word: idx for idx, (word, _) in enumerate(vocab.items())}
    idx_to_word = {idx: word for idx, (word, _) in enumerate(vocab.items())}
    return word_to_idx, idx_to_word

word_to_idx, idx_to_word = build_vocab(cleaned_corpus)
print(word_to_idx)
print(idx_to_word)

{'natural': 0, 'language': 1, 'processing': 2, 'fascinating': 3, 'field': 4, 'study': 5, 'evolving': 6, 'rapidly': 7, 'past': 8, 'decades': 9, 'machine': 10, 'learning': 11, 'provides': 12, 'powerful': 13, 'tools': 14, 'automating': 15, 'tasks': 16, 'making': 17, 'predictions': 18, 'data': 19, 'text': 20, 'messy': 21, 'unstructured': 22, 'makes': 23, 'challenging': 24, 'analyze': 25, 'understand': 26, 'right': 27, 'deep': 28, 'models': 29, 'shown': 30, 'remarkable': 31, 'success': 32, 'understanding': 33, 'complex': 34, 'patterns': 35, 'especially': 36, 'related': 37, 'nlp': 38, 'love': 39, 'building': 40, 'experimenting': 41, 'different': 42, 'techniques': 43, 'improve': 44, 'performance': 45, 'clean': 46, 'properly': 47, 'preprocessed': 48, 'essential': 49, 'successful': 50, 'generalize': 51}
{0: 'natural', 1: 'language', 2: 'processing', 3: 'fascinating', 4: 'field', 5: 'study', 6: 'evolving', 7: 'rapidly', 8: 'past', 9: 'decades', 10: 'machine', 11: 'learning', 12: 'provides', 13: 

In [41]:
def create_context_target_pairs(corpus: list[str], window_size: int = 2):
    pairs = []
    for document in corpus:
        for idx, term in enumerate(document):
            start_idx = max(idx - window_size, 0)
            end_idx = min(idx + window_size + 1, len(document))
            pairs.append(([document[i] for i in range(start_idx, end_idx)], term))

    return pairs

pairs = create_context_target_pairs(cleaned_corpus)
pairs           

[(['natural', 'language', 'processing'], 'natural'),
 (['natural', 'language', 'processing', 'fascinating'], 'language'),
 (['natural', 'language', 'processing', 'fascinating', 'field'], 'processing'),
 (['language', 'processing', 'fascinating', 'field', 'study'], 'fascinating'),
 (['processing', 'fascinating', 'field', 'study', 'evolving'], 'field'),
 (['fascinating', 'field', 'study', 'evolving', 'rapidly'], 'study'),
 (['field', 'study', 'evolving', 'rapidly', 'past'], 'evolving'),
 (['study', 'evolving', 'rapidly', 'past', 'decades'], 'rapidly'),
 (['evolving', 'rapidly', 'past', 'decades'], 'past'),
 (['rapidly', 'past', 'decades'], 'decades'),
 (['machine', 'learning', 'provides'], 'machine'),
 (['machine', 'learning', 'provides', 'powerful'], 'learning'),
 (['machine', 'learning', 'provides', 'powerful', 'tools'], 'provides'),
 (['learning', 'provides', 'powerful', 'tools', 'automating'], 'powerful'),
 (['provides', 'powerful', 'tools', 'automating', 'tasks'], 'tools'),
 (['powe

In [42]:
def encode_pairs(pairs: list[str], word_to_idx: dict):
    encoded_pairs = []
    for context, target in pairs:
        context_idx = [word_to_idx[term] for term in context]
        target_idx = word_to_idx[target]
        encoded_pairs.append((context_idx, target_idx))

    return encoded_pairs

encoded_pairs = encode_pairs(pairs, word_to_idx)
encoded_pairs

[([0, 1, 2], 0),
 ([0, 1, 2, 3], 1),
 ([0, 1, 2, 3, 4], 2),
 ([1, 2, 3, 4, 5], 3),
 ([2, 3, 4, 5, 6], 4),
 ([3, 4, 5, 6, 7], 5),
 ([4, 5, 6, 7, 8], 6),
 ([5, 6, 7, 8, 9], 7),
 ([6, 7, 8, 9], 8),
 ([7, 8, 9], 9),
 ([10, 11, 12], 10),
 ([10, 11, 12, 13], 11),
 ([10, 11, 12, 13, 14], 12),
 ([11, 12, 13, 14, 15], 13),
 ([12, 13, 14, 15, 16], 14),
 ([13, 14, 15, 16, 17], 15),
 ([14, 15, 16, 17, 18], 16),
 ([15, 16, 17, 18, 19], 17),
 ([16, 17, 18, 19], 18),
 ([17, 18, 19], 19),
 ([20, 19, 21], 20),
 ([20, 19, 21, 22], 19),
 ([20, 19, 21, 22, 23], 21),
 ([19, 21, 22, 23, 24], 22),
 ([21, 22, 23, 24, 25], 23),
 ([22, 23, 24, 25, 26], 24),
 ([23, 24, 25, 26, 27], 25),
 ([24, 25, 26, 27, 14], 26),
 ([25, 26, 27, 14], 27),
 ([26, 27, 14], 14),
 ([28, 11, 29], 28),
 ([28, 11, 29, 30], 11),
 ([28, 11, 29, 30, 31], 29),
 ([11, 29, 30, 31, 32], 30),
 ([29, 30, 31, 32, 33], 31),
 ([30, 31, 32, 33, 34], 32),
 ([31, 32, 33, 34, 35], 33),
 ([32, 33, 34, 35, 19], 34),
 ([33, 34, 35, 19, 36], 35),
 ([34, 

In [43]:
import torch
import torch.nn as nn
import torch.optim as optim

class CBOWModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super(CBOWModel, self).__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.linear = nn.Linear(embedding_dim, vocab_size)

    def forward(self, context):
        embedded = self.embeddings(context).mean(dim=1)
        out = self.linear(embedded)
        return out


def train_model(model, encoded_pairs, word_to_idx, epochs, learning_rate):
    optimizer = optim.SGD(model.parameters(), lr=learning_rate)
    loss_function = nn.CrossEntropyLoss()
    losses = []

    for epoch in range(epochs):
        total_loss = 0
        for context, target in encoded_pairs:
            context_tensor = torch.tensor([context], dtype=torch.long)
            target_tensor = torch.tensor([target], dtype=torch.long)

            optimizer.zero_grad()
            output = model(context_tensor)
            loss = loss_function(output, target_tensor)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        losses.append(total_loss)

    return losses

embedding_dim = 10
model = CBOWModel(vocab_size=len(word_to_idx), embedding_dim=embedding_dim)
losses = train_model(model, encoded_pairs, word_to_idx, epochs=100, learning_rate=0.01)
losses

[260.90443658828735,
 258.69257378578186,
 256.5143029689789,
 254.36835098266602,
 252.25356817245483,
 250.16892910003662,
 248.11352348327637,
 246.0865547657013,
 244.08734774589539,
 242.11532282829285,
 240.17001032829285,
 238.25102734565735,
 236.3580801486969,
 234.49096059799194,
 232.64952182769775,
 230.83367776870728,
 229.0433909893036,
 227.27866196632385,
 225.53951120376587,
 223.82597541809082,
 222.1380751132965,
 220.47582972049713,
 218.83922016620636,
 217.2281894683838,
 215.642636179924,
 214.08240628242493,
 212.5472673177719,
 211.03693866729736,
 209.5510617494583,
 208.089213848114,
 206.65091395378113,
 205.2356197834015,
 203.8427346944809,
 202.4716159105301,
 201.12159132957458,
 199.79195046424866,
 198.48197078704834,
 197.19091093540192,
 195.91802787780762,
 194.66257762908936,
 193.42382884025574,
 192.2010635137558,
 190.99357795715332,
 189.80069839954376,
 188.6217725276947,
 187.45617401599884,
 186.30331897735596,
 185.1626352071762,
 184.03359