## Differential Transformer

In [1]:
!pip install -q transformers==4.48.1

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.4/44.4 kB[0m [31m1.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.7/9.7 MB[0m [31m60.3 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25h

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import LlamaConfig
from typing import Callable, List, Optional, Tuple, Union
from transformers.cache_utils import Cache
import math
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaModel, LlamaForCausalLM


from transformers import LlamaPreTrainedModel
from typing import Optional, Tuple, Union
import torch
from torch import nn
import os
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm, LlamaRotaryEmbedding
import tempfile
from pathlib import Path
import safetensors.torch as sfts
from huggingface_hub import hf_hub_download, snapshot_download, HfApi

In [3]:
%%writefile modeling.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import LlamaConfig
from typing import Callable, List, Optional, Tuple, Union
from transformers.cache_utils import Cache
import math
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaModel, LlamaForCausalLM


from transformers import LlamaPreTrainedModel
from typing import Optional, Tuple, Union
import torch
from torch import nn
import os
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm, LlamaRotaryEmbedding
import tempfile
from pathlib import Path
import safetensors.torch as sfts
from huggingface_hub import hf_hub_download, snapshot_download, HfApi

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Duplicate each key/value head ``n_rep`` times along the head axis.

    Produces the same result as
    ``torch.repeat_interleave(x, dim=1, repeats=n_rep)``.

    Args:
        x: 4-D tensor unpacked as ``(batch, kv_heads, seq_len, head_dim)``.
        n_rep: Number of copies of each head; ``1`` returns ``x`` itself.

    Returns:
        Tensor of shape ``(batch, kv_heads * n_rep, seq_len, head_dim)``.
    """
    batch, kv_heads, seq_len, dim = x.shape
    if n_rep == 1:
        return x
    # Insert a broadcast axis after the head axis, then fold it into the heads.
    widened = x.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, dim)
    return widened.reshape(batch, kv_heads * n_rep, seq_len, dim)

def lambda_init_fn(depth):
    """Return the depth-dependent initial lambda: ``0.8 - 0.6 * exp(-0.3 * depth)``.

    The value is 0.2 at ``depth == 0`` and approaches 0.8 as ``depth`` grows.
    """
    decay = math.exp(-0.3 * depth)
    return 0.8 - 0.6 * decay

class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension.

    Args:
        dim: Size of the normalised (last) dimension.
        eps: Constant added to the mean square before the reciprocal sqrt.
        elementwise_affine: When true, a learnable per-feature gain
            (initialised to ones) multiplies the normalised output.
        memory_efficient: Accepted but not used by this implementation.
    """

    def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine=True, memory_efficient=False):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        # Registering ``None`` keeps the ``weight`` attribute present but empty.
        gain = nn.Parameter(torch.ones(dim)) if self.elementwise_affine else None
        self.register_parameter('weight', gain)

    def _norm(self, x):
        """Scale ``x`` by the reciprocal RMS of its last dimension."""
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * inv_rms

    def forward(self, x):
        """Normalise in float32, cast back to ``x``'s dtype, then apply the gain."""
        normed = self._norm(x.float()).type_as(x)
        if self.weight is None:
            return normed
        return normed * self.weight

    def extra_repr(self) -> str:
        """Summarise the configuration for ``repr(module)``."""
        return f'dim={self.dim}, eps={self.eps}, elementwise_affine={self.elementwise_affine}'

def rotate_half(x):
    """Swap the two halves of the last dimension, negating the second half.

    For ``x = [a, b]`` split at ``x.shape[-1] // 2`` the result is ``[-b, a]``.
    """
    split = x.shape[-1] // 2
    front, back = x[..., :split], x[..., split:]
    return torch.cat((-back, front), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Rotate the query and key tensors with a rotary position embedding.

    Args:
        q (`torch.Tensor`): Query tensor.
        k (`torch.Tensor`): Key tensor.
        cos (`torch.Tensor`): Cosine component of the rotary embedding.
        sin (`torch.Tensor`): Sine component of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated; accepted for compatibility and never read.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Axis inserted into `cos` and `sin` so they broadcast against `q`
            and `k`. Use 1 when `q`/`k` are laid out as
            [batch, heads, seq_len, head_dim] and 2 when they are laid out as
            [batch, seq_len, heads, head_dim].

    Returns:
        `tuple(torch.Tensor)`: the rotated query and key tensors.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(t):
        # t * cos + rotate_half(t) * sin, applied identically to q and k.
        return (t * cos_b) + (rotate_half(t) * sin_b)

    return _rotate(q), _rotate(k)


class LlamaDiffAttention(nn.Module):
    """Multi-headed attention with DiffAttention technique.

    Two softmax attention maps are computed from the same q/k projections:
    the ordinary scaled map, and a second map whose queries and keys are
    rescaled per feature by ``lambda_q2`` / ``lambda_k2``. The second map is
    subtracted from the first after weighting by a learned scalar, and the
    result is applied to the values, normalised per head, and projected.

    ``attention_dropout`` is stored from the config but is not applied in
    ``forward``.
    """

    def __init__(self, config: LlamaConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        # Keep original projection sizes
        self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias)

        # DiffAttention-specific parameters: four float32 vectors of length
        # head_dim drawn from N(0, 0.1), plus a depth-dependent constant.
        self.lambda_init = lambda_init_fn(layer_idx)
        self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
        self.lambda_k1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
        self.lambda_q2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
        self.lambda_k2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))

        # Per-head normalisation applied to the attention output.
        self.subln = RMSNorm(self.head_dim, eps=1e-5, elementwise_affine=True)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run differential attention over ``hidden_states``.

        Args:
            hidden_states: Input whose leading two dimensions are unpacked as
                ``(batch_size, seq_length)``.
            position_embeddings: Optional ``(cos, sin)`` pair for rotary
                embeddings; skipped when ``None``.
            attention_mask: Optional additive mask added to both score maps.
            past_key_value: Optional cache; its ``update`` return value
                replaces the key/value states.
            cache_position: Forwarded to the cache update.
            **kwargs: Accepted for call compatibility and ignored.

        Returns:
            ``(attn_output, attn_weights)``: the projected output and the
            combined (differential) attention weights.
        """
        input_shape = hidden_states.shape[:-1]
        batch_size, seq_length = input_shape
        hidden_shape = (batch_size, seq_length, -1, self.head_dim)

        # Project to get query and key states (keeping original dimensions)
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        # Apply rotary embeddings if provided. cos/sin default to None so the
        # cache branch below never reads an unbound name when they are absent.
        cos, sin = None, None
        if position_embeddings is not None:
            cos, sin = position_embeddings
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # Handle past key values if provided
        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Repeat for grouped query attention
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        # Scale queries
        query_states *= self.scaling

        # Create two versions of attention weights using the same projections
        attn_weights_1 = torch.matmul(query_states, key_states.transpose(-1, -2))
        # For the second attention, we apply a learned transformation to query and key
        query_states_2 = query_states * self.lambda_q2.unsqueeze(0).unsqueeze(0)
        key_states_2 = key_states * self.lambda_k2.unsqueeze(0).unsqueeze(0)
        attn_weights_2 = torch.matmul(query_states_2, key_states_2.transpose(-1, -2))

        if attention_mask is not None:
            attn_weights_1 += attention_mask
            attn_weights_2 += attention_mask

        # Replace NaN/inf scores with finite values before the softmax.
        attn_weights_1 = torch.nan_to_num(attn_weights_1)
        attn_weights_2 = torch.nan_to_num(attn_weights_2)

        # Softmax in float32, then cast back to the score dtype.
        attn_weights_1 = F.softmax(attn_weights_1, dim=-1, dtype=torch.float32).type_as(attn_weights_1)
        attn_weights_2 = F.softmax(attn_weights_2, dim=-1, dtype=torch.float32).type_as(attn_weights_2)

        # Compute differential lambdas
        lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(query_states)
        lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(query_states)
        lambda_full = lambda_1 - lambda_2 + self.lambda_init

        # Combine attention weights with differential scaling
        attn_weights = attn_weights_1 - lambda_full * attn_weights_2

        # Compute attention output
        attn_output = torch.matmul(attn_weights, value_states)
        attn_output = self.subln(attn_output)
        attn_output = attn_output * (1 - self.lambda_init)
        attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous()

        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights

def replace_llama_attention_with_diffattention(model: nn.Module):
    """
    Replace all instances of LlamaAttention in the model with LlamaDiffAttention,
    keeping the projection weights intact.

    The q/k/v/o projection weights (and biases when ``config.attention_bias``
    is set) are cloned from each original layer. The DiffAttention-specific
    parameters keep their fresh initialisation and are moved to the device of
    the original layer's ``q_proj`` weight.

    Args:
        model (nn.Module): The model containing LlamaAttention layers.

    Returns:
        nn.Module: ``model`` with its attention layers swapped in place. If
        ``model`` itself is a LlamaAttention, the converted module is returned
        instead, since a root module cannot be replaced in place.
    """
    # Snapshot the matches first: the loop below rewires the module tree, and
    # doing so while the named_modules() generator is still live is fragile.
    targets = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, LlamaAttention)
    ]

    for name, module in targets:
        # Extract the configuration and layer index
        config = module.config
        layer_idx = module.layer_idx

        # Create a new LlamaDiffAttention instance
        diff_attention = LlamaDiffAttention(config, layer_idx)

        # Copy weights from the original LlamaAttention
        for proj_name in ("q_proj", "k_proj", "v_proj", "o_proj"):
            source = getattr(module, proj_name)
            target = getattr(diff_attention, proj_name)
            target.weight.data = source.weight.data.clone()
            if config.attention_bias:
                target.bias.data = source.bias.data.clone()

        # The cloned projections already live on the source device; bring the
        # newly created lambda / sub-norm parameters along with them.
        diff_attention.to(module.q_proj.weight.device)

        if not name:
            # The root module is itself the attention layer.
            return diff_attention

        # Replace the module in the model
        parent_path, _, child_name = name.rpartition('.')
        parent_module = model
        if parent_path:
            for part in parent_path.split('.'):
                parent_module = getattr(parent_module, part)
        setattr(parent_module, child_name, diff_attention)

    return model

class DiffLlamaModel(LlamaModel):
    """Llama backbone whose attention layers are LlamaDiffAttention modules."""

    config_class = LlamaConfig

    def __init__(self, config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # NOTE(review): the parent constructor presumably builds these same
        # sub-modules already; confirm before removing this rebuild.
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = LlamaRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

        # Swap every LlamaAttention for LlamaDiffAttention; the projections are
        # cloned from this model's own freshly initialised layers.
        self.layers = replace_llama_attention_with_diffattention(self.layers)

    def push_to_hub(self, repo_id: str, token: str = None, private: bool = False):
        """Upload the weights and ``config.json`` to a Hugging Face Hub repo.

        Args:
            repo_id: Target repository; created if it does not exist.
            token: Hub access token passed to ``HfApi``.
            private: Whether a newly created repository is private.
        """
        api = HfApi(token=token)
        repo_id = api.create_repo(repo_id=repo_id, exist_ok=True, private=private).repo_id
        print("Creating repository:", repo_id)

        with tempfile.TemporaryDirectory() as tmp:
            # Create the full path for saving
            saved_path = Path(tmp) / repo_id
            saved_path.mkdir(parents=True, exist_ok=True)

            # Prepare the model file path
            model_path = saved_path / "model.safetensors"

            # Strip only a leading 'model.' prefix; replacing every occurrence
            # could corrupt keys that merely contain that substring.
            prefix = 'model.'
            state_dict = {
                (k[len(prefix):] if k.startswith(prefix) else k): v
                for k, v in self.state_dict().items()
            }

            print(f"Saving model to {model_path}")
            sfts.save_file(state_dict, str(model_path))

            # from_pretrained reads the config from the repo, so ship it too.
            self.config.save_pretrained(str(saved_path))

            print("Uploading to Hub...")
            api.upload_folder(
                folder_path=str(saved_path),
                repo_id=repo_id,
                path_in_repo="",
            )
            print("Upload complete!")

    @staticmethod
    def _copy_linear(dst, src):
        """Overwrite ``dst``'s weight (and bias, when both have one) with clones from ``src``."""
        dst.weight.data = src.weight.data.clone()
        if dst.bias is not None and src.bias is not None:
            dst.bias.data = src.bias.data.clone()

    @classmethod
    def from_llama_model(cls, model):
        """Build a DiffLlamaModel carrying the weights of an existing Llama model.

        Args:
            model: A Llama backbone, or a wrapper exposing it as ``.model``.

        Returns:
            A new instance whose embeddings, norms, MLPs and attention
            projections are clones of the source weights. The
            DiffAttention-specific parameters keep their fresh initialisation.
        """
        base_model = model.model if hasattr(model, 'model') else model

        # Create a new DiffLlamaModel instance with the same config
        diff_model = cls(base_model.config)

        # Copy weights from the original model
        diff_model.embed_tokens.weight.data = base_model.embed_tokens.weight.data.clone()
        diff_model.norm.weight.data = base_model.norm.weight.data.clone()

        # Copy layer weights
        for orig_layer, diff_layer in zip(base_model.layers, diff_model.layers):
            # Copy non-attention weights
            for proj_name in ("gate_proj", "down_proj", "up_proj"):
                cls._copy_linear(getattr(diff_layer.mlp, proj_name), getattr(orig_layer.mlp, proj_name))
            diff_layer.input_layernorm.weight.data = orig_layer.input_layernorm.weight.data.clone()
            diff_layer.post_attention_layernorm.weight.data = orig_layer.post_attention_layernorm.weight.data.clone()

            # The replacement done in __init__ cloned the *new* model's random
            # projections, so the source attention weights must be copied here.
            for proj_name in ("q_proj", "k_proj", "v_proj", "o_proj"):
                cls._copy_linear(
                    getattr(diff_layer.self_attn, proj_name),
                    getattr(orig_layer.self_attn, proj_name),
                )

        return diff_model

    @classmethod
    def from_pretrained(
            cls,
            pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
            token = None,
            config = None,
            *model_args,
            **kwargs
        ):
        """Download a Hub repo and load ``model.safetensors`` into a new instance.

        Args:
            pretrained_model_name_or_path: Hub repository id to download.
            token: Hub access token.
            config: Optional config; read from the downloaded repo when ``None``.
            *model_args: Accepted for signature compatibility and ignored.
            **kwargs: Accepted for signature compatibility and ignored.

        Returns:
            The loaded model.
        """
        with tempfile.TemporaryDirectory() as tmp:
            # load the repo and the config
            snapshot_download(repo_id=pretrained_model_name_or_path, local_dir=f"{tmp}/repo/", token=token)
            if config is None:
                config = LlamaConfig.from_pretrained(f"{tmp}/repo/", token=token)

            # The constructor already installs LlamaDiffAttention layers.
            model = cls(config)
            sfts.load_model(model, f"{tmp}/repo/model.safetensors")

        return model


class DiffLlamaForCausalLM(LlamaForCausalLM):
    """Causal-LM head on top of a DiffLlamaModel backbone."""

    # NOTE(review): the Hugging Face base classes appear to use underscore-
    # prefixed names for these hooks; confirm whether these are ever read.
    tied_weights_keys = ["lm_head.weight"]
    tp_plan = {"lm_head": "colwise_rep"}

    def __init__(self, config):
        super().__init__(config)
        self.model = DiffLlamaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def push_to_hub(self, repo_id: str, token: str = None, private: bool = False):
        """Upload the weights and ``config.json`` to a Hugging Face Hub repo.

        Args:
            repo_id: Target repository; created if it does not exist.
            token: Hub access token passed to ``HfApi``.
            private: Whether a newly created repository is private.
        """
        api = HfApi(token=token)
        repo_id = api.create_repo(repo_id=repo_id, exist_ok=True, private=private).repo_id
        print("Creating repository:", repo_id)

        with tempfile.TemporaryDirectory() as tmp:
            saved_path = Path(tmp) / repo_id
            saved_path.mkdir(parents=True, exist_ok=True)

            model_path = saved_path / "model.safetensors"

            # save_file rejects tensors that share memory (e.g. tied input and
            # output embeddings), so give every repeated tensor its own copy.
            # Both keys stay in the file, which keeps strict loading working.
            state_dict = {}
            seen_ptrs = set()
            for key, tensor in self.state_dict().items():
                ptr = tensor.data_ptr()
                if ptr in seen_ptrs:
                    tensor = tensor.clone()
                else:
                    seen_ptrs.add(ptr)
                state_dict[key] = tensor

            print(f"Saving model to {model_path}")
            sfts.save_file(state_dict, str(model_path))

            # from_pretrained reads the config from the repo, so ship it too.
            self.config.save_pretrained(str(saved_path))

            api.upload_folder(
                folder_path=str(saved_path),
                repo_id=repo_id,
                path_in_repo="",
            )

    @classmethod
    def from_llama_model(cls, model):
        """Convert a Llama causal-LM into its DiffAttention counterpart.

        Args:
            model: Source model exposing ``config``, ``model`` and ``lm_head``.

        Returns:
            A new instance with a converted backbone and a cloned LM head.
        """
        diff_model = cls(model.config)
        diff_model.model = DiffLlamaModel.from_llama_model(model.model)
        diff_model.lm_head.weight.data = model.lm_head.weight.data.clone()
        # Swapping the backbone detaches lm_head from the new embedding matrix;
        # re-tie them (only takes effect when the config requests tying).
        diff_model.tie_weights()
        return diff_model

    @classmethod
    def from_pretrained(
            cls,
            pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
            token = None,
            config = None,
            *model_args,
            **kwargs
        ):
        """Download a Hub repo and load ``model.safetensors`` into a new instance.

        Args:
            pretrained_model_name_or_path: Hub repository id to download.
            token: Hub access token.
            config: Optional config; read from the downloaded repo when ``None``.
            *model_args: Accepted for signature compatibility and ignored.
            **kwargs: Accepted for signature compatibility and ignored.

        Returns:
            The loaded model.
        """
        with tempfile.TemporaryDirectory() as tmp:
            snapshot_download(repo_id=pretrained_model_name_or_path, local_dir=f"{tmp}/repo/", token=token)
            if config is None:
                config = LlamaConfig.from_pretrained(f"{tmp}/repo/", token=token)

            model = cls(config)

            # Load the state dict
            state_dict = sfts.load_file(f"{tmp}/repo/model.safetensors")

            # Strict load: every key must match the model.
            model.load_state_dict(state_dict)

        return model

Writing modeling.py


## Train the Differential Transformer

In [4]:
!pip install -q accelerate --upgrade

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m336.6/336.6 kB[0m [31m7.6 MB/s[0m eta [36m0:00:00[0mta [36m0:00:01[0m
[?25h

In [5]:
!pip install -q wandb

# Authenticate this notebook session with Weights & Biases.
# NOTE(review): the key is an empty string here, presumably redacted before
# sharing; supply the real key without committing it to the notebook.
import wandb
wandb.login(key="")

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33marthur-lagacherie[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [6]:
# Authenticate this notebook session with the Hugging Face Hub.
from huggingface_hub import login

# NOTE(review): "hf_xxx" looks like a placeholder token; replace it with a
# real token supplied at runtime rather than hard-coding one here.
login("hf_xxx")

In [7]:
%%writefile main.py
import torch
from accelerate import Accelerator
from tqdm.auto import tqdm
import wandb
import argparse
from dataclasses import dataclass
from typing import Optional, List
import yaml
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
import numpy as np
from modeling import * 

@dataclass
class TrainingConfig:
    """Settings for one training run; loadable from and savable to YAML."""

    # Model configuration
    model_name: str = "HuggingFaceTB/SmolLM2-360M-Instruct"
    max_length: int = 2048

    # Dataset configuration
    dataset_name: str = "Arthur-LAGACHERIE/EntierInstruct-66k"
    dataset_split: str = "train"
    subset:str = "all"
    select: Optional[int] = None
    question_column: str = "question"
    answer_column: str = "answer"
    val_size: int = 2000

    # Training configuration
    batch_size: int = 1
    num_epochs: int = 3
    learning_rate: float = 3e-4
    eval_steps: int = 100
    num_workers: int = 1

    # Logging configuration
    project_name: str = "test"
    run_name: Optional[str] = None

    @classmethod
    def from_yaml(cls, yaml_path: str) -> 'TrainingConfig':
        """Build a config from a YAML mapping of field names to values.

        An empty YAML document yields a config with all defaults. Unknown
        keys raise ``TypeError`` from the dataclass constructor.
        """
        with open(yaml_path, 'r') as f:
            config_dict = yaml.safe_load(f)
        if config_dict is None:
            # safe_load returns None for an empty file; fall back to defaults.
            config_dict = {}
        return cls(**config_dict)

    def save_yaml(self, yaml_path: str):
        """Write every field of this config to ``yaml_path`` as YAML."""
        with open(yaml_path, 'w') as f:
            yaml.dump(self.__dict__, f)

def get_dataloaders(config: TrainingConfig):
    """Build the training and validation dataloaders for a run.

    The dataset named by ``config.dataset_name`` is loaded at
    ``config.dataset_split`` and used as-is for training. The validation set
    is a random sample of up to ``config.val_size`` rows drawn from that same
    dataset, so it overlaps the training data.

    Args:
        config: Run settings (model name, dataset, batch size, workers).

    Returns:
        ``(train_dataloader, val_dataloader)``.
    """
    from datasets import load_dataset
    from torch.utils.data import DataLoader
    from transformers import DataCollatorForLanguageModeling

    tokenizer = AutoTokenizer.from_pretrained(config.model_name)

    # No chat-template formatting or tokenization happens here; an earlier
    # version did both and kept only the input_ids / attention_mask columns.
    # NOTE(review): the dataset is presumably already stored in that
    # tokenized form; confirm against the dataset.
    dataset = load_dataset(config.dataset_name, split=config.dataset_split)
    print(dataset)
    train_dataset = dataset

    # Sampling without replacement fails if more rows are requested than exist.
    val_count = min(config.val_size, len(dataset))
    val_dataset = dataset.select(np.random.choice(len(dataset), val_count, replace=False))

    data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

    train_dataloader = DataLoader(
        train_dataset,
        shuffle=True,
        batch_size=config.batch_size,
        collate_fn=data_collator,
        num_workers=config.num_workers,
    )
    val_dataloader = DataLoader(
        val_dataset,
        shuffle=False,
        batch_size=config.batch_size,
        collate_fn=data_collator,
        num_workers=config.num_workers,
    )
    return train_dataloader, val_dataloader

def get_model(config: TrainingConfig):
    """Load the pretrained causal LM named in ``config`` and convert it to DiffLlama."""
    pretrained = AutoModelForCausalLM.from_pretrained(config.model_name)
    diff_model = DiffLlamaForCausalLM.from_llama_model(pretrained)
    return diff_model


def train(config: TrainingConfig, model, train_dataloader, val_dataloader):
    """Train ``model`` with AdamW under Accelerate, logging validation loss to wandb.

    Validation runs whenever the within-epoch step index is a multiple of
    ``config.eval_steps`` (including step 0 of every epoch). wandb calls are
    issued from the main process only, so a multi-process launch produces a
    single run.

    Args:
        config: Run settings (learning rate, epochs, eval cadence, wandb names).
        model: Model whose forward pass returns an object with a ``loss``.
        train_dataloader: Batches used for optimisation.
        val_dataloader: Batches used for periodic evaluation.
    """
    accelerator = Accelerator()
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)

    train_dataloader, val_dataloader, model, optimizer = accelerator.prepare(
        train_dataloader, val_dataloader, model, optimizer
    )

    is_main = accelerator.is_main_process
    show_progress = accelerator.is_local_main_process

    if is_main:
        run_name = config.run_name or f"{config.project_name}-{config.model_name.split('/')[-1]}"
        wandb.init(project=config.project_name, name=run_name, config=config.__dict__)

    model.train()
    total_steps = len(train_dataloader) * config.num_epochs

    def evaluate(model, val_dataloader):
        """Return the mean loss over ``val_dataloader`` (NaN when it is empty)."""
        model.eval()
        total_loss = 0
        with torch.no_grad():
            for batch in tqdm(val_dataloader, desc="Evaluating", disable=not show_progress):
                outputs = model(**batch)
                total_loss += outputs.loss.item()
        model.train()
        num_batches = len(val_dataloader)
        if num_batches == 0:
            return float("nan")
        return total_loss / num_batches

    with tqdm(desc="Training", total=total_steps, disable=not show_progress) as pbar:
        for epoch in range(config.num_epochs):
            for step, batch in enumerate(train_dataloader):
                outputs = model(**batch)
                loss = outputs.loss

                accelerator.backward(loss)
                optimizer.step()
                optimizer.zero_grad()

                pbar.update(1)

                if step % config.eval_steps == 0:
                    # Every process evaluates its own shard; only the main
                    # process reports its result.
                    val_loss = evaluate(model, val_dataloader)
                    if is_main:
                        wandb.log({
                            "val_loss": val_loss,
                            "global_step": step + epoch * len(train_dataloader),
                        })

    if is_main:
        wandb.finish()

def main():
    """Parse the command line, assemble the config, and launch training.

    The config comes from ``--config`` when given, otherwise from defaults;
    any other flag that was supplied overrides the matching config field.
    """
    parser = argparse.ArgumentParser(description="Training script with Accelerate")
    cli_options = (
        ("--config", str, "Path to YAML config file"),
        ("--model_name", str, "Model name or path"),
        ("--batch_size", int, "Batch size"),
        ("--num_epochs", int, "Number of training epochs"),
        ("--learning_rate", float, "Learning rate"),
        ("--project_name", str, "WandB project name"),
        ("--run_name", str, "WandB run name"),
    )
    for flag, value_type, help_text in cli_options:
        parser.add_argument(flag, type=value_type, help=help_text)

    args = parser.parse_args()

    config = TrainingConfig.from_yaml(args.config) if args.config else TrainingConfig()

    # Override config with command line arguments
    for field_name, value in vars(args).items():
        if field_name == 'config' or value is None:
            continue
        setattr(config, field_name, value)

    model = get_model(config)
    train_loader, val_loader = get_dataloaders(config)

    train(config, model, train_loader, val_loader)

# Run training only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()

Writing main.py


In [11]:
from dataclasses import dataclass
from typing import Optional, List
import yaml

@dataclass
class TrainingConfig():
    """Settings for one training run; loadable from and savable to YAML."""

    # Model configuration
    model_name: str = "HuggingFaceTB/SmolLM2-360M-Instruct"
    max_length: int = 2048

    # Dataset configuration
    dataset_name: str = "Arthur-LAGACHERIE/smoltalk-semhashed"
    dataset_split: str = "train"
    subset:str = "main"
    select: Optional[int] = None
    question_column: str = "messages"
    answer_column: str = "messages"
    val_size: int = 2000

    # Training configuration
    batch_size: int = 1
    num_epochs: int = 1
    learning_rate: float = 3e-4
    eval_steps: int = 500
    num_workers: int = 1

    # Logging configuration
    project_name: str = "Test"
    run_name: Optional[str] = "test"

    @classmethod
    def from_yaml(cls, yaml_path: str) -> 'TrainingConfig':
        """Build a config from a YAML mapping of field names to values.

        An empty YAML document yields a config with all defaults. Unknown
        keys raise ``TypeError`` from the dataclass constructor.
        """
        with open(yaml_path, 'r') as f:
            config_dict = yaml.safe_load(f)
        if config_dict is None:
            # safe_load returns None for an empty file; fall back to defaults.
            config_dict = {}
        return cls(**config_dict)

    def save_yaml(self, yaml_path: str):
        """Write every field of this config to ``yaml_path`` as YAML."""
        with open(yaml_path, 'w') as f:
            yaml.dump(self.__dict__, f)

# Write the default settings to config.yaml so the training script can be
# launched with --config "config.yaml".
conf = TrainingConfig()
conf.save_yaml("config.yaml")

In [12]:
!accelerate launch main.py --config "config.yaml"

2025-01-28 17:19:47.516005: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-01-28 17:19:47.516526: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-01-28 17:19:47.551927: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-01-28 17:19:47.552639: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-01-28 17:19:47.564311: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory