In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import transformers
import torch

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

# dtype used both for the non-quantized modules (torch_dtype) and for the
# matmuls inside the 4-bit layers (bnb_4bit_compute_dtype). Keep them in sync.
COMPUTE_DTYPE = torch.float16

# Spell the 4-bit config out instead of passing {"load_in_4bit": True}: the
# dict form leaves bnb_4bit_compute_dtype at its float32 default, so every
# 4-bit layer casts the float16 activations up to float32 (slow; bitsandbytes
# warns about exactly this mismatch).
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=COMPUTE_DTYPE,
)

# NOTE: `pipeline` is the loaded text-generation pipeline object; the factory
# function is still reachable as `transformers.pipeline`.
pipeline = transformers.pipeline(
    "text-generation",
    model=model_id,
    device_map="auto",
    model_kwargs={
        "torch_dtype": COMPUTE_DTYPE,
        "quantization_config": quant_config,
        "low_cpu_mem_usage": True,
    },
)

In [None]:
tokenizer = AutoTokenizer.from_pretrained(model_id)

# The Llama 3 tokenizer ships without a pad token, and any padding request
# (`padding=True`, DataCollatorWithPadding) raises ValueError until one is set.
# Reusing EOS as the pad token is the usual workaround for causal-LM tuning.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

In [None]:
import sys
from pathlib import Path

from datasets import load_dataset

# Make the project's `tab_exp` package importable. This assumes the notebook
# runs from a subdirectory of the project root (hence `.parent`) -- adjust if
# the kernel's working directory differs. The membership check keeps re-runs
# of this cell from appending the same entry to sys.path again and again.
PROJECT_ROOT = str(Path.cwd().parent)
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

from tab_exp.tab import generate_synth_data, PIIData


In [None]:
def load_synth_split(samples: int, output: str, *, clean: bool = True, debug: bool = True):
    """Generate synthetic PII samples and load them as a single `datasets.Dataset`.

    Args:
        samples: number of records to ask `generate_synth_data` for.
        output: output location passed through to `generate_synth_data`.
        clean, debug: forwarded unchanged to `generate_synth_data`.

    Returns:
        (synth_result, dataset): the value returned by `generate_synth_data`
        and the Dataset read from its "combined_path" JSON file.
    """
    synth_result = generate_synth_data(samples=samples, output=output, clean=clean, debug=debug)
    # The JSON builder puts every record from a bare `data_files` path into a
    # single "train" split. Selecting it returns a Dataset rather than a
    # DatasetDict, no matter which split this data plays in our experiment.
    dataset = load_dataset("json", data_files=synth_result["combined_path"], split="train")
    return synth_result, dataset


sd_train, ds_train = load_synth_split(samples=1000, output="samples_train")
sd_test, ds_test = load_synth_split(samples=300, output="samples_test")

# sd_validate, ds_validate = load_synth_split(samples=100, output="samples_validate")


In [None]:
# Tokenize the dataset's "text" column into input_ids / attention_mask.
from transformers import DataCollatorWithPadding, TrainingArguments


def tokenize_fn(data: dict):
    """Tokenize the "text" column of a batch of dataset rows.

    `data` maps column name -> list of values, as passed by
    `Dataset.map(..., batched=True)`. A single row (column -> value) also works,
    because the tokenizer accepts either one string or a list of strings.

    No padding happens here. Padding inside `map` either does nothing (one
    example at a time) or pads everything to the longest sequence in a
    1000-row map batch. DataCollatorWithPadding pads each training batch
    dynamically instead.
    """
    return tokenizer(data["text"], truncation=True)


tokened_train_ds = ds_train.map(tokenize_fn, batched=True)

# Pads per training batch at collation time; requires tokenizer.pad_token to
# be set (not every causal-LM tokenizer defines one).
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

## TODO

PII project

- In the `tab` module:
    - For each field, generate the text together with its anonymized and original labels.
    - Create one example for the anonymized version and one for the original.
    - Build a `Dataset` from the resulting sequence of examples.
- Pass the dataset to the model.
- Follow the [ORPO fine-tuning directions here](https://huggingface.co/blog/mlabonne/orpo-llama-3).
- Run experiments with different hyperparameters.
    - Look into plotting the results with Plotly.


In [7]:
import tiktoken as tt

# Encode a small pretty-printed JSON snippet with the GPT-4o tokenizer and
# decode it back; the displayed result is the decoded string.
enc_name = tt.encoding_name_for_model("gpt-4o")
enc = tt.get_encoding(enc_name)

sample_json = '''{
    "user_id": 1234,
    "another_val": "something"
}'''
toks = enc.encode(sample_json)
decoded = enc.decode(toks)
decoded

'{\n    "user_id": 1234,\n    "another_val": "something"\n}'

: 