# EXPERIMENTAL
Sharded Whisper with NxD

In [1]:
%%writefile compile.py
import os
#os.environ['XLA_USE_BF16']='1'
import types
import torch
import argparse
import torch_neuronx
import transformers
from datasets import load_dataset
from transformers import WhisperProcessor, WhisperForConditionalGeneration

import neuronx_distributed
import neuronx_distributed.parallel_layers.utils as neuronx_dist_utils
from transformers.models.whisper.modeling_whisper import WhisperAttention

from torch import nn
from neuronx_distributed.parallel_layers import layers, parallel_state
from functools import partial

conf = {}
conf['suffix'] = 'tiny' # 'large-v3'
conf['model_id'] = f"openai/whisper-{conf['suffix']}"
conf['tp_degree'] = 2
conf['embed_dim_per_head'] = 64
conf['max_dec_len'] = 64

def wrap_model():
    global conf    
    model = WhisperForConditionalGeneration.from_pretrained(conf['model_id'], torchscript=True)

    # Here we build a model with tensor-parallel layers.
    class NxDWhisperAttention(WhisperAttention):
        def __init__(self, embed_dim, num_heads, dropout, is_decoder,bias,is_causal,config):
            super().__init__(embed_dim, num_heads, dropout, is_decoder,bias,is_causal,config)
            self.q_proj = layers.ColumnParallelLinear(embed_dim, embed_dim, bias=bias, gather_output=False)
            self.k_proj = layers.ColumnParallelLinear(embed_dim, embed_dim, bias=False, gather_output=False)
            self.v_proj = layers.ColumnParallelLinear(embed_dim, embed_dim, bias=bias, gather_output=False)
            
            self.out_proj = layers.RowParallelLinear(embed_dim, embed_dim, bias=bias, input_is_parallel=True)

            # self.num_heads = neuronx_dist_utils.divide(num_heads, parallel_state.get_tensor_model_parallel_size())
            # self.embed_dim = neuronx_dist_utils.divide(embed_dim, parallel_state.get_tensor_model_parallel_size())
            # self.head_dim = neuronx_dist_utils.divide(self.embed_dim, self.num_heads)
            # self.scaling = self.head_dim**-0.5
            self.heads_pad = get_pad_size(config.d_model, parallel_state.get_tensor_model_parallel_size())
            
            self.num_heads = neuronx_dist_utils.divide(
               num_heads + self.heads_pad,
               parallel_state.get_tensor_model_parallel_size())
            
            self.embed_dim = neuronx_dist_utils.divide(
               embed_dim + (conf['embed_dim_per_head'] * self.heads_pad),
               parallel_state.get_tensor_model_parallel_size())
            
            self.head_dim = neuronx_dist_utils.divide(self.embed_dim, self.num_heads)
            self.scaling = self.head_dim**-0.5

    for layer in model.model.encoder.layers:
        l = layer.self_attn
        layer.self_attn = NxDWhisperAttention(l.embed_dim, l.num_heads, l.dropout, l.is_decoder, True, l.is_causal, l.config)

    for layer in model.model.decoder.layers:
        l = layer.self_attn
        layer.self_attn = NxDWhisperAttention(l.embed_dim, l.num_heads, l.dropout, l.is_decoder, True, l.is_causal, l.config)
        l = layer.encoder_attn
        layer.encoder_attn = NxDWhisperAttention(l.embed_dim, l.num_heads, l.dropout, l.is_decoder, True, l.is_causal, l.config)

    if not hasattr(model.model.encoder, 'forward_'): model.model.encoder.forward_ = model.model.encoder.forward
    def enc_f(self, input_features, attention_mask, **kwargs):        
        return self.forward_(input_features, attention_mask, return_dict=False)
    model.model.encoder.forward = types.MethodType(enc_f, model.model.encoder)    

    if not hasattr(model.model.decoder, 'forward_'): model.model.decoder.forward_ = model.model.decoder.forward
    def dec_f(self, input_ids, encoder_hidden_states, **kwargs):
        return self.forward_(input_ids=input_ids, encoder_hidden_states=encoder_hidden_states, use_cache=False)[0]
    model.model.decoder.forward = types.MethodType(dec_f, model.model.decoder)

    neuronx_distributed.parallel_layers.load(f"whisper-{conf['suffix']}.pt", model, sharded=False)
    
    return model.model
    
