In [2]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Base checkpoint to fine-tune. Weights are loaded in bfloat16, and
# `device_map="auto"` lets accelerate place layers on the available devices.
model_id = "mistralai/Mistral-7B-v0.1"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

Loading checkpoint shards: 100%|██████████| 2/2 [00:05<00:00,  2.72s/it]


In [3]:
# Inspect the tokenizer's built-in special tokens before extending the vocabulary.
# The displayed map has no pad token, which is why one is added in the next cell.
tokenizer.special_tokens_map

{'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>'}

In [4]:
# Register the Self-RAG control tokens as atomic special tokens, plus a
# dedicated padding token (the base tokenizer defines none). Insertion order is
# "additional_special_tokens" first, then "pad_token", so new token IDs are
# assigned in the same order as the original one-line definition.
special_token_dict = {
    "additional_special_tokens": [
        "[No Retrieval]",
        "[Retrieval]",
        "[Continue to Use Evidence]",
        "[Irrelevant]",
        "[Relevant]",
        "<paragraph>",
        "</paragraph>",
        "[Utility:1]",
        "[Utility:2]",
        "[Utility:3]",
        "[Utility:4]",
        "[Utility:5]",
        "[Fully supported]",
        "[Partially supported]",
        "[No support / Contradictory]",
    ],
    "pad_token": "<pad>",
}
num_added_tokens = tokenizer.add_special_tokens(special_token_dict)
print(num_added_tokens)

# IDs of the passage delimiters; passed to the encode function as
# `context_markups` (presumably so passage text can be handled specially when
# building labels — confirm in `encode_with_prompt_completion_format`).
context_markups = [
    tokenizer.convert_tokens_to_ids(marker) for marker in ("<paragraph>", "</paragraph>")
]

# Grow the embedding matrix only when the tokenizer now has more entries than
# the model has embedding rows.
embedding_size = model.get_input_embeddings().weight.shape[0]
if embedding_size < len(tokenizer):
    model.resize_token_embeddings(len(tokenizer))

print(model.get_input_embeddings().weight.shape[0])

16
32016


In [5]:
from functools import partial
from self_rag.retrieval_lm.finetune import encode_with_prompt_completion_format

# Truncation length for encoded examples. Mistral-7B-v0.1's config sets
# `sliding_window`, but other checkpoints may leave it as None (or omit the
# attribute), which would hand `max_seq_length=None` to the encoder. Fall back
# to the model's maximum context length in that case; for v0.1 the value is
# unchanged.
max_seq_length = (
    getattr(model.config, "sliding_window", None)
    or model.config.max_position_embeddings
)

# Bind the tokenizer, truncation length and passage-delimiter IDs once so the
# resulting single-argument callable can be passed straight to `Dataset.map`.
encode_function = partial(
    encode_with_prompt_completion_format,
    tokenizer=tokenizer,
    max_seq_length=max_seq_length,
    context_markups=context_markups,
)

In [6]:
from datasets import load_dataset

# Self-RAG training corpus from the Hugging Face Hub (downloaded on first run,
# then served from the local datasets cache).
dataset = load_dataset(path="selfrag/selfrag_train_data")

In [7]:
# Tokenize each example and drop every column except the model-facing ones.
# The column list is read from the train split, the only split present.
columns_to_drop = [
    column
    for column in dataset["train"].column_names
    if column not in {"input_ids", "labels", "attention_mask"}
]

lm_datasets = dataset.map(
    encode_function,
    batched=False,
    num_proc=32,
    remove_columns=columns_to_drop,
    desc="Tokenizing and reformatting instruction data",
)

In [8]:
# Return torch tensors (rather than Python lists) when rows are accessed; the
# label filter and the DataLoader below rely on tensor-valued fields.
# `with_format` returns a new DatasetDict instead of mutating in place.
lm_datasets = lm_datasets.with_format("torch")

In [9]:
# Sanity check: only input_ids / labels / attention_mask should remain after tokenization.
lm_datasets

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'labels', 'attention_mask'],
        num_rows: 145619
    })
})

In [10]:
def _has_supervised_tokens(example):
    """Return whether any label position differs from the -100 ignore index.

    Examples whose labels are all -100 are fully masked out of the loss and
    contribute no training signal, so they are dropped. Relies on the torch
    format applied earlier, so ``example["labels"]`` is a tensor.
    """
    return (example["labels"] != -100).any()


lm_datasets = lm_datasets.filter(_has_supervised_tokens)

In [18]:
# The DatasetDict contains only a "train" split, which is used for training.
train_dataset = lm_datasets["train"]

In [29]:
from torch.utils.data import DataLoader
from transformers import DataCollatorForSeq2Seq

# Seed the shuffle so batch order is reproducible after a kernel restart;
# without an explicit generator the order depends on global RNG state.
SHUFFLE_SEED = 42

train_dataloader = DataLoader(
    train_dataset,
    batch_size=2,
    shuffle=True,
    generator=torch.Generator().manual_seed(SHUFFLE_SEED),
    # Pads input_ids/attention_mask with the tokenizer's pad token and labels
    # with the collator's default label_pad_token_id (-100) up to the longest
    # sequence in each batch.
    collate_fn=DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, padding="longest"),
)

In [30]:
# Pull one batch to sanity-check collation: input_ids and labels should come
# out padded to the same (batch_size, longest_seq_len) shape.
batch = next(iter(train_dataloader))

for field in ("input_ids", "labels"):
    print(batch[field].shape)

torch.Size([2, 226])
torch.Size([2, 226])


In [None]:
# Optimizer hyperparameters. The previous version of this cell referenced
# `args.learning_rate` and `optimizer_grouped_parameters`, neither of which is
# defined in this notebook, so it raised NameError on a fresh kernel.
# The values below are placeholders — TODO confirm against the training recipe.
LEARNING_RATE = 2e-5
WEIGHT_DECAY = 0.0


def build_optimizer_grouped_parameters(module, weight_decay, no_decay=("bias", "norm.weight")):
    """Split a module's trainable parameters into AdamW parameter groups.

    Parameters whose name contains any substring in ``no_decay`` (biases and
    normalization weights by default) get ``weight_decay=0.0``. All other
    trainable parameters get ``weight_decay``. Frozen parameters
    (``requires_grad=False``) are skipped, and empty groups are dropped.

    NOTE(review): Mistral's norm weights are presumably named like
    ``*.input_layernorm.weight`` / ``model.norm.weight``, which the default
    ``"norm.weight"`` substring matches — verify via ``model.named_parameters()``.

    Args:
        module: a ``torch.nn.Module`` whose ``named_parameters()`` are grouped.
        weight_decay: decay coefficient for the non-exempt group.
        no_decay: name substrings marking parameters exempt from weight decay.

    Returns:
        A list of param-group dicts accepted by ``torch.optim.AdamW``.
    """
    decay, exempt = [], []
    for name, param in module.named_parameters():
        if not param.requires_grad:
            continue
        if any(tag in name for tag in no_decay):
            exempt.append(param)
        else:
            decay.append(param)
    groups = [
        {"params": decay, "weight_decay": weight_decay},
        {"params": exempt, "weight_decay": 0.0},
    ]
    return [group for group in groups if group["params"]]


optimizer_grouped_parameters = build_optimizer_grouped_parameters(model, WEIGHT_DECAY)
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=LEARNING_RATE)