In [1]:
import os
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity

import torch
from transformers import BertTokenizer, BertForSequenceClassification

In [9]:
# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")

# Build the base model with a 21-way classification head and place it on `device`.
model = BertForSequenceClassification.from_pretrained(
    "bert-base-chinese",
    num_labels=21,
)
model = model.to(device)

# Replace the weights with the ones stored in the local checkpoint.
# NOTE(review): depending on the installed PyTorch version, torch.load may
# unpickle arbitrary objects -- only load checkpoints from a trusted source.
state_dict = torch.load(
    "models/NoteEmbedding_1211_200624/epoch_10.pth",
    map_location=device,
)
# Kept as the last expression so the key-matching report is displayed.
model.load_state_dict(state_dict)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-chinese and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


<All keys matched successfully>

In [12]:
# Table of notes with their precomputed embedding values.
df = pd.read_csv("data/note_embedding.csv")

# Everything after the first column is taken as the numeric embedding matrix
# (one row per table row); a fresh float array is produced, independent of `df`.
embeddings = df.iloc[:, 1:].to_numpy(dtype=float, copy=True)

In [14]:
def get_embedding(note, max_length=64):
    """Return the first-token ([CLS] position) embedding for a note or a batch of notes.

    Uses the notebook-level ``tokenizer``, ``model`` and ``device``. The model
    is switched to eval mode and no gradients are tracked.

    Args:
        note: A single string, or a list/tuple of strings to embed as a batch.
        max_length: Length (in tokens) every input is truncated/padded to.
            Defaults to 64, the value previously hard-coded here.
            NOTE(review): presumably this should match the setting used to
            build the precomputed embeddings it is compared against -- confirm.

    Returns:
        numpy.ndarray: For a single string, a 1-D vector (last dimension of
        ``last_hidden_state``). For a list/tuple, a 2-D array with one row per
        input -- including a batch of exactly one item, which the previous
        blanket ``.squeeze()`` collapsed to 1-D.
    """
    # Only a bare string should lose its batch axis; batches keep theirs.
    is_single = isinstance(note, str)

    model.eval()
    with torch.no_grad():
        encoding = tokenizer(
            note,
            truncation=True,
            padding='max_length',
            max_length=max_length,
            return_tensors='pt'
        )
        input_ids = encoding['input_ids'].to(device)
        attention_mask = encoding['attention_mask'].to(device)
        # Call the encoder directly (model.bert), bypassing the classification head.
        outputs = model.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Position 0 of every sequence -> shape (batch, hidden).
        cls_embedding = outputs.last_hidden_state[:, 0, :]
        if is_single:
            # Drop just the batch axis; never touch other size-1 axes.
            cls_embedding = cls_embedding.squeeze(0)
    return cls_embedding.cpu().numpy()

In [15]:
# Query note whose nearest neighbours we want to inspect.
like_note = "巧克力"

# cosine_similarity expects 2-D inputs, so the query becomes a one-row matrix.
like_embedding = get_embedding(like_note).reshape(1, -1)

# One row of scores comes back (query vs. every row of `embeddings`); keep that row.
similarities = cosine_similarity(like_embedding, embeddings)[0]

# Attach the scores to the table and rank from most to least similar.
df["similarity"] = similarities
df_sorted = df.sort_values("similarity", ascending=False)

# Show the ten highest-scoring notes.
df_sorted.head(10)[["note", "similarity"]]

Unnamed: 0,note,similarity
1387,巧克力,1.0
110,白巧克力,0.994687
678,黑巧克力,0.993454
1331,牛奶巧克力,0.992477
1388,巧克力糖浆,0.98827
957,可可,0.975067
1962,杏仁糖,0.974163
958,可可果,0.973386
875,焦糖,0.972874
1808,咸奶油软糖,0.972029
