In [1]:
import glob
from collections import deque
import numpy as np
import tensorflow as tf
import torch
from torch.utils.data import IterableDataset, DataLoader
from tqdm import tqdm


## Data Processing

In [2]:
# TFRecord parsing 

class AudioFrameTFRecordDataset(IterableDataset):
    def __init__(self, pattern, sequence_length):
        self.files   = sorted(glob.glob(pattern))
        self.seq_len = sequence_length

        # grab the audio stream
        self.seq_feat = {
            "audio": tf.io.FixedLenSequenceFeature(
                [], tf.string
            )
        }

    def __iter__(self):
        buf = deque(maxlen=self.seq_len + 1)

        for fn in self.files:
            for raw in tf.data.TFRecordDataset(fn):
                _, seq_feats = tf.io.parse_single_sequence_example(
                    raw,
                    context_features={},
                    sequence_features=self.seq_feat
                )
                audio_raw = seq_feats["audio"].numpy()

                frames = np.stack([
                    np.frombuffer(b, dtype=np.uint8).astype(np.float32)
                    for b in audio_raw
                ], axis=0)

                # slide window
                for row in frames:
                    buf.append(row)
                    if len(buf) == self.seq_len + 1:
                        seq    = torch.from_numpy(
                            np.stack(list(buf)[:-1], axis=0)
                        ).float()
                        target = torch.from_numpy(buf[-1]).float()
                        yield seq, target


In [3]:
# DataLoader

SEQ_LEN  = 10
BATCH_SZ = 32

dataset    = AudioFrameTFRecordDataset(
    "train/*.tfrecord", SEQ_LEN
)
train_loader = torch.utils.data.DataLoader(
    dataset, batch_size=BATCH_SZ
)


## Model

In [4]:
# LSTM Regressor

class AudioRNN(torch.nn.Module):
    def __init__(self, input_dim=128, hidden_dim=256, num_layers=2):
        super().__init__()
        self.lstm = torch.nn.LSTM(
            input_dim, hidden_dim,
            num_layers, batch_first=True
        )
        self.fc   = torch.nn.Linear(hidden_dim, input_dim)

    def forward(self, x):
        out, _ = self.lstm(x)
        last   = out[:, -1, :]
        return self.fc(last)


In [5]:
# Loss

model     = AudioRNN(input_dim=128, hidden_dim=256, num_layers=2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.MSELoss()


In [7]:
# Training

NUM_EPOCHS = 10

for epoch in range(1, NUM_EPOCHS+1):
    model.train()
    total_loss = 0.0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{NUM_EPOCHS}")
    for seqs, targets in pbar:
        optimizer.zero_grad()
        preds = model(seqs)
        loss  = criterion(preds, targets)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * seqs.size(0)
        pbar.set_postfix(loss=loss.item())

    avg_loss = total_loss / (len(train_loader) * BATCH_SZ)
    print(f" → Epoch {epoch} avg MSE: {avg_loss:.6f}")


Epoch 1/10: 7313it [01:52, 56.34it/s, loss=5.13e+3]2025-06-10 17:23:08.064691: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
Epoch 1/10: 21948it [06:02, 62.23it/s, loss=4.56e+3]2025-06-10 17:27:17.800073: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
Epoch 1/10: 50768it [13:59, 60.39it/s, loss=3.8e+3] 2025-06-10 17:35:14.873136: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
Epoch 1/10: 108945it [29:59, 58.30it/s, loss=4.67e+3]2025-06-10 17:51:15.517736: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
Epoch 1/10: 216989it [1:25:13, 42.43it/s, loss=8.18e+3]


KeyboardInterrupt: 