In [1]:
pip install transformers torch



In [8]:
import torch
import pandas as pd
import numpy as np
from transformers import BertModel, BertTokenizer
from torch.utils.data import DataLoader, Dataset

class SymbolDataset(Dataset):
    def __init__(self, csv_file, text_column):
        self.data = pd.read_csv(csv_file)
        self.texts = self.data[text_column].tolist()

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        return self.texts[idx]

def compute_embeddings(dataset, batch_size=32, max_length=128, model_name='klue/bert-base'):
    tokenizer = BertTokenizer.from_pretrained(model_name)
    model = BertModel.from_pretrained(model_name)
    model.eval()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    sentence_embeddings_list = []  # Store sentence-level embeddings

    with torch.no_grad():
        for batch in dataloader:
            inputs = tokenizer(batch, return_tensors='pt', padding=True, truncation=True, max_length=max_length, return_attention_mask=True)
            input_ids = inputs['input_ids'].to(device)
            attention_mask = inputs['attention_mask'].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            sentence_embeddings = (outputs.last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(dim=1, keepdim=True)  # Mean of token embeddings
            sentence_embeddings_list.append(sentence_embeddings.cpu())

    sentence_embeddings = torch.cat(sentence_embeddings_list, dim=0)  # Sentence embeddings for all expressions

    return sentence_embeddings

In [10]:
dataset = SymbolDataset(csv_file='symbols.csv', text_column='Symbol')

sentence_embeddings = compute_embeddings(dataset, batch_size=32)
sentence_embeddings_np = sentence_embeddings.numpy()

symbol_df = pd.read_csv('symbols.csv')
sentence_df = pd.DataFrame(sentence_embeddings_np)



In [11]:
result_df = pd.concat([symbol_df, sentence_df], axis=1)
result_df.to_csv('symbol_with_embeddings.csv', index=False)

In [12]:
sentence_df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,758,759,760,761,762,763,764,765,766,767
0,0.137469,-0.967392,0.744639,0.041597,0.741354,0.291326,0.618810,0.398630,0.563857,0.374512,...,-0.222111,-0.248447,0.265798,0.106606,0.254744,-0.369918,-0.437465,-0.469659,-0.544954,0.073634
1,-0.883729,-1.550201,0.067479,2.301232,0.711818,-0.167876,-0.941561,-0.555733,1.074224,0.141762,...,0.138704,-0.061283,0.354352,1.485684,0.496543,-0.409772,-0.017113,0.230947,-1.035610,0.139816
2,1.565012,-0.726382,0.406343,0.768778,-0.635288,-0.826894,-0.456032,1.051658,0.504987,-0.856611,...,-0.222164,0.598648,0.284044,0.961415,-0.621139,0.031911,1.126210,-0.084719,-0.283263,0.246667
3,0.403880,-1.215052,0.044865,1.311892,0.611136,-0.410936,-0.136824,0.197430,-0.111148,-0.331359,...,0.336527,0.047007,0.112959,0.543155,0.347441,0.012900,1.665157,0.616060,-0.060827,-0.411542
4,1.094342,-0.379413,-0.569546,0.438057,0.488836,-0.314038,-1.093029,0.883381,0.457393,-1.326222,...,0.079012,0.041419,-0.070390,0.638202,-0.400359,0.477926,0.332438,0.183040,-1.169068,0.291550
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
7265,-0.225558,0.202582,0.707888,1.324562,0.837342,0.010573,-0.287568,0.327531,1.528269,-1.313155,...,-0.012785,-0.608310,1.559495,0.237692,-0.228152,-0.704669,-0.402542,-0.574044,-0.603798,0.188569
7266,-0.892496,-0.398331,0.674856,1.379705,0.385307,-0.482905,-0.299641,-0.885883,1.091894,-1.141056,...,0.322485,-0.058503,0.723672,0.897589,0.315378,-0.164158,1.398671,-0.396113,-0.958343,0.630111
7267,0.472238,-0.973623,0.235567,1.108780,-0.640560,0.117277,0.406879,0.497449,0.676735,-0.952913,...,-0.014551,-0.142153,0.294594,0.604706,0.268467,-0.140104,0.339377,-0.271185,-1.196764,0.705126
7268,0.206559,-1.403651,0.247229,0.877581,1.168767,-0.936918,-0.552112,0.033882,-0.036005,-1.410789,...,-0.190453,0.238751,0.001054,-0.253329,0.510085,-0.044803,-0.261565,-0.638770,-0.820140,0.785677
