In [1]:
import whisper
from datasets import load_dataset
from torchaudio.functional import resample
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(DEVICE)

cuda


In [3]:
import pickle

# Restore the pickled training dataset object from disk.
# NOTE(security): pickle.load can execute arbitrary code embedded in the file;
# only load train_dataset.bin if it was produced by a trusted pipeline.
with open("train_dataset.bin", "rb") as dataset_file:
    train_dataset = pickle.load(dataset_file)

In [4]:
BATCH_SIZE = 2

# Training mini-batches: reshuffled every epoch, loaded by 8 worker processes and
# assembled with the dataset's own collate_fn. drop_last discards the final
# incomplete batch so every batch holds exactly BATCH_SIZE examples.
train_dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    collate_fn=train_dataset.collate_fn,
    batch_size=BATCH_SIZE,
    drop_last=True,
    num_workers=8,
    shuffle=True,
)

In [5]:
# Load Whisper's "base" checkpoint, then move its parameters onto DEVICE.
model = whisper.load_model("base")
model = model.to(DEVICE)

In [6]:
# Push a single batch through the model to inspect shapes and the loss.
batch = next(iter(train_dataloader))
(bidx, bx, bx_len,
 by_input, by_input_len,
 by_target, by_target_len) = batch
by_input = by_input.to(DEVICE)

# The encoder runs without autograd: benc_out carries no gradient history, so a
# loss computed on bdec_out can only produce gradients for the decoder.
with torch.no_grad():
    bx = bx.to(DEVICE)
    benc_out = model.encoder(bx)

bdec_out = model.decoder(by_input, benc_out)


In [None]:
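# bdec_out:  [B, T_in, C]  decoder logits; C is the vocabulary size
# by_target: [B, T]        right-padded target token ids (valid lengths in by_target_len)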

In [7]:
# Put the targets and their lengths on the same device as the decoder output.
by_target, by_target_len = by_target.to(DEVICE), by_target_len.to(DEVICE)

In [8]:
# Summed negative log-likelihood of the target tokens, ignoring padding:
#     loss = -sum_{b, t < len_b} log q(y_{b,t})
# This equals -sum(one_hot(y) * log_softmax(logits)), since p(x) = 1 only at the
# target token. However, it gathers the target log-probabilities directly instead
# of materialising a [B, T, vocab] one-hot tensor. It also no longer hard-codes
# the vocabulary size (51865), so any vocabulary works.
padded_target_len = by_target_len.max()
# log q(x) for the trailing padded_target_len decoder positions, which this code
# assumes are aligned with by_target.  [B, T, C]
log_q_x = torch.nn.functional.log_softmax(bdec_out[:, -padded_target_len:, :], dim=-1)
# True at real target positions, False at padding.  [B, T, 1]
mask = (
    torch.arange(padded_target_len, device=by_target_len.device).unsqueeze(0)
    < by_target_len.unsqueeze(1)
).unsqueeze(-1)
# Padding ids may lie outside [0, C) (e.g. a negative ignore value). Replace them
# with a valid index so gather is safe; those positions are zeroed by the mask below.
safe_target = by_target.masked_fill(~mask.squeeze(-1), 0)  # [B, T]
# log q(x) at each position's target token.  [B, T, 1]
log_q_target = log_q_x.gather(-1, safe_target.unsqueeze(-1))
# Zero the padding positions so they do not contribute to the summed loss.
loss = -log_q_target.masked_fill(~mask, 0).sum()

In [9]:
# Display the summed (not averaged) loss for the sample batch.
loss

tensor(219.6853, device='cuda:0', grad_fn=<NegBackward0>)

In [None]:
# Boolean mask that is True only at non-padding target positions.  [B, T]
padded_length = by_target_len.max()
mask = (
    torch.arange(padded_length, device=by_target.device).unsqueeze(0)
    < by_target_len.unsqueeze(1)
)


In [81]:
# Length-based padding mask with an extra singleton axis: shape [B, 1, T].
# torch.as_tensor accepts either a list or an existing tensor without copying.
# torch.tensor(existing_tensor) makes a needless copy and emits the
# copy-construct UserWarning.
by_target_len = torch.as_tensor(by_target_len)
padded_len = by_target_len.max()
# Build the position index directly on the lengths' device rather than on the
# CPU followed by a copy.  [B, T]
mask = torch.arange(padded_len, device=by_target_len.device) < by_target_len.unsqueeze(1)
# NOTE(review): the [B, 1, T] layout presumably broadcasts against [B, T_q, T]
# attention-style scores — confirm against the consumer.
mask = mask.unsqueeze(-2)

  by_target_len = torch.tensor(by_target_len)


In [37]:
# Swap the last two axes of the 3-D mask and look at example 0, position 0.
mask.transpose(1, 2)[0, 0]

tensor([True])

In [35]:
# Shape of the current mask.
mask.size()

torch.Size([2, 1, 17])

In [36]:
# Full mask contents: True = real target token, False = padding.
mask

tensor([[[ True,  True,  True,  True,  True,  True,  True,  True, False, False,
          False, False, False, False, False, False, False]],

        [[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
           True,  True,  True,  True,  True,  True,  True]]])