In [None]:
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import Dataset
import numpy as np
from tqdm import tqdm
import math
from torch.cuda.amp import autocast, GradScaler

In [None]:
# Set random seed for reproducibility
torch.manual_seed(0)
np.random.seed(0)

# Device setup
# 'accelerate' with device_map='auto' will handle device placement.
# We still might need this for moving specific tensors like rewards or inputs.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [None]:
# --- Memory Optimization Settings ---
# No need to call torch.cuda.empty_cache() at the start
# torch.backends.cuda.enable_mem_efficient_sdp(False) # Keep commented unless needed for specific issues

# --- Model and Tokenizer Loading ---
model_name = "Qwen/Qwen2.5-Math-1.5B"
print(f"Loading tokenizer: {model_name}")
tokenizer = AutoTokenizer.from_pretrained(model_name)

Loading tokenizer: Qwen/Qwen2.5-Math-1.5B


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

In [None]:
# Set pad token if it doesn't exist
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    print("Set tokenizer pad_token to eos_token")

print(f"Loading model: {model_name}")
# Load the main model with device_map='auto', gradient checkpointing, and float16
# device_map='auto' handles placing the model on available GPUs/CPU/disk
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto", # Let accelerate handle device placement
    torch_dtype=torch.float16,
    trust_remote_code=True # Add if required by the specific model
)
model.gradient_checkpointing_enable() # Enable gradient checkpointing *after* loading
# DO NOT call model.to(device) here - device_map handles it.
print("Model loaded onto devices via device_map='auto'")

print("Loading reference model...")
# Load the reference model similarly
reference_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto", # Let accelerate handle device placement
    torch_dtype=torch.float16,
    trust_remote_code=True # Add if required by the specific model
)
reference_model.eval() # Set reference model to evaluation mode
# DO NOT call reference_model.to(device) here - device_map handles it.
print("Reference model loaded.")


Loading model: Qwen/Qwen2.5-Math-1.5B


config.json:   0%|          | 0.00/676 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors:   0%|          | 0.00/3.09G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/138 [00:00<?, ?B/s]

Model loaded onto devices via device_map='auto'
Loading reference model...
Reference model loaded.


In [None]:
# --- Example Data ---
example = {
    "prompt": "The pressure \\( P \\) exerted by wind on a sail varies jointly as the area \\( A \\) of the sail and the cube of the wind’s velocity \\( V \\). When the velocity is \\( 8 \\) miles per hour, the pressure on a sail of \\( 2 \\) square feet is \\( 4 \\) pounds. Find the wind velocity when the pressure on \\( 4 \\) square feet of sail is \\( 32 \\) pounds. Let’s think step by step and output the final answer within \\boxed{}.",
    # Previous attempt at recalculating GT removed to fix syntax error.
    # Using the user's original GT string for consistency with the initial request.
    "ground_truth": "12.8"
}

