In [1]:
!pip install -U "huggingface_hub[cli]"
!pip install transformers datasets tqdm numpy torch
!pip install mambapy

from IPython.display import clear_output
clear_output(wait=False)

In [2]:
import warnings
warnings.filterwarnings("ignore")

In [3]:
import os
os.environ['HF_TOKEN'] = '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'
!huggingface-cli login --token $HF_TOKEN

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: fineGrained).
The token `598Project` has been saved to /home/sbhushan/.cache/huggingface/stored_tokens
Your token has been saved to /home/sbhushan/.cache/huggingface/token
Login successful.
Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


In [4]:
import torch
def print_gpu_memory():
    if torch.cuda.is_available():
        for i in range(torch.cuda.device_count()):
            total_memory = torch.cuda.get_device_properties(i).total_memory / 1024**3
            allocated_memory = torch.cuda.memory_allocated(i) / 1024**3
            cached_memory = torch.cuda.memory_reserved(i) / 1024**3
            free_memory = total_memory - allocated_memory
            
            print(f"GPU {i} - {torch.cuda.get_device_name(i)}")
            print(f"Total Memory: {total_memory:.2f} GB")
            print(f"Allocated Memory: {allocated_memory:.2f} GB")
            print(f"Cached Memory: {cached_memory:.2f} GB")
            print(f"Free Memory: {free_memory:.2f} GB")
            print("-" * 50)
    else:
        print("No GPU available")

# Call the function
print_gpu_memory()

GPU 0 - Tesla V100-PCIE-16GB
Total Memory: 15.77 GB
Allocated Memory: 0.00 GB
Cached Memory: 0.00 GB
Free Memory: 15.77 GB
--------------------------------------------------


In [5]:
def memory_check(step=""):
    print(f"\nMemory Check - {step}")
    print_gpu_memory()

In [6]:
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.utils.data import Dataset, DataLoader
import numpy as np
from tqdm.notebook import tqdm
from mambapy.mamba import Mamba, MambaConfig

In [7]:
class LLMDataset(Dataset):
    def __init__(self, input_ids, attention_masks, labels):
        self.input_ids = input_ids
        self.attention_masks = attention_masks
        self.labels = labels
    
    def __len__(self):
        return len(self.input_ids)
    
    def __getitem__(self, idx):
        return {
            'input_ids': self.input_ids[idx],
            'attention_mask': self.attention_masks[idx],
            'labels': self.labels[idx]
        }

In [8]:
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
tokenizer.pad_token = tokenizer.eos_token

# Load streaming dataset
dataset = load_dataset("HuggingFaceFW/fineweb-edu", 
                      name="sample-10BT", 
                      split="train", 
                      streaming=True)

Resolving data files:   0%|          | 0/1630 [00:00<?, ?it/s]

In [9]:
# Initialize storage
collected_tokens = []
total_tokens = 0
target_tokens = 5_000_000

# Collect samples
for sample in dataset:
    # Get text from sample
    text = sample['text']
    
    # Tokenize
    tokens = tokenizer(text, truncation=False, padding=False)['input_ids']
    
    # Add to collection
    collected_tokens.extend(tokens)
    total_tokens += len(tokens)
    
    # Print progress every 100k tokens
    if total_tokens // 100_000 > (total_tokens - len(tokens)) // 100_000:
        print(f"Collected {total_tokens:,} tokens")
    
    # Stop when we hit target
    if total_tokens >= target_tokens:
        break

# Convert to numpy array and trim to exact size
collected_tokens = np.array(collected_tokens[:target_tokens])
print(f"\nFinal token count: {len(collected_tokens):,}")

Collected 101,374 tokens
Collected 201,135 tokens
Collected 301,190 tokens
Collected 400,573 tokens
Collected 500,666 tokens
Collected 600,893 tokens
Collected 701,230 tokens
Collected 800,062 tokens
Collected 900,056 tokens
Collected 1,000,303 tokens
Collected 1,104,261 tokens
Collected 1,200,768 tokens
Collected 1,300,278 tokens
Collected 1,400,226 tokens
Collected 1,500,380 tokens
Collected 1,600,015 tokens
Collected 1,700,608 tokens
Collected 1,810,446 tokens
Collected 1,902,046 tokens
Collected 2,002,909 tokens
Collected 2,100,052 tokens
Collected 2,200,058 tokens
Collected 2,303,549 tokens
Collected 2,401,422 tokens
Collected 2,500,434 tokens
Collected 2,600,557 tokens
Collected 2,700,392 tokens
Collected 2,801,133 tokens
Collected 2,900,204 tokens
Collected 3,001,947 tokens
Collected 3,101,877 tokens
Collected 3,200,123 tokens
Collected 3,300,891 tokens
Collected 3,400,373 tokens
Collected 3,500,236 tokens
Collected 3,600,082 tokens
Collected 3,700,818 tokens
Collected 3,802,260

