<a href="https://colab.research.google.com/github/rsk2327/DistAya/blob/main/KLDBasedPruning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip -q install datasets
# I ran this on Colab, so all the required dependencies were already installed. On another machine, you'll probably need to install all the dependencies by hand.

In [None]:
# Install transformers from the main branch on GitHub, plus huggingface_hub.
!pip install -q git+https://github.com/huggingface/transformers.git huggingface_hub

In [None]:
# Enable git's on-disk credential store; presumably so the `huggingface-cli login`
# below can save the token as a git credential for later Hub operations.
!git config --global credential.helper store

In [None]:
# Interactive login; the download/upload cells in this notebook run after it.
!huggingface-cli login

In [None]:
# Download the base model into ./aya-23-8B; later cells load it by that local path.
!huggingface-cli download CohereForAI/aya-23-8B --repo-type model --local-dir aya-23-8B

In [None]:
# Download the Aya dataset into ./aya-dataset. get_aya_loaders reads it from
# /content/aya-dataset/ (presumably the Colab working directory) — adjust on other machines.
!huggingface-cli download CohereForAI/aya_dataset --repo-type dataset --local-dir aya-dataset

In [None]:
import torch

def getmodule(module: torch.nn.Module, target_module: str):
    """Walk a dotted path such as ``"model.layers.3"`` and return the submodule it names.

    Every path component made only of digits is used as an integer index
    (``current[3]``); every other component is read with ``getattr``.
    """
    current = module
    remaining = target_module
    while True:
        head, dot, remaining = remaining.partition(".")
        if head.isdigit():
            current = current[int(head)]
        else:
            current = getattr(current, head)
        if not dot:
            # No separator left: ``head`` was the final component.
            return current

def setmodule(module: torch.nn.Module, target_module: str, value: torch.nn.Module):
    """Replace the submodule named by a dotted path with ``value``.

    Intermediate components are resolved with ``getattr``; only the final
    component is treated as an integer index when it consists of digits.
    """
    *parents, leaf = target_module.split(".")
    owner = module
    for part in parents:
        owner = getattr(owner, part)
    if not leaf.isdigit():
        setattr(owner, leaf, value)
    else:
        owner[int(leaf)] = value

In [None]:
import torch
from datasets import load_dataset
from torch.utils.data.dataset import Dataset

class IndexDataset(Dataset):
    """Map-style dataset that forwards ``len()`` and indexing to a wrapped sequence.

    Used here to expose a 2-D tensor of token-id chunks row by row.
    """

    def __init__(self, tensors):
        # Any object supporting len() and integer indexing works.
        self.tensors = tensors

    def __len__(self):
        return len(self.tensors)

    def __getitem__(self, index):
        return self.tensors[index]

# Loading the data and tokenizing it
def process_data(samples, tokenizer, seq_len, field_name):
    """Tokenize a text column as one stream and cut it into fixed-length chunks.

    Args:
        samples: dataset/mapping where ``samples[field_name]`` yields strings;
            the strings are joined with blank lines before tokenizing.
        tokenizer: callable returning an object with ``input_ids`` when called
            with ``return_tensors='pt'``.
        seq_len: number of tokens per chunk.
        field_name: column holding the text.

    Returns:
        IndexDataset over a ``(nsamples, seq_len)`` tensor of token ids. The
        trailing tokens that do not fill a whole chunk are discarded.

    Raises:
        ValueError: if ``seq_len`` is not positive or the token stream is
            shorter than a single chunk (previously an opaque ``torch.stack``
            error on an empty list, or a ZeroDivisionError for ``seq_len=0``).
    """
    if seq_len <= 0:
        raise ValueError(f"seq_len must be positive, got {seq_len}")
    test_ids = tokenizer("\n\n".join(samples[field_name]), return_tensors='pt').input_ids[0]
    nsamples = test_ids.numel() // seq_len
    if nsamples == 0:
        raise ValueError(
            f"only {test_ids.numel()} tokens available, fewer than one chunk of seq_len={seq_len}"
        )
    # Drop the incomplete tail, then view the 1-D stream as (nsamples, seq_len) rows.
    test_ids_batch = test_ids[: nsamples * seq_len].reshape(nsamples, seq_len)
    return IndexDataset(tensors=test_ids_batch)

def merge_instructions(sample):
    """Combine a prompt/answer row into a single ``text`` field separated by a blank line."""
    prompt = sample["inputs"]
    answer = sample["targets"]
    return {"text": prompt + "\n\n" + answer}

