In [1]:
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset

def get_dataset(dataset_name, subset=None, split="train", size=None):
    """Load and optionally subset a dataset."""
    dataset = load_dataset(dataset_name, subset, split=split)
    if size is not None:
        dataset = dataset.select(range(size))
    return dataset

def format_alpaca_as_chat(example):
    """Format Alpaca data as Command-R chat messages."""
    instruction = example["instruction"]
    input_text = example["input"]
    full_instruction = f"{instruction}\n{input_text}" if input_text else instruction
    
    messages = [{"role": "user", "content": full_instruction}]
    return messages

def eval_cohere_ppl(model, tokenizer, dataset, device="cuda", debug=False):
    """Evaluate perplexity of Cohere model on dataset."""
    model.eval()
    total_loss = 0
    total_tokens = 0
    
    with torch.no_grad():
        for example in tqdm(dataset):
            messages = format_alpaca_as_chat(example)
            
            # Format with chat template and get target
            input_ids = tokenizer.apply_chat_template(
                messages, 
                tokenize=True,
                add_generation_prompt=True,
                return_tensors="pt"
            ).to(device)
            
            # Get model outputs
            outputs = model(input_ids, labels=input_ids)
            loss = outputs.loss
            
            # Calculate perplexity
            total_loss += loss.item() * input_ids.size(1)
            total_tokens += input_ids.size(1)
            
            if debug:
                print(f"Example input: {messages}")
                print(f"Tokenized length: {input_ids.size(1)}")
                print(f"Loss: {loss.item()}")
                break
    
    avg_loss = total_loss / total_tokens
    ppl = torch.exp(torch.tensor(avg_loss))
    return ppl.item()

# Example usage in notebook cells:

# Cell 1: Load model and tokenizer
model_name = "/home/riyasatohib_cohere_com/repos/models/command-r-refresh"
print(f"Loading model: {model_name}")

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto"
)

# Cell 2: Load dataset
dataset_size = 250  # Adjust as needed
dataset = get_dataset(
    "tatsu-lab/alpaca",
    subset=None,
    split="train",
    size=dataset_size
)

# Cell 3: Test with one example (debug mode)
debug_ppl = eval_cohere_ppl(model, tokenizer, dataset, debug=True)
print(f"Debug example perplexity: {debug_ppl:.2f}")

# Cell 4: Run full evaluation
print(f"Evaluating perplexity on {dataset_size} examples...")
ppl = eval_cohere_ppl(model, tokenizer, dataset, debug=False)
print(f"Perplexity: {ppl:.2f}")

  from .autonotebook import tqdm as notebook_tqdm


Loading model: /home/riyasatohib_cohere_com/repos/models/command-r-refresh


Loading checkpoint shards: 100%|██████████| 14/14 [00:11<00:00,  1.22it/s]
  0%|          | 0/250 [00:00<?, ?it/s]


Example input: [{'role': 'user', 'content': 'Give three tips for staying healthy.'}]
Tokenized length: 13
Loss: 2.444206714630127
Debug example perplexity: 11.52
Evaluating perplexity on 250 examples...


100%|██████████| 250/250 [00:08<00:00, 30.43it/s]

Perplexity: 11.44





In [2]:
# Cell 1: Imports and Function Definitions
import torch
import time
from tqdm import tqdm
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

def get_dataset(dataset_name, subset=None, split="train", size=None):
    """Load and optionally subset a dataset."""
    dataset = load_dataset(dataset_name, subset, split=split)
    if size is not None:
        dataset = dataset.select(range(size))
    print(f"Loaded dataset with {len(dataset)} examples")
    return dataset

def format_alpaca_as_chat(example):
    """Format Alpaca data as Command-R chat messages."""
    instruction = example["instruction"]
    input_text = example["input"]
    full_instruction = f"{instruction}\n{input_text}" if input_text else instruction
    messages = [{"role": "user", "content": full_instruction}]
    return messages

