In [None]:
import logging
import os
import time

import numpy as np
import torch
from datasets import load_from_disk, load_dataset
from torch.utils.data import DataLoader, IterableDataset

#### from disk

# Dataset previously written with `save_to_disk` (the path suggests the Common Voice 17
# test split). The default is machine-specific; set CV17_TEST_DIR to point elsewhere.
CV17_TEST_DIR = os.environ.get("CV17_TEST_DIR", "D:/proj/datasets/CV17/CV17/TEST/")
dataset = load_from_disk(CV17_TEST_DIR)

def filter_fn(example):
    """Predicate for ``Dataset.filter``: True when the example has a transcript.

    Rows whose ``sentence`` is None are rejected so that ``preprocess_fn``
    never has to encode a missing transcript.
    """
    transcript = example['sentence']
    has_transcript = transcript is not None
    return has_transcript

# Drop rows whose transcript is missing before they reach preprocess_fn.
dataset = dataset.filter(function=filter_fn)

def preprocess_fn(examples):
    """Batched ``Dataset.map`` function: turn raw audio and text into tensors.

    Args:
        examples: batch dict whose 'audio' entry is a list of dicts carrying an
            'array' of samples and whose 'sentence' entry is a list of strings.

    Returns:
        dict with 'audio' -> list of float32 tensors (one per clip) and
        'sentence' -> list of int64 tensors of each character's code point.
    """
    def encode_audio(clip):
        return torch.tensor(clip['array'], dtype=torch.float32)

    def encode_text(text):
        return torch.tensor([ord(ch) for ch in text], dtype=torch.int64)

    # Dict literals evaluate in order, so audio is still encoded before text.
    return {
        'audio': list(map(encode_audio, examples['audio'])),
        'sentence': list(map(encode_text, examples['sentence'])),
    }

# Tensorise audio and transcripts, processing rows in batches.
# NOTE(review): Dataset.map writes results back to Arrow, so rows are presumably read
# back as Python lists rather than tensors unless a torch format is set — confirm the
# collate function converts them.
dataset = dataset.map(function=preprocess_fn, batched=True)

def collate_fn(batch, audio_padding_value=0.0, sentence_padding_value=0):
    """Collate a list of {'audio', 'sentence'} examples into padded batch tensors.

    Clips and transcripts vary in length, so both are right-padded to the longest
    item in the batch; ``torch.stack`` would fail on unequal lengths. Values may be
    tensors or plain lists (a ``datasets.Dataset`` without a torch format returns
    ``map`` outputs as lists) and are converted with ``torch.as_tensor``.

    Args:
        batch: sequence of dicts with 'audio' and 'sentence' values; each value is
            assumed to be 1-D (samples / character ids) — TODO confirm for stereo input.
        audio_padding_value: fill value for padded audio positions.
        sentence_padding_value: fill value for padded sentence ids.

    Returns:
        dict with 'audio' (float32, batch x max_len) and 'sentence' (int64,
        batch x max_len) padded tensors, plus 'audio_lengths' / 'sentence_lengths'
        (int64, one entry per item) holding each item's unpadded length.
    """
    # Force the dtypes preprocess_fn produces; an empty list would otherwise become
    # float32 and clash with the int64 sentence ids during padding.
    audio = [torch.as_tensor(item['audio'], dtype=torch.float32) for item in batch]
    sentences = [torch.as_tensor(item['sentence'], dtype=torch.int64) for item in batch]
    return {
        'audio': torch.nn.utils.rnn.pad_sequence(
            audio, batch_first=True, padding_value=audio_padding_value),
        'sentence': torch.nn.utils.rnn.pad_sequence(
            sentences, batch_first=True, padding_value=sentence_padding_value),
        # Padding erases the true lengths; keep them so models can mask padded frames.
        'audio_lengths': torch.tensor([t.size(0) for t in audio], dtype=torch.int64),
        'sentence_lengths': torch.tensor([t.size(0) for t in sentences], dtype=torch.int64),
    }

# Shuffled, map-style loader over the preprocessed dataset.
dataloader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=collate_fn,
)

In [None]:
####### or streaming

# Stream the English Common Voice train split instead of loading a saved copy.
dataset = load_dataset("common_voice", "en", split='train', streaming=True)
# Drop examples without a transcript (the same rule filter_fn applies in the
# from-disk cell); preprocess_fn below iterates over the sentence and fails on None.
dataset = dataset.filter(lambda example: example['sentence'] is not None)