In [10]:
sequence_length = 512  # Changed from 1024
n_sequences = len(collected_tokens) // sequence_length

# Reshape tokens into sequences
sequences = collected_tokens[:n_sequences * sequence_length].reshape(-1, sequence_length)

# Create input and target pairs for causal language modeling
input_sequences = sequences[:, :-1]  # all tokens except last
target_sequences = sequences[:, 1:]  # all tokens except first

# Convert to torch tensors
inputs = torch.tensor(input_sequences)
masks = torch.tensor(np.ones_like(input_sequences))
targets = torch.tensor(target_sequences)

print(f"Input shape: {inputs.shape}")
print(f"Target shape: {targets.shape}")

Input shape: torch.Size([9765, 511])
Target shape: torch.Size([9765, 511])


In [11]:
val_split = 0.1
val_idx = int(len(inputs) * (1 - val_split))

# Split into train/val
train_inputs = inputs[:val_idx]
train_masks = masks[:val_idx]
train_targets = targets[:val_idx]

val_inputs = inputs[val_idx:]
val_masks = masks[val_idx:]
val_targets = targets[val_idx:]

# Create datasets
train_dataset = LLMDataset(train_inputs, train_masks, train_targets)
val_dataset = LLMDataset(val_inputs, val_masks, val_targets)

# Create dataloaders
batch_size = 2

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,  # Changed to 4
    shuffle=True,
    pin_memory=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    pin_memory=True
)

In [12]:
print("\nDataset Statistics:")
print(f"Training batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")

# Sample a batch
sample_batch = next(iter(train_loader))
print("\nBatch shapes:")
for k, v in sample_batch.items():
    print(f"{k}: {v.shape}")

# Decode a sample sequence
sample_seq = sample_batch['input_ids'][0][:10].tolist()
decoded = tokenizer.decode(sample_seq)
print(f"\nSample decoded text:\n{decoded}")


Dataset Statistics:
Training batches: 4394
Validation batches: 489

Batch shapes:
input_ids: torch.Size([2, 511])
attention_mask: torch.Size([2, 511])
labels: torch.Size([2, 511])

Sample decoded text:
 Betty LaRue was joined by the New York


In [13]:
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
#model = model.half()

if torch.cuda.is_available():
    model = model.cuda()
    print("Using CUDA")
else:
    print("Using CPU")

Using CUDA


In [14]:
# Initialize 8 Mamba blocks
mamba_blocks = {}

# Configuration for all blocks
mamba_config = MambaConfig(
    d_model=2048,          # Matches Llama's hidden size
    n_layers=1             
)

# Create blocks for layers 0,2,4,6,8,10,12,14
#for layer_idx in [1, 3, 5, 7, 8, 10, 12, 14]:
for layer_idx in [0, 2, 4, 6, 8, 10, 12, 14]:
    
    mamba_blocks[layer_idx] = Mamba(mamba_config).cuda()

# Verify shapes for each block
test_input = torch.randn(1, 5, 2048).cuda()  # Changed dtype
for layer_idx, block in mamba_blocks.items():
    test_output = block(test_input)
    print(f"Layer {layer_idx} - Input shape: {test_input.shape}, Output shape: {test_output.shape}")
    print(f"Layer {layer_idx} - Input dtype: {test_input.dtype}, Output dtype: {test_output.dtype}")