def eval_cohere_ppl(model, tokenizer, dataset, num_examples=None, device="cuda", debug=False):
    """
    Evaluate perplexity of Cohere model on dataset with clean progress display.
    """
    model.eval()
    total_loss = 0
    total_tokens = 0
    lengths = []
    
    # Select subset if num_examples is specified
    if num_examples is not None:
        num_examples = min(num_examples, len(dataset))
        dataset = dataset.select(range(num_examples))
    else:
        num_examples = len(dataset)
    
    # Create progress bar
    pbar = tqdm(dataset, total=num_examples, desc="Computing perplexity")
    
    with torch.no_grad():
        for example in pbar:
            messages = format_alpaca_as_chat(example)
            
            input_ids = tokenizer.apply_chat_template(
                messages, 
                tokenize=True,
                add_generation_prompt=True,
                return_tensors="pt"
            ).to(device)
            
            lengths.append(input_ids.size(1))
            outputs = model(input_ids, labels=input_ids)
            loss = outputs.loss
            
            total_loss += loss.item() * input_ids.size(1)
            total_tokens += input_ids.size(1)
            
            # Update progress bar with current PPL
            current_ppl = torch.exp(torch.tensor(total_loss / total_tokens)).item()
            pbar.set_postfix({'PPL': f'{current_ppl:.2f}'}, refresh=True)
            
            if debug and len(lengths) == 1:
                print(f"\nExample input: {messages}")
                print(f"Tokenized length: {input_ids.size(1)}")
                print(f"Loss: {loss.item()}")
                break
    
    # Calculate final perplexity
    avg_loss = total_loss / total_tokens
    ppl = torch.exp(torch.tensor(avg_loss))
    
    # Print summary
    print(f"\nEvaluation Summary:")
    print(f"Processed {num_examples:,} examples, {total_tokens:,} tokens")
    print(f"Average tokens per example: {total_tokens/num_examples:.1f}")
    print(f"Max length: {max(lengths)}, Min length: {min(lengths)}, Avg length: {sum(lengths)/len(lengths):.1f}")
    print(f"Final Perplexity: {ppl:.2f}")
    
    return ppl.item()

In [3]:
# Cell 2: Load Model and Tokenizer
model_name = "/home/riyasatohib_cohere_com/repos/models/command-r-refresh"
print(f"Loading model: {model_name}")

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto"
)

# Cell 3: Load Dataset
# Change size=None to process all examples, or set a specific number
dataset_size = None  # Try with 1000 first, then set to None for full dataset
dataset = get_dataset(
    "tatsu-lab/alpaca",
    subset=None,
    split="train",
    size=dataset_size
)

# Cell 4: Optional Debug Run (run this cell to test with one example)
# debug_dataset = dataset.select(range(1))
# debug_ppl = eval_cohere_ppl(model, tokenizer, dataset, num_examples=500, device="cuda", debug=False)
# # debug_ppl = eval_cohere_ppl(model, tokenizer, debug_dataset, debug=True)
# print(f"Debug example perplexity: {debug_ppl:.2f}")

# Cell 5: Full Evaluation
print(f"Starting full evaluation on {len(dataset)} examples...")
ppl = eval_cohere_ppl(model, tokenizer, dataset, num_examples=500, device="cuda", debug=False)
print(f"Final Perplexity: {ppl:.2f}")

Loading model: /home/riyasatohib_cohere_com/repos/models/command-r-refresh


Loading checkpoint shards: 100%|██████████| 14/14 [00:10<00:00,  1.35it/s]


Loaded dataset with 52002 examples
Starting full evaluation on 52002 examples...


Computing perplexity: 100%|██████████| 500/500 [00:16<00:00, 29.87it/s, PPL=11.54]


Evaluation Summary:
Processed 500 examples, 11,379 tokens
Average tokens per example: 22.8
Max length: 110, Min length: 11, Avg length: 22.8
Final Perplexity: 11.54
Final Perplexity: 11.54





---------
### Modification of the TEAL ppl file

In [4]:
import sys,os
# sys.path.append('../')
sys.path.append('/home/riyasatohib_cohere_com/repos/teal_clone/utils/utils.py')

import torch
from tqdm import tqdm
import os
import argparse

from utils.utils import get_tokenizer, get_sparse_model
from utils.eval_ppl import eval_ppl

from teal.model import (
    LlamaSparseForCausalLM, 
    LlamaSparseConfig,
    MistralSparseForCausalLM, 
    MistralSparseConfig,
    CohereSparseForCausalLM,
    CohereSparseConfig
)

