In [None]:
# Install the third-party packages for this notebook (-q keeps pip's output quiet).
!pip install -q transformers datasets deepspeed torch accelerate evaluate bitsandbytes pyyaml wandb

In [None]:
# Fetch the Weights & Biases API key from the Kaggle secrets client so it is
# never hard-coded in the notebook.
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
WANDB_API_KEY = user_secrets.get_secret("WANDB_API_KEY")

In [None]:
import wandb
# Log in to Weights & Biases with the WANDB_API_KEY value defined in an earlier cell.
wandb.login(key=WANDB_API_KEY)

In [None]:
%%writefile train_torch_pipeline.py

import torch
import torch.nn as nn
import time
import math
import argparse
import os
import wandb
import torch.distributed as dist
from torch.distributed.pipelining import pipeline, PipelineStage, ScheduleGPipe
from torch.optim import AdamW
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoModelForCausalLM, AutoTokenizer, default_data_collator, get_scheduler
from transformers.models.bloom.modeling_bloom import build_alibi_tensor
from datasets import load_dataset
from tqdm.auto import tqdm

def set_seed(seed):
    """Seed PyTorch's random number generators for reproducibility.

    Args:
        seed: Integer seed passed to ``torch.manual_seed`` and, when CUDA is
            available, to ``torch.cuda.manual_seed_all``.
    """
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        # CPU-only environment: nothing else to seed.
        return
    torch.cuda.manual_seed_all(seed)

# --- Basic classes ---
class LanguageModelLoss(nn.Module):
    """Next-token cross-entropy loss for causal language modelling.

    The training data in this file sets ``labels`` to a copy of ``input_ids``,
    so the logits produced at position ``t`` have to be scored against the
    label at position ``t + 1``. Comparing them at the same position would
    reward the model for echoing its current input token.
    """

    def __init__(self):
        super().__init__()
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, logits, labels):
        """Return the mean cross-entropy between shifted logits and labels.

        Args:
            logits: Float tensor whose last dimension is the vocabulary and
                whose second-to-last dimension is the sequence position.
            labels: Integer tensor of token ids, one per sequence position.

        Returns:
            Scalar loss tensor. A sequence of length 1 leaves nothing to
            predict after the shift, so the mean is taken over zero elements.
        """
        label_batch_size = labels.size(0)
        if logits.size(0) > label_batch_size:
            # More logit rows than label rows: keep only the trailing rows so
            # both tensors describe the same number of sequences.
            logits = logits[-label_batch_size:]

        # Drop the last prediction (it has no following token) and the first
        # label (no position predicts it).
        shifted_logits = logits[..., :-1, :]
        shifted_labels = labels[..., 1:]

        # reshape (not view): the shifted slices are no longer contiguous.
        return self.loss_fn(
            shifted_logits.reshape(-1, shifted_logits.size(-1)),
            shifted_labels.reshape(-1),
        )

