In [None]:
from skt.gcp import (
    PROJECT_ID,
    bq_insert_overwrite,
    bq_to_df,
    bq_to_pandas,
    get_bigquery_client,
    bq_table_exists,
    get_max_part,
    load_query_result_to_table,
    pandas_to_bq,
    pandas_to_bq_table,
    load_bigquery_ipython_magic,
    get_bigquery_client,
    _print_query_job_results,
    load_query_result_to_partitions
    
)

from skt.ye import (
    get_hdfs_conn,
    get_spark,
    hive_execute,
    hive_to_pandas,
    pandas_to_parquet,
    slack_send,
    get_secrets
)

In [None]:
from skt.vault_utils import get_secrets

In [None]:
from google.cloud.bigquery.job import QueryJobConfig

In [None]:
import pandas as pd
import requests
import json
from datetime import datetime, date, timedelta
from typing import List, Dict
import os

In [None]:
import torch
from transformers import (
    AdamW,
    AutoModel,
    get_linear_schedule_with_warmup,
    AutoTokenizer,
    AutoConfig
)
import torch.nn.functional as F


In [None]:
# Look up the 'proxies' entry via get_secrets; the result is indexed with
# 'http' and 'https' in the next cell.
# NOTE(review): get_secrets is imported from both skt.ye and skt.vault_utils
# above; the later import (skt.vault_utils) is the one bound here.
proxies = get_secrets('proxies')

In [None]:
# Export the fetched proxy settings as the http_proxy / https_proxy
# environment variables of this process.
for scheme in ('http', 'https'):
    os.environ[f'{scheme}_proxy'] = proxies[scheme]

In [None]:
# Echo the run parameters so they appear in the executed notebook's output.
# NOTE(review): current_dt, state and long_duration are not assigned anywhere
# in the visible notebook; presumably they are injected before execution
# (e.g. as notebook parameters) — confirm, otherwise this raises NameError
# on a fresh kernel.
print(f'current_dt: {current_dt}')
print(f'state: {state}')
print(f'long_duration: {long_duration}')


In [None]:
# Parse the run date (expected as 'YYYY-MM-DD') and derive its neighbours:
#   execution_dt_one_ago -> previous day, as a 'YYYY-MM-DD' string
#   execution_dt_next    -> next day, as a datetime
#   current_dt_next      -> next day, as a 'YYYY-MM-DD' string
execution_dt = datetime.strptime(current_dt, '%Y-%m-%d')
execution_dt_next = execution_dt + timedelta(days=1)
execution_dt_one_ago = f"{execution_dt - timedelta(days=1):%Y-%m-%d}"
current_dt_next = f"{execution_dt_next:%Y-%m-%d}"

In [None]:
# Coordinates of the source table read by the query built below.
# NOTE(review): project_id (lower case) is not referenced by any later
# visible cell.
project_id = "skt-datahub"
db_name = "adot_reco_dev"
table_name = "nudge_offering_api_table"

In [None]:
# Create a BigQuery client through the skt.gcp helper.
# NOTE(review): bq_client is not referenced by any later visible cell.
bq_client = get_bigquery_client()

In [None]:
# SQL selecting every event id and its description from the source table.
# db_name and table_name are interpolated when this cell runs, so it must be
# executed while they still hold the source-table values (both names are
# reassigned near the end of the notebook).
query = f"""
SELECT  event_id,
        event_id_description
FROM {db_name}.{table_name}
"""

In [None]:
# Run the query through bq_to_pandas. The result is used below as a DataFrame
# with `event_id` and `event_id_description` columns.
nudge_table = bq_to_pandas(query)

In [None]:
# Build an {event_id: event_id_description} mapping. Dict keys are unique, so
# repeated event_ids collapse into a single entry.
embedding_candidate_dict = (
    nudge_table
    .set_index("event_id")["event_id_description"]
    .to_dict()
)

In [None]:
# Split the {event_id: description} mapping into two index-aligned lists:
# event_ids[i] is the id whose description is query_list[i].
# The lists are taken straight from the dict views rather than from a loop
# whose variable was named `query`: that loop rebound the notebook-level SQL
# string, so re-running the BigQuery cell afterwards would have sent an event
# description as the query.
event_ids = list(embedding_candidate_dict.keys())
query_list = list(embedding_candidate_dict.values())

