In [None]:
import os
import pandas as pd
from tqdm import tqdm
from sklearn.preprocessing import LabelEncoder

import torch
from transformers import BertTokenizer, BertForSequenceClassification

In [2]:
# Load the note -> class mapping and turn the class names into integer ids.
df = pd.read_csv("data/note_class_map.csv")
notes = df["note"].to_list()

# Fit the encoder on the class column, then encode that same column; the
# encoder is kept around because later cells size the classifier head from
# `le.classes_`.
le = LabelEncoder().fit(df["class"])
labels = le.transform(df["class"])
print(le.classes_)

['东方调' '木质香调' '果香调' '柑橘调' '柑苔香调' '树脂调' '概念性' '海洋调' '烟草香调' '琥珀调' '皮革香调'
 '矿物调' '粉香调' '美食香调' '花香调' '苔藓调' '辛香调' '醛香调' '青香调' '馥奇调' '麝香调']


In [3]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Tokenizer and base weights both come from the local bert-base-chinese folder.
tokenizer = BertTokenizer.from_pretrained("./bert-base-chinese")

# The classification head gets one output per class known to the label encoder.
# The base checkpoint has no classifier weights (hence the "newly initialized"
# warning); they are supplied by the state dict loaded below.
model = BertForSequenceClassification.from_pretrained(
    "./bert-base-chinese",
    num_labels=len(le.classes_),
)
model = model.to(device)

# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source; consider weights_only=True if the installed torch supports it.
state_dict = torch.load(
    "models/NoteEmbedding_1211_200624/epoch_10.pth",
    map_location=device,
)
model.load_state_dict(state_dict)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at ./bert-base-chinese and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


<All keys matched successfully>

In [4]:
# Re-read the note/class mapping so `df` holds the CSV contents for the
# embedding cell that follows.
# NOTE(review): this reads the same file as the earlier load cell and looks
# redundant - confirm nothing relies on the fresh copy before removing it.
df = pd.read_csv("data/note_class_map.csv")

In [5]:
# Extract one [CLS] embedding per note from the BERT encoder of the loaded
# classifier and write the vectors, together with the note text, to CSV.
model.eval()  # inference mode for the module (e.g. dropout is disabled)
embeddings = []

# Gradients are never needed for feature extraction, so the whole loop runs
# inside a single no_grad context.
with torch.no_grad():
    for record in tqdm(df.itertuples(), total=len(df)):
        # Each note is tokenized on its own and padded/truncated to 64 tokens.
        tokens = tokenizer(
            record.note,
            truncation=True,
            padding='max_length',
            max_length=64,
            return_tensors='pt',
        )
        # Only the encoder (`model.bert`) is called; the classification head
        # is bypassed.
        encoder_out = model.bert(
            input_ids=tokens['input_ids'].to(device),
            attention_mask=tokens['attention_mask'].to(device),
        )
        # Position 0 of the last hidden layer is taken as the [CLS] vector;
        # squeeze() drops the size-1 batch axis before moving to NumPy.
        cls_vector = encoder_out.last_hidden_state[:, 0, :]
        embeddings.append(cls_vector.squeeze().cpu().numpy())

# One row per note: the note text first, then one column per embedding value.
embeddings_df = pd.DataFrame(embeddings)
embeddings_df.insert(0, 'note', df['note'])
embeddings_df.to_csv("data/note_embedding.csv", index=False)

100%|██████████| 2285/2285 [00:16<00:00, 137.70it/s]
