# LLM Attribution Problem

In [None]:
!pip install transformers==4.44.2

In [133]:
import torch
from datasets import load_dataset
from transformers import pipeline, set_seed, BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments, TrainerCallback

import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix, classification_report
from tqdm import tqdm
import numpy as np
import seaborn as sns
import re
import pandas as pd
from collections import defaultdict
from copy import deepcopy

In [None]:
torch.__version__

In [None]:
import transformers

In [None]:
transformers.__version__

In [3]:
set_seed(42)

In [None]:
# pipe = pipeline("text-generation", model="bigscience/bloomz-1b7", device=0)
pipe = pipeline("text-generation", model="gpt2") #, device=0)

# Creating Datasets

## Collate x_i

### Wiki Dataset

In [None]:
wiki_ds = load_dataset("wikimedia/wikipedia", "20231101.en")

In [None]:
sentence_endings = re.compile(r'([A-Z][^.!?]*[.!?])\s+(?=[A-Z])')

In [None]:
res = []
for p in wiki_ds["train"][0]["text"].split("\n\n"):
    sentences = sentence_endings.findall(p)
#     print(sentences)
    for sent in sentences:
        if len(sent.split()) < 3:
            continue
        res.append(" ".join(sent.split()[:5]))


In [None]:
len(res), sorted(res)

In [None]:
def get_wiki_dataset(train_samples=500, num_words=5):
    step = 1 # wiki_ds.num_rows["train"]//train_samples
    test_samples = train_samples//5
    train_offset = 0
    val_offset = 300000
    test_offset = 500000
    def get_split(num_samples, offset):
        res = []
        ctr = 0
        for t in wiki_ds["train"][offset:offset+1000:step]["text"]:
            if ctr >= num_samples:
                print("Done")
                break
            para = t.split("\n\n")
            if len(para) < 5:
                continue

            for p in para:
                sentences = sentence_endings.findall(p)
#                 print(sentences)
                ps = 0
                for sent in sentences:
                    if len(sent.split()) < 3:
                        continue
                    res.append((" ".join(sent.split()[:num_words]), sent))
                    ctr += 1
                    ps += 1
                    if ps >= 2:
                        break
            
        return sorted(list(set(res)))

    return (get_split(train_samples, train_offset),
            get_split(test_samples, val_offset),
            get_split(test_samples, test_offset)
           )

In [None]:
wiki_train, wiki_val, wiki_test = get_wiki_dataset(1000)

In [None]:
len(wiki_train), len(wiki_val), len(wiki_test)

In [None]:
set(wiki_val).intersection(set(wiki_test)), set(wiki_val).intersection(set(wiki_train)), set(wiki_test).intersection(set(wiki_train)), 

In [None]:
for s, org in wiki_train:
    print(s, f" |{len(s.split())}| ", org)

In [None]:
for s in wiki_val:
    print(s, len(s.split()))

In [None]:
for s in wiki_test:
    print(s, len(s.split()))

In [None]:
def get_df(l):
    df = pd.DataFrame()
    df["original_sentence"] = [s for _, s in l]
    df["truncated_sentence"] = [s for s, _ in l]
    return df

In [None]:
wiki_train_df = get_df(wiki_train)
wiki_val_df = get_df(wiki_val)
wiki_test_df = get_df(wiki_test)

In [None]:
wiki_train_df, wiki_val_df, wiki_test_df

In [None]:
wiki_train_df.to_csv("./wiki_train.csv", index=False)
wiki_val_df.to_csv("./wiki_val.csv", index=False)
wiki_test_df.to_csv("./wiki_test.csv", index=False)

In [None]:
pd.read_csv("wiki_train.csv")

In [None]:
pd.read_csv("wiki_val.csv")

In [None]:
pd.read_csv("wiki_test.csv")

### GSM8K Dataset

In [None]:
gsm8k_ds = load_dataset("openai/gsm8k", "main")

In [None]:
CHAR_LIMIT = 120
gsm8k_train_txts = [q for q in gsm8k_ds["train"]["question"] if len(q) < CHAR_LIMIT]
tv_split = int(len(gsm8k_train_txts) * 0.8)
gsm8k_train, gsm8k_val = gsm8k_train_txts[:tv_split], gsm8k_train_txts[tv_split:]
gsm8k_test = [q for q in gsm8k_ds["test"]["question"] if len(q) < CHAR_LIMIT]