from utils.data import get_dataset

from transformers import AutoConfig, AutoModelForCausalLM

AutoConfig.register("llama_sparse", LlamaSparseConfig)
AutoConfig.register("mistral_sparse", MistralSparseConfig)
AutoConfig.register("cohere_sparse", CohereSparseConfig)

AutoModelForCausalLM.register(LlamaSparseConfig, LlamaSparseForCausalLM)
AutoModelForCausalLM.register(MistralSparseConfig, MistralSparseForCausalLM)
AutoModelForCausalLM.register(CohereSparseConfig, CohereSparseForCausalLM)

In [5]:
# parser = argparse.ArgumentParser(description="Parse command line arguments for the script.")
# parser.add_argument('--model_name', type=str, default="meta-llama/Llama-2-7b-hf",help='Name of the model to use')
# parser.add_argument('--hist_path', type=str, default="meta-llama/Llama-2-7b-hf",help='Name of the model to use')
# parser.add_argument('--teal_path', type=str, required=True,help='Path to the teal input')
# parser.add_argument('--save_path', type=str, default="./model", required=True,help='Path to the teal input')
# parser.add_argument('--greedy_flag', action='store_true', help='Flag for greedy')
# parser.add_argument('--sparsity', type=float, default=0.5, help='Sparsity level')
# args = parser.parse_args()

In [6]:
class Args:
    def __init__(self):
        self.model_name = "/home/riyasatohib_cohere_com/repos/models/command-r-refresh"
        self.hist_path = "/home/riyasatohib_cohere_com/repos/teal_clone/models/command-r-refresh/histograms"
        self.teal_path = '/home/riyasatohib_cohere_com/repos/teal_clone/models'  # Required argument
        self.save_path = "/home/riyasatohib_cohere_com/repos/models/command-refresh-sparse/"  # Default value though marked as required
        self.greedy_flag = False
        self.sparsity = 0.3

    def __str__(self):
        return "\n".join([f"{key}={value}" for key, value in vars(self).items()])

args = Args()
args.model_name

'/home/riyasatohib_cohere_com/repos/models/command-r-refresh'

In [7]:
## Dataset has to be loaded like this for Command-r-compatibility
# tokenizer = get_tokenizer(args.model_name)
sps_model = get_sparse_model(args.model_name, device="auto", histogram_path=args.hist_path, apply_prefill=False) # Add this

You are using a model of type cohere to instantiate a model of type cohere_sparse. This is not supported for all configurations of models and can yield errors.
Loading checkpoint shards: 100%|██████████| 14/14 [00:10<00:00,  1.32it/s]
  histogram = torch.load(f"{self.file_path}/histograms.pt")


In [8]:
dataset_size = None  # Try with 1000 first, then set to None for full dataset
dataset = get_dataset(
    "tatsu-lab/alpaca",
    subset=None,
    split="train",
    size=dataset_size
)
args.model_name

'/home/riyasatohib_cohere_com/repos/models/command-r-refresh'

In [9]:
# print("Evaluating dense PPL")
# print("="*40)
# dense_ppl = eval_ppl(model, tokenizer, device="cuda", dataset=dataset, debug=False)
# print(f"PPL: {dense_ppl}")


# print("Evaluating sparse PPL at sparsity level: ", args.sparsity)
# print("="*40)
# if args.greedy_flag:
#     print("Evaluating greedy PPL")
#     greedy_path = os.path.join(args.teal_path, "lookup")
#     model.load_greedy_sparsities(greedy_path, args.sparsity)
# else:
#     print("Evaluating uniform PPL")
#     model.set_uniform_sparsity(args.sparsity)

# sparse_ppl = eval_ppl(model, tokenizer, device="cuda", dataset=dataset, debug=False)
# print(f"PPL: {sparse_ppl}")

# print("="*40)

### Saving model
# print(f"saving the model")
# model.save_pretrained(args.save_path)
# tokenizer.save_pretrained(args.save_path)

In [10]:
%%capture 

# set sparsity
args.sparsity = 0.00000001

print("Evaluating sparse PPL at sparsity level: ", args.sparsity)
print("="*40)
if args.greedy_flag:
    print("Evaluating greedy PPL")
    greedy_path = os.path.join(args.teal_path, "lookup")
    sps_model.load_greedy_sparsities(greedy_path, args.sparsity)
