In [1]:
from studies.gwilliams2023 import Gwilliams2023
from studies.armeini2022 import Armeini2022


study = Gwilliams2023(
    batch_type="audiotext",
    download=False,
)

rec = study.recordings[0][0][0]
raw = rec.load_raw(load_data=True)
events = rec.load_events(raw, options="both")
word_events = events["word"]
word_events

Loading Gwilliams2023 with batch type audiotext


Unnamed: 0,onset,duration,word
0,23.506,0.30,Tara
1,23.816,0.24,stood
2,24.056,0.37,stock
3,24.586,0.40,still
4,25.136,0.41,waiting
...,...,...,...
663,361.097,0.17,end
664,361.277,0.14,for
665,361.487,0.58,project
666,362.207,0.15,and


In [7]:
from dataloader import DataLoader

add_timestamps = True

dataloader = DataLoader(
    buffer_size=30,
    max_cache_size_gb=400,
    cache_dir="cache",
    notch_filter=True,
    frequency_bands={"all": (0.5, 80)},
    scaling="both",
    brain_clipping=None,
    baseline_window=0.5,
    new_freq=200,
    delay=0.15,
    batch_types={"audiotext": 1},
    batch_kwargs={
        "audiotext": {
            "max_random_shift": 1.0,
            "window_size": 4,
            "window_stride": 1,
            "audio_sample_rate": 16000,
            "hop_length": 160,
            "audio_processor": "openai/whisper-tiny.en",
            "add_timestamps": add_timestamps,
            "tokenize": True,
        }
    },
)
dataloader.start_fetching(recordings=[rec])
batch = dataloader.get_recording()

brain, audio, transcript, transcript_attention_masks, recording = (
    batch.brain_segments["all"],  # .to(device)
    batch.audio_segments,  # .to(device)
    batch.transcript,
    batch.transcript_attention_masks,
    batch.recording,
)
transcript

tensor([[50257, 50362, 50363,  ..., 50256, 50256, 50256],
        [50257, 50362, 50367,  ..., 50256, 50256, 50256],
        [50257, 50362, 50370,  ..., 50256, 50256, 50256],
        ...,
        [50257, 50362, 50412,  ..., 50256, 50256, 50256],
        [50257, 50362, 50370,  ..., 50256, 50256, 50256],
        [50257, 50362, 50376,  ..., 50256, 50256, 50256]])

In [3]:
# import re


# def remove_timestamps(transcript_line: str) -> str:
#     """
#     Removes <|x.xx|> and with space patterns from a single transcript line.
#     """
#     pattern = r"<\|\d+(\.\d+)?\|>\s?"
#     return re.sub(pattern, "", transcript_line).strip()


# def clean_timestamped_transcript(transcript_lines):
#     """
#     Removes timestamp tokens from a list of transcript lines.
#     Returns a list of cleaned lines.
#     """
#     return [remove_timestamps(line) for line in transcript_lines]


# transcript_no_timestamps = clean_timestamped_transcript(transcript)
# transcript_no_timestamps

In [10]:
from transformers import WhisperTokenizerFast

predict_timestamps = add_timestamps
add_prefix_space = True

tokenizer = WhisperTokenizerFast.from_pretrained(
    "openai/whisper-tiny.en",
    predict_timestamps=predict_timestamps,
    add_prefix_space=add_prefix_space,
)

# encoded = tokenizer(
#     transcript,
#     return_tensors="pt",
#     padding="max_length",
#     truncation=True,
#     max_length=64,  # 16 * int(windows)
# )
# input_ids, attention_mask = encoded["input_ids"], encoded["attention_mask"]

skip_special_tokens = False
decode_with_timestamps = True

decoded = tokenizer.batch_decode(
    sequences=transcript,
    skip_special_tokens=skip_special_tokens,
    decode_with_timestamps=decode_with_timestamps,
    clean_up_tokenization_spaces=True,
)
decoded = [" ".join(word.split()) for word in decoded]
decoded