In [None]:
len(gsm8k_train), len(gsm8k_val), len(gsm8k_test)

In [None]:
def get_gsm_df(l):
    df = pd.DataFrame()
    df["original_sentence"] = l
    return df

In [None]:
gsm8k_train_df = get_gsm_df(gsm8k_train)
gsm8k_val_df = get_gsm_df(gsm8k_val)
gsm8k_test_df = get_gsm_df(gsm8k_test)

In [None]:
gsm8k_train_df

In [None]:
gsm8k_val_df

In [None]:
gsm8k_test_df

In [None]:
gsm8k_train_df.to_csv("./gsm8k_train.csv", index=False)
gsm8k_val_df.to_csv("./gsm8k_val.csv", index=False)
gsm8k_test_df.to_csv("./gsm8k_test.csv", index=False)

In [None]:
pd.read_csv("gsm8k_train.csv")

In [None]:
pd.read_csv("gsm8k_val.csv")

In [None]:
pd.read_csv("gsm8k_test.csv")

## Generate model outputs

### GPT2

In [None]:
MODEL_NAME = "gpt2"
# MODEL_NAME = "gpt2-xl"
MODEL_NAME_clean = MODEL_NAME.replace("/", "-")
MODEL_NAME_clean 

In [None]:
# tokenizer = GPT2Tokenizer.from_pretrained(MODEL_NAME)
# tokenizer.pad_token_id = tokenizer.eos_token_id

In [None]:
generator = pipeline('text-generation', model=MODEL_NAME, device_map="auto") #, tokenizer=tokenizer)

### Setup Phi2

In [None]:
MODEL_NAME = "microsoft/phi-2"
MODEL_NAME_clean = MODEL_NAME.replace("/", "-")
MODEL_NAME_clean 

In [None]:
# tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) #, padding_side='left')
# tokenizer.pad_token_id = tokenizer.eos_token_id

In [None]:
generator = pipeline('text-generation', model=MODEL_NAME, device_map="auto") #, tokenizer=tokenizer)

### Setup Falcon-7B

In [None]:
import os
os.remove("state.db")

In [None]:
MODEL_NAME = "tiiuae/falcon-7b"
MODEL_NAME_clean = MODEL_NAME.replace("/", "-")
MODEL_NAME_clean 

In [None]:
# tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) #, padding=True, padding_side='left')
# tokenizer.pad_token_id = tokenizer.eos_token_id

In [None]:
generator = pipeline(
    "text-generation",
    model=MODEL_NAME,
#     tokenizer=tokenizer,
#     torch_dtype=torch.bfloat16,
#     trust_remote_code=True,
    device_map="auto"
)

In [None]:
tokenizer(" efsdf sdf sd sdf sgd f sdf sdf", padding=True)

In [None]:
# generated_text = generator("A -wide meteorite impact crater", max_length=MAX_LENGTH, num_return_sequences=1, batch_size=batch_size, do_sample=True, temperature=0.7) #, padding=True)
generated_text = generator(batch, max_length=MAX_LENGTH, num_return_sequences=1, batch_size=batch_size, do_sample=True, temperature=0.7, pad_token_id=generator.tokenizer.eos_token_id)


In [None]:
generated_text

In [None]:
len(tokenizer(generated_text[0]["generated_text"])["input_ids"])

### Setup Mistral-7B-Instruct-v0.2

In [None]:
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"
MODEL_NAME_clean = MODEL_NAME.replace("/", "-").replace(".", "-")
MODEL_NAME_clean 

In [None]:
from huggingface_hub import login
access_token_read = "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"
login(token = access_token_read)

In [None]:
# tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# tokenizer.pad_token_id = tokenizer.eos_token_id

In [None]:
generator = pipeline(
    "text-generation",
    model=MODEL_NAME,
#     tokenizer=tokenizer,
#     torch_dtype=torch.bfloat16,
#     trust_remote_code=True,
    device_map="auto",
)

### Text Generation Code

In [None]:
generator.tokenizer.pad_token_id = generator.tokenizer.eos_token_id
generator.tokenizer.padding_side = 'left'

