In [None]:
# %load_ext autoreload
# %autoreload 2

import os,sys
sys.path.insert(0, os.path.abspath('../..')) # root 
sys.path.insert(1, os.path.abspath('../../modules/deepspeech_pytorch') ) #modules deepspeech 
import torch
import torch.nn as nn
import hydra
from src.autodiff.ctc_loss_imp import ctc_loss_imp

from deepspeech_pytorch.model import DeepSpeech, SequenceWise
from src.utils.plot_utils import *

import matplotlib
%matplotlib inline

In [None]:
def get_deep_speech_model_saved_20_10(device, model_path=None) -> DeepSpeech:
    """Load a DeepSpeech model from a Lightning checkpoint for training.

    The model is loaded with ``DeepSpeech.load_from_checkpoint``, switched to
    training mode, moved to ``device``, and its ``spect_cfg`` window settings
    are overwritten with a 32 ms window and a 20 ms stride.

    NOTE(review): the function name says "20_10" but the values written to
    ``spect_cfg`` are 0.032 / 0.020 -- confirm which one is intended.

    Args:
        device: Target handed to ``model.to(...)`` (e.g. ``'cuda:0'``).
        model_path: Checkpoint to load. When ``None`` (default) the
            checkpoint of the 2023-11-01 run that used to be hard-coded here
            is used, so existing callers keep their behaviour.

    Returns:
        The loaded ``DeepSpeech`` model, in train mode, on ``device``.
    """
    from hydra.core.config_store import ConfigStore
    from deepspeech_pytorch.configs.inference_config import EvalConfig

    # Register EvalConfig under the name "config" in hydra's global
    # ConfigStore. Nothing in this function reads it back; the registration
    # is kept because it is a global side effect of the original code.
    cs = ConfigStore.instance()
    cs.store(name="config", node=EvalConfig)

    if model_path is None:
        # Alternative checkpoint that was previously tried here:
        # '/scratch/f006pq6/projects/gitrepos/deepspeech.pytorch/librispeech_pretrained_v3.ckpt'
        model_path = '/scratch/f006pq6/projects/gitrepos/deepspeech.pytorch/outputs/2023-11-01/23-31-50/lightning_logs/version_0/checkpoints/epoch=5-step=19295.ckpt'

    model_ds = DeepSpeech.load_from_checkpoint(model_path)
    model_ds.train()
    model_ds = model_ds.to(device)

    # modify window size and window stride of spect_cfg
    model_ds.spect_cfg.window_size = 0.032 # 32 ms
    model_ds.spect_cfg.window_stride = 0.020 # 20 ms

    return model_ds

def get_deep_speech_model_32_20(device) -> DeepSpeech:
    """Construct a unidirectional DeepSpeech model from the default train config.

    No checkpoint is loaded: the model is instantiated from
    ``DeepSpeechConfig()`` defaults with the labels read from the repository's
    ``labels.json``, put in training mode and moved to ``device``.

    Args:
        device: Target handed to ``model.to(...)`` (e.g. ``'cuda:0'``).

    Returns:
        The new ``DeepSpeech`` model, in train mode, on ``device``.
    """
    import json

    # Label set shipped with the vendored deepspeech_pytorch module; the path
    # is relative to the notebook's working directory.
    with open('../../modules/deepspeech_pytorch/labels.json') as fh:
        labels = json.load(fh)

    from deepspeech_pytorch.configs.train_config import (
        DeepSpeechConfig,
        UniDirectionalConfig,
    )

    train_cfg = DeepSpeechConfig()
    model = DeepSpeech(
        labels=labels,
        model_cfg=UniDirectionalConfig(),
        optim_cfg=train_cfg.optim,
        precision=train_cfg.trainer.precision,
        spect_cfg=train_cfg.data.spect,
    )
    model.train()

    # The spect_cfg window size / stride are left exactly as the config
    # provides them; no 0.032 / 0.020 override is applied in this function.
    return model.to(device)


def get_dataloader(model_ds: DeepSpeech, batch_size=1, num_workers=4, test_dir=None):
    """Build an ``AudioDataLoader`` over a spectrogram dataset for ``model_ds``.

    The dataset is created with the model's own ``spect_cfg`` and ``labels``
    and with ``normalize=True``.

    Args:
        model_ds: Model whose ``spect_cfg`` and ``labels`` configure the dataset.
        batch_size: Batch size forwarded to ``AudioDataLoader``.
        num_workers: Worker count forwarded to ``AudioDataLoader``.
        test_dir: Input path for the dataset. When ``None`` (default) the
            LibriSpeech ``test_clean`` directory that used to be hard-coded
            here is used, so existing callers keep their behaviour.

    Returns:
        The ``AudioDataLoader``; the underlying dataset is reachable through
        its ``.dataset`` attribute.
    """
    from deepspeech_pytorch.loader.data_loader import SpectrogramDataset, AudioDataLoader

    if test_dir is None:
        test_dir = '/scratch/f006pq6/datasets/librispeech/test_clean'

    test_dataset = SpectrogramDataset(
        audio_conf=model_ds.spect_cfg,
        # Resolve a possibly relative path through hydra's path helper.
        input_path=hydra.utils.to_absolute_path(test_dir),
        labels=model_ds.labels,
        normalize=True
    )
    test_loader = AudioDataLoader(
        test_dataset,
        batch_size=batch_size,
        num_workers=num_workers
    )
    return test_loader