class BloomPipelineStage1(nn.Module):
    """First pipeline stage: token embedding, embedding layer norm and the leading blocks.

    Args:
        embedding: Module that maps ``input_ids`` to hidden states.
        layernorm: Module applied to the embedding output.
        blocks: Iterable of transformer blocks executed in order by this stage.
        num_attention_heads: Head count forwarded to ``build_alibi_tensor``.
    """

    def __init__(self, embedding, layernorm, blocks, num_attention_heads):
        super().__init__()
        self.embedding = embedding
        self.layernorm = layernorm
        self.blocks = nn.ModuleList(blocks)
        self.num_attention_heads = num_attention_heads

    @staticmethod
    def _build_causal_mask(attention_mask, dtype):
        """Build an additive attention mask of shape ``(batch, 1, seq, seq)``.

        An entry is ``0`` where query position ``i`` may attend to key
        position ``j`` (``j <= i`` and ``attention_mask[b, j]`` is non-zero)
        and ``torch.finfo(dtype).min`` everywhere else.

        Args:
            attention_mask: ``(batch, seq)`` tensor, non-zero for real tokens.
            dtype: Floating-point dtype of the returned mask.
        """
        batch_size, seq_length = attention_mask.shape
        device = attention_mask.device
        positions = torch.arange(seq_length, device=device)
        # allowed[i, j] is True when key j is not in the future of query i.
        allowed = positions[None, :] <= positions[:, None]
        # Also block keys that the padding mask marks as absent.
        allowed = allowed[None, None, :, :] & attention_mask.to(torch.bool)[:, None, None, :]
        mask = torch.zeros(batch_size, 1, seq_length, seq_length, dtype=dtype, device=device)
        return mask.masked_fill(~allowed, torch.finfo(dtype).min)

    def forward(self, input_ids, attention_mask):
        """Run the embedding and this stage's blocks.

        Returns:
            Tuple ``(hidden_states, alibi, causal_mask)``. The ALiBi tensor and
            mask are returned so the next stage can reuse them.

        NOTE(review): the mask is additive (float), matching the float cast
        this stage has always applied; this assumes the installed
        ``BloomBlock`` adds ``attention_mask`` to the attention scores --
        confirm against the transformers version in use.
        """
        hidden_states = self.embedding(input_ids)
        hidden_states = self.layernorm(hidden_states)
        alibi = build_alibi_tensor(attention_mask, self.num_attention_heads, dtype=hidden_states.dtype).to(hidden_states.device)
        # A broadcast padding mask alone never hides future tokens; build a
        # real lower-triangular mask combined with the padding information.
        causal_mask = self._build_causal_mask(attention_mask, hidden_states.dtype)
        for block in self.blocks:
            outputs = block(hidden_states, alibi=alibi, attention_mask=causal_mask)
            hidden_states = outputs[0]
        return hidden_states, alibi, causal_mask

class BloomPipelineStage2(nn.Module):
    """Last pipeline stage: remaining blocks, final layer norm and the LM head.

    Args:
        blocks: Iterable of transformer blocks executed in order by this stage.
        final_layernorm: Module applied after the last block.
        lm_head: Module that projects hidden states to vocabulary logits.
    """

    def __init__(self, blocks, final_layernorm, lm_head):
        super().__init__()
        self.blocks = nn.ModuleList(blocks)
        self.final_layernorm = final_layernorm
        self.lm_head = lm_head

    def forward(self, hidden_states, alibi, attention_mask):
        """Return logits for the hidden states handed over by the previous stage."""
        states = hidden_states
        for layer in self.blocks:
            # Each block returns a tuple; its first element is the new hidden state.
            states = layer(states, alibi=alibi, attention_mask=attention_mask)[0]
        normalized = self.final_layernorm(states)
        return self.lm_head(normalized)

