In [51]:
from studies.gwilliams2023 import Gwilliams2023
from studies.armeini2022 import Armeini2022


# Open the Armeini2022 study in "audiotext" batch mode. download=False is
# passed, so presumably the data is expected to be available locally already.
study = Armeini2022(batch_type="audiotext", download=False)

# Take the first entry at every level of the nested recordings index.
rec = study.recordings[0][0][0]

# Load the raw signal (with its data), then the events with options="both".
raw = rec.load_raw(load_data=True)
events = rec.load_events(raw, options="both")

# Keep the word-level events and show them as the cell's output.
word_events = events["word"]
word_events

Loading Armeini2022 with batch type audiotext


Unnamed: 0,onset,duration,word
0,38.863643,0.129705,the
1,38.993349,0.678458,adventures
2,39.671807,0.089796,of
3,39.761603,0.488889,sherlock
4,40.250491,0.698413,holmes
...,...,...,...
8617,4648.015340,0.548753,honourable
8618,4648.564093,0.838095,title
8619,4650.220329,0.089796,of
8620,4650.310125,0.419048,the


In [52]:
from dataloader import DataLoader

# Whether transcripts carry <|t.tt|> timestamp markers. The tokenizer cell
# further down reads this same flag for its predict_timestamps option.
add_timestamps = True

dataloader = DataLoader(
    buffer_size=30,
    max_cache_size_gb=400,
    cache_dir="cache",
    notch_filter=True,
    frequency_bands={"all": (0.5, 80)},
    scaling="both",
    brain_clipping=None,
    baseline_window=0.5,
    new_freq=200,
    delay=0.15,
    batch_types={"audiotext": 1},
    batch_kwargs={
        "audiotext": {
            "max_random_shift": 1.0,
            "window_size": 4,
            "window_stride": 1,
            "audio_sample_rate": 16000,
            "hop_length": 160,
            "audio_processor": "openai/whisper-tiny.en",
            "add_timestamps": add_timestamps,
        }
    },
)

# Fetch a single batch for `rec`, then always shut the loader down. Without
# stop(), every re-run of this cell builds a new DataLoader while the previous
# one is still fetching. stop() is the shutdown call paired with
# start_fetching()/get_recording() elsewhere in this notebook.
# NOTE(review): assumes `batch` stays usable after stop() - confirm.
dataloader.start_fetching(recordings=[rec])
try:
    batch = dataloader.get_recording()
finally:
    dataloader.stop()

# Unpack the pieces of the batch used by the following cells. Nothing is moved
# to another device here; add .to(device) on brain/audio if that is needed.
brain = batch.brain_segments["all"]
audio = batch.audio_segments
transcript = batch.transcript
recording = batch.recording
transcript

