## ERA V3 Session 13 Assignment - SmolLM2-135 model


1. Find out the details about SmolLM2-135 model. You won't get model.py free of cost online. You need to read its github, get yaml file for training, and download the 135 model, and then reverse engineer it. 
2. Training it for 5000 steps while predicting every 500 steps on what it utters. Now fully stop the model and save a checkpoint. Now load this checkpoint and train for 50 more steps.
3. Use all the speedups that we have used. 
4. Submit the GitHub link where I can see your README.md explaining the model definition. Parameter Calculation. Upload the model to Spaces as well and then share both the links.

### Defining the Model using LlamaConfig class

SmolLM2-135M model has ```is_llama_config = True``` as one of the parameters, which means it can be created using the LlamaConfig Class by inputing the same parameters as written in ```config_smollm2_135M.yaml```

In [2]:
# model.py
from transformers import LlamaConfig, LlamaForCausalLM, AutoTokenizer
import torch

def create_smollm2_model():
    """
    Constructs a SmoLLM2 model based on the provided configuration.

    Returns:
        tuple: A tuple containing the initialized model and tokenizer.
    """

    model_config = LlamaConfig(
        vocab_size=49152,
        hidden_size=576,
        intermediate_size=1536,
        num_hidden_layers=30,
        num_attention_heads=9,
        num_key_value_heads=3,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.041666666666666664,
        rms_norm_eps=1.0e-05,
        # use_cache=True, As seen in training, this is not needed
        tie_word_embeddings=True,
        rope_theta=10000.0,
        rope_scaling=None,
        rope_interleaved=False,
        pretraining_tp=1,
        bos_token_id=0,
        eos_token_id=0,
        pad_token_id=None, #  pad_token_id is null in config, setting to None
    )

    model = LlamaForCausalLM(model_config)

    # Initialize weights with std from init_method if needed (Transformers usually handles initialization well)
    # init_std = 0.041666666666666664
    # You can add custom weight initialization here if required based on init_method.std

    # Set the dtype to bfloat16
    model.to(torch.bfloat16)

    # Load the tokenizer
    tokenizer_name_or_path = "HuggingFaceTB/cosmo2-tokenizer"
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)


    return model, tokenizer

model, tokenizer = create_smollm2_model()
print("SmoLLM2 model and tokenizer successfully created!")
print(f"Model dtype: {next(model.parameters()).dtype}")
print(f"Tokenizer: {tokenizer.__class__.__name__} loaded from {tokenizer.name_or_path}")

  from .autonotebook import tqdm as notebook_tqdm


SmoLLM2 model and tokenizer successfully created!
Model dtype: torch.bfloat16
Tokenizer: GPT2TokenizerFast loaded from HuggingFaceTB/cosmo2-tokenizer


### Model Architecture and parameter count