In [None]:
# DATASET = "wiki_train"
# DATASET = "wiki_val"
DATASET = "wiki_test"

# DATASET = "gsm8k_train"
# DATASET = "gsm8k_val"
# DATASET = "gsm8k_test"

# colname = "truncated_sentence"
colname = "original_sentence"

In [None]:
read_df = pd.read_csv(f"{DATASET}.csv")

In [None]:
f"{DATASET}.csv"

In [None]:
model_out_df = read_df
model_out_df[MODEL_NAME_clean] = [None] * len(model_out_df)

In [None]:
model_out_df

In [None]:
MAX_LENGTH = 256

In [None]:
batch_size = 64 #32 #1 #32 #16

print(MODEL_NAME, MODEL_NAME_clean)
for DATASET, colname in zip(["wiki_train", "wiki_val", "wiki_test", "gsm8k_train", "gsm8k_val", "gsm8k_test"], ["truncated_sentence", "truncated_sentence", "truncated_sentence", "original_sentence", "original_sentence", "original_sentence"]):
    read_df = pd.read_csv(f"{DATASET}.csv")
    print(f"{DATASET}.csv")
    print(f"{colname=}")
    model_out_df = read_df
    model_out_df[MODEL_NAME_clean] = [None] * len(model_out_df)
    print(model_out_df)
    res = []
    for s in tqdm(range(0, len(model_out_df), batch_size)):
        batch = model_out_df[colname].loc[s:s+batch_size-1].to_list()
        print(len(batch), batch)
        generated_text = generator(batch, max_length=MAX_LENGTH, num_return_sequences=1, batch_size=batch_size, do_sample=True, temperature=0.7)
        gen_outs = [g[0]["generated_text"] for g in generated_text]
        res.extend(gen_outs)
        print(gen_outs)
        model_out_df[MODEL_NAME_clean] = res + [None] * (len(model_out_df) - len(res))
        model_out_df.to_csv(f"{DATASET}_{MODEL_NAME_clean}.csv", index=False)
        print(model_out_df)
        print(f"Saved to {DATASET}_{MODEL_NAME_clean}.csv")
    #     break

In [None]:
batch_size = 32 #1 #32 # 16
res = []
for s in tqdm(range(0, len(model_out_df), batch_size)):
    batch = model_out_df[colname].loc[s:s+batch_size-1].to_list()
    print(len(batch), batch)
    generated_text = generator(batch, max_length=MAX_LENGTH, num_return_sequences=1, batch_size=batch_size, do_sample=True, temperature=0.7)
    gen_outs = [g[0]["generated_text"] for g in generated_text]
    res.extend(gen_outs)
    print(gen_outs)
    model_out_df[MODEL_NAME_clean] = res + [None] * (len(model_out_df) - len(res))
    model_out_df.to_csv(f"{DATASET}_{MODEL_NAME_clean}.csv", index=False)
#     break

In [None]:
 # batch_size = 32
# res = []
# for s in tqdm(range(0, len(model_out_df), batch_size)):
#     batch = model_out_df["truncated_sentence"].loc[s:s+batch_size-1].to_list()
#     print(len(batch), batch)
#     generated_text = generator(batch, max_length=MAX_LENGTH, num_return_sequences=1, batch_size=batch_size, do_sample=True, temperature=0.7)
#     gen_outs = [g[0]["generated_text"] for g in generated_text]
#     res.extend(gen_outs)
#     print(gen_outs)
#     model_out_df[MODEL_NAME_clean] = res + [None] * (len(model_out_df) - len(res))
#     model_out_df.to_csv(f"{DATASET}_{MODEL_NAME_clean}.csv", index=False)
#     break

In [None]:
model_out_df

In [None]:
# model_out_df - test

In [None]:
model_out_df[MODEL_NAME_clean] = res

In [None]:
model_out_df.to_csv(f"{DATASET}_{MODEL_NAME_clean}.csv", index=False)

In [None]:
pd.read_csv(f"{DATASET}_{MODEL_NAME_clean}.csv")

