In [None]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer, DataCollatorWithPadding, Trainer, TrainingArguments, BitsAndBytesConfig
import torch, evaluate, datasets, pandas as pd, numpy as np
from peft import LoraConfig, TaskType, get_peft_model

# Weighted-F1 metric used by compute_metrics.
f1 = evaluate.load("f1")

# Dataset selector: "fb" picks the 8-class label set below; any other value
# (intended: "semeval") picks the 6-class label set.
task = "fb"  # alternative: "semeval"

# Class id -> label string. Both lists are written in alphabetical order, the
# same order in which the sorted label strings are turned into integer ids.
id2label = dict(enumerate(
    ["Against, Against", "Against, Favor", "Against, None",
     "Favor, Against", "Favor, None",
     "None, Against", "None, Favor", "None, None"]
    if task == "fb" else
    ["Clinton, Against", "Clinton, Favor", "Clinton, None",
     "Trump, Against", "Trump, Favor", "Trump, None"]
))

# Raw train/test splits, read from data/{task}_{split}.csv.
train, test = (pd.read_csv(f"data/{task}_{split}.csv") for split in ("train", "test"))

########################## 1. GPU device setup
# Background reading for the QLoRA settings below:
# https://github.com/philschmid/deep-learning-pytorch-huggingface/blob/main/training/scripts/run_fsdp_qlora.py
# https://www.philschmid.de/fsdp-qlora-llama3#3-fine-tune-the-llm-with-pytorch-fsdp-q-lora-and-sdpa
# https://arxiv.org/abs/2305.14314

# Permit TF32 math in CUDA matmuls and cuDNN kernels.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# bfloat16 serves both as the compute dtype and as the storage dtype handed to
# bitsandbytes for the packed 4-bit weights.
torch_dtype = quant_storage_dtype = torch.bfloat16
device = "cuda"

# 4-bit NF4 quantization with double quantization, as in the QLoRA paper above.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch_dtype,
    bnb_4bit_quant_storage=quant_storage_dtype,
)

########################## 2. Functions
def compute_metrics(eval_pred):
    """Compute weighted F1 for one evaluation pass and save the predictions.

    Passed to ``Trainer`` as ``compute_metrics``. Besides returning the metric,
    it writes the gold and predicted labels, as readable strings from the
    module-level ``id2label``, to
    ``predicted_labels/{task}_{model}_{num_data}.csv``. The file is rewritten
    on every evaluation, so it holds the predictions of the most recent one.

    Reads the module-level globals ``task``, ``target``, ``num_data``,
    ``id2label`` and ``f1``.

    Args:
        eval_pred: ``(logits, labels)`` pair supplied by the Trainer. ``logits``
            may be a tuple when the model also returns extra outputs; only its
            first element is used.

    Returns:
        The dict produced by ``f1.compute(..., average="weighted")``.
    """
    import os

    logits, labels = eval_pred
    # The model may also return hidden_states or attentions; keep the logits.
    if isinstance(logits, tuple):
        logits = logits[0]
    predictions = np.argmax(logits, axis=1)

    # Exact id -> label-string lookup; regex replacement is pointless for ints.
    preds = pd.DataFrame({
        "completion": pd.Series(labels).map(id2label),
        "preds": pd.Series(predictions).map(id2label),
    })

    # Hoisted out of the f-string: reusing double quotes inside a double-quoted
    # f-string is a SyntaxError before Python 3.12 (PEP 701).
    model_name = target.split("/")[-1]
    os.makedirs("predicted_labels", exist_ok=True)
    preds.to_csv(f"predicted_labels/{task}_{model_name}_{num_data}.csv", index=False)

    return f1.compute(predictions=predictions, references=labels, average="weighted")

def tokenize(df):
    """Tokenize the ``prompt`` field of a single row, with truncation enabled.

    Applied through ``Dataset.map`` without ``batched=True``, so ``df`` is one
    example despite its name. Uses the module-level ``tokenizer``.
    """
    prompt_text = df["prompt"]
    return tokenizer(prompt_text, truncation=True)

########################## 3. Model list
# Checkpoints this script is set up for (assign one of them to `target`):
#   bert-base-uncased
#   sentence-transformers/all-mpnet-base-v2
#   microsoft/deberta-v3-base
#   google/flan-t5-xxl
#   mistralai/Mistral-7B-v0.3
#   meta-llama/Meta-Llama-3-8B
#   meta-llama/Meta-Llama-3-70B
target = "meta-llama/Meta-Llama-3-8B"
tokenizer = AutoTokenizer.from_pretrained(target)

# Decoder-only checkpoints: reuse EOS as the pad token so batches can be padded
# (presumably these tokenizers define no pad token of their own - confirm when
# adding new models).
if target in ("mistralai/Mistral-7B-v0.3",
              "meta-llama/Meta-Llama-3-8B",
              "meta-llama/Meta-Llama-3-70B"):
    tokenizer.pad_token = tokenizer.eos_token

