In [21]:
import torch as t
import random, re, csv, tqdm, os
import numpy as np
from transformers import AutoProcessor, WhisperForConditionalGeneration
from util import load_to_array, transcribe_batch, TrialData, WordBoundary
import pickle

# Pick the compute device in order of preference: CUDA GPU, then Apple MPS,
# and finally the CPU when no accelerator is reported as available.
if t.cuda.is_available():
    device = t.device("cuda")
elif t.backends.mps.is_available():
    device = t.device("mps")
else:
    device = t.device("cpu")

# Hugging Face checkpoint that is loaded in the model-setup cell.
model_id = "openai/whisper-small"

# Audio sample rate in Hz.
sr = 16000

# Seed Python's `random` module (used for frame shuffling / control draws).
random.seed(100)

In [22]:
# Processor for `model_id`; its feature extractor is asked to also return an
# attention mask alongside the input features.
processor = AutoProcessor.from_pretrained(model_id)
processor.feature_extractor.return_attention_mask = True

# Whisper model for `model_id`, moved onto the selected device and pinned to
# English transcription through its generation config.
model = WhisperForConditionalGeneration.from_pretrained(model_id)
model = model.to(device)
model.generation_config.task = "transcribe"
model.generation_config.language = "english"

In [None]:

def select_frames(target, voiced, voiceless, talker):
    """Pick the context ("frame") files for one trial in all three conditions.

    Args:
        target: Target pair name of the form ``"<word>_<word>"``. Voiced files
            whose name contains the first word are excluded from the context.
        voiced: File names available in the voiced condition. The list is NOT
            modified; a shuffled copy is used internally.
        voiceless: File names available in the voiceless condition.
        talker: Unused; kept so existing callers keep working.

    Returns:
        A 4-tuple ``(voiced_frames, voiceless_frames, control_frames,
        control_condition_order)``:

        * ``voiced_frames`` - the retained voiced files in random order.
        * ``voiceless_frames`` - every voiceless file whose name contains the
          ``<letters>-<2 digits>`` id of a retained voiced file, in the order
          the matches were found.
        * ``control_frames`` - for each match, either the voiced or the
          voiceless file, chosen by a coin flip.
        * ``control_condition_order`` - ``'voiced'`` / ``'voiceless'`` for each
          entry of ``control_frames``.

        Voiced files without a voiceless match stay in ``voiced_frames``, so
        that list can be longer than the other three.

    Raises:
        AttributeError: If a retained voiced file name has no
            ``<letters>-<2 digits>`` id.
    """
    endpoints = target.split("_")

    # Shuffle a private copy *before* filtering. Shuffling the caller's list
    # in place (and only after the frames were already selected) both mutated
    # the argument and left the current call's order un-randomised.
    shuffled_voiced = list(voiced)
    random.shuffle(shuffled_voiced)
    voiced_frames = [f for f in shuffled_voiced if endpoints[0] not in f]

    # Frame ids such as "word-01", used to find the voiceless counterpart.
    frame_ids = [re.search(r"([a-z]+-[0-9]{2})", f).group(1) for f in voiced_frames]

    voiceless_frames = []
    control_frames = []
    control_condition_order = []
    for voiced_frame, frame_id in zip(voiced_frames, frame_ids):
        for candidate in voiceless:
            if frame_id not in candidate:
                continue
            voiceless_frames.append(candidate)
            # Coin flip decides which version of this frame the control
            # condition hears.
            if random.randint(0, 1):
                control_frames.append(voiced_frame)
                control_condition_order.append('voiced')
            else:
                control_frames.append(candidate)
                control_condition_order.append('voiceless')

    return voiced_frames, voiceless_frames, control_frames, control_condition_order