In [3]:
model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(49152, 576)
    (layers): ModuleList(
      (0-29): 30 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=576, out_features=576, bias=False)
          (k_proj): Linear(in_features=576, out_features=192, bias=False)
          (v_proj): Linear(in_features=576, out_features=192, bias=False)
          (o_proj): Linear(in_features=576, out_features=576, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=576, out_features=1536, bias=False)
          (up_proj): Linear(in_features=576, out_features=1536, bias=False)
          (down_proj): Linear(in_features=1536, out_features=576, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((576,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((576,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm

In [4]:
tokenizer

GPT2TokenizerFast(name_or_path='HuggingFaceTB/cosmo2-tokenizer', vocab_size=49152, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'additional_special_tokens': ['<|endoftext|>', '<|im_start|>', '<|im_end|>', '<repo_name>', '<reponame>', '<file_sep>', '<filename>', '<gh_stars>', '<issue_start>', '<issue_comment>', '<issue_closed>', '<jupyter_start>', '<jupyter_text>', '<jupyter_code>', '<jupyter_output>', '<jupyter_script>', '<empty_output>']}, clean_up_tokenization_spaces=False),  added_tokens_decoder={
	0: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("<|im_start|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	2: AddedToken("<|im_end|>", rstrip=False, lstrip=False, single_word=False, normalized=False, speci

In [5]:
model.config

LlamaConfig {
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 0,
  "eos_token_id": 0,
  "hidden_act": "silu",
  "hidden_size": 576,
  "initializer_range": 0.041666666666666664,
  "intermediate_size": 1536,
  "max_position_embeddings": 2048,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 9,
  "num_hidden_layers": 30,
  "num_key_value_heads": 3,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_interleaved": false,
  "rope_scaling": null,
  "rope_theta": 10000.0,
  "tie_word_embeddings": true,
  "transformers_version": "4.44.1",
  "use_cache": true,
  "vocab_size": 49152
}

In [6]:
total_params = sum(p.numel() for p in model.parameters())
print(f'Total parameters: {total_params}')
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Trainable parameters: {trainable_params}')

Total parameters: 134515008
Trainable parameters: 134515008


## Sample output generation

In [2]:
# Example usage (optional):
input_text = "Hello, world!"
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs)
decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(f"Generated output: {decoded_output}")

Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


Generated output: Hello, world!immer Joseph params transactionDateTimeField plagiarism sniffShdis borders savage debit interfering polyunsaturatedothySh


## Loading input text and training the model

In [1]:
import os
with open("datasets/input.txt", "r", encoding="utf-8") as f:
    text = f.read()

In [3]:
text[:1000]

"First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou are all resolved rather to die than to famish?\n\nAll:\nResolved. resolved.\n\nFirst Citizen:\nFirst, you know Caius Marcius is chief enemy to the people.\n\nAll:\nWe know't, we know't.\n\nFirst Citizen:\nLet us kill him, and we'll have corn at our own price.\nIs't a verdict?\n\nAll:\nNo more talking on't; let it be done: away, away!\n\nSecond Citizen:\nOne word, good citizens.\n\nFirst Citizen:\nWe are accounted poor citizens, the patricians good.\nWhat authority surfeits on would relieve us: if they\nwould yield us but the superfluity, while it were\nwholesome, we might guess they relieved us humanely;\nbut they think we are too dear: the leanness that\nafflicts us, the object of our misery, is as an\ninventory to particularise their abundance; our\nsufferance is a gain to them Let us revenge this with\nour pikes, ere we become rakes: for the gods know I\nspeak this in hunger 

In [3]:
# train.py
import torch
from transformers import AdamW, get_linear_schedule_with_warmup, AutoModelForCausalLM # Changed import for checkpointing
# from model import create_smollm2_model
import os

# --- 1. Hyperparameters and Configuration ---
TRAIN_STEPS = 5000
PREDICT_EVERY_STEPS = 500
CHECKPOINT_EVERY_STEPS = 500
SEQUENCE_LENGTH = 2048  # As defined in config - May need to reduce if OOM
MICRO_BATCH_SIZE = 8     # As defined in config - REDUCE THIS FIRST to fix OOM
LEARNING_RATE = 3e-4     # You can adjust, simplified from config for now
WARMUP_STEPS = 500      # Simplified warmup
CHECKPOINT_PATH = "smollm2-checkpoints"
INPUT_FILE = "datasets\input.txt"
PREDICTION_PROMPT = "To be or not to be," # A starting prompt for prediction

GRADIENT_ACCUMULATION_STEPS = 4 # --- Speedup 1: Gradient Accumulation --- Accumulate gradients over this many steps
USE_ACTIVATION_CHECKPOINTING = True # --- Speedup 2: Activation Checkpointing --- Enable or disable activation checkpointing

# --- ***MEMORY OPTIMIZATION - REDUCE THESE IF OOM ERROR PERSISTS*** ---
REDUCE_MICRO_BATCH_SIZE_FACTOR = 2 # --- Reduce Micro Batch Size --- Reduce micro_batch_size by this factor
# If you STILL get OOM, try reducing SEQUENCE_LENGTH_FACTOR (but micro_batch_size reduction is usually more effective first)
REDUCE_SEQUENCE_LENGTH_FACTOR = 1 # --- Reduce Sequence Length --- Reduce sequence length by this factor if needed

# --- 2. Adjusted Hyperparameters based on Reduction Factors ---
ADJUSTED_MICRO_BATCH_SIZE = MICRO_BATCH_SIZE // REDUCE_MICRO_BATCH_SIZE_FACTOR
if ADJUSTED_MICRO_BATCH_SIZE <= 0:
    ADJUSTED_MICRO_BATCH_SIZE = 1 # Ensure micro_batch_size is at least 1
ADJUSTED_SEQUENCE_LENGTH = SEQUENCE_LENGTH // REDUCE_SEQUENCE_LENGTH_FACTOR
if ADJUSTED_SEQUENCE_LENGTH <= 0:
    ADJUSTED_SEQUENCE_LENGTH = 64 # Ensure sequence_length is at least reasonably sized

MICRO_BATCH_SIZE = ADJUSTED_MICRO_BATCH_SIZE # Update micro batch size
SEQUENCE_LENGTH = ADJUSTED_SEQUENCE_LENGTH # Update sequence length


# --- 3. Create Checkpoint Directory ---
os.makedirs(CHECKPOINT_PATH, exist_ok=True)

# --- 4. Load Model and Tokenizer ---
model, tokenizer = create_smollm2_model()
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)

# --- Speedup 3: Mixed Precision Training (bfloat16) --- Model is already in bfloat16 from model.py
# We are using bfloat16 dtype for model which speeds up training and reduces memory usage on compatible GPUs (like NVIDIA A100, H100)

if USE_ACTIVATION_CHECKPOINTING: # --- Enable activation checkpointing if flag is set ---
    model.gradient_checkpointing_enable() # Enable activation checkpointing from transformers

# --- 5. Prepare Dataset ---
print("Loading and tokenizing dataset...")
with open(INPUT_FILE, "r", encoding="utf-8") as f:
    text = f.read()

tokenized_dataset = tokenizer(text, return_tensors="pt", truncation=False) # No truncation initially
input_ids = tokenized_dataset['input_ids']

# --- 6. Create Data Batches ---
def create_batches(input_ids, seq_len, micro_batch_size):
    num_tokens = input_ids.shape[1]
    num_batches = num_tokens // seq_len
    truncated_input_ids = input_ids[:, :num_batches * seq_len] # Truncate to fit full sequences based on potentially reduced SEQUENCE_LENGTH
    batched_input_ids = truncated_input_ids.reshape(-1, seq_len) # Reshape into sequences
    num_micro_batches = num_batches // micro_batch_size
    micro_batches = []
    for i in range(num_micro_batches):
        start_index = i * micro_batch_size
        end_index = (i + 1) * micro_batch_size
        batch = batched_input_ids[start_index:end_index]
        micro_batches.append(batch)
    return micro_batches

micro_batches = create_batches(input_ids.to(device), SEQUENCE_LENGTH, MICRO_BATCH_SIZE) # Using potentially reduced SEQUENCE_LENGTH and MICRO_BATCH_SIZE
print(f"Dataset prepared with {len(micro_batches)} micro-batches using micro_batch_size={MICRO_BATCH_SIZE} and sequence_length={SEQUENCE_LENGTH}.")


# --- 7. Optimizer and Scheduler ---
# --- Speedup 4: Fused AdamW Optimizer --- Using fused AdamW if available (often faster)
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01, fused=True) # Using fused=True for potential speedup
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=WARMUP_STEPS, num_training_steps=TRAIN_STEPS // GRADIENT_ACCUMULATION_STEPS # Adjusted lr scheduler steps
) # Simplified scheduler

# --- imports ---
from tqdm import tqdm

# --- training loop ---
print("Starting training...")
model.train() # Set model to training mode
global_step = 0
accumulated_loss = 0 # To track loss over accumulation steps

# Initialize tqdm progress bar, total is global steps (weight updates)
progress_bar = tqdm(range(1, TRAIN_STEPS + 1), desc="Training", unit="step")

for step in progress_bar: # Wrap training loop with tqdm
    batch_index = (step - 1) % len(micro_batches) # Cycle through batches
    batch = micro_batches[batch_index]

    inputs = batch
    targets = torch.roll(batch, shifts=-1, dims=1) # Next token prediction

    outputs = model(inputs, labels=targets) # Labels for loss calculation
    loss = outputs.loss
    loss = loss / GRADIENT_ACCUMULATION_STEPS # --- Scale loss for gradient accumulation ---
    accumulated_loss += loss.item() # Accumulate loss for logging

    loss.backward()

    if step % GRADIENT_ACCUMULATION_STEPS == 0: # --- Update weights every GRADIENT_ACCUMULATION_STEPS ---
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        global_step += 1 # Increment global step only when weights are updated
        avg_loss = accumulated_loss / GRADIENT_ACCUMULATION_STEPS # Calculate average loss over accumulation steps
        accumulated_loss = 0 # Reset accumulated loss

        if global_step % 10 == 0: # Log loss every 10 global steps (after accumulation)
            # No need for separate print here, tqdm will handle logging

            # --- Update tqdm progress bar with current average loss ---
            progress_bar.set_postfix({"Avg Loss": f"{avg_loss:.4f}"})


    # --- 8. Prediction Interval ---
    if step % PREDICT_EVERY_STEPS == 0:
        model.eval() # Set model to evaluation mode
        print(f"\n--- Prediction at Micro-batch Step {step} (Global Step: {global_step}) ---") # Clarify step counts
        prompt_ids = tokenizer.encode(PREDICTION_PROMPT, return_tensors="pt").to(device)
        sample_outputs = model.generate(
            prompt_ids,
            max_length=len(prompt_ids[0]) + 50, # Generate up to 50 new tokens
            num_return_sequences=1,
            temperature=0.7, # Adjust temperature for creativity
            top_p=0.9,
        )
        predicted_text = tokenizer.decode(sample_outputs[0], skip_special_tokens=True)
        print(f"Prompt: '{PREDICTION_PROMPT}'")
        print(f"Generated: '{predicted_text}'\n")
        model.train() # Set back to training mode

    # --- 9. Checkpoint Saving ---
    if step % CHECKPOINT_EVERY_STEPS == 0:
        checkpoint_file = os.path.join(CHECKPOINT_PATH, f"smollm2_checkpoint_step_{step}.pth") # Step here is still micro-batch step
        torch.save({
            'step': step, # Step here is still micro-batch step
            'global_step': global_step, # Saving global step as well (steps with weight updates)
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': loss.item(), # Save current micro-batch loss (last loss in accumulation)
        }, checkpoint_file)
        print(f"Checkpoint saved at Micro-batch Step {step} (Global Step {global_step}) to {checkpoint_file}") # Clarified step counts in checkpoint message

progress_bar.close() # Close progress bar when training finishes
print("Training finished!")

Loading and tokenizing dataset...
Dataset prepared with 41 micro-batches using micro_batch_size=4 and sequence_length=2048.
Starting training...


Training:   0%|          | 0/5000 [00:00<?, ?step/s]`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
Training:  10%|▉         | 499/5000 [09:56<1:28:35,  1.18s/step, Avg Loss=1.9720]


--- Prediction at Micro-batch Step 500 (Global Step: 125) ---


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


Prompt: 'To be or not to be,'
Generated: 'To be or not to be, dm emotionally Whole theware
:



::
:::::






























'



Training:  10%|█         | 500/5000 [09:59<2:22:02,  1.89s/step, Avg Loss=1.9720]

Checkpoint saved at Micro-batch Step 500 (Global Step 125) to smollm2-checkpoints\smollm2_checkpoint_step_500.pth


Training:  20%|█▉        | 999/5000 [20:54<1:21:06,  1.22s/step, Avg Loss=1.6093]The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.



--- Prediction at Micro-batch Step 1000 (Global Step: 250) ---
Prompt: 'To be or not to be,'
Generated: 'To be or not to be, of
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,'



Training:  20%|██        | 1000/5000 [20:56<2:10:11,  1.95s/step, Avg Loss=1.6093]

Checkpoint saved at Micro-batch Step 1000 (Global Step 250) to smollm2-checkpoints\smollm2_checkpoint_step_1000.pth


Training:  30%|██▉       | 1499/5000 [32:12<1:24:19,  1.45s/step, Avg Loss=1.3746]


--- Prediction at Micro-batch Step 1500 (Global Step: 375) ---


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


Prompt: 'To be or not to be,'
Generated: 'To be or not to be,
K
KARI: I,
, the of of of
,,,,,,
 I
 IARD
 I:,,,,,, I
 I to to
 I I
 I:,,,,'



Training:  30%|███       | 1500/5000 [32:16<2:08:24,  2.20s/step, Avg Loss=1.3746]

Checkpoint saved at Micro-batch Step 1500 (Global Step 375) to smollm2-checkpoints\smollm2_checkpoint_step_1500.pth


Training:  40%|███▉      | 1999/5000 [43:52<1:15:02,  1.50s/step, Avg Loss=1.2741]


--- Prediction at Micro-batch Step 2000 (Global Step: 500) ---


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


Prompt: 'To be or not to be,'
Generated: 'To be or not to be,,
 will not:,,,,, so
,,,, a
 I,,,,,,,,, a,
 I, the the the the the the,
 I, the the the,

 you'



Training:  40%|████      | 2000/5000 [43:55<1:55:21,  2.31s/step, Avg Loss=1.2741]

Checkpoint saved at Micro-batch Step 2000 (Global Step 500) to smollm2-checkpoints\smollm2_checkpoint_step_2000.pth


Training:  50%|████▉     | 2499/5000 [54:53<50:01,  1.20s/step, Avg Loss=0.9884]  The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.



--- Prediction at Micro-batch Step 2500 (Global Step: 625) ---


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


Prompt: 'To be or not to be,'
Generated: 'To be or not to be, a to to I
 a's to to a to
 a?
 a to to to
, I I I I I
 a to, that
 a of, I I I I
 a of I I I I I
 a.'



Training:  50%|█████     | 2500/5000 [54:56<1:19:07,  1.90s/step, Avg Loss=0.9884]

Checkpoint saved at Micro-batch Step 2500 (Global Step 625) to smollm2-checkpoints\smollm2_checkpoint_step_2500.pth


Training:  60%|█████▉    | 2999/5000 [1:05:00<40:08,  1.20s/step, Avg Loss=0.6190]The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.



--- Prediction at Micro-batch Step 3000 (Global Step: 750) ---
Prompt: 'To be or not to be,'
Generated: 'To be or not to be, I, a; the!
 so is with again my the?
N now'd! will the so, it will, I
 we of or.
INGICH II
 a: I you I so
 so a- and and.'



Training:  60%|██████    | 3000/5000 [1:05:03<1:04:23,  1.93s/step, Avg Loss=0.6190]

Checkpoint saved at Micro-batch Step 3000 (Global Step 750) to smollm2-checkpoints\smollm2_checkpoint_step_3000.pth


Training:  70%|██████▉   | 3499/5000 [1:15:04<30:02,  1.20s/step, Avg Loss=0.3180]  The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.



--- Prediction at Micro-batch Step 3500 (Global Step: 875) ---


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


Prompt: 'To be or not to be,'
Generated: 'To be or not to be,,,, a
 so;! name. but on!
all him
antIL: I will it
s,, of
 made to,--, or! a
 by but the made but made and
 him'.'



Training:  70%|███████   | 3500/5000 [1:15:07<48:46,  1.95s/step, Avg Loss=0.3180]

Checkpoint saved at Micro-batch Step 3500 (Global Step 875) to smollm2-checkpoints\smollm2_checkpoint_step_3500.pth


Training:  80%|███████▉  | 3999/5000 [1:25:09<19:58,  1.20s/step, Avg Loss=0.2034]The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.



--- Prediction at Micro-batch Step 4000 (Global Step: 1000) ---


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


Prompt: 'To be or not to be,'
Generated: 'To be or not to be,,,,, I;, aw,
's thy now;
 mine butI'
P I to it or not by: he king.
L will, now I not both not.
L those!
 will at a'



Training:  80%|████████  | 4000/5000 [1:25:12<33:30,  2.01s/step, Avg Loss=0.2034]

Checkpoint saved at Micro-batch Step 4000 (Global Step 1000) to smollm2-checkpoints\smollm2_checkpoint_step_4000.pth


Training:  90%|████████▉ | 4499/5000 [1:35:13<09:59,  1.20s/step, Avg Loss=0.1876]The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.



--- Prediction at Micro-batch Step 4500 (Global Step: 1125) ---


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


Prompt: 'To be or not to be,'
Generated: 'To be or not to be,,,,, I;, aw,
's thy now;
 mine butI'
P I to it or not by: he king.
L will, now I not both not.
LENT it.
 me--'



Training:  90%|█████████ | 4500/5000 [1:35:16<15:29,  1.86s/step, Avg Loss=0.1876]

Checkpoint saved at Micro-batch Step 4500 (Global Step 1125) to smollm2-checkpoints\smollm2_checkpoint_step_4500.pth


Training: 100%|█████████▉| 4999/5000 [1:45:42<00:01,  1.34s/step, Avg Loss=0.1753]The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.



--- Prediction at Micro-batch Step 5000 (Global Step: 1250) ---


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


Prompt: 'To be or not to be,'
Generated: 'To be or not to be,,,,, I;, aw,
's thy now!
 it
 didy
 but! thee
 sweet,; did,,,, my to,
 did else and's; the if you
 will. all do'



Training: 100%|██████████| 5000/5000 [1:45:45<00:00,  1.27s/step, Avg Loss=0.1753]

Checkpoint saved at Micro-batch Step 5000 (Global Step 1250) to smollm2-checkpoints\smollm2_checkpoint_step_5000.pth
Training finished!





I have trained it for 5000 micro-steps (1250 weight-updating steps because of Gradient Accumulation)

## Loading Saved Checkpoint and re-training

To achieve the intended objective, we will load the checkpoint and further train it till 5000th weight-updating step (20000 total micro-steps)

In [3]:
RESUME_CHECKPOINT_PATH = "smollm2-checkpoints\smollm2_checkpoint_step_5000.pth"

# train.py
import torch
from transformers import get_linear_schedule_with_warmup # Changed import for checkpointing
# from model import create_smollm2_model
import os

# --- 1. Hyperparameters and Configuration ---
TARGET_GLOBAL_STEPS = 5000
PREDICT_EVERY_STEPS = 500
CHECKPOINT_EVERY_STEPS = 500
SEQUENCE_LENGTH = 2048  # As defined in config - May need to reduce if OOM
MICRO_BATCH_SIZE = 8     # As defined in config - REDUCE THIS FIRST to fix OOM
LEARNING_RATE = 3e-4     # You can adjust, simplified from config for now
WARMUP_STEPS = 500      # Simplified warmup
CHECKPOINT_PATH = "smollm2-checkpoints"
INPUT_FILE = "datasets\input.txt"
PREDICTION_PROMPT = "To be or not to be," # A starting prompt for prediction

GRADIENT_ACCUMULATION_STEPS = 4 # --- Speedup 1: Gradient Accumulation --- Accumulate gradients over this many steps
TARGET_TRAIN_STEPS = TARGET_GLOBAL_STEPS * GRADIENT_ACCUMULATION_STEPS # Adjusted total steps based on accumulation steps

USE_ACTIVATION_CHECKPOINTING = True # --- Speedup 2: Activation Checkpointing --- Enable or disable activation checkpointing

# --- ***MEMORY OPTIMIZATION - REDUCE THESE IF OOM ERROR PERSISTS*** ---
REDUCE_MICRO_BATCH_SIZE_FACTOR = 2 # --- Reduce Micro Batch Size --- Reduce micro_batch_size by this factor
# If you STILL get OOM, try reducing SEQUENCE_LENGTH_FACTOR (but micro_batch_size reduction is usually more effective first)
REDUCE_SEQUENCE_LENGTH_FACTOR = 1 # --- Reduce Sequence Length --- Reduce sequence length by this factor if needed

# --- 2. Adjusted Hyperparameters based on Reduction Factors ---
ADJUSTED_MICRO_BATCH_SIZE = MICRO_BATCH_SIZE // REDUCE_MICRO_BATCH_SIZE_FACTOR
if ADJUSTED_MICRO_BATCH_SIZE <= 0:
    ADJUSTED_MICRO_BATCH_SIZE = 1 # Ensure micro_batch_size is at least 1
ADJUSTED_SEQUENCE_LENGTH = SEQUENCE_LENGTH // REDUCE_SEQUENCE_LENGTH_FACTOR
if ADJUSTED_SEQUENCE_LENGTH <= 0:
    ADJUSTED_SEQUENCE_LENGTH = 64 # Ensure sequence_length is at least reasonably sized

MICRO_BATCH_SIZE = ADJUSTED_MICRO_BATCH_SIZE # Update micro batch size
SEQUENCE_LENGTH = ADJUSTED_SEQUENCE_LENGTH # Update sequence length


# --- 3. Create Checkpoint Directory ---
os.makedirs(CHECKPOINT_PATH, exist_ok=True)


# --- 4. Load Model and Tokenizer ---
model, tokenizer = create_smollm2_model()
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)

# --- Speedup 3: Mixed Precision Training (bfloat16) --- Model is already in bfloat16 from model.py
# We are using bfloat16 dtype for model which speeds up training and reduces memory usage on compatible GPUs (like NVIDIA A100, H100)

if USE_ACTIVATION_CHECKPOINTING: # --- Enable activation checkpointing if flag is set ---
    model.gradient_checkpointing_enable() # Enable activation checkpointing from transformers

# --- 5. Prepare Dataset ---
print("Loading and tokenizing dataset...")
with open(INPUT_FILE, "r", encoding="utf-8") as f:
    text = f.read()

tokenized_dataset = tokenizer(text, return_tensors="pt", truncation=False) # No truncation initially
input_ids = tokenized_dataset['input_ids']

# --- 8. Optimizer and Scheduler ---
# --- Speedup 4: Fused AdamW Optimizer --- Using fused AdamW if available (often faster)
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01, fused=True) # Using fused=True for potential speedup
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=WARMUP_STEPS, num_training_steps=TARGET_TRAIN_STEPS // GRADIENT_ACCUMULATION_STEPS # Adjusted lr scheduler steps
) # Simplified scheduler

# --- 6. Load Checkpoint if `resume_checkpoint_path` is provided ---
initial_global_step = 0 # Track initial global step, default is 0 for new training
if RESUME_CHECKPOINT_PATH:
    print(f"Loading checkpoint from: {RESUME_CHECKPOINT_PATH}")
    checkpoint = torch.load(RESUME_CHECKPOINT_PATH, map_location=device, weights_only=False) # Load checkpoint to correct device
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    initial_global_step = checkpoint.get('global_step', 0) # Try to get global_step, default to 0 if not found
    print(f"Resuming training from global step: {initial_global_step}")
else:
    print("Starting training from scratch.")

start_step = initial_global_step * GRADIENT_ACCUMULATION_STEPS + 1 # Calculate micro-batch step to start from

# --- 7. Create Data Batches ---
def create_batches(input_ids, seq_len, micro_batch_size):
    num_tokens = input_ids.shape[1]
    num_batches = num_tokens // seq_len
    truncated_input_ids = input_ids[:, :num_batches * seq_len] # Truncate to fit full sequences based on potentially reduced SEQUENCE_LENGTH
    batched_input_ids = truncated_input_ids.reshape(-1, seq_len) # Reshape into sequences
    num_micro_batches = num_batches // micro_batch_size
    micro_batches = []
    for i in range(num_micro_batches):
        start_index = i * micro_batch_size
        end_index = (i + 1) * micro_batch_size
        batch = batched_input_ids[start_index:end_index]
        micro_batches.append(batch)
    return micro_batches

micro_batches = create_batches(input_ids.to(device), SEQUENCE_LENGTH, MICRO_BATCH_SIZE) # Using potentially reduced SEQUENCE_LENGTH and MICRO_BATCH_SIZE
print(f"Dataset prepared with {len(micro_batches)} micro-batches using micro_batch_size={MICRO_BATCH_SIZE} and sequence_length={SEQUENCE_LENGTH}.")


# --- imports ---
from tqdm import tqdm

# --- 9. Training Loop ---
print("Starting training...")
model.train() # Set model to training mode
global_step = initial_global_step # Initialize global_step from checkpoint or 0
accumulated_loss = 0 # To track loss over accumulation steps

# Initialize tqdm progress bar, total is remaining global steps
remaining_global_steps = TARGET_GLOBAL_STEPS - initial_global_step
if remaining_global_steps <= 0:
    print("Training already completed to target steps or beyond based on checkpoint.")
    exit()

progress_bar = tqdm(range(start_step, TARGET_TRAIN_STEPS + 1), desc="Training", unit="step", initial=start_step -1, total=TARGET_TRAIN_STEPS) # Initialize tqdm with start and total

for step in progress_bar: # Wrap training loop with tqdm
    batch_index = (step - 1) % len(micro_batches) # Cycle through batches
    batch = micro_batches[batch_index]

    inputs = batch
    targets = torch.roll(batch, shifts=-1, dims=1) # Next token prediction

    outputs = model(inputs, labels=targets) # Labels for loss calculation
    loss = outputs.loss
    loss = loss / GRADIENT_ACCUMULATION_STEPS # --- Scale loss for gradient accumulation ---
    accumulated_loss += loss.item() # Accumulate loss for logging

    loss.backward()

    if step % GRADIENT_ACCUMULATION_STEPS == 0: # --- Update weights every GRADIENT_ACCUMULATION_STEPS ---
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        global_step += 1 # Increment global step only when weights are updated
        avg_loss = accumulated_loss / GRADIENT_ACCUMULATION_STEPS # Calculate average loss over accumulation steps
        accumulated_loss = 0 # Reset accumulated loss

        if global_step % 10 == 0: # Log loss every 10 global steps (after accumulation)
            # No need for separate print here, tqdm will handle logging

            # --- Update tqdm progress bar with current average loss ---
            progress_bar.set_postfix({"Avg Loss": f"{avg_loss:.4f}"})


    # --- 8. Prediction Interval ---
    if step % PREDICT_EVERY_STEPS == 0:
        model.eval() # Set model to evaluation mode
        print(f"\n--- Prediction at Micro-batch Step {step} (Global Step: {global_step}) ---") # Clarify step counts
        prompt_ids = tokenizer.encode(PREDICTION_PROMPT, return_tensors="pt").to(device)
        sample_outputs = model.generate(
            prompt_ids,
            max_length=len(prompt_ids[0]) + 50, # Generate up to 50 new tokens
            num_return_sequences=1,
            temperature=0.7, # Adjust temperature for creativity
            top_p=0.9,
        )
        predicted_text = tokenizer.decode(sample_outputs[0], skip_special_tokens=True)
        print(f"Prompt: '{PREDICTION_PROMPT}'")
        print(f"Generated: '{predicted_text}'\n")
        model.train() # Set back to training mode

    # --- 9. Checkpoint Saving ---
    if step % CHECKPOINT_EVERY_STEPS == 0:
        checkpoint_file = os.path.join(CHECKPOINT_PATH, f"smollm2_checkpoint_step_{step}.pth") # Step here is still micro-batch step
        torch.save({
            'step': step, # Step here is still micro-batch step
            'global_step': global_step, # Saving global step as well (steps with weight updates)
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': loss.item(), # Save current micro-batch loss (last loss in accumulation)
        }, checkpoint_file)
        print(f"Checkpoint saved at Micro-batch Step {step} (Global Step {global_step}) to {checkpoint_file}") # Clarified step counts in checkpoint message

progress_bar.close() # Close progress bar when training finishes
print("Training finished!")

Loading and tokenizing dataset...
Loading checkpoint from: smollm2-checkpoints\smollm2_checkpoint_step_5000.pth
Resuming training from global step: 1250
Dataset prepared with 41 micro-batches using micro_batch_size=4 and sequence_length=2048.
Starting training...


Training:  25%|██▌       | 5000/20000 [00:00<?, ?step/s]`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.



--- Prediction at Micro-batch Step 5500 (Global Step: 1375) ---
Prompt: 'To be or not to be,'
Generated: 'To be or not to be,,,,, I:
 most low o, a wife
 mine thenoth up to lives
 your power your, itt
y of.
BDEL from
-?
N: who it not pray,; not'
'



Training:  28%|██▊       | 5500/20000 [08:48<7:41:58,  1.91s/step, Avg Loss=0.0933]

Checkpoint saved at Micro-batch Step 5500 (Global Step 1375) to smollm2-checkpoints\smollm2_checkpoint_step_5500.pth


Training:  30%|██▉       | 5999/20000 [17:35<4:12:05,  1.08s/step, Avg Loss=0.1063]The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.



--- Prediction at Micro-batch Step 6000 (Global Step: 1500) ---
Prompt: 'To be or not to be,'
Generated: 'To be or not to be, now with
 o God make with, his.
 true; the with world with de, know not
 would withs I for with son
 King up ifsRE.
B thy to: God; the for with is?
A'



Training:  30%|███       | 6000/20000 [17:38<6:56:24,  1.78s/step, Avg Loss=0.1063]

Checkpoint saved at Micro-batch Step 6000 (Global Step 1500) to smollm2-checkpoints\smollm2_checkpoint_step_6000.pth


Training:  32%|███▏      | 6499/20000 [26:34<4:02:56,  1.08s/step, Avg Loss=0.0257]The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.



--- Prediction at Micro-batch Step 6500 (Global Step: 1625) ---
Prompt: 'To be or not to be,'
Generated: 'To be or not to be,,,,, Iar,'t
y'd general a child
 what weio; my!
 use play him these you
 I! will more!
S brother' with: by good,
end go,
 father upons'



Training:  32%|███▎      | 6500/20000 [26:38<6:33:22,  1.75s/step, Avg Loss=0.0257]

Checkpoint saved at Micro-batch Step 6500 (Global Step 1625) to smollm2-checkpoints\smollm2_checkpoint_step_6500.pth


Training:  35%|███▍      | 6999/20000 [35:38<3:53:10,  1.08s/step, Avg Loss=0.0220]The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.



--- Prediction at Micro-batch Step 7000 (Global Step: 1750) ---
Prompt: 'To be or not to be,'
Generated: 'To be or not to be,,,,, Iar,'t
y'd general a child
 what weio; my!
 use play him these you
 I tell thy.
N: now,?; else but there:
 mostw are.
First'



Training:  35%|███▌      | 7000/20000 [35:40<6:12:01,  1.72s/step, Avg Loss=0.0220]

Checkpoint saved at Micro-batch Step 7000 (Global Step 1750) to smollm2-checkpoints\smollm2_checkpoint_step_7000.pth


Training:  37%|███▋      | 7499/20000 [44:39<3:45:15,  1.08s/step, Avg Loss=0.0195]The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.



--- Prediction at Micro-batch Step 7500 (Global Step: 1875) ---
Prompt: 'To be or not to be,'
Generated: 'To be or not to be,,,,, Iar,'t
y'd general a child
 by upon; it fromv
 name with before the; thens you
umberland be who bute his.
CAMLOest can!
 what the is's,'



Training:  38%|███▊      | 7500/20000 [44:42<5:57:29,  1.72s/step, Avg Loss=0.0195]

Checkpoint saved at Micro-batch Step 7500 (Global Step 1875) to smollm2-checkpoints\smollm2_checkpoint_step_7500.pth


Training:  40%|███▉      | 7999/20000 [53:42<3:35:37,  1.08s/step, Avg Loss=0.0196]The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.



--- Prediction at Micro-batch Step 8000 (Global Step: 2000) ---
Prompt: 'To be or not to be,'
Generated: 'To be or not to be,,,,, Iar,'t
y'd general a child
 what weio; my!
 use play him these you
 I tell thy.
N: now,?; else but there:
 mostw are.
First'



Training:  40%|████      | 8000/20000 [53:44<5:43:33,  1.72s/step, Avg Loss=0.0196]

Checkpoint saved at Micro-batch Step 8000 (Global Step 2000) to smollm2-checkpoints\smollm2_checkpoint_step_8000.pth


Training:  42%|████▏     | 8499/20000 [1:02:43<3:26:52,  1.08s/step, Avg Loss=0.0174]The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.



--- Prediction at Micro-batch Step 8500 (Global Step: 2125) ---
Prompt: 'To be or not to be,'
Generated: 'To be or not to be,,,,, Iar,'t
 dis did your you
 are, will not
yTh-, am buted,y more my
 treason- to make
 it at andy
 time no, the will at, I so'



Training:  42%|████▎     | 8500/20000 [1:02:46<5:30:28,  1.72s/step, Avg Loss=0.0174]

Checkpoint saved at Micro-batch Step 8500 (Global Step 2125) to smollm2-checkpoints\smollm2_checkpoint_step_8500.pth


Training:  45%|████▍     | 8999/20000 [1:11:47<3:20:13,  1.09s/step, Avg Loss=0.0174]The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.



--- Prediction at Micro-batch Step 9000 (Global Step: 2250) ---
Prompt: 'To be or not to be,'
Generated: 'To be or not to be,,,,, Iar,'t
 dis did your you
 are, will not
yTh-, am but
y be name with what will no.' my:
 with, will,, a and's,
--itor'



Training:  45%|████▌     | 9000/20000 [1:11:49<5:31:20,  1.81s/step, Avg Loss=0.0174]

Checkpoint saved at Micro-batch Step 9000 (Global Step 2250) to smollm2-checkpoints\smollm2_checkpoint_step_9000.pth


Training:  47%|████▋     | 9499/20000 [1:20:58<3:12:58,  1.10s/step, Avg Loss=0.0169]The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.



--- Prediction at Micro-batch Step 9500 (Global Step: 2375) ---
Prompt: 'To be or not to be,'
Generated: 'To be or not to be,,,,, Iar,'t
 dis did your you
 are, will not
yTh-, am buted,y more my
 treason- to make
 it at andy
 time no, the will at, I so'



Training:  48%|████▊     | 9500/20000 [1:21:01<5:02:37,  1.73s/step, Avg Loss=0.0169]

Checkpoint saved at Micro-batch Step 9500 (Global Step 2375) to smollm2-checkpoints\smollm2_checkpoint_step_9500.pth


Training:  50%|████▉     | 9999/20000 [1:30:12<3:03:26,  1.10s/step, Avg Loss=0.0164]The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.



--- Prediction at Micro-batch Step 10000 (Global Step: 2500) ---
Prompt: 'To be or not to be,'
Generated: 'To be or not to be,,,,, Iar,'t
 dis did your you
 are, will not
yTh-, am but
y be name with what will no.' my:
 with,H will,?
BK for'd
 but'



Training:  50%|█████     | 10000/20000 [1:30:15<4:54:18,  1.77s/step, Avg Loss=0.0164]

Checkpoint saved at Micro-batch Step 10000 (Global Step 2500) to smollm2-checkpoints\smollm2_checkpoint_step_10000.pth


Training:  52%|█████▏    | 10499/20000 [1:39:21<2:52:47,  1.09s/step, Avg Loss=0.0165]The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.



--- Prediction at Micro-batch Step 10500 (Global Step: 2625) ---
Prompt: 'To be or not to be,'
Generated: 'To be or not to be,,,,, Iar,'t
 dis did your you
 are, will not
yTh-, am but
y be name with what will no.' my:
 with, will, a--, a that
 time to to'



Training:  52%|█████▎    | 10500/20000 [1:39:24<4:40:51,  1.77s/step, Avg Loss=0.0165]

Checkpoint saved at Micro-batch Step 10500 (Global Step 2625) to smollm2-checkpoints\smollm2_checkpoint_step_10500.pth


Training:  55%|█████▍    | 10999/20000 [1:48:32<2:43:55,  1.09s/step, Avg Loss=0.0170]The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.



--- Prediction at Micro-batch Step 11000 (Global Step: 2750) ---
Prompt: 'To be or not to be,'
Generated: 'To be or not to be,,,,, Iar,'t
 dis did your you
 are, will not
yTh-, am but
y be name with what will no.' my:
 with, will, a--, a that--;
 o'



Training:  55%|█████▌    | 11000/20000 [1:48:34<4:20:58,  1.74s/step, Avg Loss=0.0170]

Checkpoint saved at Micro-batch Step 11000 (Global Step 2750) to smollm2-checkpoints\smollm2_checkpoint_step_11000.pth


Training:  57%|█████▋    | 11499/20000 [1:57:44<2:36:39,  1.11s/step, Avg Loss=0.0163]The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.



--- Prediction at Micro-batch Step 11500 (Global Step: 2875) ---
Prompt: 'To be or not to be,'
Generated: 'To be or not to be,,,,, Iar,'t
 dis did your you
 are, will not
yTh-, am but
y be name with what will no.' my:
 with, the,, a 's with,
 by-'



Training:  57%|█████▊    | 11500/20000 [1:57:47<4:10:02,  1.77s/step, Avg Loss=0.0163]

Checkpoint saved at Micro-batch Step 11500 (Global Step 2875) to smollm2-checkpoints\smollm2_checkpoint_step_11500.pth


Training:  60%|█████▉    | 11999/20000 [2:06:58<2:26:19,  1.10s/step, Avg Loss=0.0171]The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.



--- Prediction at Micro-batch Step 12000 (Global Step: 3000) ---
Prompt: 'To be or not to be,'
Generated: 'To be or not to be,,,, Iy,y, land
 mine; beingest;est;ate' to,
 I with up to mad Mar theal
 did I yourom whoanish and,'d
 one on no time
L now
 most'



Training:  60%|██████    | 12000/20000 [2:07:00<3:55:12,  1.76s/step, Avg Loss=0.0171]

Checkpoint saved at Micro-batch Step 12000 (Global Step 3000) to smollm2-checkpoints\smollm2_checkpoint_step_12000.pth


Training:  62%|██████▏   | 12499/20000 [2:16:09<2:17:11,  1.10s/step, Avg Loss=0.0174]The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.



--- Prediction at Micro-batch Step 12500 (Global Step: 3125) ---
Prompt: 'To be or not to be,'
Generated: 'To be or not to be,,,,, Iy, master't; that
y it friends: it!
sty who but king: then to!
th he king will straight there there at man
 loss'd I a
 by! that you.
'



Training:  62%|██████▎   | 12500/20000 [2:16:12<3:39:45,  1.76s/step, Avg Loss=0.0174]

Checkpoint saved at Micro-batch Step 12500 (Global Step 3125) to smollm2-checkpoints\smollm2_checkpoint_step_12500.pth


Training:  65%|██████▍   | 12999/20000 [2:25:20<2:07:21,  1.09s/step, Avg Loss=0.0184]The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.



--- Prediction at Micro-batch Step 13000 (Global Step: 3250) ---
Prompt: 'To be or not to be,'
Generated: 'To be or not to be,,,,, Iar,'t
 dis did your you
 are, will not
yTh-, am buted,y more my
 treason- to make
 it at andy
 do: by, will are, knoww'



Training:  65%|██████▌   | 13000/20000 [2:25:22<3:23:44,  1.75s/step, Avg Loss=0.0184]

Checkpoint saved at Micro-batch Step 13000 (Global Step 3250) to smollm2-checkpoints\smollm2_checkpoint_step_13000.pth


Training:  67%|██████▋   | 13499/20000 [2:34:31<1:58:52,  1.10s/step, Avg Loss=0.0171]The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.



--- Prediction at Micro-batch Step 13500 (Global Step: 3375) ---
Prompt: 'To be or not to be,'
Generated: 'To be or not to be,,,,, Iar,'t
 dis did your you
 are, will not
yTh-, am but
y be name with what will no.' my:
 with, the,, a 's with,
 by-'



Training:  68%|██████▊   | 13500/20000 [2:34:34<3:08:03,  1.74s/step, Avg Loss=0.0171]

Checkpoint saved at Micro-batch Step 13500 (Global Step 3375) to smollm2-checkpoints\smollm2_checkpoint_step_13500.pth


Training:  70%|██████▉   | 13999/20000 [2:43:41<1:50:08,  1.10s/step, Avg Loss=0.0170]The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.



--- Prediction at Micro-batch Step 14000 (Global Step: 3500) ---
Prompt: 'To be or not to be,'
Generated: 'To be or not to be,,,,, Iar,'t
 dis did your you
 are, will not
yTh-, am but
y be name with what will no.' my:
 with, the,, a 's with,
-,'



Training:  70%|███████   | 14000/20000 [2:43:44<2:58:00,  1.78s/step, Avg Loss=0.0170]

Checkpoint saved at Micro-batch Step 14000 (Global Step 3500) to smollm2-checkpoints\smollm2_checkpoint_step_14000.pth


Training:  72%|███████▏  | 14499/20000 [2:52:53<1:40:24,  1.10s/step, Avg Loss=0.0171]The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.



--- Prediction at Micro-batch Step 14500 (Global Step: 3625) ---
Prompt: 'To be or not to be,'
Generated: 'To be or not to be,,,,, Iar,'t
 dis did your you
 are, will not
yTh-, am buted,y more my
 treason- to make
 it at andy
 do: by, will are, knoww'



Training:  72%|███████▎  | 14500/20000 [2:52:56<2:41:21,  1.76s/step, Avg Loss=0.0171]

Checkpoint saved at Micro-batch Step 14500 (Global Step 3625) to smollm2-checkpoints\smollm2_checkpoint_step_14500.pth


Training:  75%|███████▍  | 14999/20000 [3:02:04<1:30:54,  1.09s/step, Avg Loss=0.0164]The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.



--- Prediction at Micro-batch Step 15000 (Global Step: 3750) ---
Prompt: 'To be or not to be,'
Generated: 'To be or not to be,,,,, Iar,'t
 dis did your you
 are, will not
yTh-, am but
y be name with what will no.' my:
 with, will, a--, a that
 time to to'



Training:  75%|███████▌  | 15000/20000 [3:02:06<2:27:55,  1.78s/step, Avg Loss=0.0164]

Checkpoint saved at Micro-batch Step 15000 (Global Step 3750) to smollm2-checkpoints\smollm2_checkpoint_step_15000.pth


Training:  77%|███████▋  | 15499/20000 [3:11:12<1:22:18,  1.10s/step, Avg Loss=0.0164]The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.



--- Prediction at Micro-batch Step 15500 (Global Step: 3875) ---
Prompt: 'To be or not to be,'
Generated: 'To be or not to be,,,,, Iar,'t
 dis did your you
 are, will not
yTh-, am but
y be name with what will no.' my:
 with, will,, a and's,
--itor'



Training:  78%|███████▊  | 15500/20000 [3:11:16<2:10:38,  1.74s/step, Avg Loss=0.0164]

Checkpoint saved at Micro-batch Step 15500 (Global Step 3875) to smollm2-checkpoints\smollm2_checkpoint_step_15500.pth


Training:  80%|███████▉  | 15999/20000 [3:20:23<1:13:16,  1.10s/step, Avg Loss=0.0167]The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.



--- Prediction at Micro-batch Step 16000 (Global Step: 4000) ---
Prompt: 'To be or not to be,'
Generated: 'To be or not to be,,,,, Iar,'t
 dis did your you
 are, will not
yTh-, am buted,y more my
 treason- to make
 it at andy
 do: by, will are, knoww'



Training:  80%|████████  | 16000/20000 [3:20:25<1:55:33,  1.73s/step, Avg Loss=0.0167]

Checkpoint saved at Micro-batch Step 16000 (Global Step 4000) to smollm2-checkpoints\smollm2_checkpoint_step_16000.pth


Training:  82%|████████▏ | 16499/20000 [3:29:33<1:03:56,  1.10s/step, Avg Loss=0.0163]The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.



--- Prediction at Micro-batch Step 16500 (Global Step: 4125) ---
Prompt: 'To be or not to be,'
Generated: 'To be or not to be,,,,, Iar,'t
 dis did your you
 are, will not
yTh-, am but
y be name with what will no.' my:
 with, the, a ' aish and
 but. but'



Training:  82%|████████▎ | 16500/20000 [3:29:36<1:43:01,  1.77s/step, Avg Loss=0.0163]

Checkpoint saved at Micro-batch Step 16500 (Global Step 4125) to smollm2-checkpoints\smollm2_checkpoint_step_16500.pth


Training:  85%|████████▍ | 16999/20000 [3:38:43<54:40,  1.09s/step, Avg Loss=0.0163]  The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.



--- Prediction at Micro-batch Step 17000 (Global Step: 4250) ---
Prompt: 'To be or not to be,'
Generated: 'To be or not to be,,,,, Iar,'t
 dis did your you
 are, will not
yTh-, am buted,y more my
 treason- to make
 it at andy
 time no, the will at, I so'



Training:  85%|████████▌ | 17000/20000 [3:38:46<1:27:45,  1.76s/step, Avg Loss=0.0163]

Checkpoint saved at Micro-batch Step 17000 (Global Step 4250) to smollm2-checkpoints\smollm2_checkpoint_step_17000.pth


Training:  87%|████████▋ | 17499/20000 [3:47:52<45:30,  1.09s/step, Avg Loss=0.0176]  The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.



--- Prediction at Micro-batch Step 17500 (Global Step: 4375) ---
Prompt: 'To be or not to be,'
Generated: 'To be or not to be,,,, Iy,y, mastery before your; he king it most it
-! your eyesw get that daughter gone no with
e by no to: and you you
 are from time inayed tell, mistress
'



Training:  88%|████████▊ | 17500/20000 [3:47:55<1:13:01,  1.75s/step, Avg Loss=0.0176]

Checkpoint saved at Micro-batch Step 17500 (Global Step 4375) to smollm2-checkpoints\smollm2_checkpoint_step_17500.pth


Training:  90%|████████▉ | 17999/20000 [3:57:02<36:27,  1.09s/step, Avg Loss=0.0167]  The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.



--- Prediction at Micro-batch Step 18000 (Global Step: 4500) ---
Prompt: 'To be or not to be,'
Generated: 'To be or not to be,,,,, Iar,'t
 dis did your you
 are, will not
yTh-, am but
y be name with what will no.' my:
 with, will, a--, a that--;
 o'



Training:  90%|█████████ | 18000/20000 [3:57:05<58:17,  1.75s/step, Avg Loss=0.0167]

Checkpoint saved at Micro-batch Step 18000 (Global Step 4500) to smollm2-checkpoints\smollm2_checkpoint_step_18000.pth


Training:  92%|█████████▏| 18499/20000 [4:06:14<27:25,  1.10s/step, Avg Loss=0.0176]The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.



--- Prediction at Micro-batch Step 18500 (Global Step: 4625) ---
Prompt: 'To be or not to be,'
Generated: 'To be or not to be,,,,, Iy, master't; that
y it friends: it!
sty who but king: then to!
th he king will straight there there to,
 my:The-- you;
 you there hope did'



Training:  92%|█████████▎| 18500/20000 [4:06:18<43:36,  1.74s/step, Avg Loss=0.0176]

Checkpoint saved at Micro-batch Step 18500 (Global Step 4625) to smollm2-checkpoints\smollm2_checkpoint_step_18500.pth


Training:  95%|█████████▍| 18999/20000 [4:15:25<18:15,  1.09s/step, Avg Loss=0.0167]The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.



--- Prediction at Micro-batch Step 19000 (Global Step: 4750) ---
Prompt: 'To be or not to be,'
Generated: 'To be or not to be,,,,, Iar,'t
 dis did your you
 are, will not
yTh-, am but
y be name with what will no.' my:
 with, will, a--, a that
 with, the'



Training:  95%|█████████▌| 19000/20000 [4:15:27<29:06,  1.75s/step, Avg Loss=0.0167]

Checkpoint saved at Micro-batch Step 19000 (Global Step 4750) to smollm2-checkpoints\smollm2_checkpoint_step_19000.pth


Training:  97%|█████████▋| 19499/20000 [4:24:36<09:10,  1.10s/step, Avg Loss=0.0180]The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.



--- Prediction at Micro-batch Step 19500 (Global Step: 4875) ---
Prompt: 'To be or not to be,'
Generated: 'To be or not to be,,,,, Iar,'t
 dis did your you
 are, will not
yTh-, am buted,y more my
 treason- to make
 it at andy
 time no, the will at, I so'



Training:  98%|█████████▊| 19500/20000 [4:24:39<14:38,  1.76s/step, Avg Loss=0.0180]

Checkpoint saved at Micro-batch Step 19500 (Global Step 4875) to smollm2-checkpoints\smollm2_checkpoint_step_19500.pth


Training: 100%|█████████▉| 19999/20000 [4:33:47<00:01,  1.09s/step, Avg Loss=0.0168]The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.



--- Prediction at Micro-batch Step 20000 (Global Step: 5000) ---
Prompt: 'To be or not to be,'
Generated: 'To be or not to be,,,,, Iar,'t
 dis did your you
 are, will not
yTh-, am but
y be name with what will no.' my:
 with, will,, a and's,
--itor'



Training: 100%|██████████| 20000/20000 [4:33:49<00:00,  1.10s/step, Avg Loss=0.0168]

Checkpoint saved at Micro-batch Step 20000 (Global Step 5000) to smollm2-checkpoints\smollm2_checkpoint_step_20000.pth
Training finished!





As we can observe, model loaded from saved checkpoint showed average error loss start from the point where we stopped training before, showing that the **model checkpoint were saved correctly**

## Model Prediction 

SmolLLM-125M model after 5000 Steps (20000 Micro Steps)

In [8]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# from model import create_smollm2_model # Assuming model.py is in the same directory

def generate_prediction(checkpoint_path, input_text, max_new_tokens=50, temperature=0.7, top_p=0.9):
    """
    Loads a model from a checkpoint and generates text based on an input prompt.

    Args:
        checkpoint_path (str): Path to the saved model checkpoint file (e.g., 'smollm2-checkpoints/smollm2_checkpoint_step_5000.pth').
        input_text (str): The text prompt to start generation from (e.g., "To be or not to be,").
        max_new_tokens (int): Maximum number of new tokens to generate.
        temperature (float): Sampling temperature for generation (higher values more creative, lower more deterministic).
        top_p (float): Top-p sampling parameter.

    Returns:
        str: The generated text.
    """
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # 1. Load Model and Tokenizer (same as in training)
    model, tokenizer = create_smollm2_model() # Use the same model creation function
    model.to(device)

    # 2. Load Model State from Checkpoint
    print(f"Loading checkpoint from: {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location=device) # Load checkpoint to correct device
    model.load_state_dict(checkpoint['model_state_dict']) # Only load model weights

    # 3. Set Model to Evaluation Mode
    model.eval()

    # 4. Tokenize Input Text
    prompt_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)

    # 5. Generate Prediction
    print(f"Generating text with prompt: '{input_text}'")
    sample_outputs = model.generate(
        prompt_ids,
        max_length=len(prompt_ids[0]) + max_new_tokens,
        num_return_sequences=1,
        temperature=temperature,
        top_p=top_p,
    )

    # 6. Decode and Return Generated Text
    predicted_text = tokenizer.decode(sample_outputs[0], skip_special_tokens=True)
    return predicted_text

if __name__ == "__main__":
    checkpoint_path = "smollm2-checkpoints\smollm2_checkpoint_step_20000.pth" # --- REPLACE WITH YOUR ACTUAL CHECKPOINT PATH ---
    input_prompt = "O art thou " # --- REPLACE WITH YOUR DESIRED PROMPT ---

    generated_text = generate_prediction(checkpoint_path, input_prompt)

    print(f"\nPrompt: '{input_prompt}'")
    print(f"Generated: '{generated_text}'")


Loading checkpoint from: smollm2-checkpoints\smollm2_checkpoint_step_20000.pth


  checkpoint = torch.load(checkpoint_path, map_location=device) # Load checkpoint to correct device
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


Generating text with prompt: 'O art thou '

Prompt: 'O art thou '
Generated: 'O art thou  ams?', leave son
 do th will the from thee one mine, there hear?
 will theyest thy but out sh as unto no
 bro one hisal'd' so; to you at
 brother not thees thesehouse'


## Current observations

1. We have trained the model for 5000 Global Steps, but average loss value has stayed consistent at around 0.0168
2. Now we will further train the model for further with smaller learning rate (old LR: 3e-4, new LR: 3e-5)


In [2]:
RESUME_CHECKPOINT_PATH = "smollm2-checkpoints\smollm2_checkpoint_step_20000.pth"

# train.py
import torch
from transformers import get_linear_schedule_with_warmup # Changed import for checkpointing
# from model import create_smollm2_model
import os

# --- 1. Hyperparameters and Configuration ---
TARGET_GLOBAL_STEPS = 5050
PREDICT_EVERY_STEPS = 100
CHECKPOINT_EVERY_STEPS = 200
SEQUENCE_LENGTH = 2048  # As defined in config - May need to reduce if OOM
MICRO_BATCH_SIZE = 8     # As defined in config - REDUCE THIS FIRST to fix OOM
LEARNING_RATE = 3e-5     # You can adjust, simplified from config for now
WARMUP_STEPS = 250      # Simplified warmup
CHECKPOINT_PATH = "smollm2-checkpoints"
INPUT_FILE = "datasets\input.txt"
PREDICTION_PROMPT = "To be or not to be," # A starting prompt for prediction

GRADIENT_ACCUMULATION_STEPS = 4 # --- Speedup 1: Gradient Accumulation --- Accumulate gradients over this many steps
TARGET_TRAIN_STEPS = TARGET_GLOBAL_STEPS * GRADIENT_ACCUMULATION_STEPS # Adjusted total steps based on accumulation steps

USE_ACTIVATION_CHECKPOINTING = True # --- Speedup 2: Activation Checkpointing --- Enable or disable activation checkpointing

# --- ***MEMORY OPTIMIZATION - REDUCE THESE IF OOM ERROR PERSISTS*** ---
REDUCE_MICRO_BATCH_SIZE_FACTOR = 2 # --- Reduce Micro Batch Size --- Reduce micro_batch_size by this factor
# If you STILL get OOM, try reducing SEQUENCE_LENGTH_FACTOR (but micro_batch_size reduction is usually more effective first)
REDUCE_SEQUENCE_LENGTH_FACTOR = 1 # --- Reduce Sequence Length --- Reduce sequence length by this factor if needed

# --- 2. Adjusted Hyperparameters based on Reduction Factors ---
ADJUSTED_MICRO_BATCH_SIZE = MICRO_BATCH_SIZE // REDUCE_MICRO_BATCH_SIZE_FACTOR
if ADJUSTED_MICRO_BATCH_SIZE <= 0:
    ADJUSTED_MICRO_BATCH_SIZE = 1 # Ensure micro_batch_size is at least 1
ADJUSTED_SEQUENCE_LENGTH = SEQUENCE_LENGTH // REDUCE_SEQUENCE_LENGTH_FACTOR
if ADJUSTED_SEQUENCE_LENGTH <= 0:
    ADJUSTED_SEQUENCE_LENGTH = 64 # Ensure sequence_length is at least reasonably sized

MICRO_BATCH_SIZE = ADJUSTED_MICRO_BATCH_SIZE # Update micro batch size
SEQUENCE_LENGTH = ADJUSTED_SEQUENCE_LENGTH # Update sequence length


# --- 3. Create Checkpoint Directory ---
os.makedirs(CHECKPOINT_PATH, exist_ok=True)


# --- 4. Load Model and Tokenizer ---
model, tokenizer = create_smollm2_model()
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)

# --- Speedup 3: Mixed Precision Training (bfloat16) --- Model is already in bfloat16 from model.py
# We are using bfloat16 dtype for model which speeds up training and reduces memory usage on compatible GPUs (like NVIDIA A100, H100)

if USE_ACTIVATION_CHECKPOINTING: # --- Enable activation checkpointing if flag is set ---
    model.gradient_checkpointing_enable() # Enable activation checkpointing from transformers

# --- 5. Prepare Dataset ---
print("Loading and tokenizing dataset...")
with open(INPUT_FILE, "r", encoding="utf-8") as f:
    text = f.read()

tokenized_dataset = tokenizer(text, return_tensors="pt", truncation=False) # No truncation initially
input_ids = tokenized_dataset['input_ids']

# --- 8. Optimizer and Scheduler ---
# --- Speedup 4: Fused AdamW Optimizer --- Using fused AdamW if available (often faster)
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01, fused=True) # Using fused=True for potential speedup
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=WARMUP_STEPS, num_training_steps=TARGET_TRAIN_STEPS // GRADIENT_ACCUMULATION_STEPS # Adjusted lr scheduler steps
) # Simplified scheduler