['<|startoftranscript|><|notimestamps|><|0.00|> Tara <|0.30|> stood <|0.56|> stock <|1.08|> still <|1.62|> waiting <|2.04|> for <|2.16|> the <|2.26|> first <|2.54|> tiny <|2.86|> gleam <|3.10|> from <|3.28|> the <|3.36|> scout <|3.64|> craft<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>',
 '<|startoftranscript|><|notimestamps|><|0.08|> waiting <|0.50|> for <|0.62|> the <|0.72|> first <|1.00|> tiny <|1.30|> gleam <|1.56|> from <|1.74|> the <|1.80|> scout <|2.10|> craft <|2.40|> to <|2.54|> appear <|2.82|> in <|2.90|> the <|3.00|> darkness <|3.42|> of <|3.56|> the<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>',
 '<|startoftranscript|><|notimestamps|><|0.14|> from <|0.32|> the <|0.38

In [11]:
import torch
from config import SimpleConvConfig
from peft import AdaLoraConfig
from models.whisper_decoder import WhisperDecoder


brain_module_config = SimpleConvConfig(
    # Str to list of possible conditions
    mel_normalization=False,
    conditions={
        "study": [],
        "subject": [],
    },
    # Channels
    in_channels=208,
    out_channels=80,
    hidden_dim=256,
    dropout=0.2,
    initial_batch_norm=True,
    # Sensor layout settings
    layout_dim=2,
    layout_proj=True,
    layout_scaling="minmax",
    # Merger with spatial attn
    merger=False,
    merger_emb_type=None,
    merger_emb_dim=0,
    merger_channels=0,
    merger_dropout=0.0,  # Float
    merger_conditional=None,
    # Inital
    initial_linear=256,
    initial_depth=1,
    # Conditional layers
    conditional_layers=False,
    conditional_layers_dim=None,  # input or hidden_dim
    # Conv layer overall structure
    depth=4,
    kernel_size=3,
    growth=1.0,
    dilation_growth=2,
    dilation_period=5,
    glu=1,
    conv_dropout=0.2,
    dropout_input=0.1,
    batch_norm=True,
    half=True,
    cnn_pos_encoding=False,
    # Quantizer
    quantizer=False,
    num_codebooks=0,
    codebook_size=0,
    quantizer_commitment=0,
    quantizer_temp_init=0,
    quantizer_temp_min=0,
    quantizer_temp_decay=0,
    # Transformers Encoders
    transformer_input="continuous",
    transformer_encoder_emb="sinusoidal",
    transformer_encoder_layers=4,
    transformer_encoder_heads=4,
    # Conformer encoder variant
    rnn_type="conformer",
    depthwise_conv_kernel_size=15,
    use_group_norm=False,
    convolution_first=False,
    # Transformer Decoders
    transformer_decoder_emb=None,
    transformer_decoder_layers=0,
    transformer_decoder_heads=0,
    transformer_decoder_dim=0,
)

adalora_init_r = 12
adalora_target_r = 4
adalora_tinit = 450 * 3  # 5% total steps
adalora_tfinal = 450 * 8  # 50-80% total steps
adalora_deltaT = 450 * 1  # 1-5% total steps
adalora_lora_alpha = 32
adalora_lora_dropout = 0.1
adalora_total_step = 450 * 50

adalora_config = AdaLoraConfig(
    peft_type="ADALORA",
    task_type="SPEECH_RECOGNITION",
    target_modules=["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"],
    init_r=adalora_init_r,
    target_r=adalora_target_r,
    tinit=adalora_tinit,
    tfinal=adalora_tfinal,
    deltaT=adalora_deltaT,
    lora_alpha=adalora_lora_alpha,
    lora_dropout=adalora_lora_dropout,
    total_step=adalora_total_step,
)

model = WhisperDecoder(
    brain_module_config=brain_module_config,
    adalora_config=adalora_config,
    device="mps",
    audio_model_id="openai/whisper-tiny.en",
)

RNNEncoder initialized as conformer with 4 layers, 256 d_model, 4 nhead
	Embedding: sinusoidal, params: 6075392
SimpleConv initialized with 8927984 parameters, cond: ['study', 'subject']
Merger False, merger channels 0
ConvBlocks: 4, hidden_dim: 256, params 2626048
Found 40 target modules for AdaLora: ['model.decoder.layers.0.self_attn.k_proj', 'model.decoder.layers.0.self_attn.v_proj', 'model.decoder.layers.0.self_attn.q_proj', 'model.decoder.layers.0.self_attn.out_proj', 'model.decoder.layers.0.encoder_attn.k_proj', 'model.decoder.layers.0.encoder_attn.v_proj', 'model.decoder.layers.0.encoder_attn.q_proj', 'model.decoder.layers.0.encoder_attn.out_proj', 'model.decoder.layers.0.fc1', 'model.decoder.layers.0.fc2', 'model.decoder.layers.1.self_attn.k_proj', 'model.decoder.layers.1.self_attn.v_proj', 'model.decoder.layers.1.self_attn.q_proj', 'model.decoder.layers.1.self_attn.out_proj', 'model.decoder.layers.1.encoder_attn.k_proj', 'model.decoder.layers.1.encoder_attn.v_proj', 'model.dec

In [None]:
idx = 20

brain_input = brain[:idx].to("mps")
attention_mask_input = transcript_attention_masks[:idx].to("mps")
input_ids_input = transcript[:idx].to("mps")
audio_input = audio[:idx].to("mps")

# with torch.no_grad():
#     model.eval()
#     output = model(
#         x=[brain_input],
#         recording=[recording],
#         conditions=[{}],
#         mel=None,
#         train=False,
#         return_hidden_outputs=False,
#         attention_mask=None,
#         labels=input_ids_input,
#         decoder_attention_mask=attention_mask_input,
#     )

# (
#     x,  # predicted mel
#     quantizer_metrics,
#     channel_weights,
#     hidden_outputs,
#     encoder_last_hidden_state,
#     ce_loss,
# ) = output

with torch.no_grad():
    model.eval()
    (
        output_token_ids,  # [B, T]
        x,  # [B, 80, T']
        quantizer_metrics,
        channel_weights,
        hidden_outputs,
    ) = model.generate(
        x=None,  # [brain_input],
        recording=[recording],
        conditions=[{}],
        mel=audio_input,
        max_new_tokens=128,
        attention_mask=None,
        return_hidden_outputs=False,
    )
output_token_ids_decoded = tokenizer.batch_decode(
    sequences=output_token_ids,
    skip_special_tokens=skip_special_tokens,
    decode_with_timestamps=decode_with_timestamps,
    clean_up_tokenization_spaces=True,
)
output_token_ids_decoded

UnboundLocalError: cannot access local variable 'quantizer_metrics' where it is not associated with a value

In [16]:
transcript

tensor([[50257, 50362, 50363,  ..., 50256, 50256, 50256],
        [50257, 50362, 50367,  ..., 50256, 50256, 50256],
        [50257, 50362, 50370,  ..., 50256, 50256, 50256],
        ...,
        [50257, 50362, 50412,  ..., 50256, 50256, 50256],
        [50257, 50362, 50370,  ..., 50256, 50256, 50256],
        [50257, 50362, 50376,  ..., 50256, 50256, 50256]])

In [7]:
input_ids_input_decoded = tokenizer.batch_decode(
    sequences=input_ids_input,
    skip_special_tokens=skip_special_tokens,
    decode_with_timestamps=decode_with_timestamps,
    clean_up_tokenization_spaces=True,
)
input_ids_input_decoded = [" ".join(word.split()) for word in input_ids_input_decoded]
input_ids_input_decoded

['Tara stood stock still waiting for the first tiny gleam from the scout craft',
 'waiting for the first tiny gleam from the scout craft to appear in the darkness of the',
 'from the scout craft to appear in the darkness of the The gentle',
 'the darkness of the The gentle constant breeze of recycled air from the',
 'The gentle constant breeze of recycled air from the vent above blew an annoying hair',
 'air from the vent above blew an annoying hair against her nose but she ignored it',
 'hair against her nose but she ignored it A gasp from the',
 'ignored it A gasp from the psychic broke her silent',
 'from the psychic broke her silent vigil and she',
 'vigil and she turned Results',
 'and she turned Results Harmon she',
 'Harmon she suppressed the surge of annoyance that',
 'she suppressed the surge of annoyance that ran through her as she',
 'the surge of annoyance that ran through her as she contemplated the',
 'through her as she contemplated the gift of getting all the',
 'the gi

In [8]:
from utils.nlp_metrics import nlp_metrics

metrics = nlp_metrics(output_token_ids_decoded, input_ids_input_decoded)
metrics

{'bleu': 0.5525711421877697,
 'rouge_f': 0.801990348361316,
 'bert_score': 0.7350820302963257,
 'cer': 0.3835767566453689,
 'self_bleu': 0.034513324789303225}