def create_trial_batch(condition, target_pair, frames, talker, n_steps, control_order=None):
    """Build one 30 s audio buffer per continuum step, all sharing one context.

    The context words are laid out back to back, each followed by 500 ms of
    silence, and right-aligned so the context ends exactly at the 28 s mark.
    Leading zeros fill the gap at the start. Each continuum step's target clip
    is then written at 28 s into its own copy of the buffer.

    Args:
        condition: ``'voiced'``, ``'voiceless'`` or ``'control'``. For the
            first two it is also the sub-folder the context files are read
            from; for ``'control'`` the folder is taken per file from
            ``control_order``.
        target_pair: Pair name ``"<word>_<word>"``; used in the continuum file
            name and to derive the target label.
        frames: Context file names, in presentation order. May be empty, in
            which case the context is pure silence.
        talker: Talker folder name under ``audio/MP``.
        n_steps: Number of continuum steps; files ``..._1_0.wav`` through
            ``..._1_{n_steps-1}.wav`` are loaded.
        control_order: Per-frame folder names (``'voiced'``/``'voiceless'``),
            only consulted when ``condition == 'control'``.

    Returns:
        ``(batch_arrays, full_context, trials_metadata)`` where
        ``batch_arrays`` is a list of ``n_steps`` float32 arrays of
        ``30 * 16000`` samples, ``full_context`` is the shared first 28 s, and
        ``trials_metadata`` is a list of ``TrialData`` objects (one per step)
        with empty ``transcript`` and ``outcome``.

    Raises:
        ValueError: If the context does not fit in 28 s, if a target clip does
            not fit in the remaining 2 s, or if ``control_order`` is shorter
            than ``frames`` in the control condition.
    """
    # NOTE(review): assumes load_to_array returns 1-D audio sampled at 16 kHz
    # - confirm against util.load_to_array.
    SR = 16000
    TARGET_IDX = int(28.0 * SR)      # sample index where the target starts
    BUFFER_LEN = 30 * SR             # total samples per trial
    silence_500_np = np.zeros(int(0.5 * SR), dtype=np.float32)

    if control_order is None:
        control_order = []
    if condition == 'control' and len(control_order) < len(frames):
        raise ValueError(
            f"control_order has {len(control_order)} entries but there are "
            f"{len(frames)} frames"
        )

    # 1. Load every context word exactly once; the lengths are needed up
    #    front to work out how much leading padding right-aligns the context.
    context_audio = []
    for idx, f in enumerate(frames):
        c_type = control_order[idx] if condition == 'control' else condition
        context_audio.append(load_to_array(f'audio/MP/{talker}/{c_type}/{f}'))

    context_len = sum(len(audio_np) + len(silence_500_np) for audio_np in context_audio)
    pad_len = TARGET_IDX - context_len
    if pad_len < 0: raise ValueError("Context too long")

    # Starting with the padding block means the concatenate below also works
    # when there are no context frames at all.
    context_blocks = [np.zeros(pad_len, dtype=np.float32)]
    context_boundaries = []
    current_sample_ptr = pad_len     # first sample after the leading padding

    for f, audio_np in zip(frames, context_audio):
        start_sec = current_sample_ptr / SR
        end_sec = (current_sample_ptr + len(audio_np)) / SR
        # The word label is the lowercase run between the last "_" and ".wav".
        label = re.search(r"_([a-z]+)\.wav", f).group(1)
        context_boundaries.append(WordBoundary(start_sec, end_sec, label))

        context_blocks.append(audio_np)
        context_blocks.append(silence_500_np)
        current_sample_ptr += len(audio_np) + len(silence_500_np)

    full_context = np.concatenate(context_blocks)

    # 2. Stitch each continuum step onto the shared context and record
    #    its metadata.
    parts = target_pair.split("_")
    # First word of the pair for the voiceless condition, second word for
    # both the voiced and the control condition.
    target_label = parts[0] if condition == 'voiceless' else parts[1]
    t_start = TARGET_IDX / SR
    max_target_len = BUFFER_LEN - TARGET_IDX

    batch_arrays = []
    trials_metadata = []

    for i in range(n_steps):
        target_path = f'audio/MP/{talker}/continuum/{target_pair}_1_{i}.wav'
        target_np = load_to_array(target_path)
        if len(target_np) > max_target_len:
            raise ValueError(
                f"Target too long: {target_path} has {len(target_np)} samples, "
                f"only {max_target_len} fit after the context"
            )
        t_end = t_start + (len(target_np) / SR)

        # Shared context boundaries plus this step's target boundary.
        all_boundaries = context_boundaries + [WordBoundary(t_start, t_end, target_label)]

        # Assemble the 30 s buffer; anything after the target stays silent.
        trial_buffer = np.zeros(BUFFER_LEN, dtype=np.float32)
        trial_buffer[:TARGET_IDX] = full_context
        trial_buffer[TARGET_IDX : TARGET_IDX + len(target_np)] = target_np

        batch_arrays.append(trial_buffer)

        # Only the target clip is stored on the trial, not the 30 s buffer.
        trial = TrialData(
            transcript="",
            outcome="",
            word_boundaries=all_boundaries,
            condition=condition,
            continuum_step=i,
            target_word=target_label,
            target_array=target_np
        )
        trials_metadata.append(trial)

    return batch_arrays, full_context, trials_metadata