# --- 6. Load Checkpoint if `resume_checkpoint_path` is provided ---
initial_global_step = 0 # Track initial global step, default is 0 for new training
if RESUME_CHECKPOINT_PATH:
    print(f"Loading checkpoint from: {RESUME_CHECKPOINT_PATH}")
    checkpoint = torch.load(RESUME_CHECKPOINT_PATH, map_location=device, weights_only=False) # Load checkpoint to correct device
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    initial_global_step = checkpoint.get('global_step', 0) # Try to get global_step, default to 0 if not found
    print(f"Resuming training from global step: {initial_global_step}")
else:
    print("Starting training from scratch.")

start_step = initial_global_step * GRADIENT_ACCUMULATION_STEPS + 1 # Calculate micro-batch step to start from

# --- 7. Create Data Batches ---
def create_batches(input_ids, seq_len, micro_batch_size):
    num_tokens = input_ids.shape[1]
    num_batches = num_tokens // seq_len
    truncated_input_ids = input_ids[:, :num_batches * seq_len] # Truncate to fit full sequences based on potentially reduced SEQUENCE_LENGTH
    batched_input_ids = truncated_input_ids.reshape(-1, seq_len) # Reshape into sequences
    num_micro_batches = num_batches // micro_batch_size
    micro_batches = []
    for i in range(num_micro_batches):
        start_index = i * micro_batch_size
        end_index = (i + 1) * micro_batch_size
        batch = batched_input_ids[start_index:end_index]
        micro_batches.append(batch)
    return micro_batches

