# Japanese Text Embedding Generator with PLaMo-Embedding-1B

**Goal:** This notebook demonstrates how to generate high-quality text embeddings for Japanese sentences using the `pfnet/plamo-embedding-1b` model from Hugging Face. Embeddings are numerical representations of text that capture semantic meaning, useful for various NLP tasks like similarity search, clustering, and classification.

We will cover:
1.  Installing necessary libraries.
2.  Loading the pre-trained PLaMo model and its tokenizer.
3.  Defining a function to generate embeddings using mean pooling.
4.  Running examples for single sentences and batches of sentences.
5.  Saving the generated embeddings to a file and loading them back.

## 1. Library Installation

The following libraries are required:
*   `torch`: The core PyTorch library for tensor computations and neural network operations.
*   `transformers`: Hugging Face's library providing access to pre-trained models (like PLaMo) and tokenizers.
*   `sentencepiece`: A tokenizer library often used by models like PLaMo.
*   `numpy`: A library for numerical operations, especially for handling the embeddings as arrays.

The cell below will install these libraries using pip.

In [None]:
# Install necessary libraries
!pip install torch transformers sentencepiece numpy

## 2. Import Libraries

Now, let's import the installed libraries and necessary modules.
*   `torch` for tensor operations.
*   `AutoTokenizer` and `AutoModel` from `transformers` for loading the model and tokenizer automatically based on the model name.
*   `numpy` (as `np`) for numerical operations.

In [None]:
# Import libraries
import torch
from transformers import AutoTokenizer, AutoModel
import numpy as np

## 3. Load Pre-trained Model and Tokenizer

We will use `pfnet/plamo-embedding-1b`, a model published by Preferred Networks on Hugging Face and trained specifically for generating Japanese text embeddings.

*   **`MODEL_NAME`**: Specifies the Hugging Face model identifier.
*   **`AutoTokenizer.from_pretrained(...)`**: Loads the tokenizer that belongs to the specified model. The tokenizer converts text into the format the model consumes (token IDs and attention masks). The PLaMo repository ships its own tokenizer code, so this call may also need `trust_remote_code=True`.
*   **`AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True)`**: Loads the pre-trained PLaMo model.
    *   `trust_remote_code=True`: Needed for models whose Hugging Face repository contains custom Python code. It allows that code to be downloaded and executed on your machine, so use it only with sources you trust.
*   **`device`**: We check if a CUDA-enabled GPU is available and set the device accordingly (`'cuda'` or `'cpu'`).
*   **`.to(device)`**: This moves the model's parameters and buffers to the selected device (GPU if available, otherwise CPU). Processing on a GPU significantly speeds up computations.
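
The device handling described above can be tried in isolation, without downloading the model. The sketch below is illustrative only: `torch.nn.Linear` is a stand-in for the real model, and the input tensor is made up.

```python
import torch

# Pick the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stand-in for the real model (assumption: any nn.Module behaves the same way here).
layer = torch.nn.Linear(4, 2).to(device)

# Inputs must live on the same device as the module's parameters.
batch = torch.ones(3, 4).to(device)

with torch.no_grad():  # inference only, no gradient tracking
    out = layer(batch)

print(out.shape, out.device.type)
```

If the inputs and the module sit on different devices, PyTorch raises a runtime error, which is why the notebook moves both the model and the tokenized inputs to `device`.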

In [None]:
# Load pre-trained model and tokenizer
MODEL_NAME = 'pfnet/plamo-embedding-1b'
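# The PLaMo repository ships custom tokenizer code, so trust_remote_code=True is passed here as well
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
# It's good practice to move the model to the GPU if one is available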
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True).to(device)

## 4. Embedding Generation Function

The function `get_japanese_embedding` takes either a single Japanese sentence or a list of sentences and returns their embeddings.

**Function Breakdown:**

1.  **Tokenization (`tokenizer(...)`):**
    *   `text_or_texts`: The input Japanese string or list of strings.
    *   `return_tensors='pt'`: Returns PyTorch tensors.
    *   `truncation=True`: Truncates sequences to the model's maximum input length if they are too long.
    *   `padding=True`: Pads shorter sequences to the length of the longest sequence in a batch, ensuring uniform tensor dimensions.
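    *   `max_length=512`: Caps the tokenized length at 512 tokens; with `truncation=True`, longer inputs are cut off at that point. This is a practical starting point for sentence-level inputs and can be raised up to the maximum length the model supports (check the model card for the exact limit).
    *   `add_special_tokens=True`: Adds whatever special tokens the tokenizer defines for this model (for example, beginning- or end-of-sequence markers). This is the tokenizer's default behavior; the BERT-style `[CLS]` and `[SEP]` tokens are not necessarily what PLaMo uses.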
    *   `.to(device)`: Moves the tokenized inputs to the same device as the model.

2.  **Model Inference (`with torch.no_grad(): ... outputs = model(**inputs)`):**
    *   `torch.no_grad()`: Disables gradient calculations, which is crucial for inference as it reduces memory consumption and speeds up computations when we are not training the model.
    *   `model(**inputs)`: Passes the tokenized input (input IDs, attention mask, etc.) to the model. The model returns a set of outputs, including the `last_hidden_state`.

3.  **Mean Pooling for Sentence Embedding:**
    *   `last_hidden_states = outputs.last_hidden_state`: This tensor contains the embeddings for each token in the input sequence(s). Its shape is typically (batch_size, sequence_length, hidden_dim).
    *   **Attention Mask (`inputs['attention_mask']`)**: The attention mask is used to distinguish real tokens from padding tokens. It has a value of 1 for real tokens and 0 for padding tokens.
    *   `expanded_mask = attention_mask.unsqueeze(-1).expand(last_hidden_states.size()).float()`: The attention mask is expanded to match the dimensions of `last_hidden_states`. This allows element-wise multiplication to zero out the embeddings of padding tokens.
    *   `sum_embeddings = torch.sum(last_hidden_states * expanded_mask, 1)`: The embeddings of non-padding tokens are summed up along the sequence dimension (dimension 1).
    *   `sum_mask = expanded_mask.sum(1)`: Counts the actual (non-padding) tokens in each sequence.
    *   `sum_mask = torch.clamp(sum_mask, min=1e-9)`: Guards against division by zero in case a sequence has no actual tokens (tokenizer settings usually prevent this). This step runs before the division.
    *   `mean_embeddings = sum_embeddings / sum_mask`: Divides the summed embeddings by the token count to get the mean embedding. This mean-pooled vector represents the entire sentence.

4.  **Output:**
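    *   `.cpu().numpy().squeeze()`: The resulting embeddings are moved to the CPU and converted to a NumPy array, and `squeeze()` removes size-1 dimensions. A single sentence therefore yields a 1-D array of shape `(hidden_dim,)` instead of `(1, hidden_dim)`. Note that a list containing exactly one sentence is squeezed in the same way.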
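
To make the pooling steps above concrete, here is a minimal sketch on a toy batch. The values and shapes are made up for illustration, and NumPy stands in for PyTorch; the arithmetic follows the same steps as described above.

```python
import numpy as np

# Toy "last_hidden_state": 2 sequences, 3 token positions, hidden size 2.
hidden = np.array([
    [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]],  # third position is padding
    [[2.0, 0.0], [4.0, 2.0], [6.0, 4.0]],      # no padding
])
# 1 marks a real token, 0 marks padding.
attention_mask = np.array([[1, 1, 0],
                           [1, 1, 1]])

# Expand the mask to (batch, seq_len, hidden_dim) for element-wise multiplication.
expanded_mask = np.broadcast_to(attention_mask[..., None], hidden.shape).astype(np.float64)

sum_embeddings = (hidden * expanded_mask).sum(axis=1)       # padding contributes 0
sum_mask = np.clip(expanded_mask.sum(axis=1), 1e-9, None)   # guard against division by zero
mean_embeddings = sum_embeddings / sum_mask

print(mean_embeddings)
```

The padded position with the large values `[100, 100]` has no effect on the first row: it averages only the two real tokens and comes out as `[2, 3]`, while the second row averages all three tokens and comes out as `[4, 2]`.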

In [None]:
# Function to generate embeddings using mean pooling
def get_japanese_embedding(text_or_texts):
    # Tokenize the text (handles single string or list of strings)
    inputs = tokenizer(text_or_texts, return_tensors='pt', truncation=True, padding=True, max_length=512, add_special_tokens=True).to(device)
    
    # Get model outputs
    with torch.no_grad():
        outputs = model(**inputs)
    
    # Perform mean pooling
    last_hidden_states = outputs.last_hidden_state
    attention_mask = inputs['attention_mask']
    expanded_mask = attention_mask.unsqueeze(-1).expand(last_hidden_states.size()).float()
    sum_embeddings = torch.sum(last_hidden_states * expanded_mask, 1)
    sum_mask = expanded_mask.sum(1)
    sum_mask = torch.clamp(sum_mask, min=1e-9)
    mean_embeddings = sum_embeddings / sum_mask
    
    # squeeze() drops size-1 dimensions: a single string gives (1, dim), which becomes (dim,).
    # A list of N > 1 strings stays (N, dim). Note that a one-element list is also squeezed to (dim,).
    return mean_embeddings.cpu().numpy().squeeze()
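
When a list is passed, every sentence is padded to the longest one and the whole list goes through the model in a single call, so very large lists can exhaust memory. One possible workaround, sketched below, is to embed the list in fixed-size chunks and stack the results. The helper `embed_in_chunks`, its `chunk_size` parameter, and the `fake_embed` stand-in are hypothetical names introduced for this illustration and are not part of the notebook.

```python
import numpy as np

def embed_in_chunks(texts, embed_fn, chunk_size=32):
    """Embed a non-empty list of texts chunk by chunk and return one (N, dim) array.

    `embed_fn` is assumed to accept a list of strings and return an array of
    shape (len(list), dim), or (dim,) when the list holds a single string.
    """
    chunks = []
    for start in range(0, len(texts), chunk_size):
        chunk = texts[start:start + chunk_size]
        # atleast_2d restores the batch axis if embed_fn squeezed a one-item chunk.
        chunks.append(np.atleast_2d(embed_fn(chunk)))
    return np.concatenate(chunks, axis=0)

# Stand-in embedder for demonstration only: maps each text to [length, 1.0]
# and squeezes the result the same way the notebook's function does.
def fake_embed(batch):
    return np.array([[float(len(t)), 1.0] for t in batch]).squeeze()

texts = ["a", "bb", "ccc", "dddd", "eeeee"]
result = embed_in_chunks(texts, fake_embed, chunk_size=2)
print(result.shape)
```

With the function defined in this notebook, the call would look like `embed_in_chunks(sentences, get_japanese_embedding)`. A suitable `chunk_size` depends on sentence length and available memory.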

## 5. Example Usage

The following cell demonstrates how to use the `get_japanese_embedding` function.
It shows:
1.  Generating an embedding for a single Japanese sentence.
2.  Generating embeddings for a batch of multiple Japanese sentences.
    *   The function `get_japanese_embedding` handles both single strings and lists of strings automatically.

In [None]:
# Example Usage for generating embeddings
print(f"Using Model: {MODEL_NAME}\n")

# 1. Single sentence example
print("--- Single Sentence Embedding ---")
sample_text_single = "こんにちは、美しい世界！"
embedding_single = get_japanese_embedding(sample_text_single)
print(f"Original sentence: {sample_text_single}")
print(f"Embedding shape: {embedding_single.shape}")
print(f"Sample embedding (first 5 values): {embedding_single[:5]}\n")

# 2. Batch (multiple sentences) example
print("--- Batch Sentences Embedding ---")
sample_texts_batch = [
    "これは最初の文です。",
    "日本語の埋め込みをテストしています。",
    "これが最後の文になります。"
]
embeddings_batch = get_japanese_embedding(sample_texts_batch)
print(f"Original sentences: {sample_texts_batch}")
print(f"Batch embeddings shape: {embeddings_batch.shape}") # Should be (num_sentences, embedding_dim)
if embeddings_batch.ndim == 2 and embeddings_batch.shape[0] > 0:
    print(f"Sample embedding for the first sentence (first 5 values): {embeddings_batch[0, :5]}\n")
else:
    print("Could not display sample of batch embeddings due to unexpected shape.\n")
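
Similarity search, mentioned in the introduction, typically compares embeddings with cosine similarity. The sketch below shows the computation on made-up vectors; `cosine_similarity_matrix` is a hypothetical helper written for this illustration, and real embeddings from `get_japanese_embedding` could be passed in its place.

```python
import numpy as np

def cosine_similarity_matrix(embeddings):
    """Return the (N, N) matrix of pairwise cosine similarities for an (N, dim) array."""
    embeddings = np.asarray(embeddings, dtype=np.float64)
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    normalized = embeddings / np.clip(norms, 1e-12, None)  # avoid division by zero
    return normalized @ normalized.T

# Made-up vectors standing in for real sentence embeddings.
toy_embeddings = np.array([
    [1.0, 0.0],
    [2.0, 0.0],   # same direction as the first vector
    [0.0, 3.0],   # orthogonal to the first vector
])
similarities = cosine_similarity_matrix(toy_embeddings)
print(np.round(similarities, 3))
```

Vectors pointing in the same direction score 1 regardless of their length, and orthogonal vectors score 0. Applied to `embeddings_batch`, entry `[i, j]` of the result would be the similarity between sentence `i` and sentence `j`.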

## 6. Saving and Loading Embeddings

After generating embeddings, it's often useful to save them for later use, avoiding the need to recompute them. NumPy's `.npy` format is efficient for storing numerical arrays.

*   **`np.save(filename, array)`**: Saves the NumPy array (in our case, `embeddings_batch`) to the specified file (e.g., `japanese_embeddings.npy`).
*   **`np.load(filename)`**: Loads the array back from the `.npy` file.

The following code demonstrates this process, using the `embeddings_batch` generated in the previous cell.

In [None]:
# Saving and Loading Example (uses embeddings_batch from the cell above)

# Check if embeddings_batch exists and has content (e.g. if previous cell was run)
if 'embeddings_batch' in locals() and embeddings_batch.size > 0:
    # 3. Saving batch embeddings to a .npy file
    print("--- Saving Batch Embeddings ---")
    output_filename = "japanese_embeddings.npy"
    np.save(output_filename, embeddings_batch)
    print(f"Batch embeddings saved to: {output_filename}\n")

    # 4. Loading embeddings from the .npy file
    print("--- Loading Batch Embeddings ---")
    loaded_embeddings = np.load(output_filename)
    print(f"Embeddings loaded from: {output_filename}")
    print(f"Loaded embeddings shape: {loaded_embeddings.shape}")
    if loaded_embeddings.ndim == 2 and loaded_embeddings.shape[0] > 0:
        print(f"Sample of loaded embedding for the first sentence (first 5 values): {loaded_embeddings[0, :5]}")
else:
    print("Skipping saving/loading example as 'embeddings_batch' is not defined or empty. Please run the cell above first.")
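
A `.npy` file holds a single array, so the original sentences are not stored with their embeddings. If keeping the two together is useful, one option is NumPy's `.npz` archive format via `np.savez`. The sketch below uses made-up embedding values and a hypothetical file name, and writes to a temporary directory so it leaves nothing behind.

```python
import os
import tempfile
import numpy as np

# Made-up data standing in for real sentences and their embeddings.
sentences = np.array(["これは最初の文です。", "これが最後の文になります。"])
embeddings = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "embeddings_with_text.npz")  # hypothetical file name
    # np.savez stores several named arrays in a single .npz archive.
    np.savez(path, sentences=sentences, embeddings=embeddings)

    # np.load returns a dict-like archive; arrays are read when indexed by name.
    with np.load(path) as archive:
        loaded_sentences = archive["sentences"]
        loaded_embeddings = archive["embeddings"]

print(loaded_sentences[0], loaded_embeddings.shape)
```

Row `i` of `loaded_embeddings` corresponds to `loaded_sentences[i]`, provided both arrays were saved in the same order.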