In [None]:
import inspect

import torch  # type: ignore[attr-defined]
from causal_wrapper import load_causal_whisper

In [None]:
MODEL_ID = "openai/whisper-base"

# Preferred GPU index. Clamp it to the last available device so the notebook still runs on
# hosts with fewer GPUs (a hard-coded "cuda:5" raises an invalid-device-ordinal error there),
# and fall back to CPU when CUDA is unavailable.
PREFERRED_GPU = 5
if torch.cuda.is_available():
    DEVICE = f"cuda:{min(PREFERRED_GPU, torch.cuda.device_count() - 1)}"
else:
    DEVICE = "cpu"

N_MELS = 80  # mel bins per frame of the synthetic encoder input (presumably whisper-base's feature size)
TOTAL_FRAMES = 3000  # length of each synthetic mel input, in frames
SHARED_FRAMES = 500  # leading frames that are identical between the two inputs compared below

In [None]:
# Load Whisper through the project's causal-encoder wrapper. for_conditional=True selects the
# conditional-generation variant, which is why later cells reach the encoder as model.model.encoder.
model = load_causal_whisper(MODEL_ID, for_conditional=True)
model.to(DEVICE)
# Switch to eval mode for inference. The trailing ";" stops Jupyter from rendering the full
# module repr (the return value of eval()), which is a very long wall of output.
model.eval();

In [None]:
# Show the causal attention mask currently installed on the encoder.
# The model.model.encoder path applies to the conditional-generation wrapper
# (for_conditional=True); without it, the encoder is reached as model.encoder instead.
print(getattr(model.model.encoder, "causal_mask"))

In [None]:
# List the fully-qualified name of every parameter, one per line, to see how the
# wrapper nests the encoder layers. The tensors themselves are not needed.
param_names = (param_name for param_name, _tensor in model.named_parameters())
for param_name in param_names:
    print(param_name)

In [None]:
# Print the encoder's forward() source, e.g. to see how causal_mask is used.
# (`inspect` is imported in the top imports cell.)
print(inspect.getsource(model.model.encoder.forward))

In [None]:
# Build two synthetic mel inputs that share their first SHARED_FRAMES frames and differ afterwards.
# Seeding here makes the inputs (and therefore every printed difference) reproducible on re-run;
# torch.manual_seed seeds the CPU and all CUDA generators.
torch.manual_seed(0)
common = torch.randn(1, N_MELS, SHARED_FRAMES, device=DEVICE)
extra1 = torch.randn(1, N_MELS, TOTAL_FRAMES - SHARED_FRAMES, device=DEVICE)
extra2 = torch.randn(1, N_MELS, TOTAL_FRAMES - SHARED_FRAMES, device=DEVICE)

x1 = torch.cat([common, extra1], dim=2)
x2 = torch.cat([common, extra2], dim=2)

# Number of encoder positions the mask is built for. The "// 2" presumably reflects the encoder's
# strided conv front-end halving the time axis; the assertion below checks it against the output.
L = TOTAL_FRAMES // 2
# Presumably the number of future positions each position may attend to — confirm in
# _create_lookahead_mask.
look_ahead = 1
# NOTE: this mutates the encoder in place; the new mask persists for every later cell.
model.model.encoder.causal_mask = model.model.encoder._create_lookahead_mask(
    L, look_ahead, DEVICE, dtype=model.dtype
)

with torch.no_grad():
    latents_1 = model.model.encoder(x1).last_hidden_state  # type: ignore[attr-defined]
    latents_2 = model.model.encoder(x2).last_hidden_state  # type: ignore[attr-defined]

# The mask was sized for L positions, so both outputs must have exactly that many.
assert latents_1.shape[1] == L and latents_2.shape[1] == L, (
    f"encoder produced {latents_1.shape[1]}/{latents_2.shape[1]} positions "
    f"but the mask was built for {L}"
)

print(f"Latent shapes: {latents_1.shape}, {latents_2.shape}")

In [None]:
# Compare the two encoder runs. x1 and x2 share their first SHARED_FRAMES mel frames, so with a
# causal encoder (plus the configured look-ahead) an initial run of encoder positions should give
# identical latents, and the positions after it are expected to differ.
# NOTE(review): the exact boundary depends on the conv front-end's receptive field and on how the
# per-layer look-ahead compounds through the encoder stack — TODO confirm the expected value.
del_latents = latents_1 - latents_2

# Tolerance for "identical". 0.0 demands bit-exact equality (as the original check did); raise it
# if non-deterministic GPU kernels introduce tiny numerical noise.
atol = 0.0

# Max |difference| over *all* hidden features at each position. The previous loop only looked at
# feature 0, so divergence confined to other features went unreported; it also synced GPU->host
# once per position.
max_abs_diff = del_latents[0].abs().amax(dim=-1)
differing = torch.nonzero(max_abs_diff > atol).flatten().tolist()

n_positions = max_abs_diff.shape[0]
# Encoder positions aligned with the shared mel prefix, assuming uniform time downsampling.
shared_positions = SHARED_FRAMES * n_positions // TOTAL_FRAMES
print(f"Encoder positions aligned with the shared prefix: 0..{shared_positions - 1}")

if differing:
    first = differing[0]
    print(
        f"{len(differing)}/{n_positions} latent positions differ (atol={atol}); "
        f"first at position {first}, max |diff| there = {max_abs_diff[first].item():.3e}"
    )
else:
    print(f"All {n_positions} latent positions match (atol={atol}).")