def preprocess_fn(batch):
    """Convert one streamed example into tensors.

    Despite the parameter name, ``batch`` is a single example: HuggingFaceDataset
    calls this once per item it iterates. ``batch['audio']['array']`` holds the
    samples and ``batch['sentence']`` the transcript.

    Returns:
        dict with 'audio' -> float32 tensor of samples and
        'sentence' -> int64 tensor of the transcript's code points.
    """
    waveform = torch.tensor(batch['audio']['array'], dtype=torch.float32)
    codepoints = torch.tensor(list(map(ord, batch['sentence'])), dtype=torch.int64)
    return dict(audio=waveform, sentence=codepoints)



class HuggingFaceDataset(IterableDataset):
    """Expose an iterable of raw examples as a PyTorch ``IterableDataset``.

    Items are transformed lazily: ``preprocess_fn`` runs on each example only when
    the DataLoader pulls it, so the wrapped source is never materialised.

    Args:
        hf_dataset: iterable of raw examples (here, a streaming HF dataset).
        preprocess_fn: callable mapping one raw example to the item that is yielded.
    """

    def __init__(self, hf_dataset, preprocess_fn):
        self.dataset, self.preprocess_fn = hf_dataset, preprocess_fn

    def __iter__(self):
        # Still a generator function: nothing is read from the source until the
        # first item is requested, and each iteration starts a fresh pass.
        yield from (self.preprocess_fn(example) for example in self.dataset)

def collate_fn(batch, audio_padding_value=0.0, sentence_padding_value=0):
    """Collate a list of {'audio', 'sentence'} examples into padded batch tensors.

    Clips and transcripts vary in length, so both are right-padded to the longest
    item in the batch; ``torch.stack`` would fail on unequal lengths. Values may be
    tensors or plain lists and are converted with ``torch.as_tensor``.

    Args:
        batch: sequence of dicts with 'audio' and 'sentence' values; each value is
            assumed to be 1-D (samples / character ids) — TODO confirm for stereo input.
        audio_padding_value: fill value for padded audio positions.
        sentence_padding_value: fill value for padded sentence ids.

    Returns:
        dict with 'audio' (float32, batch x max_len) and 'sentence' (int64,
        batch x max_len) padded tensors, plus 'audio_lengths' / 'sentence_lengths'
        (int64, one entry per item) holding each item's unpadded length.
    """
    # Force the dtypes preprocess_fn produces; an empty list would otherwise become
    # float32 and clash with the int64 sentence ids during padding.
    audio = [torch.as_tensor(item['audio'], dtype=torch.float32) for item in batch]
    sentences = [torch.as_tensor(item['sentence'], dtype=torch.int64) for item in batch]
    return {
        'audio': torch.nn.utils.rnn.pad_sequence(
            audio, batch_first=True, padding_value=audio_padding_value),
        'sentence': torch.nn.utils.rnn.pad_sequence(
            sentences, batch_first=True, padding_value=sentence_padding_value),
        # Padding erases the true lengths; keep them so models can mask padded frames.
        'audio_lengths': torch.tensor([t.size(0) for t in audio], dtype=torch.int64),
        'sentence_lengths': torch.tensor([t.size(0) for t in sentences], dtype=torch.int64),
    }

# Wrap the stream so each example is tensorised on the fly.
streaming_dataset = HuggingFaceDataset(hf_dataset=dataset, preprocess_fn=preprocess_fn)

# Order follows the stream; no shuffle option is set for this iterable dataset.
dataloader = DataLoader(
    streaming_dataset,
    batch_size=32,
    collate_fn=collate_fn,
)


In [None]:
### Example training loop

