# Config

In [1]:
# Move the working directory one level up before the project imports below run.
# NOTE(review): presumably so `utils` / `data_generation` resolve from the
# repository root - confirm. The path is relative, so re-running this cell
# moves up one more level each time.
import os; os.chdir('../')

# Enumerate GPUs in PCI bus order and expose only device 6 to this process.
# Both variables are set before `torch` is imported further down in this cell.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] =  "6"
# Optional interactive Hugging Face login (left disabled):
# from huggingface_hub import interpreter_login; interpreter_login() # id

import gc
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import seaborn as sns
import matplotlib.pyplot as plt
import nest_asyncio

from utils import Config, Data_Manager, model_selection
from data_generation.utils_data_eval import linear_fit_torch
from data_generation.utils_data_generation import (
    OpenAI_Async_Processor, OpenAI_Batch_Processor,
    Google_Async_Processor, Google_Batch_Processor, 
    Setup_Prompts, 
    extract_activations,
    postprocess_sev_label_data
)


# Project configuration and the data manager built from it.
cfg = Config()
dm = Data_Manager(cfg)

# The 'label' lookup returns four objects, unpacked together.
(label_dict,
 sub_dim_dict,
 symp_keys,
 subdim_keys) = dm.load_dict(dict_type='label')

# The 'query' and 'abbv' lookups each return one object; they are loaded in
# that order.
query_dict, abbv_dict = (
    dm.load_dict(dict_type=kind) for kind in ('query', 'abbv')
)

# Generate, Predict, Update Data

In [None]:
"""
choose task:
    'generate-sev': Section A4.3 - Step 1 (generates the text data expressing varying severity of the labeled thoughts)
    'predict-sev':  Section A4.3 - Step 2 (predicts the text data severity)
"""

task = 'generate-sev' # 'generate' , 'predict' , 'generate-sev' , 'predict-sev'

Steps 1 & 2: Option 1 - Execute asynchronous API calls (faster option)

In [None]:
"""
ASYNC PROCESSING
"""

nest_asyncio.apply()

if 'gpt' in cfg.api_model_id.lower(): 
    processor_async = OpenAI_Async_Processor(cfg, label_dict); print(f'Using {cfg.api_model_id} for async processing \n') # 
if 'gemini' in cfg.api_model_id.lower(): 
    processor_async = Google_Async_Processor(cfg, label_dict); print(f'Using {cfg.api_model_id} for async processing \n')
    
chat_system = Setup_Prompts(cfg, label_dict, sub_dim_dict, query_dict)

for t in label_dict.keys():
    print('THOUGHT VAR:', t)
    chats, df = chat_system.setup_prompts(thought_var=t, task=task)
    df = processor_async.process_async(t, task, chats, df)


Steps 1 & 2: Option 2 - Execute batch processing (cheaper option)

In [None]:
"""
BATCH PROCESSING - API REQUEST
"""

task = 'predict' # 'generate', 'predict', 'generate-sev', 'predict-sev'
chat_system = Setup_Prompts(cfg, label_dict, sub_dim_dict, query_dict)

if 'gpt' in cfg.api_model_id: 
    processor_batch = OpenAI_Batch_Processor(cfg, label_dict)
elif 'gemini' in cfg.api_model_id: 
    processor_batch = Google_Batch_Processor(cfg, label_dict)

for t in label_dict.keys():
    print('THOUGHT VAR:', t)
    chats, df = chat_system.setup_prompts(thought_var=t, task=task)
    batch, chats, df = processor_batch.request_batch(t, task, chats, df)


In [None]:
"""
BATCH PROCESSING - CHECK STATUS, RETRIEVE, or CANCEL
"""

task = 'generate' # 'generate', 'predict', 'generate-sev', 'predict-sev'

thought_vars = list(all_label_dict.keys())

for t in thought_vars:
    print('THOUGHT VAR:', t)    
    
    if 'gpt' in cfg.api_model_id:
        processor_batch = OpenAI_Batch_Processor(cfg, all_label_dict)
        batch_list = processor_batch.load_batch(t, task)
        json_path = cfg.json_gen_file if task == 'generate' else cfg.json_pred_file
        for batch in batch_list:
            status = cfg.client.batches.retrieve(batch.id).status
            print(f'Batch ID: {batch.id}')
            print(f'Batch request status: {status}')
            print(f'Batch status: {batch.request_counts} \n')

    ### check batch status
    # batch_id = ''
    # cfg.client.batches.retrieve(batch_id)
    # cfg.client.batches.cancel(batch_id)
    
    elif 'gemini' in cfg.api_model_id:    
        processor_batch = Google_Batch_Processor(cfg, all_label_dict)
        batch_list = processor_batch.load_batch(t, task)
        json_path = cfg.json_gen_file if task == 'generate' else cfg.json_pred_file
        for batch in batch_list:
            print(f'Batch ID: {batch.display_name}')
            print(f'Batch request status: {batch.state.name} \n')    
    

In [None]:
"""
BATCH PROCESSING - TO CSV
"""

task = 'generate' # 'generate', 'predict', 'generate-sev', 'predict-sev'
thought_vars = all_label_dict.keys()
chat_system = Setup_Prompts(cfg, all_label_dict, label_dict, sub_dim_dict, query_dict)

if 'gpt' in cfg.api_model_id:
    processor_batch = OpenAI_Batch_Processor(cfg, all_label_dict)
elif 'gemini' in cfg.api_model_id:
    processor_batch = Google_Batch_Processor(cfg, all_label_dict)

for t in thought_vars:
    print('THOUGHT VAR:', t)
    batch_list = processor_batch.load_batch(t, task)
    data = []
    for batch in batch_list:
        data.extend(processor_batch.check_batch_status_api(batch, task))
    chats, df = chat_system.setup_prompts(thought_var=t, task=task)
    df = processor_batch.batch_to_df(data, df, task)

Post-process: Steps 3 & 4

In [2]:
"""
POSTPROCESS SEVERITY LABEL DATA (RUN ONLY AFTER DATA GENERATION IS DONE)
"""

filt_dfs = postprocess_sev_label_data(cfg)

# Extract LLM Activations

In [2]:
"""
EXTRACT LLM ACTIVATIONS FROM THE FILTERED DATA
"""
df = pd.read_csv(f'{cfg.sev_filt_dir}/{cfg.df_file_name}_v{cfg.current_v}.csv')

model, tokenizer = model_selection(cfg)
X_dict, y, y_sev = extract_activations(
    model, 
    tokenizer,
    dfs=df,
    label_dict=label_dict,
    batch_size=128, 
    layers=cfg.hook_layers,
    label_type='severity'
)
y = [y, y_sev]

"""
SAVE ACTIVATIONS AND LABELS
"""
for layer in cfg.hook_layers:
    if not os.path.exists(f'{cfg.outcome_dir}/layer_{layer}'):
        os.makedirs(f'{cfg.outcome_dir}/layer_{layer}')

    X_path = f'{cfg.outcome_dir}/layer_{layer}/{cfg.X_file_name}_sev_v{cfg.current_v}.pt'
    y_path = f'{cfg.outcome_dir}/layer_{layer}/{cfg.y_file_name}_sev_v{cfg.current_v}.pt'

    torch.save(X_dict[layer], X_path)
    torch.save(y, y_path)

del X_dict, y, y_sev, model, tokenizer
gc.collect()
torch.cuda.empty_cache()

Loading model...


Loading checkpoint shards:   0%|          | 0/17 [00:00<?, ?it/s]

Extracting activations...


100%|██████████| 24/24 [00:48<00:00,  2.00s/it]
