In [2]:
from torch.utils.data.dataloader import DataLoader

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from tqdm.auto import tqdm

import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix
from torch.utils.data import Subset

from src.train_test_utils import train, test 
import torch.nn as nn
from src.dataset import TextDataset

from src.probes import LinearProbeClassification
import sklearn.model_selection
import pickle

import time

tic, toc = (time.time, time.time)

In [4]:
import os
from huggingface_hub import login

# Set your Hugging Face token as an environment variable
os.environ["HUGGINGFACE_TOKEN"] = "-------
# Log in to Hugging Face
login(token=os.environ["HUGGINGFACE_TOKEN"])

In [5]:
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it", use_auth_token=True)
model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b-it", use_auth_token=True)
model.half().cuda()
model.eval()

tokenizer_config.json:   0%|          | 0.00/47.0k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]



config.json:   0%|          | 0.00/838 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/24.2k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/241M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/187 [00:00<?, ?B/s]

Gemma2ForCausalLM(
  (model): Gemma2Model(
    (embed_tokens): Embedding(256000, 2304, padding_idx=0)
    (layers): ModuleList(
      (0-25): 26 x Gemma2DecoderLayer(
        (self_attn): Gemma2Attention(
          (q_proj): Linear(in_features=2304, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2304, out_features=1024, bias=False)
          (v_proj): Linear(in_features=2304, out_features=1024, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2304, bias=False)
          (rotary_emb): Gemma2RotaryEmbedding()
        )
        (mlp): Gemma2MLP(
          (gate_proj): Linear(in_features=2304, out_features=9216, bias=False)
          (up_proj): Linear(in_features=2304, out_features=9216, bias=False)
          (down_proj): Linear(in_features=9216, out_features=2304, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
        (pre_feedforward_layernorm): Gemma2RMSNorm((2304,), eps

In [6]:
class TrainerConfig:
    # optimization parameters
    learning_rate = 1e-3
    betas = (0.9, 0.95)
    weight_decay = 0.1 # only applied on matmul weights
    # learning rate decay params: linear warmup followed by cosine decay to 10% of original
    # checkpoint settings

    def __init__(self, **kwargs):
        for k,v in kwargs.items():
            setattr(self, k, v)

In [11]:
import zipfile
def unzip_files(zip_path, extract_path):
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(extract_path)

# Unzip all necessary files
'''
for directory in '/root/usr_probes/dataset':
    zip_path = directory[:-1] + '.zip'  # Assuming the zip file has the same name as the directory
    if os.path.exists(zip_path):
        unzip_files(zip_path, os.path.dirname(directory))
'''

In [25]:
jump_socioeco = True

new_prompt_format=True
residual_stream=True
uncertainty = False
logistic = True
augmented = False
remove_last_ai_response = True
include_inst = True
one_hot = True

In [26]:
label_to_id_age = {"child": 0,
                   "adolescent": 1,
                   "adult": 2,
                   "older adult": 3,
                  }

label_to_id_gender = {"male": 0,
                      "female": 1,
                     }

label_to_id_socioeconomic = {"low": 0,
                             "middle": 1,
                             "high": 2}

label_to_id_neweducation = {"someschool": 0,
                            "highschool": 1,
                            "collegemore": 2}

prompt_translator = {"_age_": "age",
                     "_gender_": "gender",
                     "_socioeco_": "socioeconomic status",
                     "_education_": "education level",}

openai_dataset = {"_age_": "dataset/openai_age_1/",
                  "_gender_": "dataset/openai_gender_1/",
                  "_education_": "dataset/openai_education_1/",
                  "_socioeco_": "dataset/openai_socioeconomic_1/",}

accuracy_dict = {}

directories = ["/root/usr_probes/dataset/llama_age_1/", "/root/usr_probes/dataset/llama_gender_1/",
               "/root/usr_probes/dataset/llama_socioeconomic_1/", "/root/usr_probes/dataset/openai_education_1/"]

label_idfs = ["_age_", "_gender_", "_socioeco_", "_education_"]

label_to_ids = [label_to_id_age, label_to_id_gender,
                label_to_id_socioeconomic, label_to_id_neweducation]

In [24]:
for directory, label_idf, label_to_id in zip(directories, label_idfs, label_to_ids):
    if not os.path.exists(directory):
        zip_path = directory[:-1] + '.zip'
        if os.path.exists(zip_path):
            unzip_files(zip_path, os.path.dirname(directory))
        else:
            print(f"Warning: Neither directory {directory} nor zip file {zip_path} found. Skipping...")
            continue
    # additional_dataset=[directory[:-1] + "_additional/"]
    if label_idf == "_education_":
        additional_dataset=[]
    else:
        print('THEO')
        continue
        #additional_dataset=[directory[:-2] + "2.zip", openai_dataset[label_idf]]
    if label_idf == "_gender_":
        additional_dataset += ["/root/usr_probes/dataset/openai_gender_2/", "/root/usr_probes/dataset/openai_gender_3/", 
                               "/root/usr_probes/dataset/openai_gender_4",]
    if label_idf == "_education_":
        additional_dataset += ["/root/usr_probes/dataset/openai_education_2", "/root/usr_probes/dataset/openai_education_3/"]
    if label_idf == "_socioeco_":
        additional_dataset += ["/root/usr_probes/dataset/openai_socioeconomic_2/", "/root/usr_probes/dataset/openai_socioeconomic_3/"]
    if label_idf == "_age_":
        additional_dataset += ["/root/usr_probes/dataset/openai_age_2/", "/root/usr_probes/dataset/openai_age_3/"]
        
    dataset = TextDataset(directory, tokenizer, model, label_idf=label_idf, label_to_id=label_to_id,
                          convert_to_llama2_format=True, additional_datas=additional_dataset, 
                          new_format=new_prompt_format,
                          residual_stream=residual_stream, if_augmented=augmented, 
                          remove_last_ai_response=remove_last_ai_response, include_inst=include_inst, k=1,
                          one_hot=False, last_tok_pos=-1)
    dict_name = label_idf.strip("_")

    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    train_idx, val_idx = sklearn.model_selection.train_test_split(list(range(len(dataset))), 
                                                                  test_size=test_size,
                                                                  train_size=train_size,
                                                                  random_state=12345,
                                                                  shuffle=True,
                                                                  stratify=dataset.labels,
                                                                 )

    train_dataset = Subset(dataset, train_idx)
    test_dataset = Subset(dataset, val_idx)

    sampler = None
    train_loader = DataLoader(train_dataset, shuffle=True, sampler=sampler, pin_memory=True, batch_size=200, num_workers=1)
    test_loader = DataLoader(test_dataset, shuffle=False, pin_memory=True, batch_size=400, num_workers=1)

    loss_func = nn.BCELoss()
    torch_device = "cuda"

    seeds = seeds[:9]
    accuracy_dict[dict_name] = []
    accuracy_dict[dict_name + "_final"] = []
    accuracy_dict[dict_name + "_train"] = []
        
    accs = []
    final_accs = []
    train_accs = []
    for i in tqdm(range(0, 41)):
        trainer_config = TrainerConfig()
        probe = LinearProbeClassification(probe_class=len(label_to_id.keys()), device="cuda", input_dim=5120,
                                            logistic=logistic)
        optimizer, scheduler = probe.configure_optimizers(trainer_config)
        best_acc = 0
        max_epoch = 50
        verbosity = False
        layer_num = i
        print("-" * 40 + f"Layer {layer_num}" + "-" * 40)
        for epoch in range(1, max_epoch + 1):
            if epoch == max_epoch:
                verbosity = True
            # Get the train results from training of each epoch
            if uncertainty:
                train_results = train(probe, torch_device, train_loader, optimizer, 
                                        epoch, loss_func=loss_func, verbose_interval=None,
                                        verbose=verbosity, layer_num=layer_num, 
                                        return_raw_outputs=True, epoch_num=epoch, num_classes=len(label_to_id.keys()))
                test_results = test(probe, torch_device, test_loader, loss_func=loss_func, 
                                    return_raw_outputs=True, verbose=verbosity, layer_num=layer_num,
                                    scheduler=scheduler, epoch_num=epoch, num_classes=len(label_to_id.keys()))
            else:
                train_results = train(probe, torch_device, train_loader, optimizer, 
                                        epoch, loss_func=loss_func, verbose_interval=None,
                                        verbose=verbosity, layer_num=layer_num,
                                        return_raw_outputs=True,
                                        one_hot=one_hot, num_classes=len(label_to_id.keys()))
                test_results = test(probe, torch_device, test_loader, loss_func=loss_func, 
                                    return_raw_outputs=True, verbose=verbosity, layer_num=layer_num,
                                    scheduler=scheduler,
                                    one_hot=one_hot, num_classes=len(label_to_id.keys()))

            if test_results[1] > best_acc:
                best_acc = test_results[1]
                torch.save(probe.state_dict(), f"probe_checkpoints/reading_probe/{dict_name}_probe_at_layer_{layer_num}.pth")
        torch.save(probe.state_dict(), f"probe_checkpoints/reading_probe/{dict_name}_probe_at_layer_{layer_num}_final.pth")
        
        accs.append(best_acc)
        final_accs.append(test_results[1])
        train_accs.append(train_results[1])
        cm = confusion_matrix(test_results[3], test_results[2])
        cm_display = ConfusionMatrixDisplay(cm, display_labels=label_to_id.keys()).plot()
        plt.show()

        accuracy_dict[dict_name].append(accs)
        accuracy_dict[dict_name + "_final"].append(final_accs)
        accuracy_dict[dict_name + "_train"].append(train_accs)
        
        with open("probe_checkpoints/reading_probe_experiment.pkl", "wb") as outfile:
            pickle.dump(accuracy_dict, outfile)
    del dataset, train_dataset, test_dataset, train_loader, test_loader
    torch.cuda.empty_cache()

THEO
THEO
THEO