def get_model_encoder():
    global conf
    model = wrap_model().encoder
    if get_pad_size(model.config.d_model, conf['tp_degree']) > 0:
        print("Padding Encoder...")
        model = neuronx_distributed.parallel_layers.pad.pad_model(model, tp_degree=conf['tp_degree'], n_heads=model.config.encoder_attention_heads)
    return model, {}

def get_model_decoder():
    global conf
    model = wrap_model().decoder
    if get_pad_size(model.config.d_model, conf['tp_degree']) > 0:
        print("Padding Decoder...")
        model = neuronx_distributed.parallel_layers.pad.pad_model(model, tp_degree=conf['tp_degree'], n_heads=model.config.decoder_attention_heads)
    return model, {}

def get_pad_size(n_heads, tp_degree):
    heads_pad = 0
    # find the optimal padding considering at most 32 cores
    for i in range(32): 
        if (n_heads+i) % tp_degree == 0:
            heads_pad = i
            break
    print(f"Heads pad: {heads_pad}\nTP degree: {tp_degree}")    
    return heads_pad    

if __name__=='__main__':
    # 'attention_mask', 'input_features'    
    if not os.path.isfile(f"whisper-{conf['suffix']}.pt"):
        print(f"Saving the original weights: {conf['model_id']}...")
        model = WhisperForConditionalGeneration.from_pretrained(conf['model_id'], torchscript=True)
        torch.save({"model":model.state_dict()}, f"whisper-{conf['suffix']}.pt")

    print(f"Exporting encoder: whisper-{conf['suffix']}")    
    
    if conf['suffix'] == "tiny":
        dim_enc,dim_dec=80,384
    elif conf['suffix'] == "large-v3":
        dim_enc,dim_dec=128,1280
        
    inp = (torch.zeros([1, dim_enc, 3000], dtype=torch.float32), torch.zeros([1, dim_enc], dtype=torch.int64))
    model = neuronx_distributed.trace.parallel_model_trace(
        get_model_encoder, 
        inp, 
        tp_degree=conf['tp_degree'],
        compiler_args='--model-type=transformer --enable-saturate-infinity --auto-cast=all',
        compiler_workdir='./enc_dir',
        inline_weights_to_neff=False,
        max_parallel_compilations=max(4, conf['tp_degree'])
    )
    neuronx_distributed.trace.parallel_model_save(model, f"tp_model_enc_{conf['suffix']}_{conf['tp_degree']}")
        
    print(f"Exporting decoder: whisper-{conf['suffix']}")
    inp = (torch.zeros([1, conf['max_dec_len']], dtype=torch.int64), torch.zeros([1, 1500, dim_dec], dtype=torch.float32))    
    model = neuronx_distributed.trace.parallel_model_trace(
        get_model_decoder, 
        inp, 
        tp_degree=conf['tp_degree'],
        compiler_args='--model-type=transformer --enable-saturate-infinity  --auto-cast=all',
        compiler_workdir='./dec_dir',
        inline_weights_to_neff=True,
        max_parallel_compilations=max(4, conf['tp_degree'])
    )
    neuronx_distributed.trace.parallel_model_save(model, f"tp_model_dec_{conf['suffix']}_{conf['tp_degree']}")    

Overwriting compile.py


In [2]:
%%bash

rm -rf tp_model_enc_* tp_model_dec_*

kill -9 $(ps aux|grep multiprocessing|awk '{print $2}')

python3 compile.py

bash: line 4: kill: (104759) - No such process


Exporting encoder: whisper-tiny