# Pads every batch on the fly with the tokenizer configured above.
collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [None]:
########################## 4. Tokenize
import gc

# Encode the string labels in "completion" with ONE mapping shared by train and
# test. Encoding each split separately (class_encode_column per split) assigns
# different ids to the same class whenever the splits hold different label sets.
# Sorting reproduces the alphabetical order class_encode_column uses, which is
# also the order id2label is written in.
label_names = sorted(set(train["completion"]) | set(test["completion"]))
if len(label_names) != len(id2label):
    raise ValueError(
        f"Found {len(label_names)} distinct labels in the data but id2label "
        f"defines {len(id2label)}: {label_names}"
    )
label_to_id = {name: idx for idx, name in enumerate(label_names)}


def _encode(df):
    """Return a tokenized Dataset with input_ids, attention_mask and integer label."""
    with_ids = df.assign(label=df["completion"].map(label_to_id))
    return (
        datasets.Dataset.from_pandas(with_ids, preserve_index=False)
        .map(tokenize)
        .select_columns(["input_ids", "attention_mask", "label"])
    )


train_ds = _encode(train)
test_ds = _encode(test)

########################## 4.1. Vary the amount of training data
regime = ["10", "100", "1000", "all"] if task == "fb" else ["10", "100", "all"]
# Computed outside the f-strings: nesting double quotes inside a double-quoted
# f-string is a SyntaxError before Python 3.12 (PEP 701).
model_name = target.split("/")[-1]
# These checkpoints are trained as 4-bit base model + LoRA adapters; all other
# checkpoints are fully fine-tuned and therefore loaded without quantization
# (Trainer refuses to fine-tune a purely quantized model).
peft_targets = ("google/flan-t5-xxl", "mistralai/Mistral-7B-v0.3",
                "meta-llama/Meta-Llama-3-8B", "meta-llama/Meta-Llama-3-70B")
use_peft = target in peft_targets

for num_data in regime:
    # Few-shot regimes take the first N rows (capped at the split size so a
    # small training file does not raise IndexError); "all" uses everything.
    if num_data == "all":
        train_ds_subset = train_ds
    else:
        n_rows = min(int(num_data), len(train_ds))
        train_ds_subset = train_ds.select(range(n_rows))

    log_dir = f"predicted_labels/log/fine-tuned/{task}/{model_name}/{num_data}"
    training_args = TrainingArguments(
        output_dir=log_dir,
        logging_dir=log_dir,
        num_train_epochs=4,
        learning_rate=2e-4,
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        gradient_accumulation_steps=4,
        optim="adamw_torch",
        dataloader_num_workers=4,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": True},
        # NOTE(review): deprecated in favour of `eval_strategy` since
        # transformers 4.41 and dropped in later releases; rename when upgrading.
        evaluation_strategy="steps",
        eval_steps=0.5,  # a value below 1 is a ratio of the total training steps
        logging_strategy="steps",
        logging_steps=0.5,
    )

    model = AutoModelForSequenceClassification.from_pretrained(
        target,
        device_map=device,
        num_labels=len(id2label),
        id2label=id2label,
        label2id={name: idx for idx, name in id2label.items()},
        torch_dtype=quant_storage_dtype,
        quantization_config=quantization_config if use_peft else None,
        use_cache=False,
    )
    # Decoder classification heads locate the last real token through
    # config.pad_token_id; propagate the pad token set on the tokenizer.
    if tokenizer.pad_token_id is not None and model.config.pad_token_id is None:
        model.config.pad_token_id = tokenizer.pad_token_id

    if use_peft:
        if training_args.gradient_checkpointing:
            # With reentrant checkpointing and frozen embeddings, no input of a
            # checkpointed layer requires grad, so the LoRA weights inside those
            # layers would get no gradients; PEFT skips this for 4-bit models.
            model.enable_input_require_grads()
        peft_config = LoraConfig(
            lora_alpha=16,
            lora_dropout=0.1 if target == "meta-llama/Meta-Llama-3-70B" else 0.05,
            r=64,
            bias="none",
            task_type=TaskType.SEQ_CLS,
            target_modules="all-linear",
        )
        model = get_peft_model(model, peft_config)

    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=collator,
        train_dataset=train_ds_subset,
        eval_dataset=test_ds,
        compute_metrics=compute_metrics,
    )
    print("---------------------------------------------------")
    print(f"Start training with {num_data} rows for {task}")
    trainer.train()

    # Free this regime's model before the next iteration loads a new one;
    # otherwise both copies are alive on the GPU while the new one loads.
    del trainer, model
    gc.collect()
    torch.cuda.empty_cache()