In [1]:
import os
import subprocess

# Path to the directory containing the MKV files
input_directory = '../us_raw'
output_directory = '../us_raw_audio'

# Create the output directory if it doesn't exist
os.makedirs(output_directory, exist_ok=True)

# List all files in the input directory
files = os.listdir(input_directory)

# Extract audio from each MKV file
for file in files:
    if file.endswith('.mkv'):
        input_file = os.path.join(input_directory, file)
        output_file = os.path.join(output_directory, file.replace('.mkv', '.wav'))
        
        # Run ffmpeg to extract audio and save as WAV
        subprocess.run(['ffmpeg', '-i', input_file, '-vn', '-acodec', 'pcm_s16le', '-ar', '44100', '-ac', '2', output_file])

print("Audio extraction complete.")


ffmpeg version 4.3 Copyright (c) 2000-2020 the FFmpeg developers
  built with gcc 7.3.0 (crosstool-NG 1.23.0.449-a04d0)
  configuration: --prefix=/opt/conda/conda-bld/ffmpeg_1597178665428/_h_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placeh --cc=/opt/conda/conda-bld/ffmpeg_1597178665428/_build_env/bin/x86_64-conda_cos6-linux-gnu-cc --disable-doc --disable-openssl --enable-avresample --enable-gnutls --enable-hardcoded-tables --enable-libfreetype --enable-libopenh264 --enable-pic --enable-pthreads --enable-shared --disable-static --enable-version3 --enable-zlib --enable-libmp3lame
  libavutil      56. 51.100 / 56. 51.100
  libavcodec     58. 91.100 / 58. 91.100
  libavformat    58. 45.100 / 58. 45.100
  libavdevice    58. 10.100 / 58. 10.100
  libavfilter     7. 85.100 /  7. 85.100
  libavresample   4.  0.  0 /  4.  0.  0
  libsw

Audio extraction complete.


Output #0, wav, to '../us_raw_audio/s4_05.wav':
  Metadata:
    ISFT            : Lavf58.45.100
    Stream #0:0: Audio: pcm_s16le ([1][0][0][0] / 0x0001), 44100 Hz, stereo, s16, 1411 kb/s
    Metadata:
      title           : simple_aac_recording0
      DURATION        : 00:00:36.800000000
      encoder         : Lavc58.91.100 pcm_s16le
size=    6339kB time=00:00:36.80 bitrate=1411.2kbits/s speed=1.05e+03x    
video:0kB audio:6339kB subtitle:0kB other streams:0kB global headers:0kB muxing overhead: 0.001202%
ffmpeg version 4.3 Copyright (c) 2000-2020 the FFmpeg developers
  built with gcc 7.3.0 (crosstool-NG 1.23.0.449-a04d0)
  configuration: --prefix=/opt/conda/conda-bld/ffmpeg_1597178665428/_h_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placeh --cc=/opt/conda/conda-bld/ffmpeg_1597178665428/_build_env/bin/x86_64-conda_cos6-linu

In [1]:
import os
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader
from transformers import Wav2Vec2Processor, Wav2Vec2Model

from torch import nn
from torch.optim.lr_scheduler import StepLR
from model.utils import get_parser

def get_model(cfg):
    ## old
    if cfg.arch == 'vocal_stage1':
        from model.audio_vqvae import VQAutoEncoder as Model
        model = Model(args=cfg)
    else:
        raise Exception('architecture not supported yet'.format(cfg.arch))
    return model


class AudioDataset(Dataset):
    def __init__(self, audio_dir, segment_ms, processor):
        self.audio_dir = audio_dir
        self.segment_ms = segment_ms
        self.processor = processor
        self.sample_rate = 16000  # Wav2Vec2 expects 16000 Hz
        self.segment_len = int(self.sample_rate * (self.segment_ms / 1000))  # Segment length in samples
        self.file_list = []
        self.load_and_segment_files()

    def load_and_segment_files(self):
        audio_files = [os.path.join(self.audio_dir, f) for f in os.listdir(self.audio_dir) if f.endswith('.wav')]

        for audio_path in audio_files:
            waveform, sample_rate = torchaudio.load(audio_path)
            
            # Resample to Wav2Vec2's expected sample rate
            if sample_rate != self.sample_rate:
                waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=self.sample_rate)(waveform)

            if waveform.size(0) > 1:
                waveform = waveform.mean(dim=0, keepdim=True)
            
            
            # Segment the waveform into chunks of segment_len
            num_frames = waveform.size(1)
            for start_idx in range(0, num_frames, self.segment_len):
                if start_idx + self.segment_len <= num_frames:
                    segment = waveform[:, start_idx:start_idx + self.segment_len]
                    self.file_list.append(segment)

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        segment = self.file_list[idx]
        inputs = self.processor(segment.squeeze(0), sampling_rate=self.sample_rate, return_tensors="pt", padding=True)
        return inputs.input_values.squeeze(0)

def collate_fn(batch):
    return torch.stack(batch)

# Parameters
audio_directory = '../us_raw_audio'
segment_duration_ms = 25  # Desired sequence length

processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
wav_model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")

# Create dataset and dataloader
audio_dataset = AudioDataset(audio_directory,  segment_duration_ms, processor)
audio_dataloader = DataLoader(audio_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)

args = get_parser()
model = get_model(args)

# Iterate over the dataloader
for batch in audio_dataloader:
    print('batch', batch.shape)
    # with torch.no_grad():
    #     outputs = wav_model(batch)
    #     print(outputs.last_hidden_state.shape)
    # Perform operations on the batch


  from .autonotebook import tqdm as notebook_tqdm
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
usage: ipykernel_launcher.py [-h] [--config CONFIG] ...
ipykernel_launcher.py: error: unrecognized arguments: --f=/home/pengy/.local/share/jupyter/runtime/kernel-v2-22127907C4NcxuZFeXe.json


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