# Use the first CUDA device when one is available and fall back to the CPU
# otherwise, so the cell does not fail on machines without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# Model instantiated from the default training config (no checkpoint loaded).
model_ds = get_deep_speech_model_32_20(device)
# Loader over the default test directory, one utterance per batch.
test_loader = get_dataloader(model_ds, batch_size=1, num_workers=4)
test_loader_iter = iter(test_loader)

In [None]:
# Sanity check: shape of the first element of dataset sample 0.
test_loader.dataset[0][0].shape

In [33]:
def length_filter(dataset, sr=16000, window_size_ms=32, window_step_ms=20,
                  min_dur=1.0, max_dur=3.5, verbose=True):
    """
    Return the indices of the samples whose estimated duration lies in
    ``[min_dur, max_dur]`` seconds (1 s to 3.5 s by default, both inclusive).

    The duration of sample ``i`` is estimated as the number of frames,
    ``dataset[i][0].shape[1]``, times the window step.

    Args:
    - dataset: Indexable dataset (supports ``len`` and ``dataset[i]``) whose
      items hold, at position 0, a tensor with the frame (time) axis at
      dimension 1.
    - sr: Sampling rate in samples per second.
    - window_size_ms: Window size in milliseconds. Kept for interface
      compatibility; it does not enter the duration estimate.
    - window_step_ms: Window step in milliseconds.
    - min_dur: Shortest duration (seconds) to keep.
    - max_dur: Longest duration (seconds) to keep.
    - verbose: When True (default), print one line per kept sample.

    Returns:
    - List of the kept sample indices, in dataset order.
    """
    # Step between consecutive frames, in samples.
    window_step_samples = int(sr * (window_step_ms / 1000))

    good_indices = []
    for i in range(len(dataset)):
        # Indexing the dataset builds the full item, so this loop is as
        # expensive as one pass over the whole dataset.
        dur = dataset[i][0].shape[1] * window_step_samples / sr
        if min_dur <= dur <= max_dur:
            if verbose:
                print(f"Sample {i}: Length = {dur:.2f} seconds")
            good_indices.append(i)
    return good_indices


In [34]:
# Shape of the first element of dataset sample 4; duration estimates in this
# notebook read the frame count from dimension 1 of this tensor.
test_loader.dataset[4][0].shape

torch.Size([257, 554])

In [35]:
# Indices of the dataset samples kept by length_filter with its default
# settings; the call also prints one line per kept sample.
good_idx = length_filter(test_loader.dataset)

Sample 1: Length = 3.32 seconds
Sample 6: Length = 3.20 seconds
Sample 9: Length = 3.04 seconds
Sample 15: Length = 2.96 seconds
Sample 20: Length = 1.96 seconds
Sample 22: Length = 3.10 seconds
Sample 24: Length = 3.40 seconds
Sample 26: Length = 2.84 seconds
Sample 33: Length = 3.22 seconds
Sample 36: Length = 1.90 seconds
Sample 38: Length = 3.04 seconds
Sample 40: Length = 2.58 seconds
Sample 42: Length = 3.42 seconds
Sample 45: Length = 2.16 seconds
Sample 46: Length = 2.70 seconds
Sample 51: Length = 3.24 seconds
Sample 53: Length = 1.82 seconds
Sample 54: Length = 3.32 seconds
Sample 60: Length = 3.30 seconds
Sample 62: Length = 2.18 seconds
Sample 63: Length = 3.40 seconds
Sample 71: Length = 3.08 seconds
Sample 72: Length = 2.94 seconds
Sample 74: Length = 2.54 seconds
Sample 75: Length = 2.36 seconds
Sample 118: Length = 2.26 seconds
Sample 126: Length = 3.50 seconds
Sample 143: Length = 2.68 seconds
Sample 147: Length = 2.38 seconds
Sample 168: Length = 2.22 seconds
Sample 1

Sample 1019: Length = 3.10 seconds
Sample 1026: Length = 3.36 seconds
Sample 1038: Length = 2.88 seconds
Sample 1043: Length = 2.62 seconds
Sample 1050: Length = 3.02 seconds
Sample 1052: Length = 2.10 seconds
Sample 1055: Length = 2.22 seconds
Sample 1056: Length = 3.46 seconds
Sample 1061: Length = 2.18 seconds
Sample 1067: Length = 3.40 seconds
Sample 1069: Length = 3.10 seconds
Sample 1073: Length = 3.16 seconds
Sample 1075: Length = 3.20 seconds
Sample 1077: Length = 2.72 seconds
Sample 1078: Length = 3.28 seconds
Sample 1095: Length = 3.34 seconds
Sample 1096: Length = 2.24 seconds
Sample 1099: Length = 3.46 seconds
Sample 1108: Length = 2.70 seconds
Sample 1112: Length = 3.26 seconds
Sample 1113: Length = 2.72 seconds
Sample 1126: Length = 2.42 seconds
Sample 1127: Length = 2.68 seconds
Sample 1130: Length = 3.04 seconds
Sample 1134: Length = 2.64 seconds
Sample 1139: Length = 3.12 seconds
Sample 1142: Length = 3.32 seconds
Sample 1143: Length = 2.64 seconds
Sample 1145: Length 

In [36]:
# Number of samples that passed the length filter.
len(good_idx)

594