In [None]:
# for idx, rw in tqdm(wiki_train_df.iterrows()):
#     generated_text = generator(rw["truncated_sentence"], max_length=256, num_return_sequences=1, temperature=0.7)
#     genout = generated_text[0]["generated_text"]
#     print(rw["truncated_sentence"])
#     print(genout)
#     wiki_train_gpt2_df.loc[idx, "gpt2"] = genout
# #     input_text = "Complete the following: "+rw["truncated_sentence"]
# #     print(input_text)
# #     generated_text = generator(input_text, max_length=256, num_return_sequences=1, temperature=0.7)
# #     print(generated_text[0]["generated_text"])
#     if idx % 10 == 0:
#         wiki_train_gpt2_df.to_csv("wiki_train_gpt2.csv", index=False)
# #         break

In [None]:
# for p, ft in train_txt:
#     print(f"RUNNING: {p} : {ft}")
#     gen_txt = pipe(p, num_return_sequences=3)
#     for seq in gen_txt:
#         print("*", seq["generated_text"])

In [None]:
for idx, rw in model_out_df.iterrows():
    print("Input:", rw["truncated_sentence"])
    print("Output:", rw[MODEL_NAME_clean])

## BERT Sequence Classification

In [4]:
MAX_LENGTH = 256

In [5]:
SEQ_CLF_MODEL = "bert-base-cased"

CLASS_TO_IDX = {"gpt2": 0,
               "gpt2-xl": 1,
               "microsoft-phi-2": 2,
               "tiiuae-falcon-7b": 3, 
               "mistralai-Mistral-7B-Instruct-v0-2": 4}

NUM_CLASSES = len(CLASS_TO_IDX.items())
IDX_TO_CLASS = {v: k for k, v in CLASS_TO_IDX.items()}
CLASSES = [IDX_TO_CLASS[k] for k in IDX_TO_CLASS]
DATASETS = ["wiki", "gsm8k"]
SPLITS = ["train", "val", "test"]
print(f"Daatsets: {DATASETS}")
print(f"Using model for sequence classification: {SEQ_CLF_MODEL}")
print(f"Number of models (classes): {NUM_CLASSES}")
print(f"Models (Classes): {CLASSES}")
print(CLASS_TO_IDX)
print(IDX_TO_CLASS)

Daatsets: ['wiki', 'gsm8k']
Using model for sequence classification: bert-base-cased
Number of models (classes): 5
Models (Classes): ['gpt2', 'gpt2-xl', 'microsoft-phi-2', 'tiiuae-falcon-7b', 'mistralai-Mistral-7B-Instruct-v0-2']
{'gpt2': 0, 'gpt2-xl': 1, 'microsoft-phi-2': 2, 'tiiuae-falcon-7b': 3, 'mistralai-Mistral-7B-Instruct-v0-2': 4}
{0: 'gpt2', 1: 'gpt2-xl', 2: 'microsoft-phi-2', 3: 'tiiuae-falcon-7b', 4: 'mistralai-Mistral-7B-Instruct-v0-2'}


In [6]:
# import os

# for DATASET, colname in zip(["wiki", "gsm8k"], ["truncated_sentence", "original_sentence"]):
#     for SPLIT in ["train", "val", "test"]:
#         MODEL_NAME_clean = "mistralai-Mistral-7B-Instruct-v0.2"
#         fname = f"{DATASET}_{SPLIT}_{MODEL_NAME_clean}"
#         rename_to = fname.replace(".", "-")
#         try:
#             os.rename(f"{fname}.csv", f"{rename_to}.csv")
#         except:
#             pass
# #         break

In [7]:
df = pd.DataFrame(columns=["dataset", "split", "original_text", "model_input", "model_output", "model"])
print(df)

for DATASET, colname in zip(DATASETS, ["truncated_sentence", "original_sentence"]):
    for SPLIT in SPLITS:
        for MODEL_NAME_clean in CLASSES:
            fname = f"{DATASET}_{SPLIT}_{MODEL_NAME_clean}.csv"
            print(f"Reading {fname}")
            read_df = pd.read_csv(fname)