def get_aya_loaders(tokenizer, seq_len=512, batch_size=4, max_samples=256,
                    data_dir='/content/aya-dataset/'):
    """Build a DataLoader of fixed-length token chunks from the local Aya test split.

    Args:
        tokenizer: tokenizer used by ``process_data``.
        seq_len: tokens per chunk.
        batch_size: chunks per batch.
        max_samples: number of leading rows of the test split to use; ``None``
            uses the whole split. Values larger than the split are clamped
            (previously ``select`` raised IndexError).
        data_dir: directory the dataset was downloaded to (defaults to the
            original Colab path).

    Returns:
        A non-shuffled ``torch.utils.data.DataLoader``.

    NOTE(review): rows are taken from the start of the split without shuffling,
    so the subset follows whatever order the split is stored in — confirm that
    is representative enough.
    """
    test_data = load_dataset(data_dir, 'default', split='test')
    if max_samples is not None:
        # Subset before mapping so only the rows actually used are processed.
        test_data = test_data.select(range(min(max_samples, len(test_data))))
    test_data = test_data.map(merge_instructions,
                              batched=False,
                              remove_columns=test_data.column_names)
    test_dataset = process_data(test_data, tokenizer, seq_len, 'text')
    return torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

def get_wikitext_loaders(tokenizer, seq_len=128, batch_size=4, max_samples=256, seed=42):
    """Build a DataLoader of fixed-length token chunks from the WikiText-2 test split.

    Args:
        tokenizer: tokenizer used by ``process_data``.
        seq_len: tokens per chunk.
        batch_size: chunks per batch.
        max_samples: number of shuffled rows to use; ``None`` uses all rows.
            Values larger than the split are clamped (previously ``select``
            raised IndexError).
        seed: shuffle seed (default matches the previously hard-coded 42).

    Returns:
        A non-shuffled ``torch.utils.data.DataLoader``.
    """
    test_data = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
    # Shuffle before subsetting so the subset is a random sample of rows, not the first ones.
    test_data = test_data.shuffle(seed=seed)
    if max_samples is not None:
        test_data = test_data.select(range(min(max_samples, len(test_data))))
    test_dataset = process_data(test_data, tokenizer, seq_len, 'text')
    return torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
from abc import abstractmethod, ABC
from  pathlib import Path
import re
import logging
from argparse import ArgumentParser
from pandas import DataFrame
from tqdm.notebook import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn.functional as F

# Matches a dotted module name whose last two components are exactly "layers" and
# an integer index (e.g. "model.layers.3"), but not "sublayers.3" or "layers.3.mlp".
_LAYER_NAME_RE = re.compile(r"(?:^|\.)layers\.\d+$")


def get_layers(llm):
    """Yield ``(name, module)`` for every submodule named ``...layers.<int>``.

    Raw-string regex avoids the invalid-escape warning of the original
    ``"layers\\.\\d+$"``, and the component boundary stops names such as
    ``encoder_sublayers.3`` from matching.
    """
    for name, module in llm.named_modules():
        if _LAYER_NAME_RE.search(name):
            yield name, module

def set_pruned_layers(llm):
    """Wrap every decoder layer of ``llm`` in a ``PruneLayer`` (in place).

    Idempotent: layers that are already ``PruneLayer`` instances are left
    alone, so repeated calls no longer nest wrappers.
    """
    # Materialize first: setmodule() mutates the module tree that named_modules() walks.
    targets = list(get_layers(llm))
    for name, module in targets:
        if isinstance(module, PruneLayer):
            continue
        setmodule(llm, name, PruneLayer(module, drop=False, is_last=False))

class PruneLayer(torch.nn.Module):
    """Wrap a layer so it can be switched off at runtime.

    When ``drop`` is True the wrapped layer is skipped and the input hidden
    states are returned unchanged as a 1-tuple; otherwise the call is forwarded
    to the wrapped layer with all keyword arguments.
    """

    def __init__(self, layer, is_last: bool, drop: bool = False):
        super().__init__()
        self.layer = layer
        self.drop = drop
        # Kept for bookkeeping; forward() does not read it.
        self.is_last = is_last

    @torch.no_grad()
    def forward(self, hidden_states, **kwargs):
        if self.drop:
            # NOTE(review): the 1-tuple assumes the wrapped decoder layer also returns a
            # tuple whose first element is the hidden states — confirm against the
            # installed transformers version.
            return (hidden_states,)
        return self.layer(hidden_states, **kwargs)