micro_batches = create_batches(input_ids.to(device), SEQUENCE_LENGTH, MICRO_BATCH_SIZE) # Using potentially reduced SEQUENCE_LENGTH and MICRO_BATCH_SIZE
print(f"Dataset prepared with {len(micro_batches)} micro-batches using micro_batch_size={MICRO_BATCH_SIZE} and sequence_length={SEQUENCE_LENGTH}.")


# --- imports ---
from tqdm import tqdm

# --- 9. Training Loop ---
print("Starting training...")
model.train() # Set model to training mode
global_step = initial_global_step # Initialize global_step from checkpoint or 0
accumulated_loss = 0 # To track loss over accumulation steps

# Initialize tqdm progress bar, total is remaining global steps
remaining_global_steps = TARGET_GLOBAL_STEPS - initial_global_step
if remaining_global_steps <= 0:
    print("Training already completed to target steps or beyond based on checkpoint.")
    exit()

progress_bar = tqdm(range(start_step, TARGET_TRAIN_STEPS + 1), desc="Training", unit="step", initial=start_step -1, total=TARGET_TRAIN_STEPS) # Initialize tqdm with start and total

for step in progress_bar: # Wrap training loop with tqdm
    batch_index = (step - 1) % len(micro_batches) # Cycle through batches
    batch = micro_batches[batch_index]

    inputs = batch
    targets = torch.roll(batch, shifts=-1, dims=1) # Next token prediction

    outputs = model(inputs, labels=targets) # Labels for loss calculation
    loss = outputs.loss
    loss = loss / GRADIENT_ACCUMULATION_STEPS # --- Scale loss for gradient accumulation ---
    accumulated_loss += loss.item() # Accumulate loss for logging

    loss.backward()

    if step % GRADIENT_ACCUMULATION_STEPS == 0: # --- Update weights every GRADIENT_ACCUMULATION_STEPS ---
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        global_step += 1 # Increment global step only when weights are updated
        avg_loss = accumulated_loss / GRADIENT_ACCUMULATION_STEPS # Calculate average loss over accumulation steps
        accumulated_loss = 0 # Reset accumulated loss

        if global_step % 10 == 0: # Log loss every 10 global steps (after accumulation)
            # No need for separate print here, tqdm will handle logging

            # --- Update tqdm progress bar with current average loss ---
            progress_bar.set_postfix({"Avg Loss": f"{avg_loss:.4f}"})


    # --- 8. Prediction Interval ---
    if step % PREDICT_EVERY_STEPS == 0:
        model.eval() # Set model to evaluation mode
        print(f"\n--- Prediction at Micro-batch Step {step} (Global Step: {global_step}) ---") # Clarify step counts
        prompt_ids = tokenizer.encode(PREDICTION_PROMPT, return_tensors="pt").to(device)
        sample_outputs = model.generate(
            prompt_ids,
            max_length=len(prompt_ids[0]) + 50, # Generate up to 50 new tokens
            num_return_sequences=1,
            temperature=0.7, # Adjust temperature for creativity
            top_p=0.9,
        )
        predicted_text = tokenizer.decode(sample_outputs[0], skip_special_tokens=True)
        print(f"Prompt: '{PREDICTION_PROMPT}'")
        print(f"Generated: '{predicted_text}'\n")
        model.train() # Set back to training mode

    # --- 9. Checkpoint Saving ---
    if step % CHECKPOINT_EVERY_STEPS == 0:
        checkpoint_file = os.path.join(CHECKPOINT_PATH, f"smollm2_checkpoint_step_{step}.pth") # Step here is still micro-batch step
        torch.save({
            'step': step, # Step here is still micro-batch step
            'global_step': global_step, # Saving global step as well (steps with weight updates)
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': loss.item(), # Save current micro-batch loss (last loss in accumulation)
        }, checkpoint_file)
        print(f"Checkpoint saved at Micro-batch Step {step} (Global Step {global_step}) to {checkpoint_file}") # Clarified step counts in checkpoint message