In [None]:
# Load the encoder and its tokenizer from the same pretrained checkpoint.
checkpoint = 'BM-K/KoDiffCSE-RoBERTa'
model = AutoModel.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [None]:
# Use the GPU when torch reports one, otherwise fall back to the CPU, then
# move the encoder onto the chosen device.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)
model = model.to(device)

In [None]:
def batch_embedd(query_list:list, event_ids:list, batch_size = 128, method="cls"):
    """Embed texts in batches and return L2-normalised sentence vectors.

    Uses the notebook-level ``tokenizer``, ``model`` and ``device``; gradients
    are disabled for the whole run.

    Args:
        query_list: Texts to embed.
        event_ids: Identifiers aligned one-to-one with ``query_list``.
        batch_size: Number of texts per forward pass; must be positive.
        method: ``"cls"`` takes the hidden state at sequence position 0;
            ``"mean_pool"`` averages the hidden states of the non-padding
            tokens only (weighted by the tokenizer's attention mask).

    Returns:
        A tuple ``(embedding_result, event_id_result)``. ``embedding_result``
        holds one NumPy array of shape ``(1, hidden_dim)`` per input text
        (the extra leading axis is kept because callers squeeze it away);
        ``event_id_result`` holds the event ids in the same order.

    Raises:
        ValueError: If ``method`` is not supported, ``batch_size`` is not
            positive, or the two input lists differ in length.
    """
    # Fail fast instead of hitting an unbound `embeddings` inside the loop.
    if method not in ("cls", "mean_pool"):
        raise ValueError(f"unsupported method {method!r}; expected 'cls' or 'mean_pool'")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    if len(query_list) != len(event_ids):
        raise ValueError(
            f"query_list and event_ids must have the same length, "
            f"got {len(query_list)} and {len(event_ids)}"
        )

    embedding_result = []
    event_id_result = []
    with torch.no_grad():
        for start in range(0, len(query_list), batch_size):
            stop = start + batch_size
            batch = query_list[start:stop]
            inputs = tokenizer(batch, padding=True, truncation=True, return_tensors="pt").to(device)
            hidden = model(**inputs, return_dict=True).last_hidden_state
            if method == "cls":
                embeddings = hidden[:, 0, :]
            else:
                # Padding positions must not contribute to the average;
                # otherwise a text's vector depends on how long the other
                # texts in its batch happen to be.
                mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
                embeddings = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
            # (batch, hidden_dim) -> unit length -> (batch, 1, hidden_dim)
            embeddings = F.normalize(embeddings, p=2, dim=1).unsqueeze(1)
            embedding_result.extend(embeddings.cpu().numpy())
            event_id_result.extend(event_ids[start:stop])

    return embedding_result, event_id_result

In [None]:
# Embed every description with the function's defaults (batch_size=128,
# method="cls"); the two returned lists are index-aligned.
embedding_result, event_id_result = batch_embedd(query_list=query_list, event_ids=event_ids)

In [None]:
# Assemble one output row per embedded event:
#   - the vector has its leading length-1 axis squeezed away,
#   - event_id keeps only the text before the first underscore,
#   - dt is the day after the run date.
result = [
    {
        "event_description": query_list[idx],
        "event_description_vector": embedding_result[idx].squeeze(0),
        "event_id": raw_event_id.split('_')[0],
        'dt': current_dt_next,
    }
    for idx, raw_event_id in enumerate(event_id_result)
]

In [None]:
# Turn the list of row dicts into a DataFrame (one row per event).
# NOTE(review): pandas is already imported in the import cells at the top of
# the notebook, so this in-cell import is redundant.
import pandas as pd
df = pd.DataFrame(result)

In [None]:
# Coordinates of the destination table for the embeddings.
# NOTE(review): this rebinds PROJECT_ID (also imported from skt.gcp) and
# reuses db_name / table_name, which earlier pointed at the source table.
# Re-running the query-building cell after this one would therefore target
# this table instead of the source.
table_name = "nudge_eventId_embedding_userretrive"
db_name = "adot_reco_dev"
PROJECT_ID = "skt-datahub"

In [None]:
# Write the embedding rows to the fully-qualified destination table.
# NOTE(review): partition='dt' presumably names the partition column; confirm
# against the skt.gcp pandas_to_bq helper, including whether an existing
# partition for the same dt is overwritten or appended to.
pandas_to_bq(pd_df = df, destination=f"{PROJECT_ID}.{db_name}.{table_name}", partition='dt')