#             print(read_df)
            read_df["dataset"] = DATASET
            read_df["split"] = SPLIT
            read_df["model"] = MODEL_NAME_clean
            if DATASET == "gsm8k":
                read_df["model_input"] = read_df[colname]
            
            if MODEL_NAME_clean == "mistralai-Mistral-7B-Instruct-v0-2":
                read_df.rename(columns={"mistralai-Mistral-7B-Instruct-v0.2": "model_output"}, inplace=True)
            read_df.rename(columns={"original_sentence": "original_text",
                                    "truncated_sentence": "model_input",
                                    MODEL_NAME_clean: "model_output"}, inplace=True)

            read_df = read_df[df.columns.to_list()]
            df = pd.concat([df, read_df])


Empty DataFrame
Columns: [dataset, split, original_text, model_input, model_output, model]
Index: []
Reading wiki_train_gpt2.csv
Reading wiki_train_gpt2-xl.csv
Reading wiki_train_microsoft-phi-2.csv
Reading wiki_train_tiiuae-falcon-7b.csv
Reading wiki_train_mistralai-Mistral-7B-Instruct-v0-2.csv
Reading wiki_val_gpt2.csv
Reading wiki_val_gpt2-xl.csv
Reading wiki_val_microsoft-phi-2.csv
Reading wiki_val_tiiuae-falcon-7b.csv
Reading wiki_val_mistralai-Mistral-7B-Instruct-v0-2.csv
Reading wiki_test_gpt2.csv
Reading wiki_test_gpt2-xl.csv
Reading wiki_test_microsoft-phi-2.csv
Reading wiki_test_tiiuae-falcon-7b.csv
Reading wiki_test_mistralai-Mistral-7B-Instruct-v0-2.csv
Reading gsm8k_train_gpt2.csv
Reading gsm8k_train_gpt2-xl.csv
Reading gsm8k_train_microsoft-phi-2.csv
Reading gsm8k_train_tiiuae-falcon-7b.csv
Reading gsm8k_train_mistralai-Mistral-7B-Instruct-v0-2.csv
Reading gsm8k_val_gpt2.csv
Reading gsm8k_val_gpt2-xl.csv
Reading gsm8k_val_microsoft-phi-2.csv
Reading gsm8k_val_tiiuae-falco

In [8]:
df["class"] = df["model"].map(CLASS_TO_IDX)

In [9]:
df

Unnamed: 0,dataset,split,original_text,model_input,model_output,model,class
0,wiki,train,A -wide meteorite impact crater is located in ...,A -wide meteorite impact crater,"A -wide meteorite impact crater, which has bee...",gpt2,0
1,wiki,train,A 2008 study found that this anthropogenic cha...,A 2008 study found that,A 2008 study found that the majority of the pe...,gpt2,0
2,wiki,train,A 60-gun ship of that name served at the Battl...,A 60-gun ship of that,"A 60-gun ship of that caliber, carrying 300 me...",gpt2,0
3,wiki,train,A Centers for Disease Control and Prevention s...,A Centers for Disease Control,A Centers for Disease Control and Prevention s...,gpt2,0
4,wiki,train,A broad categorisation can be made between aim...,A broad categorisation can be,A broad categorisation can be achieved by usin...,gpt2,0
...,...,...,...,...,...,...,...
62,gsm8k,test,Jack had $100. Sophia gave him 1/5 of her $100...,Jack had $100. Sophia gave him 1/5 of her $100...,Jack had $100. Sophia gave him 1/5 of her $100...,mistralai-Mistral-7B-Instruct-v0-2,4
63,gsm8k,test,Mike bought 5 face masks while Johnny bought 2...,Mike bought 5 face masks while Johnny bought 2...,Mike bought 5 face masks while Johnny bought 2...,mistralai-Mistral-7B-Instruct-v0-2,4
64,gsm8k,test,Digimon had its 20th anniversary. When it cam...,Digimon had its 20th anniversary. When it cam...,Digimon had its 20th anniversary. When it cam...,mistralai-Mistral-7B-Instruct-v0-2,4
65,gsm8k,test,Sally received the following scores on her mat...,Sally received the following scores on her mat...,Sally received the following scores on her mat...,mistralai-Mistral-7B-Instruct-v0-2,4


In [10]:
tokenizer = BertTokenizer.from_pretrained(SEQ_CLF_MODEL)



In [160]:
class LLMAttribDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)