else:
    print("Evaluating uniform PPL")
    sps_model.set_uniform_sparsity(args.sparsity)
    
ppl = eval_cohere_ppl(sps_model, tokenizer, dataset, num_examples=500, device="cuda", debug=False)

In [11]:
ppl = eval_cohere_ppl(model, tokenizer, dataset, num_examples=500, device="cuda", debug=False)

Computing perplexity: 100%|██████████| 500/500 [00:16<00:00, 30.23it/s, PPL=11.54]


Evaluation Summary:
Processed 500 examples, 11,379 tokens
Average tokens per example: 22.8
Max length: 110, Min length: 11, Avg length: 22.8
Final Perplexity: 11.54





In [12]:
print(ppl)

11.53911018371582


In [11]:
dataset_size = None  # Try with 1000 first, then set to None for full dataset
dataset = get_dataset(
    "tatsu-lab/alpaca",
    subset=None,
    split="train",
    size=dataset_size
)

Loaded dataset with 52002 examples


--------
## Diagnostics

In [23]:
def verify_sparse_model_setup(sps_model, tokenizer, dataset, verbose=True):
    """Verify sparse model configuration and test with actual dataset samples"""
    issues_found = []
    
    # Check layer configurations
    for i, layer in enumerate(sps_model.model.layers):
        # Check MLP sparsification
        for name in ['gate', 'up', 'down']:
            thresh = layer.mlp.sparse_fns[name].threshold
            distr = layer.mlp.sparse_fns[name].distr
            
            if abs(thresh) > 10:  # Unusually large threshold
                issues_found.append(f"Layer {i} MLP {name} has large threshold: {thresh}")
            
            if distr is None:
                issues_found.append(f"Layer {i} MLP {name} missing distribution")
            else:
                centers = distr.bin_centers
                counts = distr.counts
                if torch.any(torch.isnan(centers)) or torch.any(torch.isnan(counts)):
                    issues_found.append(f"Layer {i} MLP {name} has NaN in distribution")
                if centers.numel() == 0 or counts.sum() == 0:
                    issues_found.append(f"Layer {i} MLP {name} has empty distribution")
        
        # Check Attention sparsification
        for name in ['q', 'k', 'v', 'o']:
            thresh = layer.self_attn.sparse_fns[name].threshold
            distr = layer.self_attn.sparse_fns[name].distr
            
            if abs(thresh) > 10:
                issues_found.append(f"Layer {i} Attention {name} has large threshold: {thresh}")
            
            if distr is None:
                issues_found.append(f"Layer {i} Attention {name} missing distribution")
            else:
                centers = distr.bin_centers
                counts = distr.counts
                if torch.any(torch.isnan(centers)) or torch.any(torch.isnan(counts)):
                    issues_found.append(f"Layer {i} Attention {name} has NaN in distribution")
                if centers.numel() == 0 or counts.sum() == 0:
                    issues_found.append(f"Layer {i} Attention {name} has empty distribution")
    
    # Test forward pass with actual dataset sample
    try:
        # Get first example from dataset
        sample_text = dataset[0] if isinstance(dataset[0], str) else dataset[0]['text']
        input_ids = tokenizer(sample_text, return_tensors="pt").input_ids.to(next(sps_model.parameters()).device)
        
        # Track activation sparsity during forward pass
        activation_stats = {}
        
        def hook_fn(name):
            def _hook(module, input, output):
                if hasattr(module, 'threshold'):
                    mask = input[0].abs() > module.threshold
                    sparsity = 1 - mask.float().mean().item()
                    activation_stats[name] = sparsity
            return _hook
        
        # Register hooks for first layer
        first_layer = sps_model.model.layers[0]
        hooks = []
        for name in ['gate', 'up', 'down']:
            hooks.append(first_layer.mlp.sparse_fns[name].register_forward_hook(hook_fn(f'mlp_{name}')))
        for name in ['q', 'k', 'v', 'o']:
            hooks.append(first_layer.self_attn.sparse_fns[name].register_forward_hook(hook_fn(f'attn_{name}')))
        
        with torch.no_grad():
            outputs = sps_model(input_ids)
            
            if torch.any(torch.isnan(outputs.logits)):
                issues_found.append("Forward pass produces NaN logits")
            if torch.any(torch.isinf(outputs.logits)):
                issues_found.append("Forward pass produces Inf logits")
        
        # Remove hooks
        for hook in hooks:
            hook.remove()
            
    except Exception as e:
        issues_found.append(f"Forward pass failed: {str(e)}")
    
    if verbose:
        if issues_found:
            print("\nIssues found:")
            for issue in issues_found:
                print(f"- {issue}")
        else:
            print("\nNo issues found in model configuration")
            
        # Print sample of thresholds and achieved sparsity
        print("\nFirst layer statistics:")
        first_layer = sps_model.model.layers[0]
        print("MLP:")
        for name in ['gate', 'up', 'down']:
            thresh = first_layer.mlp.sparse_fns[name].threshold
            achieved_sparsity = activation_stats.get(f'mlp_{name}')
            sparsity_str = f"{achieved_sparsity:.4f}" if achieved_sparsity is not None else "N/A"
            print(f"  {name}: threshold={thresh:.4f}, achieved_sparsity={sparsity_str}")
        
        print("Attention:")
        for name in ['q', 'k', 'v', 'o']:
            thresh = first_layer.self_attn.sparse_fns[name].threshold
            achieved_sparsity = activation_stats.get(f'attn_{name}')
            sparsity_str = f"{achieved_sparsity:.4f}" if achieved_sparsity is not None else "N/A"
            print(f"  {name}: threshold={thresh:.4f}, achieved_sparsity={sparsity_str}")
    
    return issues_found, activation_stats

