In [None]:
import inspect

import torch # type: ignore[attr-defined]
from transformers import WhisperForConditionalGeneration, WhisperProcessor # type: ignore[attr-defined]

from causal_wrapper import load_causal_whisper
from utils import prepare_data

In [None]:
# Checkpoint name handed to every loader in this notebook.
MODEL_ID = "openai/whisper-base"

# Preferred CUDA ordinal. It is only used when the host actually exposes that
# many devices; otherwise moving a model to a non-existent "cuda:6" would fail.
GPU_INDEX = 6

if not torch.cuda.is_available():
    DEVICE = "cpu"
elif torch.cuda.device_count() > GPU_INDEX:
    DEVICE = f"cuda:{GPU_INDEX}"
else:
    # CUDA is available but the preferred ordinal is not: use the first GPU.
    DEVICE = "cuda:0"

In [None]:
# Causal variant under test, built by the project wrapper with DEVICE passed in.
my_model = load_causal_whisper(MODEL_ID, device=DEVICE)

# Stock Hugging Face model loaded from the same checkpoint; `.to()` returns the
# module itself, so loading and moving are chained into one statement.
whisper_model = WhisperForConditionalGeneration.from_pretrained(MODEL_ID).to(DEVICE)

# Processor for the same checkpoint; later cells use it both to tokenize text
# and to turn audio into model inputs.
processor = WhisperProcessor.from_pretrained(MODEL_ID)

In [None]:
# Build the dataset through the project helper. `max_shards=1` presumably caps
# loading at a single shard to keep the notebook quick — TODO confirm against
# utils.prepare_data.
ds = prepare_data(max_shards=1)

In [None]:
# Replace the encoder's causal mask with one built by its own look-ahead helper.
# NOTE(review): 1500 presumably is the encoder sequence length and 5 the
# look-ahead width; both meanings are inferred from the helper's name only —
# confirm against causal_wrapper before relying on them.
causal_encoder = my_model.encoder
causal_encoder.causal_mask = causal_encoder._create_lookahead_mask(
    1500,
    5,
    DEVICE,
    dtype=my_model.dtype,
)

In [None]:
# Bare expression so the notebook renders the encoder's current causal mask.
my_model.encoder.causal_mask

In [None]:
# Debug aid: print the source of the constructor bound to the loaded causal
# model, to see which implementation the wrapper actually returned.
# `inspect` is imported in the notebook's top import cell rather than here.
print(inspect.getsource(my_model.__init__))

In [None]:
# Per-sample loss collectors for the causal model and the stock baseline.
# NOTE(review): nothing appends to either list yet. The baseline forward pass
# (`whisper_outputs`), its decoding, and the loss bookkeeping are all disabled,
# so both lists are still empty after this cell runs.
my_model_loss, whisper_model_loss = [], []

for i in range(1):
    sample = ds[i]  # type: ignore[attr-defined]
    audio = torch.from_numpy(sample["mp3"]["array"]).float()  # type: ignore[attr-defined]
    text = sample["json"]["text"]  # type: ignore[attr-defined]

    # Reference transcript -> token ids (no special tokens), as a batch of one.
    token_ids = processor.tokenizer(text, add_special_tokens=False).input_ids  # type: ignore[attr-defined]
    labels = torch.tensor(token_ids).unsqueeze(0).to(DEVICE)

    # Audio -> model inputs; the sampling rate is declared as 16 kHz.
    inputs = processor(audio, sampling_rate=16000, return_tensors="pt").to(DEVICE)
    features = inputs.input_features

    with torch.no_grad():
        # Encoder on its own first, then the full model with labels supplied.
        my_latents = my_model.encoder(features)
        my_outputs = my_model(features, labels=labels)

    print(my_latents.last_hidden_state.shape)

    # Greedy pick per position over the logits, decoded back to text.
    predicted_ids = my_outputs.logits.argmax(dim=-1)
    my_results = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]  # type: ignore[attr-defined]