#### Fine-tuning Llama-3 8B

In [None]:
# Import libraries
import os
import gc
import threading
import psutil
import json
import copy
import random
from util_fnx import b2mb, TorchTracemalloc, DataCollatorForInstructionTuning, IGNORE_INDEX
from pathlib import Path
import argparse
import logging
import math
from typing import List, Dict
from dataclasses import dataclass
import torch
from torch.utils.data import DataLoader
from torch.nn.utils import rnn
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
import datasets
from datasets import load_dataset, load_from_disk
from huggingface_hub import Repository, create_repo
from peft import AutoPeftModelForCausalLM, LoraConfig, TaskType, get_peft_model
import transformers
from transformers import (
    AutoTokenizer,
    PreTrainedTokenizerBase,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    SchedulerType,
    get_scheduler,
)
from tqdm.auto import tqdm
from transformers import PreTrainedTokenizerBase

# Module-level logger created through accelerate's logging helper.
logger = get_logger(__name__)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

  from .autonotebook import tqdm as notebook_tqdm


In [1]:
from datasets import load_dataset, DatasetDict

# Load the "kowndinya23/flan2022" dataset from the Hugging Face Hub.
ds = load_dataset("kowndinya23/flan2022")

In [2]:
# Split the 'cot' subset into train and test sets.
# A fixed seed keeps the 80/20 split identical across kernel restarts.
SPLIT_SEED = 42

if isinstance(ds['cot'], DatasetDict):
    # The cell was already run: ds['cot'] holds the train/test splits, so reuse them
    # instead of failing on a second train_test_split call.
    cot_dataset = ds['cot']
else:
    cot_train_test_split = ds['cot'].train_test_split(test_size=0.2, seed=SPLIT_SEED)

    # Create a new DatasetDict object for 'cot' with train and test splits
    cot_dataset = DatasetDict({
        'train': cot_train_test_split['train'],
        'test': cot_train_test_split['test']
    })

    # Update the original dataset dictionary to include the new DatasetDict for 'cot'
    ds['cot'] = cot_dataset

In [4]:
from datasets import load_dataset, DatasetDict

# Load the dataset
ds = load_dataset("kowndinya23/flan2022")

# Seed used for every subset so the 80/20 splits are reproducible across runs.
SPLIT_SEED = 42

# Create a new DatasetDict to hold the train and test splits
split_ds = DatasetDict()

# Iterate over each subset in ds and split into train and test sets
for subset in ds:
    train_test_split = ds[subset].train_test_split(test_size=0.2, seed=SPLIT_SEED)
    split_ds[subset] = DatasetDict({
        'train': train_test_split['train'],
        'test': train_test_split['test']
    })

# Save the split dataset to disk
split_ds.save_to_disk('flan2022')


