In [34]:
import os
import torch
import tqdm
import csv
import signal
import torchaudio
import json
from pathlib import Path
from accelerate import Accelerator
from transformers import AutoProcessor, WhisperForConditionalGeneration
from torch.utils.data import Dataset, DistributedSampler
import torch.multiprocessing as mp
mp.set_start_method('spawn', force=True)

In [2]:
# Configuration
AUDIO_DIR = "data/CelebV-HQ/audio"
OUTPUT_FILE = "data/CelebV-HQ/transcripts/transcriptions.txt"
WHISPER_MODEL = "openai/whisper-small"
BATCH_SIZE = 32
MAX_NEW_TOKENS = 444
TARGET_SAMPLE_RATE = 16000

In [24]:
class AudioDataset(Dataset):
    def __init__(self, audio_dir, processor, device, processed_files=None):
        self.audio_dir = audio_dir
        self.file_paths = []
        self.processor = processor
        self.device = device
        
        # Find all audio files
        self.file_paths.extend(list(Path(audio_dir).rglob("*.wav")))
        
        # Filter out already processed files
        if processed_files:
            self.file_paths = [f for f in self.file_paths if str(f) not in processed_files]
            
    def __len__(self):
        return len(self.file_paths)
    
    def __getitem__(self, idx):
        path = str(self.file_paths[idx])
        
        # Load and resample audio
        speech, sr = torchaudio.load(path)
        
        # Handle multi-channel audio (take first channel)
        if speech.shape[0] > 1:
            speech = speech[0].unsqueeze(0)
        
        # Resample to 16kHz if needed
        if sr != TARGET_SAMPLE_RATE:
            resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=TARGET_SAMPLE_RATE)
            speech = resampler(speech)
        
        # Process with processor and explicitly request attention mask
        inputs = self.processor(
            speech.squeeze(0), 
            sampling_rate=TARGET_SAMPLE_RATE, 
            return_tensors="pt",
            return_attention_mask=True
        )
        
        # Move to device and ensure tensors
        processed_inputs = {k: v.to(self.device) for k, v in inputs.items()}
        
        # Return both the path and processed audio
        return path, processed_inputs

In [27]:
accelerator = Accelerator()
device = accelerator.device
processor = AutoProcessor.from_pretrained(WHISPER_MODEL)


In [28]:
dataset = AudioDataset(AUDIO_DIR, processor, device)


In [41]:
sampler = DistributedSampler(
        dataset,
        num_replicas=accelerator.num_processes,
        rank=accelerator.process_index,
        shuffle=False
    )

dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        sampler=sampler,
        num_workers=16
    )

In [30]:
model_config = WhisperForConditionalGeneration.config_class.from_pretrained(WHISPER_MODEL)
model_config.forced_decoder_ids = None
model_config.suppress_tokens = []
model = WhisperForConditionalGeneration.from_pretrained(
    WHISPER_MODEL, 
    config=model_config,
    torch_dtype=torch.float16
).to(device)

In [20]:
# Create progress bar
pbar = tqdm.tqdm(
    total=len(dataloader), 
    desc=f"GPU {accelerator.process_index}", 
    position=accelerator.process_index
)

GPU 0:   0%|          | 0/899 [00:00<?, ?it/s]

In [44]:
for batch_idx, paths in enumerate(dataloader):
    print(batch_idx)

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/opt/conda/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/opt/conda/lib/python3.10/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'AudioDataset' on <module '__main__' (built-in)>


KeyboardInterrupt: 