Some weights of WhisperForConditionalGeneration were not initialized from the model checkpoint at openai/whisper-tiny and are newly initialized: ['proj_out.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


> initializing tensor model parallel with size 2
> initializing pipeline model parallel with size 1
> initializing data parallel with size 1
Heads pad: 0
TP degree: 2
Heads pad: 0
TP degree: 2
Heads pad: 0
TP degree: 2
Heads pad: 0
TP degree: 2
Heads pad: 0
TP degree: 2
Heads pad: 0
TP degree: 2
Heads pad: 0
TP degree: 2
Heads pad: 0
TP degree: 2
Heads pad: 0
TP degree: 2
Heads pad: 0
TP degree: 2
Heads pad: 0
TP degree: 2
Heads pad: 0
TP degree: 2
[2024-05-01 16:43:30.090: I neuronx_distributed/parallel_layers/checkpointing.py:148] `load` kwarg `model` is deprecated, please use `model_or_optimizer` instead as we are supporting to use `load` with optimizer as well
[2024-05-01 16:43:30.090: I neuronx_distributed/parallel_layers/checkpointing.py:161] loading checkpoint from whisper-tiny.pt




2024-05-01 16:43:39.000902:  107413  INFO ||NEURON_CACHE||: Compile cache path: /var/tmp/neuron-compile-cache
2024-05-01 16:43:39.000904:  107413  INFO ||NEURON_CC_WRAPPER||: Using a cached neff at /var/tmp/neuron-compile-cache/neuronxcc-2.13.68.0+6dfecc895/MODULE_9921017747686325983+d41d8cd9/model.neff. Exiting with a successfully compiled graph.
2024-05-01 16:43:40.000112:  107420  INFO ||NEURON_CACHE||: Compile cache path: /var/tmp/neuron-compile-cache
2024-05-01 16:43:40.000113:  107420  INFO ||NEURON_CC_WRAPPER||: Using a cached neff at /var/tmp/neuron-compile-cache/neuronxcc-2.13.68.0+6dfecc895/MODULE_16942971074878409330+d41d8cd9/model.neff. Exiting with a successfully compiled graph.
2024-05-01 16:43:40.000310:  107426  INFO ||NEURON_CACHE||: Compile cache path: /var/tmp/neuron-compile-cache
2024-05-01 16:43:40.000311:  107426  INFO ||NEURON_CC_WRAPPER||: Using a cached neff at /var/tmp/neuron-compile-cache/neuronxcc-2.13.68.0+6dfecc895/MODULE_1703642196029355321+d41d8cd9/model

Some weights of WhisperForConditionalGeneration were not initialized from the model checkpoint at openai/whisper-tiny and are newly initialized: ['proj_out.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Heads pad: 0
TP degree: 2
Heads pad: 0
TP degree: 2
Heads pad: 0
TP degree: 2
Heads pad: 0
TP degree: 2
Heads pad: 0
TP degree: 2
Heads pad: 0
TP degree: 2
Heads pad: 0
TP degree: 2
Heads pad: 0
TP degree: 2
Heads pad: 0
TP degree: 2
Heads pad: 0
TP degree: 2
Heads pad: 0
TP degree: 2
Heads pad: 0
TP degree: 2
[2024-05-01 16:43:42.508: I neuronx_distributed/parallel_layers/checkpointing.py:148] `load` kwarg `model` is deprecated, please use `model_or_optimizer` instead as we are supporting to use `load` with optimizer as well
2024-05-01T16:43:42Z Running DoNothing
2024-05-01T16:43:42Z DoNothing finished after 0.000 seconds
2024-05-01T16:43:42Z Running AliasDependencyInduction
2024-05-01T16:43:42Z AliasDependencyInduction finished after 0.002 seconds
2024-05-01T16:43:42Z Running CanonicalizeIR
2024-05-01T16:43:42Z CanonicalizeIR finished after 0.005 seconds
2024-05-01T16:43:42Z Running LegalizeCCOpLayout
2024-05-01T16:43:42Z LegalizeCCOpLayout finished after 0.007 seconds
2024-05-01T16:



2024-05-01T16:43:43Z MemcpyElimination finished after 0.549 seconds
2024-05-01T16:43:43Z Running LoopFusion
2024-05-01T16:43:43Z LoopFusion finished after 0.177 seconds
2024-05-01T16:43:43Z Running Simplifier
2024-05-01T16:43:43Z Simplifier finished after 0.014 seconds
2024-05-01T16:43:43Z Running Delinearization
2024-05-01T16:43:43Z Delinearization finished after 0.022 seconds
2024-05-01T16:43:43Z Running AliasDependencyElimination
2024-05-01T16:43:43Z AliasDependencyElimination finished after 0.006 seconds
2024-05-01T16:43:43Z Running DeadStoreElimination
2024-05-01T16:43:44Z DeadStoreElimination finished after 0.284 seconds
2024-05-01T16:43:44Z Running AliasDependencyInduction
2024-05-01T16:43:44Z AliasDependencyInduction finished after 0.001 seconds
2024-05-01T16:43:44Z Running Simplifier
2024-05-01T16:43:44Z Simplifier finished after 0.027 seconds
2024-05-01T16:43:44Z Running LICM
2024-05-01T16:43:44Z LICM finished after 0.012 seconds
2024-05-01T16:43:44Z Running Delinearization
2

Some weights of WhisperForConditionalGeneration were not initialized from the model checkpoint at openai/whisper-tiny and are newly initialized: ['proj_out.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


> initializing tensor model parallel with size 2
> initializing pipeline model parallel with size 1
> initializing data parallel with size 1
Heads pad: 0
TP degree: 2
Heads pad: 0
TP degree: 2
Heads pad: 0
TP degree: 2
Heads pad: 0
TP degree: 2
Heads pad: 0
TP degree: 2
Heads pad: 0
TP degree: 2
Heads pad: 0
TP degree: 2
Heads pad: 0
TP degree: 2
Heads pad: 0
TP degree: 2
Heads pad: 0
TP degree: 2
Heads pad: 0
TP degree: 2
Heads pad: 0
TP degree: 2
[2024-05-01 16:44:24.992: I neuronx_distributed/parallel_layers/checkpointing.py:148] `load` kwarg `model` is deprecated, please use `model_or_optimizer` instead as we are supporting to use `load` with optimizer as well
[2024-05-01 16:44:24.992: I neuronx_distributed/parallel_layers/checkpointing.py:161] loading checkpoint from whisper-tiny.pt
2024-05-01 16:44:35.000333:  111565  INFO ||NEURON_CACHE||: Compile cache path: /var/tmp/neuron-compile-cache
2024-05-01 16:44:35.000334:  111565  INFO ||NEURON_CC_WRAPPER||: Using a cached neff at /va

Some weights of WhisperForConditionalGeneration were not initialized from the model checkpoint at openai/whisper-tiny and are newly initialized: ['proj_out.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Heads pad: 0
TP degree: 2
Heads pad: 0
TP degree: 2
Heads pad: 0
TP degree: 2
Heads pad: 0
TP degree: 2
Heads pad: 0
TP degree: 2
Heads pad: 0
TP degree: 2
Heads pad: 0
TP degree: 2
Heads pad: 0
TP degree: 2
Heads pad: 0
TP degree: 2
Heads pad: 0
TP degree: 2
Heads pad: 0
TP degree: 2
Heads pad: 0
TP degree: 2
[2024-05-01 16:44:38.130: I neuronx_distributed/parallel_layers/checkpointing.py:148] `load` kwarg `model` is deprecated, please use `model_or_optimizer` instead as we are supporting to use `load` with optimizer as well
2024-05-01T16:44:38Z Running DoNothing
2024-05-01T16:44:38Z DoNothing finished after 0.000 seconds
2024-05-01T16:44:38Z Running AliasDependencyInduction
2024-05-01T16:44:38Z AliasDependencyInduction finished after 0.002 seconds
2024-05-01T16:44:38Z Running CanonicalizeIR
2024-05-01T16:44:38Z CanonicalizeIR finished after 0.007 seconds
2024-05-01T16:44:38Z Running LegalizeCCOpLayout
2024-05-01T16:44:38Z LegalizeCCOpLayout finished after 0.007 seconds
2024-05-01T16:

In [3]:
import os
import types
import torch
import neuronx_distributed
from datasets import load_dataset
from transformers import WhisperProcessor, WhisperForConditionalGeneration

tp_degree=2
suffix="tiny"
#suffix="large-v3"
model_id=f"openai/whisper-{suffix}"

processor = WhisperProcessor.from_pretrained(model_id)
model = WhisperForConditionalGeneration.from_pretrained(model_id, torchscript=True)
cpu_model = WhisperForConditionalGeneration.from_pretrained(model_id, torchscript=True)

dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[3]["audio"]
input_features = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt").input_features

model.model.decoder.forward_neuron = neuronx_distributed.trace.parallel_model_load(f"tp_model_dec_{suffix}_{tp_degree}")

max_dec_len = 64

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Some weights of WhisperForConditionalGeneration were not initialized from the model checkpoint at openai/whisper-tiny and are newly initialized: ['proj_out.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of WhisperForConditionalGeneration were not initialized from the model checkpoint at openai/whisper-tiny and are newly initialized: ['proj_out.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


In [4]:
import torch
import time
import torch.nn.functional as F
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions

def enc_f(self, input_features, attention_mask, **kwargs):
    if hasattr(self, 'forward_neuron'):
        return self.forward_neuron(input_features, attention_mask)
    else:
        return self.forward_(input_features, attention_mask, return_dict=False)

if not hasattr(model.model.encoder, 'forward_'): model.model.encoder.forward_ = model.model.encoder.forward
model.model.encoder.forward = types.MethodType(enc_f, model.model.encoder)

def dec_f(self, input_ids, attention_mask=None, encoder_hidden_states=None, **kwargs):    
    t=time.time()
    out = None        
    if not attention_mask is None and encoder_hidden_states is None:
        ## I know, this is a quick-n-dirty workaround to align the input parameters for NeuronSDK tracer
        encoder_hidden_states, attention_mask = attention_mask,encoder_hidden_states
    inp = [input_ids, encoder_hidden_states]
    # pad the input to max_dec_len
    pad_size = torch.as_tensor(self.max_length - inp[0].shape[1])
    inp[0] = F.pad(inp[0], (0, pad_size), "constant", processor.tokenizer.pad_token_id)
    if hasattr(self, 'forward_neuron'):
        out = self.forward_neuron(*inp)
    else:        
        out = self.forward_(input_ids=inp[0], encoder_hidden_states=inp[1], return_dict=False, use_cache=False)[0]
    print(f"Elapsed encoder forward: {time.time()-t}")
    # unpad the output
    last_hidden_state = out[:, :input_ids.shape[1], :]
    return BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=last_hidden_state)
if not hasattr(model.model.decoder, 'forward_'): model.model.decoder.forward_ = model.model.decoder.forward
model.model.decoder.forward = types.MethodType(dec_f, model.model.decoder)
model.model.decoder.max_length = max_dec_len

## Only decoder running on inf2

In [8]:
# Load a sample
import time
t=time.time()
y1 = model.generate(input_features)
print(f"Elapsed time: {time.time()-t}")
y2 = cpu_model.generate(input_features)
y1,y2

Elapsed encoder forward: 0.0016117095947265625
Elapsed encoder forward: 0.0013756752014160156
Elapsed encoder forward: 0.0013403892517089844
Elapsed encoder forward: 0.0012969970703125
Elapsed encoder forward: 0.0014400482177734375
Elapsed encoder forward: 0.001407623291015625
Elapsed encoder forward: 0.0013933181762695312
Elapsed encoder forward: 0.0013897418975830078
Elapsed encoder forward: 0.0013837814331054688
Elapsed encoder forward: 0.0013625621795654297
Elapsed encoder forward: 0.0013911724090576172
Elapsed encoder forward: 0.0013849735260009766
Elapsed encoder forward: 0.0013873577117919922
Elapsed encoder forward: 0.0013608932495117188
Elapsed encoder forward: 0.0013782978057861328
Elapsed encoder forward: 0.0013666152954101562
Elapsed encoder forward: 0.001371145248413086
Elapsed encoder forward: 0.0013751983642578125
Elapsed encoder forward: 0.0013742446899414062
Elapsed encoder forward: 0.001374959945678711
Elapsed encoder forward: 0.0013704299926757812
Elapsed encoder for

(tensor([[50258, 50259, 50359, 50363,   634,   575, 12525, 22618,  1968,  6144,
          35617, 20084,  1756,   311,   589,   307,   534, 10281,   934,   439,
            293,   393,  4411,   294,   309,   457,   707,   295, 26916,   286,
            392,  6628,    13, 50257]]),
 tensor([[50258, 50259, 50359, 50363,   634,   575, 12525, 22618,  1968,  6144,
          35617, 20084,  1756,   311,   589,   307,   534, 10281,   934,   439,
            293,   393,  4411,   294,   309,   457,   707,   295, 26916,   286,
            392,  6628,    13, 50257]]))

In [6]:
y1.shape,y2.shape

(torch.Size([1, 34]), torch.Size([1, 34]))

In [7]:
processor.batch_decode(y1, skip_special_tokens=True)

[" He has grave doubts whether Sir Frederick Layton's work is really Greek after all and can discover in it but little of Rocky Ithaca."]

## Now both encoder and decoder on inf2

In [9]:
model.model.encoder.forward_neuron = neuronx_distributed.trace.parallel_model_load(f"tp_model_enc_{suffix}_{tp_degree}")

In [10]:
# Load a sample
import time
t=time.time()
y1 = model.generate(input_features)
print(f"Elapsed time: {time.time()-t}")
y2 = cpu_model.generate(input_features)
y1,y2

Elapsed encoder forward: 0.003704547882080078
Elapsed encoder forward: 0.0015168190002441406
Elapsed encoder forward: 0.0015499591827392578
Elapsed encoder forward: 0.0013225078582763672
Elapsed encoder forward: 0.0014243125915527344
Elapsed time: 1.6213769912719727


(tensor([[50258, 50259, 50359, 50363,  2411, 50257]]),
 tensor([[50258, 50259, 50359, 50363,   634,   575, 12525, 22618,  1968,  6144,
          35617, 20084,  1756,   311,   589,   307,   534, 10281,   934,   439,
            293,   393,  4411,   294,   309,   457,   707,   295, 26916,   286,
            392,  6628,    13, 50257]]))