def main():
    """Fine-tune a BLOOM causal LM split into two pipeline stages, one per rank.

    Steps: parse CLI arguments, initialise the NCCL process group, load the
    Hugging Face model and keep only this rank's stage, tokenize and chunk the
    dataset, then train with a GPipe schedule while rank 0 logs to
    Weights & Biases.

    Raises:
        KeyError: If ``LOCAL_RANK`` or ``WORLD_SIZE`` is missing from the
            environment.
        ValueError: If the world size is not 2 (the stage split below only
            defines a first and a last stage) or the batch size cannot be
            divided evenly into micro-batches.
    """
    # Local import keeps this function self-contained.
    from itertools import chain

    # --- Argument configuration ---
    parser = argparse.ArgumentParser(description="Finetune BLOOM with torch.distributed.pipelining")

    parser.add_argument("--model_name", type=str, default="bigscience/bloom-560m")
    parser.add_argument("--dataset_name", type=str, default="Salesforce/wikitext")
    parser.add_argument("--dataset_config", type=str, default="wikitext-2-raw-v1")

    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--num_epochs", type=int, default=2)
    parser.add_argument("--block_size", type=int, default=128)
    parser.add_argument("--batch_size", type=int, default=2)
    parser.add_argument("--lr", type=float, default=3e-6)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--warmup_steps", type=int, default=20)

    parser.add_argument("--logging_steps", type=int, default=5)
    parser.add_argument("--wandb_project", type=str, default="torch_pipeline_finetune")
    parser.add_argument("--wandb_run_name", type=str, default=None)

    args = parser.parse_args()

    # --- Distributed environment ---
    local_rank = int(os.environ["LOCAL_RANK"])
    # The pipeline position is the global rank; fall back to the local rank
    # when RANK is not exported (single-node launch).
    rank = int(os.environ.get("RANK", local_rank))
    world_size = int(os.environ["WORLD_SIZE"])
    if world_size != 2:
        # Only a first stage (rank 0) and a last stage (any other rank) are
        # built below, so any other world size would mis-wire the pipeline.
        raise ValueError(
            f"This script splits the model into exactly 2 pipeline stages, got WORLD_SIZE={world_size}."
        )
    n_microbatches = world_size
    if args.batch_size % n_microbatches != 0:
        # NOTE(review): the pipeline stages are presumed to need identical
        # micro-batch shapes on every step -- confirm against the torch version.
        raise ValueError(
            f"--batch_size ({args.batch_size}) must be divisible by the number of micro-batches ({n_microbatches})."
        )
    device = torch.device(f"cuda:{local_rank}")
    # Bind this process to its GPU before creating the NCCL process group.
    torch.cuda.set_device(device)
    dist.init_process_group(backend="nccl")

    set_seed(args.seed)

    if args.wandb_run_name:
        args.wandb_run_name += f"-{int(time.time())}"

    if rank == 0:
        wandb.init(project=args.wandb_project, name=args.wandb_run_name, config=vars(args))
        print("Arguments:", args)

    # --- Load the model and carve out this rank's stage ---
    if rank == 0: print("Đang tải tokenizer và model...")
    model_hf = AutoModelForCausalLM.from_pretrained(args.model_name).cpu()
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token

    num_layers = len(model_hf.transformer.h)
    layers_per_stage = num_layers // world_size
    if rank == 0:
        print(f"Chia {num_layers} lớp thành {world_size} giai đoạn, mỗi giai đoạn {layers_per_stage} lớp.")

    if rank == 0:
        stage_model = BloomPipelineStage1(
            embedding=model_hf.transformer.word_embeddings,
            layernorm=model_hf.transformer.word_embeddings_layernorm,
            blocks=model_hf.transformer.h[:layers_per_stage],
            num_attention_heads=model_hf.config.num_attention_heads
        ).to(device)
    else:  # rank == 1 (guaranteed by the world-size check above)
        stage_model = BloomPipelineStage2(
            blocks=model_hf.transformer.h[layers_per_stage:],
            final_layernorm=model_hf.transformer.ln_f,
            lm_head=model_hf.lm_head
        ).to(device)

    # Drop the reference to the full model; only this stage's modules stay alive.
    del model_hf

    # --- Load and preprocess the dataset ---
    if rank == 0: print("Đang tải và xử lý dữ liệu...")
    raw_datasets = load_dataset(args.dataset_name, args.dataset_config)
    train_split = raw_datasets['train']
    # Keep at most the first 1000 rows; min() avoids an out-of-range selection
    # on a shorter split.
    raw_datasets['train'] = train_split.select(range(min(1000, len(train_split))))
    splits_to_remove = [split for split in raw_datasets.keys() if split != 'train']
    for split in splits_to_remove: del raw_datasets[split]

    column_names = raw_datasets["train"].column_names
    text_column_name = "text" if "text" in column_names else column_names[0]

    def tokenize_function(examples):
        return tokenizer(examples[text_column_name])

    tokenized_datasets = raw_datasets.map(
        tokenize_function, batched=True, remove_columns=column_names, desc="Running tokenizer on dataset" if rank == 0 else None
    )

    def group_texts(examples):
        # Concatenate every column, then cut it into block_size chunks; the
        # tail that does not fill a whole block is discarded.
        concatenated_examples = {k: list(chain.from_iterable(examples[k])) for k in examples.keys()}
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        total_length = (total_length // args.block_size) * args.block_size
        result = {
            k: [t[i : i + args.block_size] for i in range(0, total_length, args.block_size)]
            for k, t in concatenated_examples.items()
        }
        result["labels"] = result["input_ids"].copy()
        return result

    lm_datasets = tokenized_datasets.map(
        group_texts, batched=True, desc=f"Grouping texts in chunks of {args.block_size}" if rank == 0 else None
    )
    if rank == 0: print(lm_datasets)

    train_dataset = lm_datasets["train"]
    # The ranks are stages of ONE pipeline, not data-parallel replicas: the
    # inputs fed on rank 0 and the targets consumed on the last rank must come
    # from the same samples. A single-replica sampler with a shared seed gives
    # every rank the same shuffled order.
    train_sampler = DistributedSampler(train_dataset, num_replicas=1, rank=0, shuffle=True, seed=args.seed)
    # drop_last keeps the batch shape constant across steps.
    dataloader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        collate_fn=default_data_collator,
        sampler=train_sampler,
        drop_last=True,
    )

    # --- Pipeline setup ---
    stage = PipelineStage(stage_model, rank, world_size, device)
    loss_fn = LanguageModelLoss().to(device)
    schedule = ScheduleGPipe(stage, n_microbatches, loss_fn=loss_fn)

    # --- Optimizer & scheduler ---
    optimizer = AdamW(stage_model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    num_training_steps = args.num_epochs * len(dataloader)
    lr_scheduler = get_scheduler(
        name="linear",
        optimizer=optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=num_training_steps,
    )

    # --- Training loop ---
    if rank == 0: print("\n*** Bắt đầu huấn luyện ***")

    global_step = 0
    start_training_time = time.time()

    for epoch in range(args.num_epochs):
        stage_model.train()
        train_sampler.set_epoch(epoch)
        start_epoch_time = time.time()
        if rank == 0:
            print(f"\n--- Bắt đầu Epoch {epoch+1}/{args.num_epochs} ---")

        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}", disable=(rank != 0))

        for batch in progress_bar:
            optimizer.zero_grad()

            inputs = (batch['input_ids'].to(device), batch['attention_mask'].to(device))
            targets = batch['labels'].to(device)

            # Every rank needs a loss tensor so it can take part in the broadcast below.
            loss = torch.tensor(0.0, device=device)

            if rank == 0:
                schedule.step(*inputs)
            elif rank == world_size - 1:  # last stage
                losses = []
                schedule.step(target=targets, losses=losses)
                # Average over all micro-batches instead of reporting only the first one.
                loss = torch.stack([mb_loss.detach() for mb_loss in losses]).mean().float()
            else:  # intermediate stages
                schedule.step()

            optimizer.step()
            lr_scheduler.step()

            global_step += 1

            # Send the loss from the last rank to every rank so rank 0 can log it.
            dist.broadcast(loss, src=world_size - 1)

            # All ranks now hold the loss value, but only rank 0 logs.
            if rank == 0:
                progress_bar.set_postfix({"loss": loss.item()})
                if global_step % args.logging_steps == 0:
                    log_metrics = {
                        "train_loss": loss.item(),
                        "learning_rate": lr_scheduler.get_last_lr()[0],
                        "epoch": epoch + (progress_bar.n / progress_bar.total),
                    }
                    wandb.log(log_metrics, step=global_step)

        dist.barrier()
        if rank == 0:
            epoch_time = time.time() - start_epoch_time
            print(f"--- Epoch {epoch+1} hoàn tất trong {epoch_time:.2f} giây ---")
            wandb.log({
                "epoch_time (s)": epoch_time,
                "epoch": epoch + 1
            }, step=global_step)

    # --- Shutdown ---
    dist.barrier()
    if rank == 0:
        total_training_time = time.time() - start_training_time
        print(f"\n*** Huấn luyện hoàn tất trong: {total_training_time:.2f} giây ***")
        wandb.finish()

    dist.destroy_process_group()

# Run the training entry point only when the file is executed as a script.
if __name__ == "__main__":
    main()

In [None]:
# Launch the training script on one node with two processes; the script's stage
# split defines only a first stage (rank 0) and a last stage (the other rank).
!torchrun --nnodes 1 --nproc_per_node 2 /kaggle/working/train_torch_pipeline.py \
    --model_name "bigscience/bloom-1b7" \
    --batch_size 4 \
    --num_epochs 3 \
    --logging_steps 2 \
    --wandb_project "PARADIS-bloom_1b7" \
    --wandb_run_name "TD_Pipeline"