Layer 0 - Input shape: torch.Size([1, 5, 2048]), Output shape: torch.Size([1, 5, 2048])
Layer 0 - Input dtype: torch.float32, Output dtype: torch.float32
Layer 2 - Input shape: torch.Size([1, 5, 2048]), Output shape: torch.Size([1, 5, 2048])
Layer 2 - Input dtype: torch.float32, Output dtype: torch.float32
Layer 4 - Input shape: torch.Size([1, 5, 2048]), Output shape: torch.Size([1, 5, 2048])
Layer 4 - Input dtype: torch.float32, Output dtype: torch.float32
Layer 6 - Input shape: torch.Size([1, 5, 2048]), Output shape: torch.Size([1, 5, 2048])
Layer 6 - Input dtype: torch.float32, Output dtype: torch.float32
Layer 8 - Input shape: torch.Size([1, 5, 2048]), Output shape: torch.Size([1, 5, 2048])
Layer 8 - Input dtype: torch.float32, Output dtype: torch.float32
Layer 10 - Input shape: torch.Size([1, 5, 2048]), Output shape: torch.Size([1, 5, 2048])
Layer 10 - Input dtype: torch.float32, Output dtype: torch.float32
Layer 12 - Input shape: torch.Size([1, 5, 2048]), Output shape: torch.Size

In [15]:
# Freeze LLaMA
for param in model.parameters():
    param.requires_grad = False

# Verify LLaMA is frozen
trainable_params = [p for p in model.parameters() if p.requires_grad]
print(f"LLaMA trainable parameters: {len(trainable_params)} (should be 0)")

# Ensure Mamba blocks are trainable
for mamba in mamba_blocks.values():
    for param in mamba.parameters():
        param.requires_grad = True

LLaMA trainable parameters: 0 (should be 0)


In [16]:
optimizers = {}
for layer_idx in [0, 2, 4, 6, 8, 10, 12, 14]:
    optimizers[layer_idx] = torch.optim.AdamW(mamba_blocks[layer_idx].parameters(), lr=1e-4)

In [17]:
from tqdm.notebook import tqdm
import time
from datetime import timedelta
import torch.nn.functional as F