def train_and_evaluate(model, dataloader, eval_loader, optimizer, loss_fn, scheduler, num_epochs=1, max_steps=None, device='cuda', accumulation_steps=1, clear_cache=True, log_interval=10, eval_interval=20, save_interval=100, checkpoint_dir="checkpoint_dir", log_dir="log_dir"):
    model.to(device)
    global_step = 0
    scaler = torch.amp.GradScaler()
    writer = SummaryWriter(log_dir=log_dir)

    for epoch in range(num_epochs):
        if max_steps is not None and global_step >= max_steps:
            break

        model.train()
        total_loss = 0
        optimizer.zero_grad()
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}")

        for step, batch in enumerate(progress_bar):
            if max_steps is not None and global_step >= max_steps:
                break

            start_time = time.time()

            try:
                audio_features = batch['audio'].to(device)
                sentences = batch['sentence'].to(device)
            except KeyError as e:
                print(f"Key error: {e}. Available keys in batch: {batch.keys()}")
                continue

            with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
                with record_function("model_training"):
                    with torch.amp.autocast(device_type='cuda'):
                        audio_features_encoded = model.encoder(audio_features)
                        decoder_output = model.decoder(audio_features_encoded, sentences)
                        logits = decoder_output.view(-1, decoder_output.size(-1))
                        loss = loss_fn(logits, sentences.view(-1))
                        total_loss += loss.item()
                        loss = loss / accumulation_steps

                    scaler.scale(loss).backward()

                    if (step + 1) % accumulation_steps == 0:
                        scaler.step(optimizer)
                        scaler.update()
                        optimizer.zero_grad()

                        if clear_cache:
                            torch.cuda.empty_cache()

            global_step += 1
            end_time = time.time()
            samples_per_sec = len(batch['audio']) / (end_time - start_time)

            total_norm = 0
            for p in model.parameters():
                if p.grad is not None:
                    param_norm = p.grad.data.norm(2)
                    total_norm += param_norm.item() ** 2
            total_norm = total_norm ** (1. / 2)

            if global_step % log_interval == 0:
                writer.add_scalar('Loss/train', total_loss / (step + 1), global_step)
                writer.add_scalar('GradientNorm', total_norm, global_step)
                writer.add_scalar('LearningRate', optimizer.param_groups[0]['lr'], global_step)
                writer.add_scalar('SamplesPerSec', samples_per_sec, global_step)
                writer.add_scalar("Memory/Allocated", torch.cuda.memory_allocated(), global_step)
                writer.add_scalar("Memory/Cached", torch.cuda.memory_reserved(), global_step)

        if global_step % eval_interval == 0:
            model.eval()
            eval_loss = 0
            all_predictions = []
            all_labels = []
            with torch.no_grad():
                for eval_batch in eval_loader:
                    try:
                        audio_features = eval_batch['audio'].to(device)
                        sentences = eval_batch['sentence'].to(device)
                    except KeyError as e:
                        print(f"Key error: {e}. Available keys in eval batch: {eval_batch.keys()}")
                        continue

                    audio_features_encoded = model.encoder(audio_features)
                    decoder_output = model.decoder(audio_features_encoded, sentences)

                    logits = decoder_output.view(-1, decoder_output.size(-1))
                    loss = loss_fn(logits, sentences.view(-1))
                    eval_loss += loss.item()

                    all_predictions.extend(torch.argmax(decoder_output, dim=-1).cpu().numpy().tolist())
                    all_labels.extend(sentences.cpu().numpy().tolist())

            predictions = {
                "predictions": np.array(all_predictions, dtype="object"),
                "label_ids": np.array(all_labels, dtype="object")
            }

            metrics = compute_metrics(predictions)
            writer.add_scalar('Loss/eval', eval_loss / len(eval_loader), global_step)
            writer.add_scalar('CER', metrics['cer'], global_step)

            scheduler.step(eval_loss / len(eval_loader))

            sample_indices = range(min(1, len(all_predictions)))
            for idx in sample_indices:
                pred_str = tokenizer.decode(all_predictions[idx], skip_special_tokens=True)
                label_str = tokenizer.decode(all_labels[idx], skip_special_tokens=True)
                print(f"Sample {idx}: Prediction: {pred_str}, Label: {label_str}")
                logging.info(f"Sample {idx}: Prediction: {pred_str}, Label: {label_str}")

            model.train()

        if global_step % save_interval == 0:
            checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_step_{global_step}.pt')
            torch.save(model.state_dict(), checkpoint_path)
            print(f"Model saved at step {global_step} to {checkpoint_path}")
            logging.info(f"Model saved at step {global_step} to {checkpoint_path}")

    print(f'Epoch {epoch + 1}, Loss: {total_loss / len(dataloader)}')
    logging.info(f'Epoch {epoch + 1}, Loss: {total_loss / len(dataloader)}')

final_model_path = os.path.join(checkpoint_dir, 'final_model.pt')
torch.save(model.state_dict(), final_model_path)
print(f"Final model saved to {final_model_path}")
logging.info(f"Final model saved to {final_model_path}")
writer.close()