Saving the dataset (15/15 shards): 100%|██████████| 4289888/4289888 [00:29<00:00, 144233.37 examples/s]
Saving the dataset (4/4 shards): 100%|██████████| 1072473/1072473 [00:06<00:00, 156708.12 examples/s]
Saving the dataset (8/8 shards): 100%|██████████| 1320246/1320246 [00:09<00:00, 140141.95 examples/s]
Saving the dataset (2/2 shards): 100%|██████████| 330062/330062 [00:02<00:00, 138829.20 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 147078/147078 [00:00<00:00, 197763.19 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 36770/36770 [00:00<00:00, 181025.58 examples/s]
Saving the dataset (21/21 shards): 100%|██████████| 8053516/8053516 [00:52<00:00, 153929.79 examples/s]
Saving the dataset (6/6 shards): 100%|██████████| 2013380/2013380 [00:12<00:00, 155264.08 examples/s]
Saving the dataset (2/2 shards): 100%|██████████| 443095/443095 [00:02<00:00, 164692.59 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 110774/110774 [00:00<00:00, 163850.02 

In [5]:
# Reload the training split of the 'flan2021' subset saved under ./flan2022.
# (Plain string: the previous f-string had no placeholders.)
subset = DatasetDict.load_from_disk("flan2022/flan2021")["train"]

In [None]:
subset 

Dataset({
    features: ['inputs', 'targets', 'task_source', 'task_name', 'template_type'],
    num_rows: 4289888
})

In [14]:
# Dataset.unique returns the distinct values of a column without first building
# a Python list of every row's value.
unique_task_names = set(subset.unique('task_name'))
print(f"Total number of unique task names: {len(unique_task_names)}")
print(unique_task_names)

Total number of unique task names: 70
{'ai2_arc/ARC-Easy:1.0.0', 'aeslc:1.0.0', 'wmt16_translate/tr-en:1.0.0', 'wmt14_translate/fr-en:1.0.0', 'wmt16_translate/cs-en:1.0.0', 'natural_questions_open:1.0.0', 'lambada:1.0.0', 'super_glue/wic:1.0.2', 'huggingface:xsum', 'piqa:1.0.0', 'anli/r3:0.1.0', 'ai2_arc/ARC-Challenge:1.0.0', 'glue/qnli:2.0.0', 'samsum:1.0.0', 'gem/wiki_lingua_english_en:1.1.0', 'gem/dart:1.1.0', 'super_glue/wsc.fixed:1.0.2', 'word_segment', 'gem/e2e_nlg:1.1.0', 'glue/sst2:2.0.0', 'super_glue/cb:1.0.2', 'super_glue/copa:1.0.2', 'glue/wnli:2.0.0', 'glue/qqp:2.0.0', 'drop:2.0.0', 'coqa:1.0.0', 'paws_wiki:1.1.0', 'multi_news:1.0.0', 'wmt16_translate/ro-en:1.0.0', 'cnn_dailymail:3.4.0', 'story_cloze/2016:1.0.0', 'gigaword:1.2.0', 'ag_news_subset:1.0.0', 'snli:1.1.0', 'hellaswag:1.1.0', 'wmt16_translate/ru-en:1.0.0', 'trivia_qa/rc:1.1.0', 'definite_pronoun_resolution:1.1.0', 'wmt16_translate/fi-en:1.0.0', 'trec:1.0.0', 'unified_qa_science_inst', 'glue/stsb:2.0.0', 'super_gl

In [3]:
ds.save_to_disk('flan2022')

Saving the dataset (15/15 shards): 100%|██████████| 4289888/4289888 [00:25<00:00, 171351.61 examples/s]
Saving the dataset (4/4 shards): 100%|██████████| 1072473/1072473 [00:06<00:00, 172117.08 examples/s]
Saving the dataset (8/8 shards): 100%|██████████| 1320246/1320246 [00:08<00:00, 164318.24 examples/s]
Saving the dataset (2/2 shards): 100%|██████████| 330062/330062 [00:01<00:00, 172510.28 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 147078/147078 [00:00<00:00, 233681.73 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 36770/36770 [00:00<00:00, 234782.74 examples/s]
Saving the dataset (21/21 shards): 100%|██████████| 8053516/8053516 [00:46<00:00, 174857.94 examples/s]
Saving the dataset (6/6 shards): 100%|██████████| 2013380/2013380 [00:11<00:00, 168334.92 examples/s]
Saving the dataset (2/2 shards): 100%|██████████| 443095/443095 [01:05<00:00, 6772.28 examples/s]  
Saving the dataset (1/1 shards): 100%|██████████| 110774/110774 [00:00<00:00, 201926.18 

In [5]:
ds.save_to_disk('flan2022')

Saving the dataset (18/18 shards): 100%|██████████| 5362361/5362361 [00:06<00:00, 784431.61 examples/s]
Saving the dataset (10/10 shards): 100%|██████████| 1650308/1650308 [00:03<00:00, 515223.51 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 183848/183848 [00:00<00:00, 910694.77 examples/s] 
Saving the dataset (27/27 shards): 100%|██████████| 10066896/10066896 [00:11<00:00, 866352.95 examples/s]
Saving the dataset (3/3 shards): 100%|██████████| 553869/553869 [00:01<00:00, 478013.90 examples/s]


In [4]:
5362361 + 1650308 + 183848 + 10066896 + 553869

17817282

In [8]:
ds

DatasetDict({
    flan2021: Dataset({
        features: ['inputs', 'targets', 'task_source', 'task_name', 'template_type'],
        num_rows: 5362361
    })
    t0: Dataset({
        features: ['inputs', 'targets', 'task_source', 'task_name', 'template_type'],
        num_rows: 1650308
    })
    cot: Dataset({
        features: ['inputs', 'targets', 'task_source', 'task_name', 'template_type'],
        num_rows: 183848
    })
    niv2: Dataset({
        features: ['inputs', 'targets', 'task_source', 'task_name', 'template_type'],
        num_rows: 10066896
    })
    dialog: Dataset({
        features: ['inputs', 'targets', 'task_source', 'task_name', 'template_type'],
        num_rows: 553869
    })
})

In [13]:
ds['flan2021']

{'inputs': 'input ---- Just as with the color of your walls the color of your flooring also plays a significant role in the appearance of space\noutput ---- Just as with the color of your walls, the color of your flooring also plays a significant role in the appearance of space.\n\n\nAdd punctuation: 3821 But if you refuse to go forth this is the word that Yahweh has shown me\nA: 38:21 But if you refuse to go forth, this is the word that Yahweh has shown me:\n\n\nQUESTION: Fix punctuation: Later this month the German government will present its new energy outlook for 2050 with a key focus on the nuclear phaseout and the composition of the countrys future energy mix\nANS: Later this month, the German government will present its new “energy outlook for 2050,” with a key focus on the nuclear phaseout and the composition of the country’s future energy mix.\n\n\nQUES: Your vision of life may be more universal and you may be drawn to spiritual or esoteric subjects which previously you might 

In [14]:
# Count the unique template types.
# Dataset.unique works on the single column, avoiding a Python-level loop that
# decodes every field of every row.
unique_template_types = set(ds['flan2021'].unique('template_type'))
total_unique_template_types = len(unique_template_types)

print(f"Total number of unique template types: {total_unique_template_types}")

Total number of unique template types: 4


In [15]:
unique_template_types

{'fs_noopt', 'fs_opt', 'zs_noopt', 'zs_opt'}

In [16]:
df = ds['flan2021'].to_pandas()

In [17]:
df.describe(include='all')

Unnamed: 0,inputs,targets,task_source,task_name,template_type
count,5362361,5362361,5362361,5362361,5362361
unique,4211385,1347467,1,70,4
top,Write a sentence not in English.,no,Flan2021,glue/mnli:2.0.0,zs_opt
freq,37707,272296,5362361,216560,1341883


In [None]:
# Lookup table from a dtype name to the value handed to `torch_dtype` at model-load time.
TORCH_DTYPES = {name: getattr(torch, name) for name in ("float32", "float16", "bfloat16")}
TORCH_DTYPES["auto"] = "auto"  # passed through unchanged

# Resolve the default choice ("auto") through the table.
torch_dtype = TORCH_DTYPES["auto"]

In [None]:
# Load the "meta-llama/Llama-3.2-1B" tokenizer and causal-LM weights directly from the Hub.
# NOTE(review): `tokenizer` is reassigned by the Setup cell below and `model` by the PEFT
# cell further down — confirm this cell is still needed.
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")

Setup

In [None]:
# Tokenizer for `model_name`, fetched with the configured cache directory and access token.
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir, token=hf_access_token)

# Pad on the right and reuse the end-of-sequence token id as the padding id.
tokenizer.padding_side = "right"
tokenizer.pad_token_id = tokenizer.eos_token_id

In [None]:
# model configurations
# `use_flash_attention_2=` is the deprecated spelling; newer transformers releases select
# the attention backend through `attn_implementation`. The argument is only passed when
# flash attention is requested so the library default applies otherwise.
attn_kwargs = {"attn_implementation": "flash_attention_2"} if use_flash_attention_2 else {}

base_model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path=model_name,
    torch_dtype=torch_dtype,
    cache_dir=cache_dir,
    token=hf_access_token,
    **attn_kwargs,
)
# The key/value cache is disabled on the config for training.
base_model.config.use_cache = False
base_model.config.sliding_window = sliding_window

# Grow the embedding matrix only when the tokenizer has more entries than the model.
embedding_size = base_model.get_input_embeddings().weight.shape[0]
if len(tokenizer) > embedding_size:
    base_model.resize_token_embeddings(len(tokenizer))

Configure the Environment

In [None]:
# Configure root logging once: INFO level, timestamped "time - level - name - message" records.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)
# logger.info(accelerator.state, main_process_only=False)

In [None]:
from pathlib import Path
import os
import torch

# Define variables to replace args.<variable>
push_to_hub = True
hub_model_id = None  # Replace with model ID or set to None to infer from output_dir
output_dir = "/path/to/output"
private_repo = True
# Read the access token from the environment rather than hard-coding a secret in the
# notebook. When it is None, huggingface_hub falls back to the locally stored login.
hub_token = os.environ.get("HF_TOKEN")

if push_to_hub:
    # Retrieve or infer repo_name
    repo_name = hub_model_id if hub_model_id is not None else Path(output_dir).absolute().name

    # Create repo and retrieve repo_id
    repo_id = create_repo(repo_name, exist_ok=True, token=hub_token, private=private_repo).repo_id

    # Clone repo locally
    # NOTE(review): `Repository` is deprecated in recent huggingface_hub releases —
    # confirm the installed version still provides it.
    repo = Repository(output_dir, clone_from=repo_id, token=hub_token)

    # Make sure checkpoint folders are git-ignored WITHOUT discarding an existing
    # .gitignore: read what is there and append only the missing patterns.
    # (Opening with "w+" truncated the file, and `in` on a file object compares against
    # whole lines including the newline, so the membership checks never matched.)
    gitignore_path = os.path.join(output_dir, ".gitignore")
    existing_text = ""
    if os.path.exists(gitignore_path):
        with open(gitignore_path, "r") as gitignore:
            existing_text = gitignore.read()
    existing_patterns = {line.strip() for line in existing_text.splitlines()}
    missing_patterns = [p for p in ("step_*", "epoch_*") if p not in existing_patterns]
    if missing_patterns:
        with open(gitignore_path, "a") as gitignore:
            # Start on a fresh line if the existing file lacks a trailing newline.
            if existing_text and not existing_text.endswith("\n"):
                gitignore.write("\n")
            for pattern in missing_patterns:
                gitignore.write(f"{pattern}\n")
elif output_dir is not None:
    os.makedirs(output_dir, exist_ok=True)


Process the dataset

In [None]:
# load the dataset
data_path = ""  # TODO: set to a dataset repository id or a local dataset path
if not data_path:
    # Fail early with an actionable message instead of passing an empty path to load_dataset.
    raise ValueError("data_path is empty: set it to a dataset name or local path before loading.")
raw_dataset = load_dataset(path=data_path)
# Preprocessing the datasets
# Original column names, removed after tokenization by the map() call below.
raw_dataset_column_names = raw_dataset["train"].column_names

 We preprocess the data in THREE steps:
   1. Concatenate prompts and responses
   2. Tokenize the concatenated prompt-response pairs
   3. Set the labels corresponding to the prompt tokens to IGNORE_INDEX

In [None]:
def preprocess_function(examples):
    """Tokenize prompt/response pairs and mask the prompt tokens out of the labels.

    Args:
        examples: batch mapping with parallel "prompt" and "response" lists of strings
            (the batched form passed by ``Dataset.map(..., batched=True)``).

    Returns:
        dict holding every field the tokenizer produced for the concatenated
        ``prompt + " " + response`` text, plus "labels": a copy of "input_ids" in which
        the leading prompt positions are replaced by IGNORE_INDEX.

    Reads the module-level ``tokenizer``, ``max_seq_length`` and ``IGNORE_INDEX``.
    """
    prompts = examples["prompt"]
    full_texts = [p + " " + r for p, r in zip(prompts, examples["response"])]
    full_tokenized = tokenizer(full_texts, truncation=True, max_length=max_seq_length)
    prompt_tokenized = tokenizer(prompts, truncation=True, max_length=max_seq_length)

    all_labels = []
    for input_ids, prompt_ids in zip(full_tokenized["input_ids"], prompt_tokenized["input_ids"]):
        labels = list(input_ids)  # independent copy; input_ids must stay untouched
        # Clamp the masked span to the sequence length: slice-assigning more values than
        # the slice holds would silently lengthen `labels` beyond `input_ids`.
        masked = min(len(prompt_ids), len(labels))
        labels[:masked] = [IGNORE_INDEX] * masked
        all_labels.append(labels)

    result = dict(full_tokenized.items())
    result["labels"] = all_labels
    return result

# Tokenize every split with preprocess_function, dropping the original text columns.
preprocessed_dataset = raw_dataset.map(
    preprocess_function,
    batched=True,
    num_proc=preprocessing_num_workers,
    load_from_cache_file=not overwrite_cache,
    remove_columns=raw_dataset_column_names,
    desc="Preprocessing the raw dataset",
)

train_dataset = preprocessed_dataset["train"]
eval_dataset = preprocessed_dataset["validation"]

# Batch construction is delegated to the instruction-tuning collator.
data_collator = DataCollatorForInstructionTuning(tokenizer)

# Training batches are reshuffled each epoch.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=args.per_device_train_batch_size,
    shuffle=True,
    collate_fn=data_collator,
    num_workers=8,
    pin_memory=True,
)

# Evaluation batches keep dataset order.
eval_dataloader = DataLoader(
    eval_dataset,
    batch_size=args.per_device_eval_batch_size,
    shuffle=False,
    collate_fn=data_collator,
    num_workers=8,
    pin_memory=True,
)

In [None]:
# Log a few random samples from the training set.
# random.sample raises ValueError when asked for more items than exist, so cap the
# count at the dataset size (fewer than three rows, or none at all).
num_samples_to_log = min(3, len(train_dataset))
for index in random.sample(range(len(train_dataset)), num_samples_to_log):
    logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")

In [None]:
# Train the base model as-is unless --use_peft is passed, in which case wrap it with LoRA adapters.
if not args.use_peft:
    model = base_model
else:
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=args.peft_lora_r,
        lora_alpha=args.peft_lora_alpha,
        lora_dropout=args.peft_lora_dropout,
        # Comma-separated module names -> list of names.
        target_modules=args.peft_target_modules.split(","),
        bias="none",
    )
    model = get_peft_model(base_model, peft_config)

# Move the (possibly wrapped) model to the selected device.
model.to(device)

Training the Model

In [None]:
# AdamW over all model parameters, configured from args.
optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=args.learning_rate,
    weight_decay=args.weight_decay,
    fused=args.adamw_fused,
)

# Scheduler and math around the number of training steps:
# one optimizer update happens every `gradient_accumulation_steps` batches.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
overrode_max_train_steps = args.max_train_steps is None
if overrode_max_train_steps:
    # No explicit step budget given: derive it from the number of epochs.
    args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch

# Warm-up covers a fixed fraction of the total number of training steps.
lr_scheduler = get_scheduler(
    name=args.lr_scheduler_type,
    optimizer=optimizer,
    num_warmup_steps=math.floor(args.lr_warmup_fraction * args.max_train_steps),
    num_training_steps=args.max_train_steps,
)


In [None]:
import torch
import math
import os
from torch.utils.data import DataLoader

def b2mb(x):
    """Convert a byte count to mebibytes by dividing by 2**20."""
    bytes_per_mib = 1 << 20
    return x / bytes_per_mib

class TorchTracemalloc:
    """Context manager that records CUDA memory usage across the wrapped block.

    Attributes set on exit (all in bytes):
        begin / used / peaked: taken from ``torch.cuda.memory_allocated()`` and
            ``torch.cuda.max_memory_allocated()``; ``used`` and ``peaked`` are
            relative to ``begin``.
        cpu_begin / cpu_used / cpu_peaked: taken from ``torch.cuda.memory_reserved()``
            and ``torch.cuda.max_memory_reserved()``. Despite the ``cpu_`` prefix these
            are CUDA figures, not host RAM.

    When CUDA is unavailable nothing is queried and every counter stays at 0, so the
    surrounding code also runs on the CPU fallback device.
    """

    def __init__(self):
        self.begin = 0
        self.peaked = 0
        self.used = 0
        self.cpu_begin = 0
        self.cpu_peaked = 0
        self.cpu_used = 0

    def __enter__(self):
        # The CUDA peak-statistics calls fail on machines without CUDA; skip them there.
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()
            torch.cuda.empty_cache()
            self.begin = torch.cuda.memory_allocated()
            self.cpu_begin = torch.cuda.memory_reserved()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if not torch.cuda.is_available():
            return
        self.used = torch.cuda.memory_allocated() - self.begin
        self.peaked = torch.cuda.max_memory_allocated() - self.begin
        self.cpu_used = torch.cuda.memory_reserved() - self.cpu_begin
        self.cpu_peaked = torch.cuda.max_memory_reserved() - self.cpu_begin

# Training / evaluation loop driven by `args.*` settings.
# Expects starting_epoch, resume_step, completed_steps, progress_bar and
# checkpointing_steps to be defined before this cell runs.
for epoch in range(starting_epoch, args.num_train_epochs):
    with TorchTracemalloc() as tracemalloc:
        model.train()
        total_loss = 0 if args.with_tracking else None

        # Handle checkpoint resumption: number of leading batches to skip in the first
        # resumed epoch. (Calling next(iter(train_dataloader)) repeatedly builds a fresh
        # iterator every time and never advances the loader iterated below.)
        skip_steps = 0
        if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
            skip_steps = resume_step

        for step, batch in enumerate(train_dataloader):
            if step < skip_steps:
                continue

            batch = {k: v.to(model.device) for k, v in batch.items()}
            outputs = model(**batch)
            loss = outputs.loss

            if args.with_tracking:
                total_loss += loss.detach().float()

            # Gradient accumulation: scale each batch loss by the accumulation window.
            loss = loss / args.gradient_accumulation_steps
            loss.backward()

            if (step + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

                if args.with_tracking:
                    # logger.log() takes a level as its first argument, so passing only a
                    # dict raises TypeError; emit the metrics at INFO level instead.
                    logger.info({
                        "instant_loss": loss.item() * args.gradient_accumulation_steps,
                        "lr": optimizer.param_groups[0]["lr"],
                        "step": completed_steps
                    })

                completed_steps += 1
                progress_bar.update(1)

                if isinstance(checkpointing_steps, int) and completed_steps % checkpointing_steps == 0:
                    # Use a dedicated name so the module-level `output_dir` read by other
                    # cells is not overwritten with a checkpoint path.
                    step_checkpoint_path = f"step_{completed_steps}"
                    if args.output_dir is not None:
                        os.makedirs(args.output_dir, exist_ok=True)
                        step_checkpoint_path = os.path.join(args.output_dir, step_checkpoint_path)
                    torch.save({
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'scheduler_state_dict': lr_scheduler.state_dict(),
                        'epoch': epoch,
                    }, step_checkpoint_path)

            if completed_steps >= args.max_train_steps:
                break

    # Print memory usage for training
    print(f"GPU Memory before entering the train : {b2mb(tracemalloc.begin)}")
    print(f"GPU Memory consumed at the end of the train: {b2mb(tracemalloc.used)}")
    print(f"GPU Peak Memory consumed during the train: {b2mb(tracemalloc.peaked)}")
    print(f"GPU Total Peak Memory consumed during the train: {b2mb(tracemalloc.peaked + tracemalloc.begin)}")

    print(f"CPU Memory before entering the train : {b2mb(tracemalloc.cpu_begin)}")
    print(f"CPU Memory consumed at the end of the train: {b2mb(tracemalloc.cpu_used)}")
    print(f"CPU Peak Memory consumed during the train: {b2mb(tracemalloc.cpu_peaked)}")
    print(f"CPU Total Peak Memory consumed during the train: {b2mb(tracemalloc.cpu_peaked + tracemalloc.cpu_begin)}")

    model.eval()
    losses = []
    with TorchTracemalloc() as tracemalloc:
        with torch.no_grad():
            for batch in eval_dataloader:
                batch = {k: v.to(model.device) for k, v in batch.items()}
                outputs = model(**batch)
                loss = outputs.loss
                # Each batch loss is repeated per_device_eval_batch_size times before averaging.
                losses.extend([loss] * args.per_device_eval_batch_size)

    # Print memory usage for evaluation
    print(f"GPU Memory before entering the eval : {b2mb(tracemalloc.begin)}")
    print(f"GPU Memory consumed at the end of the eval: {b2mb(tracemalloc.used)}")
    print(f"GPU Peak Memory consumed during the eval: {b2mb(tracemalloc.peaked)}")
    print(f"GPU Total Peak Memory consumed during the eval: {b2mb(tracemalloc.peaked + tracemalloc.begin)}")

    print(f"CPU Memory before entering the eval : {b2mb(tracemalloc.cpu_begin)}")
    print(f"CPU Memory consumed at the end of the eval: {b2mb(tracemalloc.cpu_used)}")
    print(f"CPU Peak Memory consumed during the eval: {b2mb(tracemalloc.cpu_peaked)}")
    print(f"CPU Total Peak Memory consumed during the eval: {b2mb(tracemalloc.cpu_peaked + tracemalloc.cpu_begin)}")

    try:
        eval_loss = torch.mean(torch.tensor(losses))
        perplexity = math.exp(eval_loss.item())
    except OverflowError:
        # math.exp overflowed: report an infinite perplexity.
        perplexity = float("inf")

    logger.info(f"epoch {epoch}: perplexity: {perplexity} eval_loss: {eval_loss}")

    if args.with_tracking:
        logger.info({
            "perplexity": perplexity,
            "eval_loss": eval_loss.item(),
            "train_loss": total_loss.item() / len(train_dataloader) if total_loss is not None else None,
            "epoch": epoch,
            "step": completed_steps,
        })

    if args.push_to_hub and epoch < args.num_train_epochs - 1:
        model.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)
        repo.push_to_hub(
            commit_message=f"Training in progress epoch {epoch}",
            blocking=False,
            auto_lfs_prune=True
        )

    if args.checkpointing_steps == "epoch":
        epoch_checkpoint_path = f"epoch_{epoch}"
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)
            epoch_checkpoint_path = os.path.join(args.output_dir, epoch_checkpoint_path)
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': lr_scheduler.state_dict(),
            'epoch': epoch,
        }, epoch_checkpoint_path)

    # The inner `break` only leaves the batch loop; stop iterating epochs as well once
    # the step budget is exhausted.
    if completed_steps >= args.max_train_steps:
        break

In [None]:
# Set training parameters as variables
starting_epoch = 0
num_train_epochs = 10  # Example value
with_tracking = True
resume_from_checkpoint = False
gradient_accumulation_steps = 4
output_dir = "/path/to/output"
max_train_steps = 1000
checkpointing_steps = "epoch"
per_device_eval_batch_size = 8
push_to_hub = False

# Counters the loop updates; they were previously used without ever being initialised.
completed_steps = 0
progress_bar = tqdm(range(max_train_steps))

for epoch in range(starting_epoch, num_train_epochs):
    with TorchTracemalloc() as tracemalloc:
        model.train()
        total_loss = 0 if with_tracking else None

        # Checkpoint resumption is not implemented in this variant of the loop.

        for step, batch in enumerate(train_dataloader):
            batch = {k: v.to(model.device) for k, v in batch.items()}
            outputs = model(**batch)
            loss = outputs.loss

            if with_tracking:
                total_loss += loss.detach().float()

            # Gradient accumulation: scale each batch loss by the accumulation window.
            loss = loss / gradient_accumulation_steps
            loss.backward()

            if (step + 1) % gradient_accumulation_steps == 0:
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

                if with_tracking:
                    # logger.log() takes a level as its first argument, so passing only a
                    # dict raises TypeError; emit the metrics at INFO level instead.
                    logger.info({
                        "instant_loss": loss.item() * gradient_accumulation_steps,
                        "lr": optimizer.param_groups[0]["lr"],
                        "step": completed_steps
                    })

                completed_steps += 1
                progress_bar.update(1)

                if isinstance(checkpointing_steps, int) and completed_steps % checkpointing_steps == 0:
                    step_output_dir = f"step_{completed_steps}"
                    if output_dir is not None:
                        os.makedirs(output_dir, exist_ok=True)
                        step_output_dir = os.path.join(output_dir, step_output_dir)
                    torch.save({
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'scheduler_state_dict': lr_scheduler.state_dict(),
                        'epoch': epoch,
                    }, step_output_dir)

            if completed_steps >= max_train_steps:
                break

    # Print memory usage for training
    print(f"GPU Memory before entering the train : {b2mb(tracemalloc.begin)}")
    print(f"GPU Memory consumed at the end of the train: {b2mb(tracemalloc.used)}")
    print(f"GPU Peak Memory consumed during the train: {b2mb(tracemalloc.peaked)}")
    print(f"GPU Total Peak Memory consumed during the train: {b2mb(tracemalloc.peaked + tracemalloc.begin)}")

    print(f"CPU Memory before entering the train : {b2mb(tracemalloc.cpu_begin)}")
    print(f"CPU Memory consumed at the end of the train: {b2mb(tracemalloc.cpu_used)}")
    print(f"CPU Peak Memory consumed during the train: {b2mb(tracemalloc.cpu_peaked)}")
    print(f"CPU Total Peak Memory consumed during the train: {b2mb(tracemalloc.cpu_peaked + tracemalloc.cpu_begin)}")

    model.eval()
    losses = []
    with TorchTracemalloc() as tracemalloc:
        with torch.no_grad():
            for batch in eval_dataloader:
                batch = {k: v.to(model.device) for k, v in batch.items()}
                outputs = model(**batch)
                loss = outputs.loss
                # Each batch loss is repeated per_device_eval_batch_size times before averaging.
                losses.extend([loss] * per_device_eval_batch_size)

    # Print memory usage for evaluation
    print(f"GPU Memory before entering the eval : {b2mb(tracemalloc.begin)}")
    print(f"GPU Memory consumed at the end of the eval: {b2mb(tracemalloc.used)}")
    print(f"GPU Peak Memory consumed during the eval: {b2mb(tracemalloc.peaked)}")
    print(f"GPU Total Peak Memory consumed during the eval: {b2mb(tracemalloc.peaked + tracemalloc.begin)}")

    print(f"CPU Memory before entering the eval : {b2mb(tracemalloc.cpu_begin)}")
    print(f"CPU Memory consumed at the end of the eval: {b2mb(tracemalloc.cpu_used)}")
    print(f"CPU Peak Memory consumed during the eval: {b2mb(tracemalloc.cpu_peaked)}")
    print(f"CPU Total Peak Memory consumed during the eval: {b2mb(tracemalloc.cpu_peaked + tracemalloc.cpu_begin)}")

    try:
        eval_loss = torch.mean(torch.tensor(losses))
        perplexity = math.exp(eval_loss.item())
    except OverflowError:
        # math.exp overflowed: report an infinite perplexity.
        perplexity = float("inf")

    logger.info(f"epoch {epoch}: perplexity: {perplexity} eval_loss: {eval_loss}")

    if with_tracking:
        logger.info({
            "perplexity": perplexity,
            "eval_loss": eval_loss.item(),
            "train_loss": total_loss.item() / len(train_dataloader) if total_loss is not None else None,
            "epoch": epoch,
            "step": completed_steps,
        })

    if push_to_hub and epoch < num_train_epochs - 1:
        model.save_pretrained(output_dir)
        tokenizer.save_pretrained(output_dir)
        repo.push_to_hub(
            commit_message=f"Training in progress epoch {epoch}",
            blocking=False,
            auto_lfs_prune=True
        )

    if checkpointing_steps == "epoch":
        epoch_output_dir = f"epoch_{epoch}"
        if output_dir is not None:
            os.makedirs(output_dir, exist_ok=True)
            epoch_output_dir = os.path.join(output_dir, epoch_output_dir)
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': lr_scheduler.state_dict(),
            'epoch': epoch,
        }, epoch_output_dir)

    # The inner `break` only leaves the batch loop; stop iterating epochs as well once
    # the step budget is exhausted.
    if completed_steps >= max_train_steps:
        break