# Usage:
print("Verifying sparse model configuration...")
issues, act_stats = verify_sparse_model_setup(sps_model, tokenizer, dataset)

Verifying sparse model configuration...
Casting input hidden states to torch.float16 (this should not be happening)
Casting input hidden states to torch.float16 (this should not be happening)
Casting input hidden states to torch.float16 (this should not be happening)
Casting input hidden states to torch.float16 (this should not be happening)
Casting input hidden states to torch.float16 (this should not be happening)
Casting input hidden states to torch.float16 (this should not be happening)
Casting input hidden states to torch.float16 (this should not be happening)
Casting input hidden states to torch.float16 (this should not be happening)
Casting input hidden states to torch.float16 (this should not be happening)
Casting input hidden states to torch.float16 (this should not be happening)
Casting input hidden states to torch.float16 (this should not be happening)
Casting input hidden states to torch.float16 (this should not be happening)
Casting input hidden states to torch.float16 (th

In [9]:
sps_model

CohereSparseForCausalLM(
  (model): CohereModel(
    (embed_tokens): Embedding(256000, 8192, padding_idx=0)
    (layers): ModuleList(
      (0-39): 40 x CohereDecoderLayer(
        (self_attn): CohereFlashAttention2(
          (q_proj): Linear(in_features=8192, out_features=8192, bias=False)
          (k_proj): Linear(in_features=8192, out_features=1024, bias=False)
          (v_proj): Linear(in_features=8192, out_features=1024, bias=False)
          (o_proj): Linear(in_features=8192, out_features=8192, bias=False)
          (rotary_emb): CohereRotaryEmbedding()
          (sparse_fns): ModuleDict(
            (q): SparsifyFn()
            (k): SparsifyFn()
            (v): SparsifyFn()
            (o): SparsifyFn()
          )
        )
        (mlp): CohereMLP(
          (gate_proj): Linear(in_features=8192, out_features=24576, bias=False)
          (up_proj): Linear(in_features=8192, out_features=24576, bias=False)
          (down_proj): Linear(in_features=24576, out_features=8192, b

In [11]:
### Saving model
print(f"saving the model")
save_path='/home/riyasatohib_cohere_com/repos/models/command-r-ref-sps'
sps_model.save_pretrained(save_path)
tokenizer.save_pretrained(save_path)

saving the model


('/home/riyasatohib_cohere_com/repos/models/command-r-ref-sps/tokenizer_config.json',
 '/home/riyasatohib_cohere_com/repos/models/command-r-ref-sps/special_tokens_map.json',
 '/home/riyasatohib_cohere_com/repos/models/command-r-ref-sps/tokenizer.json')