In [None]:
import torch  # type: ignore[attr-defined]
from causal_wrapper import load_causal_whisper

In [None]:
# Checkpoint identifier handed to load_causal_whisper.
MODEL_ID = "openai/whisper-base"

# Preferred GPU ordinal. It is only used when the host actually exposes that
# many devices; otherwise fall back to the first GPU, then to the CPU, so the
# notebook still runs on machines with fewer than six GPUs.
PREFERRED_GPU_INDEX = 5
if torch.cuda.device_count() > PREFERRED_GPU_INDEX:
    DEVICE = f"cuda:{PREFERRED_GPU_INDEX}"
elif torch.cuda.is_available():
    DEVICE = "cuda:0"
else:
    DEVICE = "cpu"

# Size of axis 1 of the random encoder inputs built below.
N_MELS = 80
# Length of the last (concatenation) axis of each full encoder input.
TOTAL_FRAMES = 3000
# Number of leading frames that are identical in both inputs.
SHARED_FRAMES = 500

In [None]:
# Build the model for MODEL_ID on DEVICE using the project-local helper.
model = load_causal_whisper(MODEL_ID, device=DEVICE)
# NOTE(review): presumably switches the model to inference mode (torch.nn.Module.eval
# semantics) — confirm the helper returns an nn.Module. As the cell's last expression,
# the value returned by eval() is what the cell displays.
model.eval()

In [None]:
# Print the name and shape of every parameter that is excluded from gradient
# computation (requires_grad is off); trainable parameters are skipped.
for name, param in model.named_parameters():
    if param.requires_grad:
        continue
    print(name, param.shape)

In [None]:
# Causality probe: run the encoder on two inputs that share their first
# SHARED_FRAMES frames and differ afterwards, using the same look-ahead mask.

# Seed the global RNG so that Restart & Run All regenerates the same inputs
# (and therefore the same latents) on every run.
torch.manual_seed(0)

# Draw order (common, extra1, extra2) is kept fixed so the seeded stream maps
# to the same tensors each time.
common = torch.randn(1, N_MELS, SHARED_FRAMES, device=DEVICE)
extra1 = torch.randn(1, N_MELS, TOTAL_FRAMES - SHARED_FRAMES, device=DEVICE)
extra2 = torch.randn(1, N_MELS, TOTAL_FRAMES - SHARED_FRAMES, device=DEVICE)

# Both inputs are TOTAL_FRAMES long on the last axis and agree on the prefix.
x1 = torch.cat([common, extra1], dim=2)
x2 = torch.cat([common, extra2], dim=2)

# NOTE(review): mask size is half the input length — presumably because the
# encoder halves the time axis; confirm against the encoder implementation.
L = TOTAL_FRAMES // 2
# NOTE(review): 0 presumably means "no access to future positions" — confirm
# against _create_lookahead_mask.
look_ahead = 0
mask = model.encoder._create_lookahead_mask(L, look_ahead, DEVICE, dtype=model.dtype)

# Forward passes only; no gradients are needed for the comparison.
with torch.no_grad():
    lat1 = model.encoder(x1, causal_mask=mask).last_hidden_state  # type: ignore[attr-defined]
    lat2 = model.encoder(x2, causal_mask=mask).last_hidden_state  # type: ignore[attr-defined]

print(f"Latent shapes: {lat1.shape}, {lat2.shape}")

In [None]:
# Element-wise difference between the latents of the two encoder runs.
del_latents = lat1 - lat2

# Vectorised scan of the first batch item. The comparison is exact (!= 0), so
# NaN differences count as mismatches. This replaces a per-element Python loop
# that read the tensor one scalar at a time.
mismatch = del_latents[0] != 0          # (positions, features) boolean map
row_has_diff = mismatch.any(dim=1)      # True for positions with any mismatch
ok = not bool(row_has_diff.any())       # True only if the runs agree everywhere

diff_rows = []
if not ok:
    diff_idx = torch.nonzero(row_has_diff, as_tuple=True)[0]
    # argmax returns the first maximal index, i.e. the first mismatching
    # feature of each differing position.
    first_cols = mismatch[diff_idx].int().argmax(dim=1)
    diff_rows = diff_idx.tolist()
    # One line per differing position, reporting its first mismatching feature.
    for i, j in zip(diff_rows, first_cols.tolist()):
        print(f"Latent {i} {j} is not zero: {del_latents[0][i][j]}")

# Report the overall outcome so the result of the scan is visible in the output.
if ok:
    print("All latent positions are identical between the two runs.")
else:
    print(
        f"{len(diff_rows)} of {del_latents.shape[1]} latent positions differ; "
        f"first differing position: {diff_rows[0]}"
    )

In [None]:
# Display the difference vector at index 439 of the first batch item.
# NOTE(review): 439 is a hand-picked index with no explanation in this notebook —
# confirm why this position is of interest.
del_latents[0][439]