In [161]:
class CustomCallback(TrainerCallback):
    def __init__(self, trainer) -> None:
        super().__init__()
        self._trainer = trainer
    
    def on_step_end(self, args, state, control, **kwargs):
        if control.should_evaluate:
            control_copy = deepcopy(control)
            self._trainer.evaluate(eval_dataset=self._trainer.train_dataset, metric_key_prefix="train_")
            return control_copy

In [162]:
def get_dataset(dataset, split):
    if dataset=="all":
        temp_df = df[(df["split"]==split)]
    else:
        temp_df = df[(df["dataset"]==dataset) & (df["split"]==split)]
#     print(temp_df)
    clf_input_text = temp_df["model_output"].to_list()
    encodings = tokenizer(clf_input_text, truncation=True, padding="max_length", max_length=MAX_LENGTH)
    clf_label = temp_df["class"].to_list()
    pt_ds = LLMAttribDataset(encodings, clf_label)
    print(f"Created -> Dataset: {dataset} | Split: {split} | Len: {pt_ds.__len__()}")
    return pt_ds

In [163]:
def compute_metrics(p):
    preds = p.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(p.label_ids, preds, average='macro')
    acc = accuracy_score(p.label_ids, preds)
    return {"accuracy": acc, "f1": f1, "precision": precision, "recall": recall}

In [172]:
def plot_training_loss(log_history):
#     train_state_df = pd.DataFrame(trainer.state.log_history, columns=["epoch", "step", "loss", "eval_loss"])
    training_loss = [log["loss"] for log in log_history if "loss" in log and "step" in log]
    validation_loss = [log["eval_loss"] for log in log_history if "eval_loss" in log]
    end_step = min(len(training_loss), len(validation_loss))
    training_loss = training_loss[:end_step]
    validation_loss = validation_loss[:end_step]
    
    steps = [log["step"] for log in log_history if "loss" in log][:end_step]
    
    print(len(training_loss), len(validation_loss), len(steps))
    
    # plot loss
    plt.plot(steps, training_loss, label="Training Loss", color='blue')
    plt.plot(steps, validation_loss, label="Validation Loss", color='orange')
    
    plt.xlabel("Steps")
    plt.ylabel("Loss")
    plt.title("Training and Validation Loss Over Time")
    plt.grid(True)
    plt.legend()
    plt.show()
    

In [173]:
def plot_confusion_mat(cm):
    plt.figure(figsize=(8, 4))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=TICKS, yticklabels=TICKS)
    plt.ylabel('Actual')
    plt.xlabel('Predicted')
    plt.title('Confusion Matrix')
    plt.show()

In [None]:
# get dataset
pt_dataset = defaultdict(dict)
    for ds in DATASETS+["all"]:
        for split in SPLITS:
            pt_dataset[ds][split] = get_dataset(ds, split)            

In [176]:
# Experiments

ds = "gsm8k"
epochs = 1
out_dir = f"{ds}-{MAX_LENGTH}-{epochs}"

print(ds)
print(MAX_LENGTH)
print(epochs)
print(out_dir)