class Sensivity(ABC):
    """Base class for layer-sensitivity scorers.

    The constructor wraps the decoder layers of ``llm`` in ``PruneLayer`` so a
    single layer can be dropped at a time. Subclasses implement ``score`` for one
    batch; ``__call__`` scores every layer and writes the results to CSV.
    (The "Sensivity" spelling of the class/method/parameter names is kept for
    backward compatibility.)
    """

    def __init__(self, llm):
        self.llm = llm
        set_pruned_layers(llm)

    @abstractmethod
    def score(self, batch: torch.Tensor, target_module: str) -> float:
        """PPL or KL-div"""
        pass

    def sensivity(self, test_lodaer, target_module):
        """Mean ``score`` over all batches of ``test_lodaer`` with ``target_module`` dropped.

        The layer is always restored after each batch, even if ``score`` raised
        or did not unprune it itself, so one layer's pruning can never leak into
        the next batch or the next layer.

        Returns NaN (with a warning) when the loader yields no batches.
        """
        scores = []
        for batch in tqdm(test_lodaer):
            batch = batch.to(self.llm.device)
            try:
                scores.append(self.score(batch, target_module=target_module))
            finally:
                self.unprune_layer(target_module)
        if not scores:
            logging.warning("no batches scored for %s; returning NaN", target_module)
            return float("nan")
        return torch.tensor(scores).mean().item()

    def prune_layer(self, target_module):
        """Switch the PruneLayer at ``target_module`` off (skip the layer)."""
        getmodule(self.llm, target_module).drop = True

    def unprune_layer(self, target_module):
        """Switch the PruneLayer at ``target_module`` back on."""
        getmodule(self.llm, target_module).drop = False

    def __call__(self, test_dataset, output_folder):
        """Score every decoder layer and write ``<output_folder>/sensivities.csv``.

        ``output_folder`` may be a ``str`` or ``Path``. Returns the DataFrame of
        ``layer``/``score`` rows (the original returned ``None``).
        """
        output_folder = Path(output_folder)
        layers = [name for name, _ in get_layers(self.llm)]
        results = []
        df = DataFrame(results)
        for name in tqdm(layers):
            layer_idx = name.split(".")[-1]
            sensivity = self.sensivity(test_lodaer=test_dataset, target_module=name)
            logging.info("pruned layer=%s, sensvity=%s", layer_idx, sensivity)
            row = {
                "layer": name,
                "score": sensivity,
            }
            results.append(row)
            print(row)
            # Rewrite the CSV after every layer so partial results survive an interruption.
            df = DataFrame(results)
            df.to_csv(output_folder / "sensivities.csv")
        return df

class PPLSensivity(Sensivity):
    """Score a layer by the perplexity of the model with that layer dropped."""

    @torch.no_grad()
    def score(self, batch: torch.Tensor, target_module) -> float:
        """Return ``exp(mean next-token NLL)`` for ``batch`` with ``target_module`` dropped.

        ``batch`` holds token ids indexed as ``(batch, seq_len)``. The original
        returned ``mean(exp(token_loss))``, which is not perplexity and is
        dominated by a few high-loss tokens.
        """
        self.prune_layer(target_module=target_module)
        try:
            lm_logits = self.llm(batch, use_cache=False, output_attentions=False).logits
        finally:
            # Restore the layer even if the forward pass fails.
            self.unprune_layer(target_module=target_module)

        # Upcast: logits may be bfloat16; .float() is a no-op if they are already float32.
        shift_logits = lm_logits[:, :-1, :].float()
        shift_labels = batch[:, 1:]
        mean_nll = F.cross_entropy(
            shift_logits.reshape(-1, shift_logits.size(-1)),
            shift_labels.reshape(-1),
        )
        return torch.exp(mean_nll).item()