progress_bar.close() # Close progress bar when training finishes
print("Training finished!")

Loading and tokenizing dataset...
Loading checkpoint from: smollm2-checkpoints\smollm2_checkpoint_step_20000.pth
Resuming training from global step: 5000
Dataset prepared with 41 micro-batches using micro_batch_size=4 and sequence_length=2048.
Starting training...


Training:  99%|█████████▉| 20000/20200 [00:00<?, ?step/s]`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
Training: 100%|█████████▉| 20099/20200 [02:12<02:15,  1.34s/step, Avg Loss=0.0175]


--- Prediction at Micro-batch Step 20100 (Global Step: 5025) ---


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Training: 100%|█████████▉| 20100/20200 [02:15<03:01,  1.81s/step, Avg Loss=0.0175]

Prompt: 'To be or not to be,'
Generated: 'To be or not to be,,,,, Iar,'t
 dis did your you
 are, will not
yTh-, am buted,y more my
 treason- to make
 it at andy
 time no, the will at, I so'



Training: 100%|█████████▉| 20199/20200 [04:30<00:01,  1.34s/step, Avg Loss=0.0171]The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.



--- Prediction at Micro-batch Step 20200 (Global Step: 5050) ---
Prompt: 'To be or not to be,'
Generated: 'To be or not to be,,,,, Iar,'t
 dis did your you
 are, will not
yTh-, am buted,y more my
 treason- to make
 it at andy
 time no, the will at, I so'



Training: 100%|██████████| 20200/20200 [04:32<00:00,  1.36s/step, Avg Loss=0.0171]

Checkpoint saved at Micro-batch Step 20200 (Global Step 5050) to smollm2-checkpoints\smollm2_checkpoint_step_20200.pth
Training finished!





Final Average loss achieved: **0.0171** 

- Reducing LR rate by 10 did not significantly reduced model loss value

### Prediction using the updated model

In [9]:
checkpoint_path = "smollm2-checkpoints\smollm2_checkpoint_step_20200.pth" # --- REPLACE WITH YOUR ACTUAL CHECKPOINT PATH ---
input_prompt = "O art thou " # --- REPLACE WITH YOUR DESIRED PROMPT ---

generated_text = generate_prediction(checkpoint_path, input_prompt)

print(f"\nPrompt: '{input_prompt}'")
print(f"Generated: '{generated_text}'")

Loading checkpoint from: smollm2-checkpoints\smollm2_checkpoint_step_20200.pth


  checkpoint = torch.load(checkpoint_path, map_location=device) # Load checkpoint to correct device
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


Generating text with prompt: 'O art thou '

Prompt: 'O art thou '
Generated: 'O art thou  ams?', leave son
 do th will the from thee one mine, there hear?
 will theyest thy but out sh as unto no
 bro one hisal'd' so; to you at
 brother not thees thesehouse'
