In [None]:
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

# Root of the torchtitan checkout: the notebook switches its working directory
# here and joins the train-config paths onto it.
# TODO: Change to the correct path if needed
PROJECT_ROOT: str = ""

# Base directory under which the experiment's outputs/ and checkpoints/
# sub-folders are located.
# TODO: Change to the correct path if needed
SAVE_ROOT: str = ""


In [2]:
import os
# Run from the project root so relative paths (train configs, test assets) resolve.
# An empty PROJECT_ROOT means "stay in the current directory": os.chdir("") would
# raise FileNotFoundError on a fresh kernel.
if PROJECT_ROOT:
    os.chdir(PROJECT_ROOT)
# Restrict the notebook to one GPU; set before torch is imported below.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import tyro
from typing import Iterable, Iterator

import torch
from transformers import AutoModelForCausalLM

import torchtitan.protocols.train_spec as train_spec_module
from torchtitan.components.checkpoint import CheckpointManager
from torchtitan.config import ConfigManager
from torchtitan.tools import utils
from torchtitan.experiments.evaluation.generator.utils import DummyOptimizerContainer, DummyLRSchedulerContainer

# Model key -> train-config TOML path (relative to PROJECT_ROOT).
TOML = dict(
    llama3_2_1b="torchtitan/experiments/evaluation/llama3/train_configs/llama3.2_1b.toml",
    llama3_2_3b="torchtitan/experiments/evaluation/llama3/train_configs/llama3.2_3b.toml",
)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Experiment name: sub-folder used under SAVE_ROOT/outputs and SAVE_ROOT/checkpoints.
exp_name = "llama_3.2_1b_dcp"  # TODO: Change to the correct experiment name if needed
# Key into TOML selecting which train config to load.
model_size = "llama3_2_1b"

# Fail early with the list of valid choices instead of a bare KeyError on a typo.
if model_size not in TOML:
    raise KeyError(
        f"Unknown model_size {model_size!r}; expected one of {sorted(TOML)}"
    )
file_path = os.path.join(PROJECT_ROOT, TOML[model_size])

In [4]:
# Initialize ConfigManager and load the configuration, as a CLI run that passes only
# `--job.config_file` would.
# NOTE(review): this relies on private ConfigManager helpers (_maybe_load_toml,
# _maybe_add_custom_args, _dict_to_dataclass); presumably it mirrors
# ConfigManager.parse_args without config validation -- confirm against the
# installed torchtitan version.
config_manager = ConfigManager()

args = [f"--job.config_file={file_path}"]

toml_values = config_manager._maybe_load_toml(args)
config_cls = config_manager._maybe_add_custom_args(args, toml_values)

# Start from the TOML values (or plain defaults if nothing was loaded) and let tyro
# apply the CLI args on top.
base_config = (
    config_manager._dict_to_dataclass(config_cls, toml_values)
    if toml_values
    else config_cls()
)
custom_registry = tyro.constructors.ConstructorRegistry()
job_config = tyro.cli(
    config_cls, args=args, default=base_config, registry=custom_registry
)

# TODO: Change the dataset and dataset path if needed
job_config.training.dataset = "c4_test"
# Resolve against the project root rather than a user-specific absolute path.
job_config.training.dataset_path = os.path.join(PROJECT_ROOT, "tests/assets/c4_test")

# Outputs and checkpoints are namespaced by exp_name under SAVE_ROOT.
job_config.job.dump_folder = os.path.join(SAVE_ROOT, "outputs", exp_name)
job_config.checkpoint.folder = os.path.join(SAVE_ROOT, "checkpoints", exp_name)
# Presumably -1 selects the latest saved step -- confirm in CheckpointManager.load.
job_config.checkpoint.load_step = -1
job_config.checkpoint.enable_checkpoint = True

In [5]:
# Build Model and Tokenizer, then restore the trained weights from
# job_config.checkpoint.folder via a CheckpointManager.
train_spec = train_spec_module.get_train_spec(job_config.model.name)

model_cls = train_spec.model_cls
model_args = train_spec.model_args[job_config.model.flavor]

tokenizer = (
    train_spec.build_tokenizer_fn(job_config)
    if train_spec.build_tokenizer_fn is not None
    else None
)
# set the model args from training job configs
model_args.update_from_config(job_config)

# Materialize the model directly on the GPU (no meta-device init); its initial
# weights are expected to be replaced by the checkpoint load below.
with torch.device("cuda"):
    model = model_cls(model_args)