class KLDivSensivity(Sensivity):
    """Score a layer by KL(teacher || student).

    The teacher is the full model and the student is the model with
    ``target_module`` dropped. Both distributions are softened by temperature
    ``t``. The divergence is summed over positions and vocabulary, divided by the
    batch size, and scaled by ``t ** 2``.
    """

    def __init__(self, llm):
        super().__init__(llm)
        self.t = 2  # softmax temperature

    @torch.no_grad()
    def score(self, batch: torch.Tensor, target_module: str) -> float:
        b = batch.shape[0]
        # Teacher pass runs with every layer active.
        teacher_logits = self.llm(batch, use_cache=False, output_attentions=False).logits
        self.prune_layer(target_module=target_module)
        try:
            student_logits = self.llm(batch, use_cache=False, output_attentions=False).logits
        finally:
            # The original never unpruned here, so dropped layers accumulated across
            # batches and layers and corrupted later teacher passes.
            self.unprune_layer(target_module=target_module)

        # Work in float32 log-space: t_probs.log() gave 0 * -inf = nan for underflowed probs.
        t_log = F.log_softmax(teacher_logits.float() / self.t, dim=-1)
        del teacher_logits
        s_log = F.log_softmax(student_logits.float() / self.t, dim=-1)
        del student_logits
        # kl_div(input=s_log, target=t_log, log_target=True) = sum exp(t_log) * (t_log - s_log)
        kl_d = F.kl_div(s_log, t_log, reduction="sum", log_target=True) / b * (self.t ** 2)
        return kl_d.item()

def main(args):
    """Run the layer-sensitivity sweep configured by ``args``.

    ``args`` must provide the following fields, and the CSV is written under
    ``args.output_folder``:

    - ``output_folder``
    - ``model``
    - ``data``: ``'wikitext'``, or anything else for Aya
    - ``score``: ``'perplexity'``, or anything else for KL divergence
    - ``subset``
    - ``batch_size``

    Returns whatever the scorer's ``__call__`` returns.
    """
    output_folder = Path(args.output_folder)
    output_folder.mkdir(exist_ok=True, parents=True)

    tokenizer = AutoTokenizer.from_pretrained(args.model,
                                              trust_remote_code=True)
    loader = get_wikitext_loaders if args.data == 'wikitext' else get_aya_loaders
    test_loader = loader(tokenizer=tokenizer, max_samples=args.subset, batch_size=args.batch_size)

    # device_map="auto" already places the weights; the former llm.cuda() was redundant
    # on one GPU and can conflict with accelerate's dispatch when the model is split/offloaded.
    llm = AutoModelForCausalLM.from_pretrained(args.model,
                                               torch_dtype=torch.bfloat16,
                                               trust_remote_code=True,
                                               device_map="auto")
    print(llm)

    # The Sensivity constructor wraps the decoder layers itself; wrapping them here as well
    # nested every layer in two PruneLayer wrappers.
    scorer = PPLSensivity(llm) if args.score == "perplexity" else KLDivSensivity(llm)
    return scorer(test_dataset=test_loader, output_folder=output_folder)

In [None]:
from dataclasses import dataclass

@dataclass
class Args:
    """Notebook stand-in for the command-line options consumed by ``main``.

    Field order and defaults are part of the interface (positional construction).
    """

    # Directory that receives sensivities.csv.
    output_folder: str = "sensivities"
    # Model path or Hub id passed to from_pretrained (here, the locally downloaded copy).
    model: str = "aya-23-8B"
    # "wikitext" selects WikiText-2; any other value selects the local Aya dataset.
    data: str = "aya"
    # "perplexity" selects PPLSensivity; any other value selects KLDivSensivity.
    score: str = "kl_div"
    # Number of dataset rows to tokenize (forwarded as max_samples).
    subset: int = 512
    # Sequences per forward pass.
    batch_size: int = 2

In [None]:
# Run the sensitivity sweep with the default configuration.
args = Args()
main(args=args)

In [None]:
from transformers import AutoModelForCausalLM
import torch

def sort_by_importance(sensitivity_scores):
    """Return layer names from a sensitivity CSV, lowest score first.

    The file needs a header row, followed by rows whose last two fields are the
    layer name and its score. This is the shape written by ``Sensivity.__call__``.

    A lower score means the model with that layer dropped stayed closer to the
    full model (lower KL or lower perplexity), so the least important layers
    come first. The sort is stable, so ties keep file order.

    Returns an empty list when the file has no data rows; the original raised
    ValueError in that case.
    """
    import csv

    layers_scores = []
    with open(sensitivity_scores, "r", newline="", encoding="utf-8") as handle:
        reader = csv.reader(handle)
        next(reader, None)  # skip the header
        for row in reader:
            if not row:
                continue  # tolerate blank lines
            *_, layer, score = row
            layers_scores.append((layer, float(score)))
    layers_scores.sort(key=lambda item: item[1])
    return [layer for layer, _ in layers_scores]

