# Fine-tuning StarCoderBase-1B for TeX generation!

First, we need to install and import the required packages.

In [None]:
!pip install datasets accelerate bitsandbytes wandb huggingface_hub

In [None]:
import os
import transformers
import torch
import bitsandbytes
import accelerate

from torch.utils.data import IterableDataset
from datasets import load_dataset
from transformers import  AutoConfig, AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments, logging, set_seed
from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from tqdm import tqdm
from accelerate import Accelerator

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Identifier of the base checkpoint handed to `from_pretrained` further down.
MODEL_NAME = "bigcode/starcoderbase-1b"

## Loading Dataset and DataLoaders

In [None]:
# Download the dataset, then hold out a validation split from its "train" split.
data = load_dataset('loganbarnhart/text-to-tex')
train_ratio = 0.95
# `set_seed` is only called further down in the notebook, so pin the shuffle
# seed here. Without it every run draws a different train/validation split,
# which makes evaluation losses incomparable between runs.
split_seed = 0
data = data["train"].train_test_split(train_size=train_ratio, seed=split_seed)
train_data = data['train']
val_data = data['test']

Downloading readme:   0%|          | 0.00/571 [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.91M [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

In [None]:
def chars_token_ratio(dataset, tokenizer, input_column_name="text", nb_examples=400):
    """
    Estimate the average number of characters per token in the dataset.

    Only the first ``nb_examples`` examples of ``dataset`` are inspected (fewer
    if the dataset is shorter).

    Args:
        dataset: Iterable of examples; each example is indexed with
            ``input_column_name`` to obtain its text.
        tokenizer: Tokenizer used to count tokens. Fast tokenizers are called
            directly and their ``tokens()`` are counted; other tokenizers go
            through ``tokenize``.
        input_column_name (str): Key of the text field in each example.
        nb_examples (int): Maximum number of examples to sample.

    Returns:
        float: Total characters divided by total tokens over the sample.

    Raises:
        ValueError: If the sample produced no tokens (empty dataset,
            ``nb_examples <= 0`` or only empty texts), in which case the ratio
            is undefined.
    """
    total_characters, total_tokens = 0, 0
    # zip() against range() caps the sample at nb_examples without consuming
    # more of the dataset than needed.
    for _, example in tqdm(zip(range(nb_examples), iter(dataset)), total=nb_examples):
        text = example[input_column_name]
        total_characters += len(text)
        if tokenizer.is_fast:
            total_tokens += len(tokenizer(text).tokens())
        else:
            total_tokens += len(tokenizer.tokenize(text))

    if total_tokens == 0:
        # Fail with an actionable message instead of a bare ZeroDivisionError.
        raise ValueError(
            "Cannot estimate characters per token: no tokens were produced "
            f"from the first {nb_examples} example(s) of column {input_column_name!r}."
        )

    return total_characters / total_tokens

class ConstantLengthDataset(IterableDataset):
    """
    Iterable dataset that returns constant length chunks of tokens from stream of text files.

    Texts are read from ``dataset`` into a buffer until a character budget of
    ``seq_length * chars_per_token * num_of_sequences`` is reached. The buffer is
    tokenized in one call, a separator token is appended after every text, and
    the resulting token stream is cut into chunks of exactly ``seq_length``
    tokens. A trailing chunk shorter than ``seq_length`` is dropped.

    Args:
        tokenizer (Tokenizer): The processor used for processing the data.
        dataset (dataset.Dataset): Dataset with text files.
        infinite (bool): If True the iterator is reset after dataset reaches end else stops.
        seq_length (int): Length of token sequences to return.
        num_of_sequences (int): Number of token sequences to keep in buffer.
        chars_per_token (float): Number of characters per token used to estimate number of tokens in text buffer.
        input_column_name (str): Key of the text field in each dataset example.
        eos_token_id (int, optional): Separator token id used only when the
            tokenizer does not define ``eos_token_id`` itself.

    Raises:
        ValueError: If neither the tokenizer nor ``eos_token_id`` provides a
            separator token id.
    """

    def __init__(
        self,
        tokenizer,
        dataset,
        infinite=False,
        seq_length=1024,
        num_of_sequences=1024,
        chars_per_token=3.6,
        input_column_name="text",
        eos_token_id=None,
    ):
        self.tokenizer = tokenizer
        # The tokenizer's own EOS id wins; the explicit argument is a fallback.
        if tokenizer.eos_token_id is not None:
            self.concat_token_id = tokenizer.eos_token_id
        elif eos_token_id is not None:
            self.concat_token_id = eos_token_id
        else:
            raise ValueError(
                "The tokenizer defines no eos_token_id; pass eos_token_id explicitly "
                "so that concatenated texts can be separated."
            )
        self.dataset = dataset
        self.seq_length = seq_length
        self.infinite = infinite
        # Number of chunks yielded so far, accumulated across iterations.
        self.current_size = 0
        # Character budget for one buffer fill.
        self.max_buffer_size = seq_length * chars_per_token * num_of_sequences
        self.input_column_name = input_column_name

    def __iter__(self):
        """Yield ``{"input_ids", "labels"}`` dicts of ``seq_length`` token ids each."""
        iterator = iter(self.dataset)
        more_examples = True
        # Examples read since the iterator was last (re)started; lets an empty
        # dataset be detected so that infinite mode cannot spin forever.
        read_since_restart = 0
        while more_examples:
            buffer, buffer_len = [], 0
            while buffer_len < self.max_buffer_size:
                try:
                    buffer.append(next(iterator)[self.input_column_name])
                except StopIteration:
                    if self.infinite and read_since_restart > 0:
                        iterator = iter(self.dataset)
                        read_since_restart = 0
                        continue
                    more_examples = False
                    break
                buffer_len += len(buffer[-1])
                read_since_restart += 1

            if not buffer:
                # Nothing was read, so there is nothing to tokenize or yield.
                break

            tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
            all_token_ids = []
            for tokenized_input in tokenized_inputs:
                all_token_ids.extend(tokenized_input)
                all_token_ids.append(self.concat_token_id)

            # Step over full-length chunks only; the incomplete tail is dropped.
            last_start = len(all_token_ids) - self.seq_length
            for start in range(0, last_start + 1, self.seq_length):
                input_ids = all_token_ids[start : start + self.seq_length]
                self.current_size += 1
                yield {
                    "input_ids": torch.LongTensor(input_ids),
                    "labels": torch.LongTensor(input_ids),
                }


## Training loop and args

In [None]:
# Reproducibility and output location.
seed = 0
output_dir = "./checkpoints"

# Schedule, all counted in optimizer steps.
max_steps = 5000
eval_freq = 50
save_freq = 300
log_freq = 50

# Batching: per-device batch size and gradient accumulation steps.
batch_size = 8
gradient_accumulation_steps = 8

# Optimizer and learning-rate schedule.
learning_rate = 1e-5
lr_scheduler_type = "cosine"
num_warmup_steps = 100
weight_decay = 0.05

# Memory / precision switches. Gradient checkpointing is enabled whenever
# `no_gradient_checkpointing` is False.
no_gradient_checkpointing = False
bf16 = False

In [None]:
def run_training(model, train_data, val_data):
    """Fine-tune ``model`` with the Hugging Face ``Trainer`` and save the result.

    All hyperparameters (``output_dir``, ``max_steps``, ``batch_size``, ...) are
    read from the notebook-level settings rather than passed in.

    Args:
        model: Model to train; it is also the object saved at the end.
        train_data: Dataset used for training. Its ``start_iteration``
            attribute is set to 0 as a side effect.
        val_data: Dataset used for the periodic evaluation.
    """
    # Nothing visible in this notebook reads this attribute; it is set to keep
    # the function's side effects unchanged.
    train_data.start_iteration = 0

    print("Starting main loop")

    trainer_settings = dict(
        output_dir=output_dir,
        dataloader_drop_last=True,
        # Evaluate and checkpoint on a step schedule; the best checkpoint is
        # reloaded into the model when training finishes.
        evaluation_strategy="steps",
        save_strategy="steps",
        load_best_model_at_end=True,
        max_steps=max_steps,
        eval_steps=eval_freq,
        save_steps=save_freq,
        logging_steps=log_freq,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        learning_rate=learning_rate,
        lr_scheduler_type=lr_scheduler_type,
        warmup_steps=num_warmup_steps,
        gradient_accumulation_steps=gradient_accumulation_steps,
        gradient_checkpointing=not no_gradient_checkpointing,
        bf16=bf16,
        weight_decay=weight_decay,
        run_name="StarCoder-finetuned",
        report_to="wandb",
        ddp_find_unused_parameters=False,
    )

    trainer = Trainer(
        model=model,
        args=TrainingArguments(**trainer_settings),
        train_dataset=train_data,
        eval_dataset=val_data,
    )

    print("Training...")
    trainer.train()

    print("Saving last checkpoint of the model")
    final_dir = os.path.join(output_dir, "final_checkpoint/")
    model.save_pretrained(final_dir)

## Training:

In [None]:
# Seed the random number generators so the training run is repeatable.
set_seed(seed)
# Make sure the checkpoint directory exists before anything is written to it.
os.makedirs(output_dir, exist_ok=True)
# `logging` is the transformers logging module here: show error-level messages only.
logging.set_verbosity_error()

In [None]:
# Load the tokenizer that belongs to the base checkpoint.
# NOTE(review): newer transformers releases prefer `token=` over
# `use_auth_token=` -- confirm against the installed version.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_auth_token=True)

# Measure characters-per-token on the training split; both datasets use it to
# size their text buffers.
chars_per_token = chars_token_ratio(train_data, tokenizer)

# Same construction for both splits, training split first.
train_dataset, eval_dataset = (
    ConstantLengthDataset(tokenizer, split, chars_per_token=chars_per_token)
    for split in (train_data, val_data)
)



tokenizer_config.json:   0%|          | 0.00/677 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/777k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/442k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.06M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/532 [00:00<?, ?B/s]

100%|██████████| 400/400 [00:00<00:00, 1620.05it/s]


In [None]:
# Load the base model onto this process's device.
# The KV cache and gradient checkpointing are incompatible during training, so
# the cache is enabled only when gradient checkpointing is switched off
# (`no_gradient_checkpointing=True`), never when it is on.
model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        use_auth_token=True,
        use_cache=no_gradient_checkpointing,
        # load_in_8bit=True, figure out why this throws bitsandbytes error
        device_map={"": Accelerator().process_index},
    )



config.json:   0%|          | 0.00/1.05k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/4.55G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

In [None]:
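### wandb api key: do NOT paste it here. Provide it through the WANDB_API_KEY environment variable (or run `wandb login`). The key that was previously committed on this line must be treated as leaked and revoked.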

In [None]:
# Start fine-tuning on the constant-length datasets built above.
run_training(model=model, train_data=train_dataset, val_data=eval_dataset)

In [None]:
# Write the fine-tuned weights together with the tokenizer files so that the
# final checkpoint directory can be loaded on its own.
final_checkpoint_dir = os.path.join(output_dir, "final_checkpoint/")
model.save_pretrained(final_checkpoint_dir)
tokenizer.save_pretrained(final_checkpoint_dir)