In [None]:
class WordEmbedder(nn.Module):
    """Embed tokens and predict each position from the mean of its neighbours.

    For every non-padding position that has ``context_window_size`` tokens
    available on both sides, ``forward`` averages the embeddings of those
    surrounding tokens (the centre token itself is excluded).

    Args:
        vocab_size: Number of rows in the embedding table.
        embedding_dim: Size of each embedding vector.
        context_window_size: Number of tokens taken on each side of a position.
        pad_index: Token id treated as padding. Optional for backward
            compatibility: when left as ``None`` the module-level name
            ``pad_index`` is looked up at call time, as the original code did.
    """

    def __init__(self, vocab_size, embedding_dim, context_window_size, pad_index=None):
        super().__init__()
        self.embeddings = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)
        # Kept on the instance so forward() does not depend on notebook globals.
        self.embedding_dim = embedding_dim
        self.context_window_size = context_window_size
        self.pad_index = pad_index

    def forward(self, titles):
        """Return the averaged context embedding for every interior position.

        Args:
            titles: 2-D integer tensor of token ids, unpacked as
                ``(batch_size, total_length)``.

        Returns:
            Tensor of shape
            ``(batch_size, total_length - 2 * context_window_size, embedding_dim)``.
            Output index ``j`` corresponds to title position
            ``j + context_window_size``, i.e. it lines up with
            ``titles[:, context_window_size:-context_window_size]``.
            Positions whose centre token is padding keep the fill value 1.0.
            Note that the context window itself may still contain padding
            tokens; only the centre token is checked.
        """
        batch_size, total_length = titles.shape
        window = self.context_window_size
        # Fall back to the notebook-level ``pad_index`` only when none was given.
        pad = self.pad_index if self.pad_index is not None else pad_index

        # Allocate next to the embedding weights so the output lives on the same
        # device / dtype as the vectors that are written into it.
        weight = self.embeddings.weight
        result_embeddings = torch.ones(
            batch_size,
            total_length - 2 * window,
            self.embedding_dim,
            dtype=weight.dtype,
            device=weight.device,
        )

        for row in range(batch_size):
            title = titles[row]

            # as_tuple=True keeps the index vector 1-D even when a single token
            # matches (``.nonzero().squeeze()`` would collapse it to 0-d and
            # make the loop below raise).
            non_pad_positions = (title != pad).nonzero(as_tuple=True)[0].tolist()

            for pos in non_pad_positions:
                if pos - window < 0 or pos + window >= total_length:
                    # Not enough tokens on one side for a full window.
                    continue

                # The bounds check above guarantees both slices are in range.
                left_context = title[pos - window:pos]
                right_context = title[pos + 1:pos + 1 + window]
                context_indices = torch.cat((left_context, right_context), dim=0)

                # Shift by ``window`` so the output index matches the trimmed
                # sequence rather than the original title position.
                result_embeddings[row, pos - window] = self.embeddings(context_indices).mean(dim=0)

        return result_embeddings

In [None]:
import torch
from tqdm import tqdm

num_epochs = 15  # how many full passes over data_loader to run

for epoch in range(num_epochs):
    progress_bar = tqdm(data_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch")
    total_loss = 0.0  # running sum of per-batch losses for this epoch

    for batch in progress_bar:
        # The first element of each batch holds the token ids; move them to the training device.
        titles = batch[0].to(device)

        # Trim context_window_size tokens from both ends: these are the centre
        # tokens whose embeddings serve as regression targets.
        target_indices = titles[:, context_window_size:-context_window_size]

        # Boolean mask over the trimmed sequence (True where the centre token is
        # not padding), broadcast across the embedding dimension.
        pad_mask = target_indices.ne(pad_index).unsqueeze(-1).expand(-1, -1, embedding_dim)

        # Clear gradients left over from the previous step.
        optimizer.zero_grad()

        # Forward pass: averaged-context predictions and the embeddings they should match.
        predicted_context_embeddings = word_embedder(titles)
        target_embeddings = word_embedder.embeddings(target_indices)

        # Zero out padded positions in place on both sides so they contribute
        # nothing to the difference measured by loss_fn.
        predicted_context_embeddings.mul_(pad_mask)
        target_embeddings.mul_(pad_mask)

        loss = loss_fn(predicted_context_embeddings, target_embeddings)

        # Backward pass and parameter update.
        loss.backward()
        optimizer.step()

        total_loss = total_loss + loss.item()
        progress_bar.set_postfix(loss=loss.item())

    # Report the mean batch loss for the epoch.
    average_loss = total_loss / len(data_loader)
    print(f'Epoch [{epoch+1}/{num_epochs}], Average Loss: {average_loss:.4f}')