In [None]:
import os
import json
import torch

# Define variables to replace args.<variable>
with_tracking = True
output_dir = "/path/to/output"
push_to_hub = False
use_peft = False
merge_weights = False

# End tracking if specified
if with_tracking:
    # Logic to end tracking (use wandb, tensorboard, etc. if needed)
    print("End of tracking")

# Save model and tokenizer if output directory is provided
if output_dir is not None:
    # Save the model
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)

    # If pushing to a hub
    if push_to_hub:
        # Replace with the hub-specific code for pushing the repo
        repo.push_to_hub(commit_message="End of Training", auto_lfs_prune=True)

    # Save final results in JSON (perplexity from the last evaluation pass)
    with open(os.path.join(output_dir, "all_results.json"), "w") as f:
        json.dump({"perplexity": perplexity}, f)

# Merge weights if using PEFT and merging is specified
if use_peft and merge_weights:
    # Free memory for merging weights. `model` (the PEFT wrapper) still references the
    # base model, so both names must be dropped before the cache can actually shrink.
    del base_model, model
    gc.collect()
    torch.cuda.empty_cache()

    # Load the saved adapter together with its base weights and fold the adapter in
    model = AutoPeftModelForCausalLM.from_pretrained(output_dir, device_map="auto", torch_dtype=torch_dtype)
    model = model.merge_and_unload()

    # Save the merged model to a separate directory
    output_merged_dir = os.path.join(output_dir, "final_merged_checkpoint")
    model.save_pretrained(output_merged_dir)
    tokenizer.save_pretrained(output_merged_dir)


In [2]:
# Load the "meta-llama/Meta-Llama-3-8B" tokenizer and causal-LM weights directly from the Hub.
# NOTE(review): this rebinds `tokenizer` and `model` used by the cells above — confirm
# the intended execution order.
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")

Loading checkpoint shards: 100%|██████████| 4/4 [00:05<00:00,  1.29s/it]


In [None]:
def download_flan2022():
    print("Check if FLAN 2022 dataset is already downloaded. If not, download it and load it")
    for submixture in SUBMIXTURES:
        print(f"Loading {submixture} dataset...")
        dataset=load_dataset(f"{HUB_USERNAME}/{submixture}-submix-4096")
        print(dataset)