In [None]:
import os

import jsonlines
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchaudio

# All paths are derived from the kernel's working directory, so they depend on
# where the notebook server / kernel was started.
cur_dir = os.getcwd()
src_dir = os.path.dirname(cur_dir)  # 1 directory up (parent of cur_dir)
# NOTE(review): an earlier comment claimed src_dir was "2 directories up";
# the code goes up one level — confirm which was intended.
til_dir = os.path.dirname(os.path.dirname(src_dir))  # 3 directories up from cur_dir
home_dir = os.path.dirname(til_dir)  # 4 directories up from cur_dir
test_dir = os.path.join(home_dir, 'novice')  # contains asr.jsonl and the audio/ folder
audio_dir = os.path.join(test_dir, 'audio')
data_dir = os.path.join(cur_dir, 'data')

In [None]:
# Load the ASR manifest into parallel column lists: index i of every list
# refers to the same JSONL record.
data = {'key': [], 'audio': [], 'transcript': []}
data_path = os.path.join(test_dir, "asr.jsonl")
with jsonlines.open(data_path) as reader:
    for record in reader:
        # Pull exactly the expected fields. A record missing one raises KeyError
        # instead of silently shifting that column out of alignment with the
        # others, and unexpected extra fields are ignored rather than crashing.
        for field, column in data.items():
            column.append(record[field])

def get_audio_length(audio_path, audio_dir, processor, target_sample_rate=16000):
    """Return the length of the processor's extracted features for one audio file.

    Loads ``audio_dir/audio_path`` with torchaudio, down-mixes it to mono,
    resamples to ``target_sample_rate`` when the file's rate differs, and runs
    ``processor.feature_extractor`` on the resulting waveform.

    Args:
        audio_path: File name (or relative path) of the clip inside ``audio_dir``.
        audio_dir: Directory containing the audio files.
        processor: Object exposing ``feature_extractor(waveform, sampling_rate=...)``
            whose result has an ``input_features`` sequence.
        target_sample_rate: Sample rate passed to the feature extractor
            (defaults to 16000, the value previously hard-coded).

    Returns:
        ``len(input_features[0])`` as a Python ``int``, or ``None`` if anything
        fails. Errors are printed rather than raised so that one bad file does
        not abort a pass over the whole dataset.
    """
    try:
        # torchaudio.load returns (channels, frames). Averaging over channels keeps
        # one sample per frame; flattening would concatenate stereo channels
        # end-to-end and double the apparent length.
        waveform, sample_rate = torchaudio.load(os.path.join(audio_dir, audio_path))
        waveform = waveform.mean(dim=0)

        # Resample if needed
        if sample_rate != target_sample_rate:
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
            waveform = resampler(waveform)

        # Extract audio features.
        # NOTE(review): if this extractor returns features shaped (n_mels, frames),
        # len() below counts mel bins rather than time frames — confirm which axis
        # is time for the processor in use.
        audio_features = processor.feature_extractor(
            waveform.numpy(), sampling_rate=target_sample_rate
        ).input_features[0]
        return int(len(audio_features))  # Ensure the length is a Python int
    except Exception as e:
        print(f"Error processing audio {audio_path!r}: {e}")
        return None

# NOTE(review): `processor` is not defined in any cell shown here; it must already
# exist in the kernel before this cell runs — TODO create it explicitly near the top.
audio_lengths = [(audio, get_audio_length(audio, audio_dir, processor)) for audio in data['audio']]
# Drop files that failed to load (get_audio_length returned None).
audio_lengths = [(audio, length) for audio, length in audio_lengths if length is not None]

# Summarize only the numeric lengths: np.array over (filename, int) pairs would
# build a string array, making np.max compare lexicographically and np.mean raise.
length_values = np.array([length for _, length in audio_lengths])
if length_values.size:
    print(f"Max length: {length_values.max()}")
    print(f"Mean length: {length_values.mean()}")
else:
    print("No audio files could be processed.")

In [None]:
def get_transcription_length(transcript, processor, max_length=None):
    """Return the number of token ids the processor's tokenizer produces for a transcript.

    Args:
        transcript: Text to tokenize.
        processor: Object exposing a callable ``tokenizer`` whose result has an
            ``input_ids`` sequence.
        max_length: Optional cap on the token count. When ``None`` (the default),
            the text is tokenized without truncation so its true length is measured.

    Returns:
        The token count as an ``int``, or ``None`` if tokenization raised. Errors
        are printed rather than raised so that one bad transcript does not abort
        a pass over the whole dataset.
    """
    try:
        # Truncate only when a cap is requested. With Hugging Face tokenizers,
        # truncation=True and max_length=None falls back to the model's own
        # maximum length, which would silently clip the lengths being measured.
        labels = processor.tokenizer(
            transcript,
            max_length=max_length,
            truncation=max_length is not None,
        ).input_ids
        return len(labels)
    except Exception as e:
        print(f"Error processing transcription: {e}")
        return None
    
# NOTE(review): `processor` is not defined in any cell shown here; it must already
# exist in the kernel before this cell runs — TODO create it explicitly near the top.
transcription_lengths = [get_transcription_length(transcript, processor) for transcript in data['transcript']]
# Drop transcripts that failed to tokenize (get_transcription_length returned None).
transcription_lengths = [length for length in transcription_lengths if length is not None]
assert transcription_lengths, "No transcripts could be tokenized; nothing to analyze."

# Analyze the lengths
fig, ax = plt.subplots()
ax.hist(transcription_lengths, bins=20)
ax.set(
    xlabel='Length of tokenized transcription',
    ylabel='Frequency',
    title='Distribution of Tokenized Transcription Lengths',
)
plt.show()

max_length = max(transcription_lengths)
print(f"Maximum transcription length: {max_length}")