In [1]:
!pip install  sentence-transformers gensim 

# Import necessary libraries
import pandas as pd
import numpy as np
from tqdm import tqdm
from gensim.models import Word2Vec
from sentence_transformers import SentenceTransformer, models, util
import torch
import os
import pickle


Collecting sentence-transformers
  Downloading sentence_transformers-3.0.1-py3-none-any.whl.metadata (10 kB)
Downloading sentence_transformers-3.0.1-py3-none-any.whl (227 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 227.1/227.1 kB 6.8 MB/s eta 0:00:00
Installing collected packages: sentence-transformers
Successfully installed sentence-transformers-3.0.1


  from tqdm.autonotebook import tqdm, trange
2024-06-23 10:50:00.313422: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-06-23 10:50:00.313577: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-06-23 10:50:00.471711: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
class CustomTokenizer:
    """Whitespace tokenizer that maps words to ids through a fixed vocabulary.

    Text is lower-cased and split on whitespace. Words that are not keys of
    the vocabulary are dropped; there is no unknown-token id.

    Attributes:
        vocab: mapping from word to token id.
    """

    # File name of the pickled vocabulary inside the save directory.
    VOCAB_FILE = 'vocab.pkl'

    def __init__(self, vocab):
        self.vocab = vocab

    def tokenize(self, text, **kwargs):
        """Return the ids of the in-vocabulary words of ``text``, in order.

        Args:
            text: input string; it is lower-cased and split on whitespace.
            **kwargs: accepted and ignored, so a caller that forwards extra
                tokenizer options does not raise a TypeError.

        Returns:
            A list of vocabulary ids; empty when no word is known.
        """
        vocab = self.vocab
        return [vocab[word] for word in text.lower().split() if word in vocab]

    def save(self, output_path):
        """Pickle the vocabulary to ``<output_path>/vocab.pkl``.

        The directory (including missing parents) is created first, so saving
        into a fresh location does not fail with FileNotFoundError.
        """
        os.makedirs(output_path, exist_ok=True)
        with open(os.path.join(output_path, self.VOCAB_FILE), 'wb') as f:
            pickle.dump(self.vocab, f)

    @staticmethod
    def load(input_path):
        """Rebuild a tokenizer from ``<input_path>/vocab.pkl``.

        Raises:
            FileNotFoundError: if the vocabulary file does not exist.
        """
        # SECURITY: pickle.load can execute arbitrary code. Only load
        # vocabulary files that were produced by `save` from a trusted source.
        with open(os.path.join(input_path, CustomTokenizer.VOCAB_FILE), 'rb') as f:
            vocab = pickle.load(f)
        return CustomTokenizer(vocab)

def create_custom_model(word2vec_model_path):
    """Build a SentenceTransformer on top of a saved gensim Word2Vec model.

    The resulting pipeline is a ``models.WordEmbeddings`` module (weights
    taken from the Word2Vec vectors, tokenized by ``CustomTokenizer``)
    followed by a ``models.Pooling`` module with its default settings.

    Args:
        word2vec_model_path: path handed to ``Word2Vec.load``.

    Returns:
        A ``SentenceTransformer`` made of the embedding and pooling modules.
    """
    word2vec = Word2Vec.load(word2vec_model_path)
    keyed_vectors = word2vec.wv

    # torch.tensor copies the array into a float32 tensor; it replaces the
    # legacy torch.FloatTensor(ndarray) constructor.
    embedding_weights = torch.tensor(keyed_vectors.vectors, dtype=torch.float32)

    # The tokenizer emits key_to_index values, which the embedding layer then
    # uses as row indices into the weight matrix built above.
    vocab = keyed_vectors.key_to_index
    custom_tokenizer = CustomTokenizer(vocab)

    # Pass the tokenizer to the constructor directly, rather than passing the
    # raw vocab dict as the tokenizer and patching `.tokenizer` afterwards.
    w2v_embeddings = models.WordEmbeddings(custom_tokenizer, embedding_weights)

    # Pooling collapses the per-token vectors into one sentence vector.
    pooling_layer = models.Pooling(w2v_embeddings.get_word_embedding_dimension())

    return SentenceTransformer(modules=[w2v_embeddings, pooling_layer])

In [3]:
# Locations used by this cell, defined once so save and load cannot drift apart.
WORD2VEC_PATH = '/kaggle/input/word2vec-new/word2vec_new.model'
MODEL_DIR = '/kaggle/working/sbert_model'
# Sub-directory of the saved model that holds the tokenizer vocabulary.
TOKENIZER_DIR = f'{MODEL_DIR}/0_WordEmbeddings'

# Saving the model and the custom tokenizer
model = create_custom_model(WORD2VEC_PATH)
model.save(MODEL_DIR)

# Ensure the directory exists and save the custom tokenizer
os.makedirs(TOKENIZER_DIR, exist_ok=True)
custom_tokenizer = model[0].tokenizer
custom_tokenizer.save(TOKENIZER_DIR)

# Loading the model
loaded_model = SentenceTransformer(MODEL_DIR)

# Load the custom tokenizer and attach it to the first module of the loaded model
custom_tokenizer = CustomTokenizer.load(TOKENIZER_DIR)
loaded_model[0].tokenizer = custom_tokenizer


# Example: encoding a sentence
sentence_embeddings = loaded_model.encode(["Das ist ein test"])

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

In [4]:
# Display the array returned by loaded_model.encode for the example sentence.
sentence_embeddings

array([[ 0.6527207 ,  0.06922134, -0.96157366,  1.0999242 ,  0.63465333,
         0.10980794,  0.07311545,  0.67537653, -0.940318  ,  0.6189362 ,
         0.09219496,  0.19711597, -1.2204223 , -0.53486025, -0.6097276 ,
        -0.77752554,  0.611468  , -0.13125095, -0.9501294 , -0.16616935,
        -1.0042355 ,  1.1879679 ,  0.38627875, -0.88356274,  0.9980704 ,
        -0.25445908, -0.24891172, -0.7029323 ,  0.50321424, -0.33768168,
        -1.5039492 , -0.20570269, -0.41746753,  0.48846605, -0.08247624,
         0.07022371,  0.24974115, -0.5829959 , -0.26858747,  0.3057723 ,
         0.26451892,  0.7270578 , -0.32864368,  1.2760888 ,  0.85239685,
        -0.9060792 ,  0.47935364,  0.18874416,  0.22411975, -0.54197335,
        -1.0213695 , -2.3708992 ,  0.55412096, -0.25532648, -0.00754103,
         0.3824754 ,  0.26769072, -0.24422595, -0.1189203 ,  0.8046151 ,
        -1.1530076 , -1.4395635 , -1.044627  ,  0.9247408 ,  0.46085832,
        -0.02773243,  1.4447328 , -1.059274  ,  0.7