def prune(llm, sensitivities, reduction: float = 50.0, output_dir: str = "aya-4b-kld-pruning"):
    """Remove the least sensitive decoder layers from ``llm`` and save the result.

    Args:
        llm: causal LM whose decoder layers are in ``llm.model.layers``; it is
            modified in place.
        sensitivities: path to the CSV read by ``sort_by_importance``; layer
            names end in their integer index.
        reduction: percentage of ``llm.config.num_hidden_layers`` to remove.
        output_dir: directory passed to ``llm.save_pretrained`` (defaults to the
            previously hard-coded name).
    """
    sorted_layers = sort_by_importance(sensitivity_scores=sensitivities)
    num_layers_to_skip = round((reduction * llm.config.num_hidden_layers) / 100)
    # The original sliced [:num_layers_to_skip + 1], removing one layer more than requested.
    drop = {int(name.split(".")[-1]) for name in sorted_layers[:num_layers_to_skip]}

    # Rebuild the ModuleList instead of delattr-ing entries. delattr leaves gaps in the
    # child names ("0", "2", "5", ...), so the saved state dict kept the old layer indices
    # while num_hidden_layers shrank, and a reload would not line weights up with layers.
    kept = [layer for idx, layer in enumerate(llm.model.layers) if idx not in drop]
    for new_idx, layer in enumerate(kept):
        attn = getattr(layer, "self_attn", None)
        # NOTE(review): HF attention modules usually carry a layer_idx used for KV-cache
        # lookups; renumber it to the layer's new position when present.
        if attn is not None and hasattr(attn, "layer_idx"):
            attn.layer_idx = new_idx
    llm.model.layers = torch.nn.ModuleList(kept)
    llm.config.num_hidden_layers = len(kept)

    print(f"Parameters of the pruned LLM: {llm.num_parameters():,}")
    llm.save_pretrained(output_dir)

In [None]:
tokenizer = AutoTokenizer.from_pretrained("aya-23-8B",
                                          trust_remote_code=True)

In [None]:
# NOTE(review): this downloads the previously published pruned model into the same
# directory that prune() below saves to, so that save overwrites it — confirm this
# cell is still needed.
!huggingface-cli download yaya-sy/aya-4b-kld-pruning --local-dir aya-4b-kld-pruning

In [None]:
# Load the unpruned base model; prune() below edits it in place.
llm = AutoModelForCausalLM.from_pretrained("aya-23-8B",
                                            torch_dtype=torch.bfloat16,
                                            trust_remote_code=True,
                                            device_map="auto")

In [None]:
# Remove the lowest-scoring layers listed in the CSV and save the pruned model.
# NOTE(review): main() with the default Args writes sensivities/sensivities.csv, while
# this reads ./sensivities.csv — confirm the file was copied/uploaded here.
prune(llm=llm, sensitivities="sensivities.csv")

In [None]:
# Reload the saved pruned model on CPU to inspect it.
llm_pruned = AutoModelForCausalLM.from_pretrained("aya-4b-kld-pruning",
                                                  torch_dtype=torch.bfloat16,
                                                  trust_remote_code=True,
                                                  device_map="cpu")

In [None]:
llm_pruned

In [None]:
# Put the base model's tokenizer files next to the pruned weights.
tokenizer.save_pretrained("aya-4b-kld-pruning")

In [None]:
# NOTE(review): this pushes the model object only; despite the "-with-tokenizer" repo
# name, the tokenizer saved above is presumably not uploaded by this call — confirm,
# tokenizer.push_to_hub(...) may also be needed.
llm_pruned.push_to_hub("yaya-sy/aya-4b-kld-pruning-with-tokenizer")

In [None]:
# Upload the whole local folder (model and tokenizer files) to the Hub repo.
!huggingface-cli upload yaya-sy/aya-4b-kld-pruning aya-4b-kld-pruning

In [None]:
# Delete the local copy, presumably so the next cells exercise the Hub download path.
!rm -rf /content/aya-4b-kld-pruning

In [None]:
# NOTE(review): unlike the other loads, no torch_dtype is given here, so the weights may
# load in the library's default dtype (possibly float32) — confirm this is intended.
test = AutoModelForCausalLM.from_pretrained("yaya-sy/aya-4b-kld-pruning",
                                          trust_remote_code=True,
                                          device_map="auto")

In [None]:
tokenizer = AutoTokenizer.from_pretrained("yaya-sy/aya-4b-kld-pruning",
                                          trust_remote_code=True)

In [None]:
tokenizer