In [9]:
import torch # type: ignore[attr-defined]
from transformers import WhisperForConditionalGeneration, WhisperProcessor # type: ignore[attr-defined]
from causal_wrapper import load_causal_whisper
from utils import prepare_data

In [10]:
MODEL_ID = "openai/whisper-base"
DEVICE = "cuda:6" if torch.cuda.is_available() else "cpu"

In [None]:
# The model weights are loaded properly, so when you give it inf (=1500) look ahead, it should give you the same output as the original model.

my_model = load_causal_whisper(MODEL_ID, for_conditional=True)
my_model.model.encoder.causal_mask = my_model.model.encoder._create_lookahead_mask(1500, 1500)
my_model.to(DEVICE)

whisper_model = WhisperForConditionalGeneration.from_pretrained(MODEL_ID)
whisper_model.to(DEVICE)

processor = WhisperProcessor.from_pretrained(MODEL_ID) 

Some weights of CausalWhisperForConditionalGeneration were not initialized from the model checkpoint at openai/whisper-base and are newly initialized: ['model.encoder.causal_mask']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [29]:
ds = prepare_data(max_shards=1)

In [30]:
my_model.model.encoder.causal_mask

tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]]], device='cuda:6')

In [32]:
import inspect
print(inspect.getsource(my_model.model.encoder.layers[0].forward))

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        layer_head_mask: torch.Tensor,
        output_attentions: bool = False,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                `(encoder_attention_heads,)`.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

In [33]:
for i in range(1):
    sample = ds[i]                                            # type: ignore[attr-defined]
    audio = torch.from_numpy(sample["mp3"]["array"]).float()  # type: ignore[attr-defined]
    text = sample["json"]["text"]                             # type: ignore[attr-defined]
    
    labels = torch.tensor(processor.tokenizer(text, add_special_tokens=False).input_ids).unsqueeze(0)  # type: ignore[attr-defined]
    inputs = processor(audio, sampling_rate=16000, return_tensors="pt")  
    
    labels = labels.to(DEVICE)
    inputs = inputs.to(DEVICE)
    
    with torch.no_grad():
        my_latents = my_model.model.encoder(
            inputs.input_features 
            )
        
        my_outputs = my_model(
            inputs.input_features,
            labels=labels
        )
        
        whisper_latents = whisper_model.model.encoder(
            inputs.input_features
            )
        
        whisper_outputs = whisper_model(
            inputs.input_features,
            labels=labels
            )
        
        
    my_results = processor.batch_decode(my_outputs.logits.argmax(dim=-1), skip_special_tokens=True)[0]  # type: ignore[attr-defined]
    whisper_results = processor.batch_decode(whisper_outputs.logits.argmax(dim=-1), skip_special_tokens=True)[0]  # type: ignore[attr-defined]
        
    print(my_outputs.loss.item())
    print(whisper_outputs.loss.item())
    print("\n")
    
    print(my_latents.last_hidden_state.shape)
    print(whisper_latents.last_hidden_state.shape)
    
    print(f"diff {my_latents.last_hidden_state[0][0] - whisper_latents.last_hidden_state[0][0]}")

3.61093807220459
3.61093807220459


torch.Size([1, 1500, 512])
torch.Size([1, 1500, 512])
diff tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0