['<|1.08|> the <|1.22|> adventures <|1.90|> of <|1.98|> sherlock <|2.48|> holmes',
 '<|0.34|> of <|0.44|> sherlock <|0.92|> holmes <|2.36|> a <|2.52|> scandal <|3.40|> in',
 '<|0.94|> a <|1.10|> scandal <|1.98|> in <|2.20|> bohemia',
 '<|0.54|> in <|0.76|> bohemia',
 '<|2.32|> to <|2.48|> sherlock <|2.92|> holmes',
 '<|0.36|> to <|0.52|> sherlock <|0.96|> holmes <|1.86|> she <|2.08|> is <|2.28|> always <|3.20|> the <|3.28|> woman',
 '<|0.64|> she <|0.86|> is <|1.06|> always <|1.98|> the <|2.06|> woman',
 '<|0.08|> the <|0.16|> woman <|2.02|> i <|2.22|> have <|2.34|> seldom <|2.98|> heard <|3.24|> him <|3.40|> mention <|3.78|> her',
 '<|0.94|> i <|1.16|> have <|1.26|> seldom <|1.90|> heard <|2.16|> him <|2.32|> mention <|2.70|> her <|2.92|> under <|3.16|> any <|3.50|> other <|3.62|> name',
 '<|0.04|> heard <|0.28|> him <|0.46|> mention <|0.84|> her <|1.04|> under <|1.28|> any <|1.62|> other <|1.76|> name <|3.30|> in <|3.44|> his',
 '<|0.04|> under <|0.28|> any <|0.62|> other <|0.74|> na

In [40]:
# import re


# def remove_timestamps(transcript_line: str) -> str:
#     """
#     Removes <|x.xx|> timestamp tokens, each with an optional trailing space, from a single transcript line.
#     """
#     pattern = r"<\|\d+(\.\d+)?\|>\s?"
#     return re.sub(pattern, "", transcript_line).strip()


# def clean_timestamped_transcript(transcript_lines):
#     """
#     Removes timestamp tokens from a list of transcript lines.
#     Returns a list of cleaned lines.
#     """
#     return [remove_timestamps(line) for line in transcript_lines]


# transcript_no_timestamps = clean_timestamped_transcript(transcript)
# transcript_no_timestamps

In [None]:
from transformers import WhisperTokenizerFast

# Tokenizer options. predict_timestamps mirrors the add_timestamps flag used
# when the transcripts were built, so both settings stay in step.
predict_timestamps = add_timestamps
add_prefix_space = True

tokenizer = WhisperTokenizerFast.from_pretrained(
    "openai/whisper-tiny.en",
    add_prefix_space=add_prefix_space,
    predict_timestamps=predict_timestamps,
)

# Encode every transcript line to a fixed length of 64 tokens: shorter lines
# are padded up to max_length and longer ones are truncated.
encoded = tokenizer(
    transcript,
    max_length=64,  # 16 * int(windows)
    padding="max_length",
    truncation=True,
    return_tensors="pt",
)
input_ids = encoded["input_ids"]
attention_mask = encoded["attention_mask"]

In [84]:
# Decoding options: special tokens are skipped and timestamps are not decoded.
skip_special_tokens = True
decode_with_timestamps = False

# Decode each encoded line back to text, then collapse every run of whitespace
# to a single space (str.split() with no arguments also trims both ends).
decoded = [
    " ".join(line.split())
    for line in tokenizer.batch_decode(
        sequences=input_ids,
        skip_special_tokens=skip_special_tokens,
        decode_with_timestamps=decode_with_timestamps,
        clean_up_tokenization_spaces=True,
    )
]
decoded

['the adventures of sherlock holmes',
 'of sherlock holmes a scandal in',
 'a scandal in bohemia',
 'in bohemia',
 'to sherlock holmes',
 'to sherlock holmes she is always the woman',
 'she is always the woman',
 'the woman i have seldom heard him mention her',
 'i have seldom heard him mention her under any other name',
 'heard him mention her under any other name in his',
 'under any other name in his eyes she',
 'in his eyes she eclipses and predominates the',
 'she eclipses and predominates the whole of her sex',
 'and predominates the whole of her sex it was not that he',
 'sex it was not that he felt any emotion akin to love',
 'it was not that he felt any emotion akin to love for irene adler',
 'to love for irene adler all emotions and that',
 'adler all emotions and that one particularly',
 'and that one particularly were abhorrent to',
 'particularly were abhorrent to his cold precise but',
 'were abhorrent to his cold precise but admirably',
 'precise but admirably balanced m

In [None]:
# from train.training_session_v1 import load_training_session
# import multiprocessing
# import torch

# device = "cuda"

# session = load_training_session(
#     save_path="saves/phase2/architecture/task/transformers/4C4Con_d256/epoch_39",
#     studies={"gwilliams2023": "audio"},
#     data_path="/home/ubuntu/storage-texas/data",
#     cache_name="cache",
# )

# dataloader = session.get_dataloader(buffer_size=1, num_workers=1, max_cache_size=100)

# # Unseen both
# # recording = session.studies["gwilliams2023"].recordings[19][0][0]

# # Seen
# # recording = session.studies["gwilliams2023"].recordings[15][0][1]

# # Unseen task
# # recording = session.studies["gwilliams2023"].recordings[18][0][0]

# # Unseen subject
# recording = session.studies["gwilliams2023"].recordings[19][0][1]

# print(
#     f"Showing recording: {recording.study_name}_{recording.subject_id}_{recording.task_id}"
# )

# dataloader.start_fetching(recordings=[recording])
# batch = dataloader.get_recording()
# brain, audio, recording = (
#     batch.brain_segments["all"].to(device),
#     batch.audio_segments.to(device),
#     batch.recording,
# )

# conditions = {
#     "study": f"{recording.study_name}",
#     "subject": f"{recording.study_name}_{recording.subject_id}",
# }
# session.model.to(device).eval()

# # with torch.no_grad():
# #     (
# #         x,  # [B, C, T]
# #         quantizer_metrics,
# #         channel_weights,
# #         hidden_outputs,
# #         encoder_hidden_states,  # L * [B, T, D]
# #     ) = session.model(
# #         x=[brain],
# #         recording=[recording],
# #         conditions=[conditions],
# #         mel=[audio],
# #         train=False,
# #         return_hidden_outputs=False,
# #     )

# dataloader.stop()