def run_experiment():
    
    print(f"Allocated memory: {torch.cuda.memory_allocated() / 1024 ** 2} MB")
    print(f"Cached memory: {torch.cuda.memory_reserved() / 1024 ** 2} MB")
    torch.cuda.empty_cache()
    print("After emptying cache:")
    print(f"Allocated memory: {torch.cuda.memory_allocated() / 1024 ** 2} MB")
    print(f"Cached memory: {torch.cuda.memory_reserved() / 1024 ** 2} MB")

    print(f"Loading mode: {SEQ_CLF_MODEL}...")
    model = BertForSequenceClassification.from_pretrained(SEQ_CLF_MODEL, num_labels=NUM_CLASSES)
    print("Model loaded.")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    model.to(device)

    training_args = TrainingArguments(
        output_dir=f"./results/{out_dir}",
        logging_dir='./logs',
        per_device_train_batch_size=32,
        per_device_eval_batch_size=16,
        num_train_epochs=epochs,
        weight_decay=0.01,
        eval_strategy="steps",
        logging_steps=50,
        eval_steps=50,
        save_steps=50,     # Save the model every 500 steps
        save_total_limit=2,  
        report_to="none",  # Disable report to Weights & Biases
        load_best_model_at_end=True,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=pt_dataset[ds]["train"],
        eval_dataset=pt_dataset[ds]["val"],
        compute_metrics=compute_metrics,
    )
    print("Adding Custom Callback")
    trainer.add_callback(CustomCallback(trainer)) 
    # return trainer    
    print("Start training...")
    train_results = trainer.train()
    print("Done training...")

    print("Saving trainer state...")
    trainer.save_state()
    print("Saving best model...")
    trainer.save_model(output_dir=f"./results/{out_dir}/best_model")
    
    print("Train set:")
    print(train_results.metrics)
    trainer.log_metrics("train", train_results.metrics)
    trainer.save_metrics("train", train_results.metrics)
    print("Validation set:")
    val_results = trainer.evaluate()
    print(val_results)
    trainer.log_metrics("eval", val_results)
    trainer.save_metrics("eval", val_results)
    print("Test set:")
    test_results = trainer.predict(test_dataset=pt_dataset[ds]["test"])
    print(test_results.metrics)
    trainer.log_metrics("test", test_results.metrics)
    trainer.save_metrics("test", test_results.metrics)
    
    # plots
    plot_training_loss(trainer.state.log_history)
    
    TICKS = ['gpt2',
             'gpt2-xl',
             'phi-2',
             'falcon-7b',
             'mistral-7B']
    predictions = np.argmax(test_results.predictions, axis=1)
    cm = confusion_matrix(test_results.label_ids, predictions)
    plot_confusion_mat(cm)
    
    print("All done!")
    
    return trainer

gsm8k
256
1
gsm8k-256-1


In [177]:
trainer = run_experiment()

Allocated memory: 4979.15234375 MB
Cached memory: 10408.0 MB
After emptying cache:
Allocated memory: 4979.15234375 MB
Cached memory: 5626.0 MB
Loading mode: bert-base-cased...


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model loaded.
Using device: cuda
Adding Custom Callback


In [179]:
print("Test set:")
test_results = trainer.predict(test_dataset=pt_dataset[ds]["test"])
print(test_results.metrics)
trainer.log_metrics("test", test_results.metrics)
trainer.save_metrics("test", test_results.metrics)

Test set:


  with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):


  _warn_prf(average, modifier, msg_start, len(result))


{'test_loss': 1.7380963563919067, 'test_accuracy': 0.18208955223880596, 'test_f1': 0.08085971285244767, 'test_precision': 0.05448897173035104, 'test_recall': 0.18208955223880596, 'test_runtime': 2.9895, 'test_samples_per_second': 112.058, 'test_steps_per_second': 3.68}
***** test metrics *****
  test_accuracy           =     0.1821
  test_f1                 =     0.0809
  test_loss               =     1.7381
  test_precision          =     0.0545
  test_recall             =     0.1821
  test_runtime            = 0:00:02.98
  test_samples_per_second =    112.058
  test_steps_per_second   =       3.68


In [44]:
!zip -r wiki-256-noshuffle.zip /kaggle/working/results/wiki-256

  pid, fd = os.forkpty()


  adding: kaggle/working/results/wiki-256/ (stored 0%)
  adding: kaggle/working/results/wiki-256/checkpoint-300/ (stored 0%)
  adding: kaggle/working/results/wiki-256/checkpoint-300/trainer_state.json (deflated 73%)
  adding: kaggle/working/results/wiki-256/checkpoint-300/model.safetensors (deflated 7%)
  adding: kaggle/working/results/wiki-256/checkpoint-300/rng_state.pth (deflated 25%)
  adding: kaggle/working/results/wiki-256/checkpoint-300/optimizer.pt (deflated 13%)
  adding: kaggle/working/results/wiki-256/checkpoint-300/scheduler.pt (deflated 56%)
  adding: kaggle/working/results/wiki-256/checkpoint-300/training_args.bin (deflated 51%)
  adding: kaggle/working/results/wiki-256/checkpoint-300/config.json (deflated 53%)
  adding: kaggle/working/results/wiki-256/checkpoint-860/ (stored 0%)
  adding: kaggle/working/results/wiki-256/checkpoint-860/trainer_state.json (deflated 78%)
  adding: kaggle/working/results/wiki-256/checkpoint-860/model.safetensors (deflated 7%)
  adding: kaggl

In [45]:
from IPython.display import FileLink
FileLink(r'wiki-256-noshuffle.zip')