In [None]:
!pip install -e ./sparseml[transformers]

In [1]:
from datasets import load_dataset

# Load UltraChat 200k from the Hugging Face Hub; its "train_sft" split is selected below.
dataset = load_dataset("HuggingFaceH4/ultrachat_200k")

  table = cls._concat_blocks(blocks, axis=0)


In [3]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Chat model that SFTTrainer fine-tunes further down, loaded with float16 weights,
# plus its tokenizer (which carries the chat template used for formatting).
model_id = "HuggingFaceH4/mistral-7b-sft-beta"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(model_id)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [7]:
# NOTE(review): reloads the same tokenizer the model-loading cell already created;
# redundant unless that tokenizer object was modified in between.
tokenizer = AutoTokenizer.from_pretrained(model_id)

In [8]:
# Show the Jinja chat template; the <|user|> / <|assistant|> markers it emits are
# what the collator templates below search for. print() renders the newlines.
print(tokenizer.chat_template)

{% for message in messages %}
{% if message['role'] == 'user' %}
{{ '<|user|>
' + message['content'] + eos_token }}
{% elif message['role'] == 'system' %}
{{ '<|system|>
' + message['content'] + eos_token }}
{% elif message['role'] == 'assistant' %}
{{ '<|assistant|>
'  + message['content'] + eos_token }}
{% endif %}
{% if loop.last and add_generation_prompt %}
{{ '<|assistant|>' }}
{% endif %}
{% endfor %}


In [None]:
# Keep only the SFT training split. Guarded so that re-running this cell (or
# Run All after it already ran) does not index the selected split with the
# string "train_sft", which would fail.
from datasets import DatasetDict

if isinstance(dataset, DatasetDict):
    dataset = dataset["train_sft"]

In [10]:
def add_system_prompt(batch):
    """Prepend a fixed system message to every conversation in a batch.

    Written for ``Dataset.map(..., batched=True)``.

    Args:
        batch: mapping whose ``"messages"`` entry is a list of conversations,
            each a list of message dicts.

    Returns:
        ``{"messages_with_sys_prompt": [...]}``, with one new list per input
        conversation, in the same order. ``batch`` itself is not modified.
    """
    system_prompt = {"content": "You are a friendly chatbot", "role": "system"}
    return {
        "messages_with_sys_prompt": [
            [system_prompt] + conversation for conversation in batch["messages"]
        ]
    }

# Add a "messages_with_sys_prompt" column (system message + original turns) to every row.
dataset = dataset.map(
    add_system_prompt,
    batched=True,
    num_proc=32,
    batch_size=1000,
)

  table = cls._concat_blocks(blocks, axis=0)


In [20]:
from transformers import DataCollatorForLanguageModeling
from typing import List, Union, Any, Dict
import warnings
import numpy as np

class DataCollatorForChatLM(DataCollatorForLanguageModeling):
    """Causal-LM collator that keeps the loss only on assistant responses.

    The batch is built by ``DataCollatorForLanguageModeling`` (always with
    ``mlm=False``). For each example, ``labels`` is then scanned for the token-id
    sequences that open a user turn (``instruction_template``) and an assistant
    turn (``response_template``). Every label that is not part of an assistant
    reply is set to ``ignore_index`` (-100):

    * everything before the first user turn (e.g. a system prompt);
    * each user turn, through the end of the response marker that follows it;
    * a trailing user turn with no following response, e.g. one cut off by
      truncation, which is masked to the end of the sequence.

    An example in which either template cannot be found is masked entirely,
    and a warning is issued.

    Args:
        response_template: token ids that open an assistant turn.
        instruction_template: token ids that open a user turn.
        *args, **kwargs: forwarded to ``DataCollatorForLanguageModeling``.
    """

    def __init__(
        self,
        response_template: List[int],
        instruction_template: List[int],
        *args,
        **kwargs,
    ):
        super().__init__(*args, mlm=False, **kwargs)

        self.instruction_token_ids = instruction_template
        self.response_token_ids = response_template
        self.ignore_index = -100

    @staticmethod
    def _find_template_starts(labels, template: List[int]) -> List[int]:
        """Return the positions in ``labels`` at which the id sequence ``template`` begins."""
        width = len(template)
        return [
            int(pos)
            for pos in np.where(labels == template[0])[0]
            if labels[pos : pos + width].tolist() == template
        ]

    def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        batch = super().torch_call(examples)

        for i in range(len(examples)):
            labels = batch["labels"][i]

            # Position just after each response marker, i.e. where an assistant
            # reply (the part we train on) begins.
            response_ends = [
                pos + len(self.response_token_ids)
                for pos in self._find_template_starts(labels, self.response_token_ids)
            ]
            if not response_ends:
                warnings.warn(
                    f"Could not find response key `{self.response_token_ids}` in the "
                    f'following instance: {self.tokenizer.decode(batch["input_ids"][i])} '
                    f"This instance will be ignored in loss calculation. "
                    f"Note, if this happens often, consider increasing the `max_seq_length`."
                )
                batch["labels"][i, :] = self.ignore_index
                continue

            human_starts = self._find_template_starts(labels, self.instruction_token_ids)
            if not human_starts:
                warnings.warn(
                    f"Could not find instruction key `{self.instruction_token_ids}` in the "
                    f'following instance: {self.tokenizer.decode(batch["input_ids"][i])} '
                    f"This instance will be ignored in loss calculation. "
                    f"Note, if this happens often, consider increasing the `max_seq_length`."
                )
                batch["labels"][i, :] = self.ignore_index
                continue

            # Everything before the first user turn (e.g. the system prompt) is context only.
            batch["labels"][i, : human_starts[0]] = self.ignore_index

            for start in human_starts:
                # Mask this user turn up to the first response that follows it. With no
                # following response (truncated conversation), end is None and the
                # slice masks through the end of the sequence.
                end = next((e for e in response_ends if e > start), None)
                batch["labels"][i, start:end] = self.ignore_index

        return batch

In [13]:
# Token ids of the assistant-turn marker as they appear mid-conversation. In the
# rendered chat template the marker follows a newline, and "<|assistant|>" tokenizes
# differently with that context. So encode it with a leading "\n", then drop the two
# ids the newline context produces (the trl completion-only collator pattern).
response_template_ids = tokenizer.encode("\n<|assistant|>", add_special_tokens=False)[2:]
print(response_template_ids)
tokenizer.decode(response_template_ids)

[28789, 28766, 489, 11143, 28766, 28767]


'<|assistant|>'

In [14]:
# Token ids of the user-turn marker as they appear mid-conversation. Encode it with
# the newline that precedes it in the rendered chat template, then drop the two ids
# that the newline context produces.
instruction_template_ids = tokenizer.encode("\n<|user|>", add_special_tokens=False)[2:]
print(instruction_template_ids)
tokenizer.decode(instruction_template_ids)

[28789, 28766, 1838, 28766, 28767]


'<|user|>'

In [47]:
from functools import partial
from trl import DataCollatorForCompletionOnlyLM

# trl's completion-only collator sets labels outside assistant responses to -100.
# The custom DataCollatorForChatLM defined above accepts the same arguments and can
# be substituted here.
collator = DataCollatorForCompletionOnlyLM(
    response_template=response_template_ids,
    instruction_template=instruction_template_ids,
    tokenizer=tokenizer,
)

def apply_chat_template(tokenizer, messages_col, batch):
    """Render each conversation in ``batch[messages_col]`` to a single string.

    Args:
        tokenizer: object exposing ``apply_chat_template``; it is called with
            ``tokenize=False`` so that text, not ids, comes back.
        messages_col: name of the batch column that holds the conversations.
        batch: mapping from column name to a list of conversations.

    Returns:
        A list of rendered strings, one per conversation, in input order.
    """
    return [
        tokenizer.apply_chat_template(conversation, tokenize=False)
        for conversation in batch[messages_col]
    ]

# Bind the tokenizer and column name so that the formatter takes just a batch; both
# tokenize() below and SFTTrainer's formatting_func call it that way.
chat_formatting_func = partial(apply_chat_template, tokenizer, "messages_with_sys_prompt")

In [48]:
def tokenize(element):
    """Render a batch of conversations with the chat template and tokenize them.

    Returns ``input_ids`` and ``attention_mask`` (truncated to 2048 tokens, no
    padding), plus the rendered ``chat_format`` strings, which are kept for
    inspection.
    """
    rendered = chat_formatting_func(element)
    encoded = tokenizer(
        rendered,
        max_length=2048,
        truncation=True,
        padding=False,
        return_overflowing_tokens=False,
        return_length=False,
    )
    return {
        "input_ids": encoded["input_ids"],
        "attention_mask": encoded["attention_mask"],
        "chat_format": rendered,
    }

# Tokenize the first 10,000 conversations in batches of 10 across 32 worker processes.
tokenized_dataset = dataset.select(range(10000)).map(
    tokenize,
    batched=True,
    num_proc=32,
    batch_size=10,
)

Map (num_proc=32):   0%|          | 0/10000 [00:00<?, ? examples/s]

  return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
  return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
  return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
  return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
  return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
  return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
  return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
  return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
  return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
  return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
  return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
  return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
  return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
  return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
  return cls._concat_blocks(pa_tables_to_concat_

In [49]:
# Keep only the model-input columns for the DataLoader/collator; drop the raw text columns.
keep_columns = {"input_ids", "attention_mask"}
tokenized_dataset_removed = tokenized_dataset.remove_columns(
    [name for name in tokenized_dataset.column_names if name not in keep_columns]
)

In [50]:
from torch.utils.data import DataLoader

# One example per batch, collated by the completion-only collator; grab the first batch.
dataloader_params = dict(batch_size=1, collate_fn=collator)
dataloader = DataLoader(tokenized_dataset_removed, **dataloader_params)
batch = next(iter(dataloader))

In [51]:
import torch
# Inspect the first 100 token ids and the rendered text of the first example.
print(torch.tensor(tokenized_dataset[0]["input_ids"][:100]))
print(tokenized_dataset[0]["chat_format"])

tensor([    1,   523, 28766,  6574, 28766, 28767,    13,  1976,   460,   264,
        10131, 10706, 10093,     2, 28705,    13, 28789, 28766,  1838, 28766,
        28767,    13, 18171, 11382,  5580,   298,  4211, 28733,  5527, 18978,
          325,  1146, 13532,   495, 28705, 28784, 28723, 28734, 28806, 28725,
         8337,  1380, 28705, 28781, 28723, 28734, 28806, 28725,  2316,   455,
          897, 28705, 28770, 28723, 28734, 28806,  6372,  1798, 28705, 28750,
        28723, 28734, 28806, 28725,   351,   598, 16712, 28705, 28782, 28723,
        28734, 28806,   609,  1824,  7335,  2751,   837,   315,  1413, 28804,
           13,  2486,   574, 27395,  6718,   567, 22114, 28715, 27395, 12458,
        28725,   368,   541,  5061,  1347,   272, 13461,  3469,   302,   264])
<|system|>
You are a friendly chatbot</s>
<|user|>
These instructions apply to section-based themes (Responsive 6.0+, Retina 4.0+, Parallax 3.0+ Turbo 2.0+, Mobilia 5.0+). What theme version am I using?
On your Collecti

In [44]:
# Tokenize the assistant marker plus a reply, preceded by a newline as in the rendered template.
print(tokenizer.encode("\n<|assistant|>\nThis feature only applies to Collection pages and Featured Collections sections of the section-based themes listed in the text material."))

[1, 28705, 13, 28789, 28766, 489, 11143, 28766, 28767, 13, 3260, 4480, 865, 15588, 298, 13079, 6718, 304, 22114, 28715, 27395, 12458, 302, 272, 4211, 28733, 5527, 18978, 9206, 297, 272, 2245, 3388, 28723]


In [46]:
# Tokenize the start of the same reply with only a leading newline, for comparison with the previous cell.
print(tokenizer.encode("\nThis feature only applies to Collection pages"))

[1, 28705, 13, 3260, 4480, 865, 15588, 298, 13079, 6718]


In [57]:
# Decode a hand-copied run of ids to confirm that it is the first reply followed by </s>.
tokenizer.decode([13,  3260,  4480,   865, 15588,   298, 13079,  6718,   304, 22114, 28715, 27395, 12458,   302,   272,  4211, 28733,  5527, 18978,  9206,   297,   272,  2245,  3388, 28723,     2])

'\nThis feature only applies to Collection pages and Featured Collections sections of the section-based themes listed in the text material.</s>'

In [62]:
# Copy of the first example's labels with the ignore index (-100) replaced by token
# id 1, so that the tensor can be passed to tokenizer.decode (masked positions then
# show up as <s>). Indexing batch["labels"] returns a view, so clone() first to
# leave `batch` itself unmodified; this also keeps the cell safe to re-run.
labels = batch["labels"][0].clone()
labels[labels == -100] = 1


In [67]:
# Rendered text of the first example, for comparison with the decoded labels.
print(tokenized_dataset[0]["chat_format"])

<|system|>
You are a friendly chatbot</s>
<|user|>
These instructions apply to section-based themes (Responsive 6.0+, Retina 4.0+, Parallax 3.0+ Turbo 2.0+, Mobilia 5.0+). What theme version am I using?
On your Collections pages & Featured Collections sections, you can easily show the secondary image of a product on hover by enabling one of the theme's built-in settings!
Your Collection pages & Featured Collections sections will now display the secondary product image just by hovering over that product image thumbnail.
Does this feature apply to all sections of the theme or just specific ones as listed in the text material?</s>
<|assistant|>
This feature only applies to Collection pages and Featured Collections sections of the section-based themes listed in the text material.</s>
<|user|>
Can you guide me through the process of enabling the secondary image hover feature on my Collection pages and Featured Collections sections?</s>
<|assistant|>
Sure, here are the steps to enable the se

In [66]:
# Masked positions decode as <s>; only the remaining text contributes to the loss.
print(tokenizer.decode(labels))

<s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s>
This feature only applies to Collection pages and Featured Collections sections of the section-based themes listed in the text material.<s> 
<s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s>
Sure, here are the steps to enable the secondary image hover feature on your Collection pages and Featured Collections sections:

1. Log in to your Shopify account and go to your Online Store.
2. Click on Customize the

In [69]:
from trl import SFTTrainer

# Supervised fine-tuning on 10,000 conversations formatted with the chat template,
# using the completion-only collator for label masking.
# NOTE(review): construction failed with CUDA OOM on a ~16 GB GPU (see output).
# The float16 7B weights alone are close to that size; presumably this needs a larger
# GPU, quantization/PEFT, or sharding. TODO confirm.
sft_trainer = SFTTrainer(
    model=model,
    train_dataset=dataset.select(range(10000)),
    data_collator=collator,
    formatting_func=chat_formatting_func,
    max_seq_length=2048
)

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 250.00 MiB (GPU 0; 15.74 GiB total capacity; 0 bytes already allocated; 201.19 MiB free; 0 bytes reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [195]:
# Pull one collated batch from the trainer's DataLoader to inspect its labels.
batch = next(iter(sft_trainer.get_train_dataloader()))

start = 1648 / end = 2186
start = 2263 / end = 2322
start = 2442 / end = 2510
start = 2659 / end = 2715
start = 16 / end = 47
start = 852 / end = 901
start = 1498 / end = 1563
start = 2142 / end = 2211
start = 1524 / end = 2246
start = 2379 / end = 2426
start = 2610 / end = 2642
start = 1482 / end = 1574
start = 2581 / end = 2647
start = 910 / end = 994
start = 1712 / end = 1753
start = 2044 / end = 2092
start = 2384 / end = 2448
start = 2008 / end = 2091
start = 2542 / end = 2575
start = 1882 / end = 2389
start = 2520 / end = 2558
start = 2699 / end = 2748
start = 1077 / end = 1171
start = 1761 / end = 1818
start = 2098 / end = 2141
start = 2430 / end = 2487
start = 828 / end = 855
start = 1711 / end = 1749
start = 2030 / end = 2099
start = 1423 / end = 1567
start = 2223 / end = 2266
start = 2416 / end = 2471
start = 1621 / end = 1696
start = 1858 / end = 1920
start = 2118 / end = 2179
start = 2465 / end = 2535
start = 1043 / end = 1099
start = 1371 / end = 1408
start = 1749 / end = 1

In [196]:
# Print the first example's labels in three chunks: [0, 1000), [1000, 2000), [2000, end).
first_labels = batch["labels"][0]
for lo, hi in ((0, 1000), (1000, 2000), (2000, None)):
    print(first_labels[lo:hi])

tensor([-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100, -100, -1

In [114]:
# Full token-id list of the first tokenized example.
print(tokenized_dataset[0]["input_ids"])

[1, 523, 28766, 1838, 28766, 28767, 13, 18171, 11382, 5580, 298, 4211, 28733, 5527, 18978, 325, 1146, 13532, 495, 28705, 28784, 28723, 28734, 28806, 28725, 8337, 1380, 28705, 28781, 28723, 28734, 28806, 28725, 2316, 455, 897, 28705, 28770, 28723, 28734, 28806, 6372, 1798, 28705, 28750, 28723, 28734, 28806, 28725, 351, 598, 16712, 28705, 28782, 28723, 28734, 28806, 609, 1824, 7335, 2751, 837, 315, 1413, 28804, 13, 2486, 574, 27395, 6718, 567, 22114, 28715, 27395, 12458, 28725, 368, 541, 5061, 1347, 272, 13461, 3469, 302, 264, 2093, 356, 18848, 486, 25748, 624, 302, 272, 7335, 28742, 28713, 4429, 28733, 262, 6472, 28808, 13, 11159, 13079, 6718, 567, 22114, 28715, 27395, 12458, 622, 1055, 4249, 272, 13461, 2093, 3469, 776, 486, 18848, 288, 754, 369, 2093, 3469, 15762, 21418, 28723, 13, 20510, 456, 4480, 5580, 298, 544, 12458, 302, 272, 7335, 442, 776, 2948, 4413, 390, 9206, 297, 272, 2245, 3388, 28804, 2, 28705, 13, 28789, 28766, 489, 11143, 28766, 28767, 13, 3260, 4480, 865, 15588, 298, 

In [123]:
# NOTE(review): duplicate of the earlier column-removal cell; keeps only input_ids/attention_mask.
tokenized_dataset_removed = tokenized_dataset.remove_columns([col for col in tokenized_dataset.column_names if col not in ["input_ids", "attention_mask"]])

In [124]:
# Confirm that only the model-input columns remain.
tokenized_dataset_removed

Dataset({
    features: ['input_ids', 'attention_mask'],
    num_rows: 1000
})

In [125]:
# Full tokenized dataset, including the original and rendered-text columns.
tokenized_dataset

Dataset({
    features: ['prompt', 'prompt_id', 'messages', 'input_ids', 'attention_mask', 'chat_format'],
    num_rows: 1000
})

In [140]:
from torch.utils.data import DataLoader

# NOTE(review): repeats the earlier DataLoader cell, rebuilding `dataloader` and `batch`.
dataloader = DataLoader(tokenized_dataset_removed, **dataloader_params)

batch = next(iter(dataloader))

In [106]:
# Collated batch: input_ids, attention_mask and masked labels.
batch

{'input_ids': tensor([[    1, 26588, 28725,  ...,     2, 28705,    13],
        [    1,  9305, 28731,  ...,     2, 28705,    13],
        [    1, 18413,   302,  ...,     2, 28705,    13],
        ...,
        [    2,     2,     2,  ...,     2, 28705,    13],
        [    1,   829,   427,  ...,     2, 28705,    13],
        [    2,     2,     2,  ...,     2, 28705,    13]]), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        ...,
        [0, 0, 0,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1]]), 'labels': tensor([[ -100,  -100,  -100,  ...,  -100, 28705,    13],
        [ -100,  -100,  -100,  ...,  -100, 28705,    13],
        [ -100,  -100,  -100,  ...,  -100, 28705,    13],
        ...,
        [ -100,  -100,  -100,  ...,  -100, 28705,    13],
        [ -100,  -100,  -100,  ...,  -100, 28705,    13],
        [ -100,  -100,  -100,  ...,  -100, 28705,    13]])}

In [98]:
# Check whether the collator's response template matches input_ids starting at position 351.
idx = 351
collator.response_token_ids == batch["input_ids"][0][idx : idx + len(collator.response_token_ids)].tolist()

True

In [96]:
# The ten ids starting at position 351 of the first example.
batch["input_ids"][0][351:351+10]

tensor([28789, 28766,   489, 11143, 28766, 28767,    13, 13617,   404,   356])

In [91]:
# Positions among the first 1000 ids equal to 28789, the first id of both template id lists.
(batch["input_ids"][0][:1000] == 28789)

tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, 

In [50]:
# `dataset` is rebound to the "train_sft" split by an earlier cell, so only index the
# split while `dataset` is still the full DatasetDict; this makes the cell work under Run All.
from datasets import DatasetDict

(dataset["train_sft"] if isinstance(dataset, DatasetDict) else dataset).select(range(1000))

Dataset({
    features: ['prompt', 'prompt_id', 'messages'],
    num_rows: 1000
})

In [60]:
# repr of the rendered text, which shows the exact "\n" placement around the role markers.
tokenized_dataset[0]["chat_format"]

"<|user|>\nThese instructions apply to section-based themes (Responsive 6.0+, Retina 4.0+, Parallax 3.0+ Turbo 2.0+, Mobilia 5.0+). What theme version am I using?\nOn your Collections pages & Featured Collections sections, you can easily show the secondary image of a product on hover by enabling one of the theme's built-in settings!\nYour Collection pages & Featured Collections sections will now display the secondary product image just by hovering over that product image thumbnail.\nDoes this feature apply to all sections of the theme or just specific ones as listed in the text material?</s>\n<|assistant|>\nThis feature only applies to Collection pages and Featured Collections sections of the section-based themes listed in the text material.</s>\n<|user|>\nCan you guide me through the process of enabling the secondary image hover feature on my Collection pages and Featured Collections sections?</s>\n<|assistant|>\nSure, here are the steps to enable the secondary image hover feature on 

In [63]:
# NOTE(review): the recorded output below differs from the value printed by the
# response_template_ids cell; the saved outputs come from different executions.
response_template_ids

[1, 732, 28711, 28789, 28766, 489, 11143, 28766, 28767]

In [62]:
# NOTE(review): duplicate of an earlier cell; prints the first example's token ids.
print(tokenized_dataset[0]["input_ids"])

[1, 523, 28766, 1838, 28766, 28767, 13, 18171, 11382, 5580, 298, 4211, 28733, 5527, 18978, 325, 1146, 13532, 495, 28705, 28784, 28723, 28734, 28806, 28725, 8337, 1380, 28705, 28781, 28723, 28734, 28806, 28725, 2316, 455, 897, 28705, 28770, 28723, 28734, 28806, 6372, 1798, 28705, 28750, 28723, 28734, 28806, 28725, 351, 598, 16712, 28705, 28782, 28723, 28734, 28806, 609, 1824, 7335, 2751, 837, 315, 1413, 28804, 13, 2486, 574, 27395, 6718, 567, 22114, 28715, 27395, 12458, 28725, 368, 541, 5061, 1347, 272, 13461, 3469, 302, 264, 2093, 356, 18848, 486, 25748, 624, 302, 272, 7335, 28742, 28713, 4429, 28733, 262, 6472, 28808, 13, 11159, 13079, 6718, 567, 22114, 28715, 27395, 12458, 622, 1055, 4249, 272, 13461, 2093, 3469, 776, 486, 18848, 288, 754, 369, 2093, 3469, 15762, 21418, 28723, 13, 20510, 456, 4480, 5580, 298, 544, 12458, 302, 272, 7335, 442, 776, 2948, 4413, 390, 9206, 297, 272, 2245, 3388, 28804, 2, 28705, 13, 28789, 28766, 489, 11143, 28766, 28767, 13, 3260, 4480, 865, 15588, 298, 