In [None]:
def selective_adaptation(output_path, talkers, n_continuum, n_trials):
    """Run the selective-adaptation experiment and pickle the results.

    For every talker and every trial, a target pair is chosen round-robin,
    context frames are drawn with ``select_frames``, and for each of the
    ``'voiced'``, ``'voiceless'`` and ``'control'`` conditions the whole
    continuum is stitched with ``create_trial_batch`` and transcribed with
    ``transcribe_batch`` (using the notebook-level ``processor``, ``model``
    and ``device``).

    Args:
        output_path: Destination file. The results are written with
            ``pickle.dump`` regardless of the file extension.
        talkers: Talker folder names under ``audio/MP``.
        n_continuum: Number of continuum steps per target pair.
        n_trials: Number of trials to run per talker.

    Returns:
        A list with one dict per (talker, trial, condition, continuum step),
        holding ``"context"`` (shared context audio), ``"metadata"`` (the
        ``TrialData`` with its ``transcript`` filled in), ``"target_pair"``
        and ``"iteration"``.

    Raises:
        ValueError: If a talker's continuum folder yields no target pairs
            while ``n_trials`` is positive.
    """
    all_experiment_data = []

    # Create the output folder up front so a long run cannot fail at the
    # very end just because the directory is missing.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    for talker in talkers:
        # Collect the distinct "<word>_<word>" pair names from the continuum
        # folder. Entries that do not match (e.g. '.DS_Store') are skipped.
        pair_names = set()
        for entry in os.listdir(f'audio/MP/{talker}/continuum'):
            match = re.search(r'([a-zA-Z]+_[a-zA-Z]+)_[0-9]', entry)
            if match:
                pair_names.add(match.group(1))
        # Sorted so the round-robin below is reproducible: neither set
        # iteration order nor os.listdir order is guaranteed between runs.
        targets = sorted(pair_names)
        if n_trials > 0 and not targets:
            raise ValueError(f"No continuum targets found for talker {talker!r}")

        voiced = sorted(
            v for v in os.listdir(f'audio/MP/{talker}/voiced') if v != '.DS_Store'
        )
        voiceless = sorted(
            v for v in os.listdir(f'audio/MP/{talker}/voiceless') if v != '.DS_Store'
        )

        for n in tqdm.tqdm(range(n_trials)):
            target_pair = targets[n % len(targets)]
            v_frames, vl_frames, c_frames, c_order = select_frames(target_pair, voiced, voiceless, talker)
            frames_by_condition = {
                'voiced': v_frames,
                'voiceless': vl_frames,
                'control': c_frames,
            }

            for c, frames in frames_by_condition.items():
                # One stitched array per continuum step, all sharing a context.
                batch_audio, context_audio, metadata_list = create_trial_batch(
                    c, target_pair, frames, talker, n_continuum, c_order
                )

                # Transcribe the whole continuum for this condition in a
                # single transcribe_batch call.
                transcripts = transcribe_batch(batch_audio, processor, model, device)

                for i, text in enumerate(transcripts):
                    metadata_list[i].transcript = text
                    all_experiment_data.append({
                        "context": context_audio,
                        "metadata": metadata_list[i],
                        "target_pair": target_pair,
                        "iteration": n
                    })

    # Persist the results gathered in THIS call. The previous code dumped the
    # global name `data`, which is undefined on a fresh kernel and stale
    # otherwise.
    with open(output_path, 'wb') as file:
        pickle.dump(all_experiment_data, file)

    return all_experiment_data


# Run the experiment for talker 'hope': 13 continuum steps, 300 trials.
# NOTE(review): selective_adaptation writes this file with pickle.dump, so the
# '.csv' extension is misleading - consider renaming once any readers of this
# path are updated.
# NOTE(review): the saved progress bar below reports 2/2 iterations, which does
# not match n_trials=300; the stored output looks like it is from an older run.
data = selective_adaptation(
    output_path='data/SelAd_test.csv',
    talkers=['hope'],
    n_continuum=13,
    n_trials=300,
)


100%|██████████| 2/2 [00:28<00:00, 14.15s/it]