In [18]:
def train_epoch(model, mamba_blocks, train_loader, optimizers, accumulation_steps, mse_factor=1e3):
    layer_indices = [0, 2, 4, 6, 8, 10, 12, 14]
    layer_losses = {idx: 0.0 for idx in layer_indices}
    
    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), desc='Training')
    
    for batch_idx, batch in progress_bar:
        # Move inputs to GPU efficiently
        input_ids = batch['input_ids'].cuda()
        
        # Create position IDs
        input_shape = input_ids.size()
        position_ids = torch.arange(0, input_shape[-1], dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        
        # Get LLaMA hidden states and cache them
        with torch.no_grad():
            true_outputs = model(
                input_ids=input_ids,
                position_ids=position_ids,  # Add position_ids here
                output_hidden_states=True,
                use_cache=False
            )
            hidden_states = true_outputs.hidden_states
            
            # Extract teacher model outputs for each layer
            teacher_outputs = []
            for layer_idx in layer_indices:
                layer = model.model.layers[layer_idx]
                # Get normalized input and attention output
                layer_input = layer.input_layernorm(hidden_states[layer_idx])
                # Call attention with hidden states and position_ids
                attn_output = layer.self_attn(
                    hidden_states=layer_input,
                    position_ids=position_ids  # Add position_ids here
                )[0]
                
                # Verify shapes before storing
                if attn_output is not None and layer_input is not None:
                    teacher_outputs.append((layer_input.cpu(), attn_output.cpu()))
                else:
                    raise ValueError(f"Layer {layer_idx} produced None output")
            
            # Clear GPU memory
            del true_outputs, hidden_states
            torch.cuda.empty_cache()

        # Process each Mamba block
        total_loss = 0
        for idx, layer_idx in enumerate(layer_indices):
            layer_input, teacher_output = teacher_outputs[idx]
            
            # Zero gradients
            optimizers[layer_idx].zero_grad()
            
            # Forward pass through Mamba
            mamba_output = mamba_blocks[layer_idx](
                layer_input.cuda(),
            )
            
            # Verify Mamba output shape
            if mamba_output is None:
                raise ValueError(f"Mamba block {layer_idx} produced None output")
            
            # Compute loss
            loss = F.mse_loss(
                mamba_output,
                teacher_output.cuda()
            ) * mse_factor
            
            # Scale loss and backward pass
            scaled_loss = loss / accumulation_steps
            scaled_loss.backward()
            
            # Update tracking
            layer_losses[layer_idx] += loss.item()
            total_loss += loss.item()
            
            # Optimizer step on accumulation boundary
            if (batch_idx + 1) % accumulation_steps == 0:
                optimizers[layer_idx].step()
            
            # Clear memory
            del mamba_output, layer_input, teacher_output
            torch.cuda.empty_cache()
        
        # Update progress bar
        if (batch_idx + 1) % accumulation_steps == 0:
            avg_losses = {
                idx: loss/accumulation_steps 
                for idx, loss in layer_losses.items()
            }
            avg_total = total_loss/len(layer_indices)
            
            loss_str = " ".join([
                f"L{idx}: {loss:.4f}" 
                for idx, loss in avg_losses.items()
            ])
            progress_bar.set_description(
                f"Train | Avg: {avg_total:.4f} | {loss_str}"
            )
            
            # Reset loss tracking
            layer_losses = {idx: 0.0 for idx in layer_indices}
            total_loss = 0
            
    return layer_losses

In [19]:
def validate(model, mamba_blocks, val_loader, mse_factor=1e3):
    layer_indices = [0, 2, 4, 6, 8, 10, 12, 14]
    layer_losses = {idx: 0.0 for idx in layer_indices}
    num_batches = 0
    
    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch['input_ids'].cuda()
            position_ids = torch.arange(0, input_ids.size(-1), device=input_ids.device).unsqueeze(0)
            
            # Get teacher outputs
            outputs = model(input_ids=input_ids, 
                          position_ids=position_ids,
                          output_hidden_states=True)
            hidden_states = outputs.hidden_states
            
            # Get teacher attention outputs
            teacher_outputs = []
            for layer_idx in layer_indices:
                layer = model.model.layers[layer_idx]
                layer_input = layer.input_layernorm(hidden_states[layer_idx])
                attn_output = layer.self_attn(
                    hidden_states=layer_input,
                    position_ids=position_ids
                )[0]
                teacher_outputs.append((layer_input, attn_output))
            
            # Compare with Mamba outputs
            for idx, layer_idx in enumerate(layer_indices):
                layer_input, teacher_output = teacher_outputs[idx]
                mamba_output = mamba_blocks[layer_idx](layer_input)
                loss = F.mse_loss(mamba_output, teacher_output) * mse_factor
                layer_losses[layer_idx] += loss.item()
            
            num_batches += 1
            
    return {idx: loss/num_batches for idx, loss in layer_losses.items()}

In [20]:
os.makedirs('mamba_blocks', exist_ok=True)

In [21]:
num_epochs = 2
accumulation_steps = 32
best_val_losses = {idx: float('inf') for idx in [0, 2, 4, 6, 8, 10, 12, 14]}
patience = 3
no_improve_count = {idx: 0 for idx in [0, 2, 4, 6, 8, 10, 12, 14]}

start_time = time.time()

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    
    # Train
    train_losses = train_epoch(model, mamba_blocks, train_loader, optimizers, accumulation_steps)
    
    # Validate
    val_losses = validate(model, mamba_blocks, val_loader)
    
    # Print validation losses
    val_str = " ".join([f"L{idx}: {loss:.4f}" for idx, loss in val_losses.items()])
    print(f"Validation losses: {val_str}")
    
    # Check early stopping per layer
    active_layers = False
    for layer_idx in [0, 2, 4, 6, 8, 10, 12, 14]:
        if val_losses[layer_idx] < best_val_losses[layer_idx]:
            best_val_losses[layer_idx] = val_losses[layer_idx]
            no_improve_count[layer_idx] = 0
            # Save best model for this layer
            torch.save(mamba_blocks[layer_idx].state_dict(), f'mamba_blocks/mamba_layer_{layer_idx}.pt')
        else:
            no_improve_count[layer_idx] += 1
            if no_improve_count[layer_idx] >= patience:
                print(f"Early stopping for layer {layer_idx}")
                # Load best model for this layer
                mamba_blocks[layer_idx].load_state_dict(torch.load(f'mamba_blocks/mamba_layer_{layer_idx}.pt'))
            else:
                active_layers = True
    
    if not active_layers:
        print("All layers converged - stopping training")
        break
    
    epoch_time = time.time() - start_time
    print(f"Time elapsed: {timedelta(seconds=int(epoch_time))}")

total_time = time.time() - start_time
print(f"\nTraining completed in: {timedelta(seconds=int(total_time))}")


Epoch 1/2


Training:   0%|          | 0/4394 [00:00<?, ?it/s]

The attention layers in this model are transitioning from computing the RoPE embeddings internally through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed `position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be removed and `position_embeddings` will be mandatory.


Validation losses: L0: 10.9501 L2: 62.4131 L4: 64.3233 L6: 79.8924 L8: 91.3085 L10: 116.2376 L12: 97.6159 L14: 106.8220
All layers converged - stopping training

Training completed in: 0:59:47


In [22]:
class MambaWrapper(torch.nn.Module):
    def __init__(self, mamba_block):
        super().__init__()
        self.mamba = mamba_block
        self.last_state = None

    def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, output_attentions=False, use_cache=False, **kwargs):
        if past_key_value is not None:
            self.last_state = past_key_value
        output = self.mamba(hidden_states)
        self.last_state = output.last_state if hasattr(output, 'last_state') else None
        return (output, None, self.last_state)

# Replace attention with wrapped Mamba blocks
for layer_idx in [0, 2, 4, 6, 8, 10, 12, 14]:
    model.model.layers[layer_idx].self_attn = MambaWrapper(mamba_blocks[layer_idx])

In [55]:
test_input = torch.randint(0, 1000, (1, 10)).cuda()  # batch_size=1, seq_len=10
test_output = model(test_input)
print("Forward pass successful!")

NameError: name 'model' is not defined

In [24]:
def prepare_for_finetuning(model):
    print("Freezing all parameters except Mamba blocks...")
    
    # First freeze everything
    for param in model.parameters():
        param.requires_grad = False
    
    # Count trainable parameters and unfreeze only Mamba blocks
    num_params = 0
    for layer_idx, layer in enumerate(model.model.layers):
        if isinstance(layer.self_attn, MambaWrapper):
            for param in layer.self_attn.mamba.parameters():
                param.requires_grad = True
                num_params += param.numel()
    
    print(f"Number of trainable parameters: {num_params:,}")

    # Verify which layers are trainable
    print("\nTrainable layers:")
    for layer_idx, layer in enumerate(model.model.layers):
        if isinstance(layer.self_attn, MambaWrapper):
            num_trainable = sum(p.requires_grad for p in layer.self_attn.mamba.parameters())
            print(f"Layer {layer_idx}: {num_trainable} trainable parameters")

In [54]:
# Call the function
prepare_for_finetuning(model)

NameError: name 'model' is not defined

In [56]:
def verify_grad_status(model):
    print("\nGradient status check:")
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f"Trainable: {name}")

verify_grad_status(model)

NameError: name 'model' is not defined

In [53]:
def save_hybrid_model(model, mamba_blocks):
    save_dict = {
        'model_state': model.state_dict(),
        'mamba_configs': {
            'hidden_size': 2048,
            'replaced_layers': [0, 2, 4, 6, 8, 10, 12, 14]
        },
        'mamba_states': {
            idx: block.state_dict() 
            for idx, block in mamba_blocks.items()
        }
    }
    
    os.makedirs('hybrid_model', exist_ok=True)
    torch.save(save_dict, 'hybrid_model/pre_lora_model.pt')
    print("Hybrid model saved successfully!")

# Save the current model
save_hybrid_model(model, mamba_blocks)

NameError: name 'model' is not defined

In [40]:
# Load dataset
from datasets import load_dataset
dataset = load_dataset("yahma/alpaca-cleaned")
train_val = dataset['train'].train_test_split(test_size=0.1)
train_dataset, val_dataset = train_val['train'], train_val['test']

# Format with template
def format_alpaca_prompt(instruction, output):
    return f"<|begin of text|> Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n{output}<|end of text|>"

# Create a custom dataset class
class AlpacaDataset(Dataset):
    def __init__(self, dataset, tokenizer):
        self.dataset = dataset
        self.tokenizer = tokenizer
    
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        item = self.dataset[idx]
        text = format_alpaca_prompt(item['instruction'], item['output'])
        
        encodings = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=512,
            return_tensors='pt'
        )
        
        return {
            'input_ids': encodings['input_ids'].squeeze(),
            'attention_mask': encodings['attention_mask'].squeeze()
        }

# Create proper collate function
def collate_fn(batch):
    input_ids = torch.stack([item['input_ids'] for item in batch])
    attention_mask = torch.stack([item['attention_mask'] for item in batch])
    
    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask
    }

# Create datasets
train_dataset = AlpacaDataset(train_val['train'], tokenizer)
val_dataset = AlpacaDataset(train_val['test'], tokenizer)

# Create dataloaders with collate_fn
train_loader = DataLoader(
    train_dataset,
    batch_size=2,
    shuffle=True,
    collate_fn=collate_fn
)

val_loader = DataLoader(
    val_dataset,
    batch_size=2,
    shuffle=False,
    collate_fn=collate_fn
)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

Training samples: 46584
Validation samples: 5176


In [51]:
# Delete existing model and optimizers if they exist
try:
    del model
    del optimizer
    del scheduler
    del train_loader
    del val_loader
except:
    pass

# Clear CUDA cache
torch.cuda.empty_cache()

# Force garbage collection
import gc
gc.collect()

372

In [52]:
def print_gpu_memory():
    print("\nGPU Memory Status:")
    print(f"Total Memory: {torch.cuda.get_device_properties(0).total_memory/1e9:.2f} GB")
    print(f"Allocated Memory: {torch.cuda.memory_allocated()/1e9:.2f} GB")
    print(f"Cached Memory: {torch.cuda.memory_reserved()/1e9:.2f} GB")
    print(f"Free Memory: {(torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated())/1e9:.2f} GB")

print_gpu_memory()


GPU Memory Status:
Total Memory: 16.93 GB
Allocated Memory: 15.99 GB
Cached Memory: 16.07 GB
Free Memory: 0.94 GB


In [29]:
from peft import LoraConfig, get_peft_model

# Define LoRA config targeting Mamba projection layers
lora_config = LoraConfig(
    r=8,  # LoRA rank
    lora_alpha=16,  # Scaling factor
    target_modules=["in_proj", "x_proj", "dt_proj", "out_proj"],  # Mamba projection layers
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM"
)

In [33]:
def train_epoch_lora(model, train_loader, optimizer, scheduler, accumulation_steps=128):
    model.train()
    total_loss = 0
    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader))
    
    for batch_idx, batch in progress_bar:
        input_ids = batch['input_ids'].cuda()
        attention_mask = batch['attention_mask'].cuda()
        labels = input_ids.clone()
        
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )
        
        loss = outputs.loss / accumulation_steps
        total_loss += loss.item() * accumulation_steps
        
        loss.backward()
        
        if (batch_idx + 1) % accumulation_steps == 0:
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            
            # Clear cache periodically
            if batch_idx % (accumulation_steps * 4) == 0:
                torch.cuda.empty_cache()
            
        progress_bar.set_description(
            f"Loss: {total_loss/(batch_idx+1):.4f}"
        )
    
    return total_loss / len(train_loader)

In [34]:
def validate_lora(model, val_loader):
    model.eval()
    total_loss = 0
    
    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validating"):
            input_ids = batch['input_ids'].cuda()
            attention_mask = batch['attention_mask'].cuda()
            labels = input_ids.clone()
            
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )
            
            total_loss += outputs.loss.item()
            
    return total_loss / len(val_loader)

In [41]:
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=2e-4,
    weight_decay=0.01
)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, 
    T_max=len(train_loader)
)

num_epochs = 3
best_val_loss = float('inf')
patience = 2
no_improve = 0

print("Starting training...")
print(f"Number of training batches: {len(train_loader)}")
print(f"Number of validation batches: {len(val_loader)}")

# Training loop
for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    
    # Clear cache at start of epoch
    torch.cuda.empty_cache()
    
    train_loss = train_epoch_lora(
        model, 
        train_loader, 
        optimizer,
        scheduler
    )
    
    val_loss = validate_lora(model, val_loader)
    
    print(f"Train loss: {train_loss:.4f}")
    print(f"Val loss: {val_loss:.4f}")
    
    # Early stopping and model saving
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        no_improve = 0
        print("Saving best model...")
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': train_loss,
            'val_loss': val_loss,
        }, "best_hybrid_model.pt")
    else:
        no_improve += 1
        if no_improve >= patience:
            print("Early stopping triggered")
            break
            
print("Training completed!")

Starting training...
Number of training batches: 23292
Number of validation batches: 2588

Epoch 1/3


  0%|          | 0/23292 [00:00<?, ?it/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 256.00 MiB. GPU 0 has a total capacity of 15.77 GiB of which 172.19 MiB is free. Including non-PyTorch memory, this process has 15.59 GiB memory in use. Of the allocated memory 15.14 GiB is allocated by PyTorch, and 78.69 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)