# The checkpointer needs a dataloader plus optimizer / LR-scheduler containers to
# restore into. These are throwaway objects; evaluation builds its own dataloader.
dummy_dataloader = train_spec.build_dataloader_fn(
    dp_world_size=1,
    dp_rank=0,
    tokenizer=tokenizer,
    job_config=job_config,
)

dummy_optimizers = DummyOptimizerContainer()
dummy_lr_schedulers = DummyLRSchedulerContainer()

checkpointer = CheckpointManager(
    dataloader=dummy_dataloader,
    model_parts=[model],
    optimizers=dummy_optimizers,
    lr_schedulers=dummy_lr_schedulers,
    states={},  # no extra custom state to restore
    checkpoint_config=job_config.checkpoint,
    sd_adapter=None,
    ft_manager=None,
)

try:
    loaded = checkpointer.load(step=job_config.checkpoint.load_step)
except Exception as e:
    # Author's note: if WORLD_SIZE differs from the saving run, load() may raise even
    # though loading itself should still work (unverified) -- report and continue.
    print(f"Failed to load checkpoint: {e}")
else:
    # NOTE(review): assumes CheckpointManager.load returns a falsy value when no
    # checkpoint is found (rather than raising) -- confirm. Without this check the
    # comparison below would silently run on the un-restored model.
    if not loaded:
        print(
            f"WARNING: no checkpoint was loaded from {job_config.checkpoint.folder!r}; "
            "the TorchTitan model keeps its initial (untrained) weights."
        )



In [None]:
# Load the Hugging Face reference model matching `model_size`, so selecting a
# different TorchTitan config above does not silently compare against the 1B model.
HF_MODEL_IDS = {
    "llama3_2_1b": "meta-llama/Llama-3.2-1B",
    "llama3_2_3b": "meta-llama/Llama-3.2-3B",
}
hf_model = AutoModelForCausalLM.from_pretrained(HF_MODEL_IDS[model_size])

In [7]:
# Fresh single-rank dataloader for evaluation, independent of the one handed to
# the checkpointer; batches are pulled from `data_iterator`.
dataloader = train_spec.build_dataloader_fn(
    job_config=job_config,
    tokenizer=tokenizer,
    dp_rank=0,
    dp_world_size=1,
)
data_iterator = iter(dataloader)

def next_batch(
    data_iterator: Iterator,
) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
    """Pull the next batch from ``data_iterator`` and move it to the compute device.

    Args:
        data_iterator: Iterator yielding ``(input_dict, labels)`` pairs, e.g.
            ``iter(dataloader)``.

    Returns:
        ``(input_dict, labels)`` with every tensor moved to ``utils.device_type``.
        ``input_dict`` is a new dict; the yielded batch is not mutated.

    Raises:
        StopIteration: if ``data_iterator`` is exhausted.
    """
    input_dict, labels = next(data_iterator)

    device_type = utils.device_type
    input_dict = {key: tensor.to(device_type) for key, tensor in input_dict.items()}
    return input_dict, labels.to(device_type)

In [8]:
# Compare TorchTitan and Hugging Face losses on the same batches; with matching
# weights they should agree up to floating-point noise.
loss_fn = train_spec.build_loss_fn(job_config)

sample_numbers = 5
# Max allowed |TorchTitan - HF| loss gap per batch (recorded gaps are < 1e-6).
# Set to None to only print the losses without checking parity.
LOSS_ATOL = 1e-4

model.eval()
hf_model.eval()
model.to("cuda")
hf_model.to("cuda")

with torch.no_grad():
    for _ in range(sample_numbers):
        input_dict, labels = next_batch(data_iterator)
        tokens = input_dict["input"]  # avoid shadowing the builtin `input`

        pred1 = model(tokens)
        pred2 = hf_model(tokens)

        loss1 = loss_fn(pred1, labels)
        print(f"TorchTitan loss: {loss1.item():.10f}")
        loss2 = loss_fn(pred2.logits, labels)
        print(f"HuggingFace loss: {loss2.item():.10f}")

        if LOSS_ATOL is not None:
            torch.testing.assert_close(loss1, loss2, rtol=0.0, atol=LOSS_ATOL)

TorchTitan loss: 3.0213506222
HuggingFace loss: 3.0213508606
TorchTitan loss: 3.1651334763
HuggingFace loss: 3.1651337147
TorchTitan loss: 3.0910332203
HuggingFace loss: 3.0910334587
TorchTitan loss: 2.9420754910
HuggingFace loss: 2.9420757294
TorchTitan loss: 3.0103466511
HuggingFace loss: 3.0103468895