# --- Dataset Preparation ---
# Reduced batch size for memory, effective batch size simulated with accumulation
batch_size = 4 # Significantly reduced for memory
effective_batch_size = 32 # Target effective batch size
accumulation_steps = max(1, effective_batch_size // batch_size) # Calculate accumulation steps
print(f"Physical Batch Size: {batch_size}, Accumulation Steps: {accumulation_steps}, Effective Batch Size: {batch_size * accumulation_steps}")

# Duplicate the single example to create a batch
data = [example] * (batch_size * accumulation_steps) # Create enough data for one effective batch
dataset = Dataset.from_dict({
    "prompt": [d["prompt"] for d in data],
    "ground_truth": [d["ground_truth"] for d in data]
})

Physical Batch Size: 4, Accumulation Steps: 8, Effective Batch Size: 32


In [None]:
# --- Hyperparameters ---
learning_rate = 1e-6
kl_coeff = 0.02 # Adjusted KL coefficient, often needs tuning
entropy_coeff = 0.001 # Entropy bonus to encourage exploration
rollout_temperature = 0.7 # Temperature for sampling responses
weight_decay = 0.01
max_prompt_length = 512 # Max length for input prompt tokens
# Adjust max_response_length based on expected output length and memory
# max_new_tokens will be max_total_length - prompt_length
max_total_length = 768 # Reduced total length (prompt + response)
num_steps = 200 # Reduced number of steps for quicker testing
samples_per_prompt = 2 # Reduced samples per prompt for memory


In [None]:
# --- Optimizer and Scaler ---
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
scaler = GradScaler() # For mixed precision

# --- Reward Function ---
def compute_reward(response, ground_truth):
    """Binary reward: 1.0 if ground_truth is in response, 0.0 otherwise."""
    # Simple check, might need more sophisticated reward logic (e.g., parsing the boxed answer)
    # Check if the exact ground_truth string is present in the response.
    # A more robust check might involve parsing "\\boxed{...}"
    return 1.0 if ground_truth in response else 0.0

  scaler = GradScaler() # For mixed precision


In [None]:
import torch
import torch.nn.functional as F

def compute_grpo_loss(policy_logits, policy_sampled_ids, policy_attention_mask,
                      ref_log_probs_generated, rewards, kl_coeff, entropy_coeff,
                      pad_token_id, prompt_length):
    """
    Memory-efficient GRPO loss computation (without entropy to save memory).

    Args:
        policy_logits: Logits from the policy model for generated tokens. Shape (B*S, GenLen, V)
        policy_sampled_ids: Full sequence IDs (prompt + generated). Shape (B*S, FullLen)
        policy_attention_mask: Attention mask for the full sequence. Shape (B*S, FullLen)
        ref_log_probs_generated: Log probabilities from reference model for generated tokens. Shape (B*S, GenLen)
        rewards: List of scalar rewards for each sampled sequence.
        kl_coeff: Coefficient for KL divergence term.
        entropy_coeff: Coefficient for entropy bonus term (unused, kept for compatibility).
        pad_token_id: ID of the padding token.
        prompt_length: Length of the initial prompt sequence.

    Returns:
        Total loss tensor.
        Tuple containing (pg_loss, kl_loss, entropy_loss) for logging.
    """
    B_times_S = policy_logits.shape[0]
    gen_len = policy_logits.shape[1]

    # Extract generated tokens only
    generated_ids = policy_sampled_ids[:, prompt_length:]  # (B*S, GenLen)
    gen_attention_mask = policy_attention_mask[:, prompt_length:].float()  # (B*S, GenLen)

    # Compute policy log probs efficiently
    policy_log_probs = F.log_softmax(policy_logits, dim=-1)  # (B*S, GenLen, V)
    policy_log_probs_sampled = torch.gather(
        policy_log_probs,
        dim=-1,
        index=generated_ids.unsqueeze(-1)
    ).squeeze(-1)  # (B*S, GenLen)

    # Clean up immediately to save memory
    del policy_log_probs

    # --- 1. Policy Gradient Loss (REINFORCE) ---
    rewards_tensor = torch.tensor(rewards, device=policy_logits.device, dtype=policy_logits.dtype)

    # Normalize rewards
    if len(rewards) > 1:
        rewards_mean = rewards_tensor.mean()
        rewards_std = rewards_tensor.std() + 1e-8
        normalized_rewards = (rewards_tensor - rewards_mean) / rewards_std
    else:
        normalized_rewards = rewards_tensor

    # Expand rewards to token level
    normalized_rewards_expanded = normalized_rewards.view(-1, 1)  # (B*S, 1)

    # Policy gradient: -log_prob * advantage (negative because we minimize)
    pg_loss_per_token = -policy_log_probs_sampled * normalized_rewards_expanded  # (B*S, GenLen)
    pg_loss = (pg_loss_per_token * gen_attention_mask).sum() / gen_attention_mask.sum().clamp(min=1)

    # --- 2. KL Divergence Penalty ---
    # KL(π || ref) = log π(a|s) - log ref(a|s)
    kl_per_token = policy_log_probs_sampled - ref_log_probs_generated  # (B*S, GenLen)
    mean_kl = (kl_per_token * gen_attention_mask).sum() / gen_attention_mask.sum().clamp(min=1)
    kl_penalty = kl_coeff * mean_kl

    # --- 3. Entropy (set to 0 to save memory) ---
    # The entropy calculation causes OOM, so we skip it
    entropy_loss = torch.tensor(0.0, device=policy_logits.device)

    # --- Total Loss ---
    total_loss = pg_loss + kl_penalty + entropy_loss

    return total_loss, (pg_loss.item(), kl_penalty.item(), entropy_loss.item())

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import os
import gc
from contextlib import nullcontext

# ===== SETUP AND CONFIGURATION =====
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# Enable gradient checkpointing
if hasattr(model, 'gradient_checkpointing_enable'):
    model.gradient_checkpointing_enable()

# Memory management
def clear_memory():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    gc.collect()

# ===== SAFE LOSS FUNCTION =====
def safe_compute_grpo_loss(policy_logits, policy_sampled_ids, policy_attention_mask,
                          ref_log_probs_generated, rewards, kl_coeff, entropy_coeff,
                          pad_token_id, prompt_length):
    """
    Safe version that handles tensor size mismatches
    """
    # Extract generated tokens
    generated_ids = policy_sampled_ids[:, prompt_length:]
    generated_mask = policy_attention_mask[:, prompt_length:]
    gen_length = generated_ids.shape[1]

    # Ensure all tensors have the same sequence length
    min_length = min(policy_logits.shape[1], ref_log_probs_generated.shape[1], gen_length)

    if min_length < gen_length:
        print(f"    WARNING: Truncating sequences from {gen_length} to {min_length}")
        generated_ids = generated_ids[:, :min_length]
        generated_mask = generated_mask[:, :min_length]
        policy_logits = policy_logits[:, :min_length, :]
        ref_log_probs_generated = ref_log_probs_generated[:, :min_length]
        gen_length = min_length

    # Calculate policy log probs
    policy_log_probs = F.log_softmax(policy_logits, dim=-1)
    policy_log_probs_sampled = torch.gather(
        policy_log_probs, -1, generated_ids.unsqueeze(-1)
    ).squeeze(-1)

    # Calculate KL divergence
    kl_div = policy_log_probs_sampled - ref_log_probs_generated
    kl_div = kl_div * generated_mask  # Mask out padding

    # Calculate entropy
    entropy = -torch.sum(
        torch.exp(policy_log_probs) * policy_log_probs * generated_mask.unsqueeze(-1),
        dim=-1
    )

    # Calculate policy gradient loss
    rewards_tensor = torch.tensor(rewards, device=policy_logits.device).unsqueeze(-1)
    pg_loss = -torch.mean(policy_log_probs_sampled * rewards_tensor * generated_mask)

    # Final loss components
    kl_loss = torch.mean(kl_div * generated_mask)
    ent_loss = torch.mean(entropy)

    total_loss = pg_loss + kl_coeff * kl_loss - entropy_coeff * ent_loss

    return total_loss, (pg_loss.item(), kl_loss.item(), ent_loss.item())

# ===== OPTIMIZED REFERENCE MODEL COMPUTATION =====
def compute_reference_log_probs_memory_efficient(sampled_ids, prompt_length, chunk_size=5):
    """
    Process reference model in chunks to avoid OOM
    """
    generated_ids_only = sampled_ids[:, prompt_length:]
    gen_length = generated_ids_only.shape[1]

    # If no generated tokens, return empty tensor
    if gen_length == 0:
        return torch.empty((1, 0), device=sampled_ids.device)

    ref_log_probs_chunks = []

    for start_idx in range(0, gen_length, chunk_size):
        end_idx = min(start_idx + chunk_size, gen_length)

        chunk_ids = generated_ids_only[:, start_idx:end_idx]
        chunk_mask = (chunk_ids != tokenizer.pad_token_id).long().to(chunk_ids.device)

        with torch.no_grad():
            ref_outputs = reference_model(
                input_ids=chunk_ids,
                attention_mask=chunk_mask
            )

            # Use mixed precision carefully
            with torch.amp.autocast('cuda', enabled=False):
                chunk_logits = ref_outputs.logits.float()
                chunk_log_probs = F.log_softmax(chunk_logits, dim=-1)

            # Gather probabilities for actual tokens
            chunk_sampled_probs = torch.gather(
                chunk_log_probs, -1, chunk_ids.unsqueeze(-1)
            ).squeeze(-1)

            ref_log_probs_chunks.append(chunk_sampled_probs.cpu())

        # Cleanup
        del ref_outputs, chunk_logits, chunk_log_probs, chunk_sampled_probs
        if start_idx % (chunk_size * 2) == 0:  # Clear memory periodically
            clear_memory()

    return torch.cat(ref_log_probs_chunks, dim=1)

# ===== OPTIMIZED TRAINING LOOP =====
model.train()
global_step = 0

print("\nStarting memory-optimized training...")
print(f"Config: batch_size={batch_size}, samples_per_prompt={samples_per_prompt}")

for step in range(num_steps):
    print(f"\n--- Step {step + 1} / {num_steps} ---")

    # Sample data for this step
    epoch_data = dataset.shuffle(seed=step)[:batch_size * accumulation_steps]

    optimizer.zero_grad(set_to_none=True)
    total_loss_accum = 0.0
    avg_reward_accum = 0.0
    all_pg_loss, all_kl_loss, all_ent_loss = 0.0, 0.0, 0.0

    for accum_step in range(accumulation_steps):
        print(f"  Accumulation Step {accum_step + 1}/{accumulation_steps}")

        # --- Prepare Micro-Batch ---
        start_idx = accum_step * batch_size
        end_idx = (accum_step + 1) * batch_size
        batch = {
            "prompt": epoch_data["prompt"][start_idx:end_idx],
            "ground_truth": epoch_data["ground_truth"][start_idx:end_idx]
        }
        prompts = batch["prompt"]
        ground_truths = batch["ground_truth"]

        # Tokenize (keep on CPU initially)
        inputs = tokenizer(
            prompts,
            return_tensors="pt",
            max_length=max_prompt_length,
            truncation=True,
            padding="max_length",
            return_attention_mask=True
        )
        prompt_length = inputs.input_ids.shape[1]

        all_samples = []

        # --- Rollout Phase - Process one prompt at a time ---
        print(f"    Generating {samples_per_prompt} samples per prompt...")

        for i in range(batch_size):
            prompt_input_ids = inputs.input_ids[i:i+1]
            prompt_attention_mask = inputs.attention_mask[i:i+1]

            # Move single prompt to GPU
            prompt_input_ids = prompt_input_ids.to(device)
            prompt_attention_mask = prompt_attention_mask.to(device)

            for sample_idx in range(samples_per_prompt):
                # Initialize variables
                outputs = sampled_ids = ref_log_probs = None

                try:
                    with torch.no_grad():
                        # Generate with lower precision if supported
                        with torch.amp.autocast('cuda', dtype=torch.float16):
                            outputs = model.generate(
                                input_ids=prompt_input_ids,
                                attention_mask=prompt_attention_mask,
                                max_new_tokens=max_total_length - prompt_length,
                                do_sample=True,
                                temperature=rollout_temperature,
                                top_p=1.0,
                                pad_token_id=tokenizer.pad_token_id,
                                return_dict_in_generate=True,
                                output_scores=False,
                                repetition_penalty=1.1,
                            )

                        sampled_ids = outputs.sequences

                        # Compute reward (on CPU)
                        decoded_response = tokenizer.decode(sampled_ids[0].cpu(), skip_special_tokens=True)
                        reward = compute_reward(decoded_response, ground_truths[i])

                        # Memory-efficient reference log probs
                        ref_log_probs = compute_reference_log_probs_memory_efficient(
                            sampled_ids, prompt_length, chunk_size=3
                        )

                        # Store sample data
                        all_samples.append({
                            'sampled_ids': sampled_ids.cpu(),
                            'attention_mask': (sampled_ids != tokenizer.pad_token_id).long().cpu(),
                            'ref_log_probs': ref_log_probs.cpu(),
                            'reward': reward,
                        })

                except Exception as e:
                    print(f"    Error generating sample {sample_idx} for prompt {i}: {e}")
                    continue

                finally:
                    # Clean up
                    variables_to_clean = ['outputs', 'sampled_ids', 'ref_log_probs']
                    for var_name in variables_to_clean:
                        if var_name in locals() and locals()[var_name] is not None:
                            del locals()[var_name]

                # Clear memory after each sample
                if sample_idx % 2 == 0:
                    clear_memory()

            # Clear prompt-specific memory
            del prompt_input_ids, prompt_attention_mask
            clear_memory()

        # --- Loss Calculation - Process samples sequentially ---
        print("    Calculating loss...")
        if not all_samples:
            print("    WARNING: No samples generated. Skipping.")
            continue

        total_samples = len(all_samples)
        micro_batch_loss = 0.0
        micro_batch_rewards = 0.0
        micro_pg_loss, micro_kl_loss, micro_ent_loss = 0.0, 0.0, 0.0

        for sample_idx, sample in enumerate(all_samples):
            # Initialize variables to None
            sampled_ids = attention_mask = ref_log_probs = policy_outputs = None
            policy_logits = sample_loss = weighted_loss = None

            try:
                # Move only one sample to GPU at a time
                sampled_ids = sample['sampled_ids'].to(device, non_blocking=True)
                attention_mask = sample['attention_mask'].to(device, non_blocking=True)
                ref_log_probs = sample['ref_log_probs'].to(device, dtype=torch.float16, non_blocking=True)
                reward = sample['reward']

                # Forward pass with mixed precision
                with torch.amp.autocast('cuda', dtype=torch.float16):
                    policy_outputs = model(
                        input_ids=sampled_ids,
                        attention_mask=attention_mask,
                        output_hidden_states=False,
                        output_attentions=False
                    )

                    # Extract policy logits for generated tokens only
                    policy_logits = policy_outputs.logits[:, prompt_length-1:-1, :]

                    # Compute loss for single sample using safe function
                    sample_loss, (sample_pg, sample_kl, sample_ent) = safe_compute_grpo_loss(
                        policy_logits=policy_logits.unsqueeze(0),  # Add batch dim
                        policy_sampled_ids=sampled_ids.unsqueeze(0),
                        policy_attention_mask=attention_mask.unsqueeze(0),
                        ref_log_probs_generated=ref_log_probs.unsqueeze(0),
                        rewards=[reward],
                        kl_coeff=kl_coeff,
                        entropy_coeff=entropy_coeff,
                        pad_token_id=tokenizer.pad_token_id,
                        prompt_length=prompt_length
                    )

                    # Scale loss for accumulation
                    weighted_loss = sample_loss / (total_samples * accumulation_steps)

                # Backward pass
                scaler.scale(weighted_loss).backward()

                # Accumulate metrics
                micro_batch_loss += weighted_loss.item()
                micro_batch_rewards += reward
                micro_pg_loss += sample_pg / total_samples
                micro_kl_loss += sample_kl / total_samples
                micro_ent_loss += sample_ent / total_samples

            except RuntimeError as e:
                if "out of memory" in str(e):
                    print(f"    OOM at sample {sample_idx}, skipping...")
                    clear_memory()
                    continue
                elif "size of tensor" in str(e):
                    print(f"    Tensor size mismatch at sample {sample_idx}: {e}")
                    print(f"    Sample shapes - sampled_ids: {sample['sampled_ids'].shape}, ref_log_probs: {sample['ref_log_probs'].shape}")
                    clear_memory()
                    continue
                else:
                    print(f"    Runtime error at sample {sample_idx}: {e}")
                    clear_memory()
                    continue

            except Exception as e:
                print(f"    Unexpected error at sample {sample_idx}: {e}")
                clear_memory()
                continue

            finally:
                # Safe cleanup - only delete variables that were defined
                variables_to_delete = ['sampled_ids', 'attention_mask', 'ref_log_probs',
                                     'policy_outputs', 'policy_logits', 'sample_loss', 'weighted_loss']

                for var_name in variables_to_delete:
                    if var_name in locals() and locals()[var_name] is not None:
                        del locals()[var_name]

                if sample_idx % 2 == 0:
                    clear_memory()

        # Update accumulators
        if total_samples > 0:
            total_loss_accum += micro_batch_loss
            avg_reward_accum += micro_batch_rewards / total_samples
            all_pg_loss += micro_pg_loss / accumulation_steps
            all_kl_loss += micro_kl_loss / accumulation_steps
            all_ent_loss += micro_ent_loss / accumulation_steps

            print(f"    Micro-batch Loss: {micro_batch_loss:.4f} (PG: {micro_pg_loss:.4f}, KL: {micro_kl_loss:.4f}, Ent: {micro_ent_loss:.4f})")
            print(f"    Avg Reward: {micro_batch_rewards / total_samples:.4f}")

        # Clean up micro-batch data
        del all_samples, batch, inputs
        clear_memory()

    # --- Optimizer Step ---
    print(f"  Performing optimizer step...")

    # Gradient clipping
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad(set_to_none=True)

    global_step += 1

    # --- Logging ---
    print(f"--- Step {step + 1} Completed ---")
    print(f"  Total Loss: {total_loss_accum:.4f}")
    print(f"  Avg Reward: {avg_reward_accum / accumulation_steps:.4f}")
    print(f"  Components (PG: {all_pg_loss:.4f}, KL: {all_kl_loss:.4f}, Ent: {all_ent_loss:.4f})")

    # --- Checkpointing ---
    if (step + 1) % 50 == 0:
        print(f"Saving checkpoint at step {step + 1}...")
        try:
            # Save to CPU to free GPU memory
            model_cpu = {k: v.cpu() for k, v in model.state_dict().items()}
            checkpoint_path = f"grpo_checkpoint_step_{step + 1}.pt"
            torch.save(model_cpu, checkpoint_path)
            del model_cpu
            print(f"Checkpoint saved to {checkpoint_path}")
        except Exception as e:
            print(f"Error saving checkpoint: {e}")
        clear_memory()

print("\nTraining finished successfully!")


Starting memory-optimized training...
Config: batch_size=4, samples_per_prompt=2

--- Step 1 / 200 ---
  Accumulation Step 1/8
    Generating 2 samples per prompt...


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
Caching is incompatible with gradient checkpointing in Qwen2DecoderLayer. Setting `past_key_values=None`.


    Calculating loss...
    Tensor size mismatch at sample 0: The size of tensor a (768) must match the size of tensor b (256) at non-singleton dimension 2
    Sample shapes - sampled_ids: torch.Size([1, 768]), ref_log_probs: torch.Size([1, 256])
    Tensor size mismatch at sample 1: The size of tensor a (768) must match the size of tensor b (256) at non-singleton dimension 2
    Sample shapes - sampled_ids: torch.Size([1, 768]), ref_log_probs: torch.Size([1, 256])
    Tensor size mismatch at sample 2: The size of tensor a (768) must match the size of tensor b (256) at non-singleton dimension 2
    Sample shapes - sampled_ids: torch.Size([1, 768]), ref_log_probs: torch.Size([1, 256])
    Tensor size mismatch at sample 3: The size of tensor a (768) must match the size of tensor b (256) at non-singleton dimension 2
    Sample shapes - sampled_ids: torch.Size([1, 768]), ref_log_probs: torch.Size([1, 256])
    Tensor size mismatch at sample 4: The size of tensor a (768) must match the size

AssertionError: Attempted unscale_ but _scale is None.  This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration.