# Setup environment

## Environment variables

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1" # Only use 1 GPU
os.environ["TOKENIZERS_PARALLELISM"] = "false"

## Get secrets

In [2]:
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
HF_TOKEN = user_secrets.get_secret("HF_TOKEN")
WANDB_API_KEY = user_secrets.get_secret("WANDB_API_KEY")

## Import modules

In [3]:
!pip install -qU transformers==4.51.0 accelerate bitsandbytes deepspeed mpi4py

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m27.0 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m466.3/466.3 kB[0m [31m33.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Installing backend dependencies ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.4/10.4 MB[0m [31m103.9 MB/s[0m eta [36m0:00:00[0m00:01[0m0:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m365.3/365.3 kB[0m [31m30.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m67.0/67.0 MB[0m [31m26.9 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    get_linear_schedule_with_warmup
)
from datasets import load_dataset

import wandb
import os
import numpy as np
from datetime import datetime
import json
from tqdm.auto import tqdm
import gc
import math
import time
import deepspeed

[2025-06-21 10:39:26,639] [INFO] [real_accelerator.py:254:get_accelerator] Setting ds_accelerator to cuda (auto detect)


/usr/bin/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status
/usr/bin/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status


[2025-06-21 10:39:29,890] [INFO] [logging.py:107:log_dist] [Rank -1] [TorchCheckpointEngine] Initialized with serialization = False


2025-06-21 10:39:32.478288: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1750502372.700884      35 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1750502372.788975      35 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


## Random seed & device

In [5]:
# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Check GPU availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


# Finetune config

In [None]:
class Config:
    # Model configuration
    model_name = "Qwen/Qwen3-0.6B-Instruct"
    # model_name = "Qwen/Qwen3-1.7B"
    dataset_name = "vietgpt/wikipedia_vi"
    
    # Training configuration
    output_dir = "./qwen-vietnamese-wiki-finetuned"
    # output_dir = "./qwen-vietnamese-wiki-finetuned-2"
    num_train_epochs = 5
    per_device_train_batch_size = 1
    per_device_valid_batch_size = 1
    gradient_accumulation_steps = 8
    learning_rate = 5e-5
    weight_decay = 0.01
    warmup_ratio = 0.1
    max_length = 128

    # Optimization settings
    adam_epsilon = 1e-8
    max_grad_norm = 1.0
    
    # Logging and saving
    logging_steps = 40
    save_strategy = "epoch"
    valid_strategy = "epoch"
    
    # Other settings
    fp16 = True
    num_workers = os.cpu_count()
    
    # W&B configuration
    use_wandb = True
    wandb_run_id = None
    wandb_project = "DeepSpeed-Pipeline-Qwen"
    # wandb_project = "PARADIS-Qwen3_1.7B"
    wandb_run_name = "2GPU-qwen3_0.6B"

    # HuggingFace configuration
    use_hf = True
    hf_repo = "ThanhND1501/DeepSpeed-Pipeline-Qwen-10kWikiVi"
    # hf_repo = "h9art/PARADIS-Qwen3_1.7B-10kWikiVi-1GPU"
    
    # Dataset
    train_size = 10000
    valid_size = 10000
    test_size = 5000
    min_text_length = 50
    random_seed = 42

config = Config()

In [7]:
config_dict = {k: v for k, v in Config.__dict__.items() if not k.startswith("__") and not callable(v)}
config_dict

{'model_name': 'Qwen/Qwen2.5-0.5B-Instruct',
 'dataset_name': 'vietgpt/wikipedia_vi',
 'output_dir': './qwen-vietnamese-wiki-finetuned',
 'num_train_epochs': 5,
 'per_device_train_batch_size': 1,
 'per_device_valid_batch_size': 1,
 'gradient_accumulation_steps': 8,
 'learning_rate': 5e-05,
 'weight_decay': 0.01,
 'warmup_ratio': 0.1,
 'max_length': 128,
 'adam_epsilon': 1e-08,
 'max_grad_norm': 1.0,
 'logging_steps': 40,
 'save_strategy': 'epoch',
 'valid_strategy': 'epoch',
 'fp16': True,
 'num_workers': 4,
 'use_wandb': True,
 'wandb_run_id': None,
 'wandb_project': 'DeepSpeed-Pipeline-Qwen',
 'wandb_run_name': '2GPU-qwen2.5_0.5B',
 'use_hf': True,
 'hf_repo': 'ThanhND1501/DeepSpeed-Pipeline-Qwen-10kWikiVi',
 'train_size': 10000,
 'valid_size': 10000,
 'test_size': 5000,
 'min_text_length': 50,
 'random_seed': 42}

In [8]:
ds_config = {
    "train_batch_size": config.per_device_train_batch_size * config.gradient_accumulation_steps,
    "gradient_accumulation_steps": config.gradient_accumulation_steps,
    "fp16": {
        "enabled": config.fp16
    },
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": config.learning_rate,
            "eps": config.adam_epsilon,
            "weight_decay": config.weight_decay
        }
    },
    "pipeline": {
        "stages": 4,
        "partition_method": "size"
    },
    "zero_optimization": {
        "stage": 2
    }
}
ds_config_path = "ds_config.json"
with open(ds_config_path, "w") as f:
    json.dump(ds_config, f, indent=2)

# Setup wandb

In [9]:
wandb.login(key=WANDB_API_KEY)
if config.use_wandb:
    if config.wandb_run_id is None:
        wandb.init( # New run
            project=config.wandb_project,
            name=config.wandb_run_name,
            config=config_dict,
        )
    else:
        wandb.init( # Resume to created run
            project=config.wandb_project,
            id=config.wandb_run_id,
            resume='allow',
        )

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mthanh_nd[0m ([33mthanh_nd_ai[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


# Setup HuggingFace

In [10]:
if config.use_hf:
    from huggingface_hub import login, HfApi
    login(HF_TOKEN)
    hf_api = HfApi()

# Model and tokenizer

## Download and quantization

In [11]:
print("Loading tokenizer and model...")

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    config.model_name,
    trust_remote_code=True,
    padding_side="right"
)

# Add pad token if it doesn't exist
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Cấu hình 4-bit quantization
# quantization_config = BitsAndBytesConfig(load_in_4bit=True)

# Load model
model = AutoModelForCausalLM.from_pretrained(
    config.model_name,
    device_map="auto", # automatically move to correct device
    #quantization_config=quantization_config,
    torch_dtype=torch.float32,
    trust_remote_code=True
)

# Turn on gradient checkpointing to save memory
model.config.use_cache = False
model.gradient_checkpointing_enable()

# Num parameters
print(f"Model loaded. Parameters: {model.num_parameters():,}")

Loading tokenizer and model...


tokenizer_config.json:   0%|          | 0.00/7.30k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/2.78M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/1.67M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/7.03M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/659 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/988M [00:00<?, ?B/s]

Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


generation_config.json:   0%|          | 0.00/242 [00:00<?, ?B/s]

Model loaded. Parameters: 494,032,768


## Generation function

In [12]:
def generate_text(
    prompt,
    max_length=config.max_length,
    temperature=0.7,
    top_p=0.9,
    top_k=50
):
    """Generate text using DeepSpeed pipeline-enabled model."""

    model.eval()

    # Tokenize and move to correct device
    inputs = tokenizer.encode(prompt, return_tensors='pt').to(model.device)

    with torch.no_grad():
        outputs = model.module.generate(
            inputs,
            max_length=max_length,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.1
        )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)


# Dataset

## Custom dataset

In [13]:
wikiclass = """from torch.utils.data import DataLoader, Dataset
class WikiViDataset(Dataset):
    def __init__(self, dataset, tokenizer, max_length):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_length = max_length
        
    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Get data
        item = self.dataset[idx]
        combined_text = f"Tiêu đề: {item['title']}\\n\\nNội dung: {item['text']}"

        # Tokenize data
        tokenized_text = self.tokenizer(
            combined_text,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )

        # # Print a tokenized sample
        # print(tokenized_text)

        # Prepare data from tokenizer output
        input_ids = tokenized_text["input_ids"]
        attention_mask = tokenized_text["attention_mask"]
        labels = input_ids.clone() # In causal LM, labels is the same with input_ids
        labels[attention_mask == 0] = -100 # Do not calculate loss on padding tokens

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }"""
with open('wikidataset.py', 'w', encoding='utf-8') as f:
    f.write(wikiclass)

In [14]:
from wikidataset import WikiViDataset

## Load wikipedia_vi dataset

In [15]:
print("Loading dataset...")
dataset = load_dataset(config.dataset_name, split="train")
print(f"Dataset loaded. Total samples: {len(dataset)}")

Loading dataset...


README.md:   0%|          | 0.00/632 [00:00<?, ?B/s]

(…)-00000-of-00003-6218d2963e302058.parquet:   0%|          | 0.00/245M [00:00<?, ?B/s]

(…)-00001-of-00003-12e6c4fadbec91d4.parquet:   0%|          | 0.00/55.2M [00:00<?, ?B/s]

(…)-00002-of-00003-175fcfe1c45b0b85.parquet:   0%|          | 0.00/270M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1284930 [00:00<?, ? examples/s]

Dataset loaded. Total samples: 1284930


In [16]:
dataset[0]

{'id': 2,
 'revid': '90949',
 'url': 'https://vi.wikipedia.org/wiki?curid=2',
 'title': 'Trang Chính',
 'text': '&lt;templatestyles src="Wiki2021/styles.css" /&gt;__NOEDITSECTION__\n \n \n \n '}

## Preprocess data

In [17]:
# keep only title and text column
dataset = dataset.select_columns(['title', 'text'])

In [18]:
dataset[0]

{'title': 'Trang Chính',
 'text': '&lt;templatestyles src="Wiki2021/styles.css" /&gt;__NOEDITSECTION__\n \n \n \n '}

In [19]:
def filter_function(example):
    """Filter out empty or very short texts"""
    
    return (
        example['text'] is not None and 
        example['title'] is not None and
        len(example['text'].strip()) > config.min_text_length
    )

dataset = dataset.filter(filter_function)
print(f"After filtering: {len(dataset)} samples")

Filter:   0%|          | 0/1284930 [00:00<?, ? examples/s]

After filtering: 1263196 samples


## Create splits

In [20]:
dataset = dataset.shuffle(seed=config.random_seed)

train_split = dataset.select(range(
    config.train_size
))

valid_split = dataset.select(range(
    config.train_size,
    config.train_size + config.valid_size
))

test_split = dataset.select(range(
    config.train_size + config.valid_size,
    config.train_size + config.valid_size + config.test_size
))

print(f'train split: {len(train_split)} samples')
print(f'valid split: {len(valid_split)} samples')
print(f'test split: {len(test_split)} samples')

train split: 10000 samples
valid split: 10000 samples
test split: 5000 samples


In [21]:
train_ds = WikiViDataset(train_split, tokenizer, config.max_length)
valid_ds = WikiViDataset(valid_split, tokenizer, config.max_length)
test_ds = WikiViDataset(test_split, tokenizer, config.max_length)

In [22]:
# # Display a sample
# train_ds[0]

## Data loader

In [23]:
train_dataloader = DataLoader(
    train_ds,
    batch_size=config.per_device_train_batch_size,
    shuffle=True,
    num_workers=config.num_workers,
    pin_memory=False,
)

valid_dataloader = DataLoader(
    valid_ds,
    batch_size=config.per_device_valid_batch_size,
    shuffle=True,
    num_workers=config.num_workers,
    pin_memory=False,
)

In [24]:
print(f"Train batches: {len(train_dataloader)}")
print(f"Valid batches: {len(valid_dataloader)}")

Train batches: 10000
Valid batches: 10000


# Optimizer & scheduler

In [25]:
total_steps = len(train_dataloader) * config.num_train_epochs // config.gradient_accumulation_steps
warmup_steps = int(total_steps * config.warmup_ratio)

print(f"Total training steps: {total_steps}")
print(f"Warmup steps: {warmup_steps}")

Total training steps: 6250
Warmup steps: 625


In [31]:
# Initialize DeepSpeed engine
model, optimizer, _, scheduler = deepspeed.initialize(
    model=model,
    config_params=ds_config,
    dist_init_required=True
)

[2025-06-21 10:43:14,448] [INFO] [logging.py:107:log_dist] [Rank 0] DeepSpeed info: version=0.17.1, git-hash=unknown, git-branch=unknown
[2025-06-21 10:43:14,450] [INFO] [comm.py:700:init_distributed] Distributed backend already initialized
[2025-06-21 10:43:14,451] [INFO] [config.py:655:__init__] Config mesh_device None world_size = 1
[2025-06-21 10:43:14,465] [INFO] [engine.py:1325:_configure_distributed_model] ********** distributed groups summary **********
	 self.dp_world_size=1
	 self.mp_world_size=1
	 self.seq_dp_world_size=1
	 self.sequence_parallel_size=1
***********************************************
[2025-06-21 10:43:14,480] [INFO] [logging.py:107:log_dist] [Rank 0] DeepSpeed Flops Profiler Enabled: False
[2025-06-21 10:43:14,482] [INFO] [logging.py:107:log_dist] [Rank 0] Using DeepSpeed Optimizer param name adam as basic optimizer
[2025-06-21 10:43:14,483] [INFO] [logging.py:107:log_dist] [Rank 0] Removing param_group that has no 'params' in the basic Optimizer
[2025-06-21

# Training function

In [32]:
def train_epoch(model, dataloader, optimizer, scheduler, epoch):
    """Train for one epoch."""
    model.train()
    total_loss = 0
    model.zero_grad()
    progress_bar = tqdm(dataloader, desc=f"Training Epoch {epoch + 1}")

    for step, batch in enumerate(progress_bar):
        input_ids = batch['input_ids']  # REMOVE .to(model.device)
        attention_mask = batch['attention_mask']
        labels = batch['labels']
        #print(input_ids.shape)

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )
        loss = outputs.loss / config.gradient_accumulation_steps

        model.backward(loss)

        total_loss += loss.item()

        if (step + 1) % config.gradient_accumulation_steps == 0:
            model.step()
            model.zero_grad()

        progress_bar.set_postfix({
            'loss': f"{loss.item() * config.gradient_accumulation_steps:.4f}",
            'lr': f"{scheduler.get_last_lr()[0]:.2e}"
        })

        if (step + 1) % config.logging_steps == 0:
            avg_loss = total_loss / (step + 1) * config.gradient_accumulation_steps
            print(f"Step {step + 1}/{len(dataloader)}, Loss: {avg_loss:.4f}, LR: {scheduler.get_last_lr()[0]:.2e}")

            if config.use_wandb:
                wandb.log({
                    "train_loss": avg_loss,
                    "learning_rate": scheduler.get_last_lr()[0],
                    "train_step": epoch * len(dataloader) + step + 1
                })

    return total_loss / len(dataloader) * config.gradient_accumulation_steps

# Validation function

In [33]:
def validate(model, dataloader):
    """Validate the model."""
    model.eval()
    total_loss = 0
    total_steps = 0

    with torch.no_grad():
        progress_bar = tqdm(dataloader, desc="Validating")
        for batch in progress_bar:
            input_ids = batch['input_ids']
            attention_mask = batch['attention_mask']
            labels = batch['labels']
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )

            loss = outputs.loss
            total_loss += loss.item()
            total_steps += 1

            progress_bar.set_postfix({'valid_loss': f"{loss.item():.4f}"})

    avg_loss = total_loss / total_steps
    perplexity = math.exp(avg_loss)

    return avg_loss, perplexity

# Training loop

## Test before training

In [29]:
#test_prompts = [
#    "Việt Nam là một quốc gia",
#    "Tiêu đề: Hà Nội\n\nNội dung:",
#    "Lịch sử Việt Nam bắt đầu từ",
#    "Văn hóa truyền thống của người Việt",
#    "Tiêu đề: Phở\n\nNội dung: Phở là"
#]

#print("\n" + "=" * 50)
#print("TESTING THE ORIGINAL MODEL")
#print("=" * 50)

#for i, prompt in enumerate(test_prompts, 1):
#    print(f"\n--- Test {i} ---")
#    print(f"Prompt: {prompt}")
#    print("-" * 40)
#    
#    generated = generate_text(prompt, max_length=150, temperature=0.7)
#    print(f"Generated: {generated}")

## Main loop

In [34]:
print("Starting training...")

# Create output directory
os.makedirs(config.output_dir, exist_ok=True)

# Training history
training_history = {
    'train_losses': [],
    'train_times': [],
    'valid_losses': [],
    'valid_perplexities': [],
    'valid_times': [],
    'learning_rates': []
}

best_valid_loss = float('inf')
step_count = 0

for epoch in range(config.num_train_epochs):
    print(f"\n{'=' * 50}")
    print(f"Epoch {epoch + 1}/{config.num_train_epochs}")
    print(f"{'=' * 50}")
    
    # Training
    start_time = time.time()
    train_loss = train_epoch(model, train_dataloader, optimizer, scheduler, epoch)
    end_time = time.time()
    
    elapsed_time = end_time - start_time
    train_mins, train_secs = divmod(elapsed_time, 60)
    training_history['train_times'].append(train_mins)
    print(f"Training Time: {int(train_mins)} mins {int(train_secs)} seconds")
    
    training_history['train_losses'].append(train_loss)
    print(f"Training Loss: {train_loss:.4f}")
    
    # Validation
    start_time = time.time()
    valid_loss, perplexity = validate(model, valid_dataloader)
    end_time = time.time()
    
    elapsed_time = end_time - start_time
    valid_mins, valid_secs = divmod(elapsed_time, 60)
    training_history['valid_times'].append(valid_mins)
    print(f"Training Time: {int(valid_mins)} mins {int(valid_secs)} seconds")
    
    training_history['valid_losses'].append(valid_loss)
    training_history['valid_perplexities'].append(perplexity)
    print(f"Validation Loss: {valid_loss:.4f}")
    print(f"Perplexity: {perplexity:.2f}")
    
    # Log to wandb
    if config.use_wandb:
        wandb.log({
            "epoch": epoch + 1,
            "train_time (m)": train_mins,
            "valid_time (m)": valid_mins,
            "valid_loss": valid_loss,
            "perplexity": perplexity,
        })
    
    # Save best model
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        
        model.save_pretrained(config.output_dir)
        tokenizer.save_pretrained(config.output_dir)
        print(f"New best model! Saved to {config.output_dir}")
        
        if config.use_hf:
            model.push_to_hub(config.hf_repo)
            tokenizer.push_to_hub(config.hf_repo)
            print(f"Also saved to repo {config.hf_repo}")
        
    # Save training state
    torch.save({
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'best_valid_loss': best_valid_loss,
        'training_history': training_history
    }, os.path.join(config.output_dir, 'training_state.pt'))
    print(f"Training state saved to {config.output_dir}!")

    if config.use_hf:
        hf_api.upload_file(
            path_or_fileobj=os.path.join(config.output_dir, 'training_state.pt'),
            path_in_repo="training_state.pt",
            repo_id=config.hf_repo,
            repo_type="model",
        )
    print(f"Training state pushed to repo {config.hf_repo}!")
    
    # Clean up GPU memory
    torch.cuda.empty_cache()
    gc.collect()

Starting training...

Epoch 1/5


Training Epoch 1:   0%|          | 0/10000 [00:12<?, ?it/s]

ValueError: too many values to unpack (expected 4)

# After training

## Test after training

In [None]:
test_prompts = [
    "Việt Nam là một quốc gia",
    "Tiêu đề: Hà Nội\n\nNội dung:",
    "Lịch sử Việt Nam bắt đầu từ",
    "Văn hóa truyền thống của người Việt",
    "Tiêu đề: Phở\n\nNội dung: Phở là"
]

print("\n" + "=" * 60)
print("TESTING THE FINE-TUNED MODEL")
print("=" * 60)

for i, prompt in enumerate(test_prompts, 1):
    print(f"\n--- Test {i} ---")
    print(f"Prompt: {prompt}")
    print("-" * 40)
    
    generated = generate_text(prompt, max_length=150, temperature=0.7)
    print(f"Generated: {generated}")

## Save training log

In [None]:
# Save comprehensive training log
training_log = {
    'config': vars(config),
    'model_info': {
        'model_name': config.model_name,
        'num_parameters': model.module.num_parameters(),
        'dataset_name': config.dataset_name,
        'train_samples': len(train_ds),
        'valid_samples': len(valid_ds)
    },
    'training_results': {
        'best_valid_loss': best_valid_loss,
        'final_perplexity': training_history['valid_perplexities'][-1],
        'total_epochs': config.num_train_epochs,
        'total_steps': total_steps
    },
    'training_history': training_history,
    'training_date': datetime.now().isoformat()
}

log_path = os.path.join(config.output_dir, 'training_log.json')
with open(log_path, 'w', encoding='utf-8') as f:
    json.dump(training_log, f, indent=2, ensure_ascii=False)

print(f"\nTraining log saved to {log_path}")

if config.use_hf:
    hf_api.upload_file(
        path_or_fileobj=log_path,
        path_in_repo="training_log.json",
        repo_id=config.hf_repo,
        repo_type="model",
    )
    print(f"\nTraining log pushed to repo {config.hf_repo}")


## Clean up

In [None]:
if config.use_wandb:
    wandb.finish()

In [None]:
# !deepspeed --num_gpus 1 train.py --deepspeed ds_config_pipeline.json