In [None]:
import argparse
import math
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Union

import datasets
import torch
# from accelerate import Accelerator
# from accelerate.logging import get_logger
from datasets import DatasetDict, concatenate_datasets, load_dataset
from huggingface_hub import Repository, create_repo
from torch.utils.data.dataloader import DataLoader
from tqdm.auto import tqdm

import transformers

from utils import *

In [None]:
def parse_args():
    """Build the command-line parser for the pre-training script and parse ``sys.argv``.

    Unrecognized arguments are ignored (``parse_known_args``), which lets the function run
    inside an interactive kernel whose ``sys.argv`` carries unrelated options.

    Returns:
        argparse.Namespace: the parsed options. ``model_name_or_path`` defaults to ``None`` and
        must be supplied (on the command line or by assignment) before a model or feature
        extractor is loaded from it.

    Raises:
        ValueError: if ``--push_to_hub`` is passed without ``--output_dir``.

    Side effects:
        Creates ``output_dir`` (including parents) when it is set.
    """
    parser = argparse.ArgumentParser(description="Finetune a transformers model on a text classification task")
    parser.add_argument(
        "--dataset_name",
        type=str,
        default='librispeech_asr',
        help="The name of the dataset to use (via the datasets library).",
    )
    # NOTE(review): the default is ONE string 'clean, clean', not two entries — confirm the intent.
    parser.add_argument('--range', nargs='+', type=str, default=['clean, clean'], help='specify a range')
    parser.add_argument(
        "--dataset_config_names",
        nargs='+',
        type=str,
        default=['clean', 'clean'],
        help="The configuration names of the dataset to use (via the datasets library).",
    )
    parser.add_argument(
        "--dataset_split_names",
        nargs="+",
        type=str,
        default=['validation', 'test'],
        help="The names of the training data set splits to use (via the datasets library).",
    )
    parser.add_argument(
        "--preprocessing_num_workers",
        type=int,
        default=None,
        help="The number of processes to use for the preprocessing.",
    )
    parser.add_argument(
        "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
    )
    parser.add_argument(
        "--preprocessing_only",
        action="store_true",
        help="Only run the preprocessing script to be cached for future use",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="Where do you want to store the pretrained models downloaded from huggingface.co",
    )
    parser.add_argument(
        "--validation_split_percentage",
        type=int,
        default=1,
        help="Percentage of training data that should be used for validation if no validation is present in dataset.",
    )
    parser.add_argument(
        "--logging_steps",
        type=int,
        default=500,
        help="Number of steps between each logging",
    )
    parser.add_argument(
        "--saving_steps",
        type=int,
        default=500,
        help="Number of steps between each model save",
    )
    parser.add_argument(
        "--audio_column_name",
        type=str,
        default="audio",
        help="Column in the dataset that contains speech file path. Defaults to 'audio'",
    )
    # Not `required=True`: a required option would make argparse exit when the script is run
    # without command-line arguments. The default of None must be replaced before use.
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default=None,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--config_name",
        type=str,
        default=None,
        help="Pretrained config name or path if not the same as model_name",
    )
    parser.add_argument(
        "--train_cache_file_name",
        type=str,
        default=None,
        help="Path to the train cached file name",
    )
    parser.add_argument(
        "--validation_cache_file_name",
        type=str,
        default=None,
        help="Path to the validation cached file name",
    )
    parser.add_argument(
        "--per_device_train_batch_size",
        type=int,
        default=8,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument(
        "--per_device_eval_batch_size",
        type=int,
        default=8,
        help="Batch size (per device) for the evaluation dataloader.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-5,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
    parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.")
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="If True, use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    # Kept as a plain string so the parser does not depend on an enum type being imported.
    parser.add_argument(
        "--lr_scheduler_type",
        type=str,
        default="linear",
        help="The scheduler type to use.",
        choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
    )
    parser.add_argument(
        "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
    parser.add_argument("--seed", type=int, default=0, help="A seed for reproducible training.")
    parser.add_argument(
        "--max_gumbel_temperature",
        type=float,
        default=2.0,
        help="Maximum temperature for gumbel softmax.",
    )
    parser.add_argument(
        "--min_gumbel_temperature",
        type=float,
        default=0.5,
        help="Minimum temperature for gumbel softmax.",
    )
    parser.add_argument(
        "--gumbel_temperature_decay", type=float, default=0.999995, help="Decay of gumbel temperature during training."
    )
    parser.add_argument(
        "--max_duration_in_seconds",
        type=float,
        default=5.0,
        help="Filter out audio files that are longer than `max_duration_in_seconds` seconds",
    )
    parser.add_argument(
        "--min_duration_in_seconds",
        type=float,
        default=3.0,
        help="Filter out audio files that are shorter than `min_duration_in_seconds` seconds",
    )
    parser.add_argument(
        "--pad_to_multiple_of",
        type=int,
        default=None,
        help=(
            "If set will pad the sequence to a multiple of the provided value. This is especially useful to enable the"
            " use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta)."
        ),
    )
    parser.add_argument(
        "--adam_beta1",
        type=float,
        default=0.9,
        help="Beta1 for AdamW optimizer",
    )
    parser.add_argument(
        "--adam_beta2",
        type=float,
        default=0.999,
        help="Beta2 for AdamW optimizer",
    )
    parser.add_argument(
        "--adam_epsilon",
        type=float,
        default=1e-8,
        help="Epsilon for AdamW optimizer",
    )
    # `store_true` so that passing the flag actually enables pushing; with `store_false` and a
    # default of False the option could never become True.
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument(
        "--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`."
    )
    parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--mask_time_prob",
        type=float,
        default=None,
        help=(
            "Percentage (between 0 and 1) of all feature vectors along the time axis which will be masked in the"
            " contrastive task. If omitted, will pull value from model config."
        ),
    )
    parser.add_argument(
        "--mask_time_length",
        type=int,
        default=None,
        help=(
            "Length of each vector mask span to mask along the time axis in the contrastive task."
            " If omitted, will pull value from model config."
        ),
    )
    # parse_known_args() returns (namespace, leftover_args); the leftovers are ignored on purpose.
    args = parser.parse_known_args()[0]

    # Explicit exception instead of `assert`, which is stripped when Python runs with -O.
    if args.push_to_hub and args.output_dir is None:
        raise ValueError("Need an `output_dir` to create a repo when `--push_to_hub` is passed.")

    if args.output_dir is not None:
        os.makedirs(args.output_dir, exist_ok=True)

    return args

In [None]:
# Parse the options once; every later cell reads them from `args`.
args = parse_args()

# Echo the dataset configurations and splits that will be loaded, one list per line.
print(args.dataset_config_names, args.dataset_split_names, sep="\n")

In [None]:
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
# NOTE(review): `Accelerator`, `logger`, `is_wandb_available` and `set_seed` are not imported in the
# visible part of this file (the accelerate imports above are commented out) — presumably they are
# provided by `from utils import *`; confirm.
accelerator = Accelerator()
logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
    datasets.utils.logging.set_verbosity_warning()
    transformers.utils.logging.set_verbosity_info()

    # set up weights and biases if available
    if is_wandb_available():
        import wandb

        # `--output_dir` defaults to None, so only derive the project name from it when it is
        # set; otherwise pass None and let wandb pick its own default project.
        wandb.init(project=args.output_dir.split("/")[-1] if args.output_dir is not None else None)
else:
    # Non-main processes only report errors to keep the logs readable.
    datasets.utils.logging.set_verbosity_error()
    transformers.utils.logging.set_verbosity_error()

# If passed along, set the training seed now.
if args.seed is not None:
    set_seed(args.seed)

# Handle the repository creation
if accelerator.is_main_process:
    if args.push_to_hub and not args.preprocessing_only:
        # Retrieve or infer repo_name
        repo_name = args.hub_model_id
        if repo_name is None:
            repo_name = Path(args.output_dir).absolute().name
        # Create repo and retrieve repo_id
        repo_id = create_repo(repo_name, exist_ok=True, token=args.hub_token).repo_id
        # Clone repo locally; `repo` is only bound on the main process in this branch
        repo = Repository(args.output_dir, clone_from=repo_id, token=args.hub_token)
    elif args.output_dir is not None:
        os.makedirs(args.output_dir, exist_ok=True)
# Make every process wait until the main process has finished the repository/directory setup.
accelerator.wait_for_everyone()

In [None]:
# 1. Download and create train, validation dataset
# We load all dataset configuration and dataset split pairs passed in
# ``args.dataset_config_names`` and ``args.dataset_split_names``.
# The two lists are paired positionally; `zip` stops at the shorter one.
datasets_splits = []
for dataset_config_name, train_split_name in zip(args.dataset_config_names, args.dataset_split_names):
    # load dataset
    dataset_split = load_dataset(
        args.dataset_name,
        dataset_config_name,
        split=train_split_name,
        cache_dir=args.cache_dir,
    )
    datasets_splits.append(dataset_split)

# Next, we concatenate all configurations and splits into a single training dataset.
# The result must be stored under the "train" key: everything below indexes
# ``raw_datasets["train"]`` / ``raw_datasets["validation"]``, so rebinding ``raw_datasets``
# itself to a plain dataset would break those lookups.
raw_datasets = DatasetDict()
if len(datasets_splits) > 1:
    raw_datasets["train"] = concatenate_datasets(datasets_splits).shuffle(seed=args.seed)
else:
    raw_datasets["train"] = datasets_splits[0]

# Take ``args.validation_split_percentage`` from the training dataset for the validation split
num_validation_samples = raw_datasets["train"].num_rows * args.validation_split_percentage // 100

if num_validation_samples == 0:
    raise ValueError(
        "`args.validation_split_percentage` is less than a single sample "
        f"for {len(raw_datasets['train'])} training samples. Increase "
        "`args.validation_split_percentage`. "
    )

# The first `num_validation_samples` rows become the validation set, the remainder stays in train.
raw_datasets["validation"] = raw_datasets["train"].select(range(num_validation_samples))
raw_datasets["train"] = raw_datasets["train"].select(range(num_validation_samples, raw_datasets["train"].num_rows))

# 2. Preprocessing: audio loading, resampling and normalization.
# `datasets` decodes and resamples the audio on access, so all that is needed here is to declare
# the target sampling rate and let the `feature_extractor` normalize the input.
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(args.model_name_or_path)

# Re-declare the audio column so that it is decoded at the extractor's sampling rate.
raw_datasets = raw_datasets.cast_column(
    args.audio_column_name,
    datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate),
)

# Un-normalized inputs are rejected outright.
if not feature_extractor.do_normalize:
    raise ValueError(
        "Training is only supported for normalized inputs. Make sure ``feature_extractor.do_normalize == True``"
    )

# Convert the duration limits from seconds to a number of samples (truncated to int).
max_length, min_length = (
    int(seconds * feature_extractor.sampling_rate)
    for seconds in (args.max_duration_in_seconds, args.min_duration_in_seconds)
)

def prepare_dataset(batch):
    """Run the feature extractor on one example and attach its outputs.

    Reads the audio entry stored under ``args.audio_column_name``, passes its ``"array"`` and
    ``"sampling_rate"`` to the module-level ``feature_extractor`` (truncating to ``max_length``),
    and stores the first extracted sequence as ``"input_values"`` together with its length as
    ``"input_length"``. The same ``batch`` object is returned.
    """
    audio = batch[args.audio_column_name]

    extracted = feature_extractor(
        audio["array"],
        sampling_rate=audio["sampling_rate"],
        max_length=max_length,
        truncation=True,
    )

    first_sequence = extracted.input_values[0]
    batch["input_values"] = first_sequence
    batch["input_length"] = len(first_sequence)

    return batch

# Optional explicit cache files for the mapped datasets; None leaves `map` with no explicit file names.
cache_file_names = (
    {"train": args.train_cache_file_name, "validation": args.validation_cache_file_name}
    if args.train_cache_file_name is not None
    else None
)

# Run the feature extraction inside `main_process_first()`.
with accelerator.main_process_first():
    vectorized_datasets = raw_datasets.map(
        prepare_dataset,
        remove_columns=raw_datasets["train"].column_names,
        num_proc=args.preprocessing_num_workers,
        cache_file_names=cache_file_names,
    )

    # Keep only examples strictly longer than the minimum length.
    if min_length > 0.0:
        vectorized_datasets = vectorized_datasets.filter(
            lambda x: x > min_length,
            input_columns=["input_length"],
            num_proc=args.preprocessing_num_workers,
        )

    # The length column was only needed for the filter above.
    vectorized_datasets = vectorized_datasets.remove_columns("input_length")

In [None]:
# Toggle between random initialisation (True) and loading pre-trained weights (False).
from_scratch = False

if from_scratch:
    ## If we want to train the model from scratch
    # 3. Load model
    config = Wav2Vec2Config.from_pretrained(args.model_name_or_path)

    # pretraining is only supported for "newer" stable layer norm architecture
    # apply_spec_augment has to be True, mask_feature_prob has to be 0.0
    if not config.do_stable_layer_norm or config.feat_extract_norm != "layer":
        raise ValueError(
            "PreTraining is only supported for ``config.do_stable_layer_norm=True`` and"
            " ``config.feat_extract_norm='layer'"
        )

    # initialize random model
    model = Wav2Vec2ForPreTraining(config)

else:
    ## If we want to train the model using any pre-trained weights
    model = Wav2Vec2ForPreTraining.from_pretrained(args.model_name_or_path)
    # Bind `config` on this branch too: later code reads `config.mask_time_prob` and
    # `config.mask_time_length` regardless of how the model was created.
    config = model.config

# Activate gradient checkpointing if needed
if args.gradient_checkpointing:
    model.gradient_checkpointing_enable()

# 4. Define data collator, optimizer and scheduler
# NOTE(review): `DataCollatorForWav2Vec2Pretraining`, `AdamW` and `get_scheduler` are not imported
# in the visible part of this file — presumably provided by `from utils import *`; confirm.

# Masking hyper-parameters: a command-line value wins, otherwise fall back to the model's own
# config. `model.config` is used because it exists however the model was created.
mask_time_prob = model.config.mask_time_prob if args.mask_time_prob is None else args.mask_time_prob
mask_time_length = model.config.mask_time_length if args.mask_time_length is None else args.mask_time_length

data_collator = DataCollatorForWav2Vec2Pretraining(
    model=model,
    feature_extractor=feature_extractor,
    pad_to_multiple_of=args.pad_to_multiple_of,
    mask_time_prob=mask_time_prob,
    mask_time_length=mask_time_length,
)
train_dataloader = DataLoader(
    vectorized_datasets["train"],
    shuffle=True,
    collate_fn=data_collator,
    batch_size=args.per_device_train_batch_size,
)
eval_dataloader = DataLoader(
    vectorized_datasets["validation"], collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
)

# Optimizer
optimizer = AdamW(
    list(model.parameters()),
    lr=args.learning_rate,
    betas=[args.adam_beta1, args.adam_beta2],
    eps=args.adam_epsilon,
)

# Prepare everything with our `accelerator`.
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader
)

# Scheduler and math around the number of training steps.
# One optimizer update happens every `gradient_accumulation_steps` batches.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)

if args.max_train_steps is None:
    args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch

lr_scheduler = get_scheduler(
    # Fall back to "linear" when the parsed options carry no `lr_scheduler_type`.
    name=getattr(args, "lr_scheduler_type", "linear"),
    optimizer=optimizer,
    num_warmup_steps=args.num_warmup_steps,
    num_training_steps=args.max_train_steps,
)

# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

# 5. Train
# NOTE(review): `multiply_grads`, `get_grad_norm` and `is_wandb_available` are not defined in the
# visible part of this file — presumably provided by `from utils import *`; confirm.
total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

logger.info("***** Running training *****")
logger.info(f"  Num examples = {len(vectorized_datasets['train'])}")
logger.info(f"  Num Epochs = {args.num_train_epochs}")
logger.info(f"  Instantaneous batch size per device = {args.per_device_train_batch_size}")
logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f"  Total optimization steps = {args.max_train_steps}")

# Only show the progress bar once on each machine.
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
completed_steps = 0
starting_epoch = 0
for epoch in range(starting_epoch, args.num_train_epochs):
    model.train()
    for step, batch in enumerate(train_dataloader):
        # compute num of losses
        num_losses = batch["mask_time_indices"].sum()
        # `sub_attention_mask` is removed from the batch before the forward pass; when it is
        # absent, a mask of ones shaped like `mask_time_indices` is used instead.
        sub_attention_mask = batch.pop("sub_attention_mask", None)
        sub_attention_mask = (
            sub_attention_mask if sub_attention_mask is not None else torch.ones_like(batch["mask_time_indices"])
        )
        percent_masked = num_losses / sub_attention_mask.sum()

        # forward
        outputs = model(**batch)

        # divide loss by gradient accumulation steps since gradients
        # are accumulated for multiple backward passes in PyTorch
        loss = outputs.loss / args.gradient_accumulation_steps
        accelerator.backward(loss)

        # make sure that `num_losses` is summed for distributed training
        # and average gradients over losses of all devices
        if accelerator.state.num_processes > 1:
            num_losses = accelerator.gather_for_metrics(num_losses).sum()
            gradient_multiplier = accelerator.state.num_processes / num_losses
            multiply_grads(model.module.parameters(), gradient_multiplier)
        else:
            multiply_grads(model.parameters(), 1 / num_losses)

        # update step: every `gradient_accumulation_steps` batches, and on the last batch of the epoch
        if (step + 1) % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
            # compute grad norm for monitoring
            scale = (
                accelerator.scaler._scale.item()
                if hasattr(accelerator, "scaler") and accelerator.scaler is not None
                else 1
            )
            if accelerator.state.num_processes > 1:
                grad_norm = get_grad_norm(model.module.parameters(), scale)
            else:
                grad_norm = get_grad_norm(model.parameters(), scale)

            # update parameters
            optimizer.step()
            optimizer.zero_grad()

            # the learning-rate schedule only advances when the optimizer step was really applied
            if not accelerator.optimizer_step_was_skipped:
                lr_scheduler.step()
            elif accelerator.is_local_main_process:
                progress_bar.write(
                    f"Gradients have overflown - skipping update step... Updating gradient scale to {scale}..."
                )

            # update gumbel temperature: exponential decay in the number of completed steps,
            # clamped from below by `min_gumbel_temperature`
            gumbel_temperature = max(
                args.max_gumbel_temperature * args.gumbel_temperature_decay**completed_steps,
                args.min_gumbel_temperature,
            )
            if hasattr(model, "module"):
                model.module.set_gumbel_temperature(gumbel_temperature)
            else:
                model.set_gumbel_temperature(gumbel_temperature)

            progress_bar.update(1)
            completed_steps += 1

        # 6. Log all results
        # This condition implies the update-step condition above, so `grad_norm` and
        # `gumbel_temperature` are always bound here.
        if (step + 1) % (args.gradient_accumulation_steps * args.logging_steps) == 0:
            # `Tensor.detach()` returns a new tensor rather than detaching in place, so the
            # results have to be rebound for the logged values to be cut from the graph.
            loss = loss.detach()
            outputs.contrastive_loss = outputs.contrastive_loss.detach()
            outputs.diversity_loss = outputs.diversity_loss.detach()

            if accelerator.state.num_processes > 1:
                loss = accelerator.gather_for_metrics(loss).sum()
                outputs.contrastive_loss = accelerator.gather_for_metrics(outputs.contrastive_loss).sum()
                outputs.diversity_loss = accelerator.gather_for_metrics(outputs.diversity_loss).sum()
                percent_masked = accelerator.gather_for_metrics(percent_masked).sum()

            train_logs = {
                # undo the division by `gradient_accumulation_steps` applied before backward
                "loss": (loss * args.gradient_accumulation_steps) / num_losses,
                "constrast_loss": outputs.contrastive_loss / num_losses,
                "div_loss": outputs.diversity_loss / num_losses,
                "%_mask_idx": percent_masked / accelerator.num_processes,
                "ppl": outputs.codevector_perplexity,
                "lr": torch.tensor(optimizer.param_groups[0]["lr"]),
                "temp": torch.tensor(gumbel_temperature),
                "grad_norm": torch.tensor(grad_norm),
            }
            log_str = "".join("| {}: {:.3e}".format(k, v.item()) for k, v in train_logs.items())

            if accelerator.is_local_main_process:
                progress_bar.write(log_str)
                if is_wandb_available():
                    wandb.log(train_logs)

        # save model every `args.saving_steps` steps
        if (step + 1) % (args.gradient_accumulation_steps * args.saving_steps) == 0:
            if (args.push_to_hub and epoch < args.num_train_epochs - 1) or args.output_dir is not None:
                accelerator.wait_for_everyone()
                unwrapped_model = accelerator.unwrap_model(model)
                unwrapped_model.save_pretrained(
                    args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
                )

            if (args.push_to_hub and epoch < args.num_train_epochs - 1) and accelerator.is_main_process:
                repo.push_to_hub(
                    commit_message=f"Training in progress step {completed_steps}",
                    blocking=False,
                    auto_lfs_prune=True,
                )

        # if completed steps > `args.max_train_steps` stop
        # (this leaves the batch loop only; validation for the current epoch still runs below)
        if completed_steps >= args.max_train_steps:
            break

    # 7. Validate!
    model.eval()

    # init logs
    val_logs = {
        "val_loss": 0,
        "val_contrastive_loss": 0,
        "val_diversity_loss": 0,
        "val_num_losses": 0,
    }
    for step, batch in enumerate(eval_dataloader):
        with torch.no_grad():
            batch.pop("sub_attention_mask", None)
            outputs = model(**batch)

        val_logs["val_loss"] += outputs.loss
        val_logs["val_contrastive_loss"] += outputs.contrastive_loss
        val_logs["val_diversity_loss"] += outputs.diversity_loss
        val_logs["val_num_losses"] += batch["mask_time_indices"].sum()

    # sum over devices in multi-processing
    if accelerator.num_processes > 1:
        val_logs = {k: accelerator.gather_for_metrics(v).sum() for k, v in val_logs.items()}

    # normalize every accumulated value by the total number of masked positions
    # (the right-hand side reads the old dict, so "val_num_losses" itself becomes 1)
    val_logs = {k: v / val_logs["val_num_losses"] for k, v in val_logs.items()}

    log_str = "".join("| {}: {:.3e}".format(k, v.item()) for k, v in val_logs.items())

    if accelerator.is_local_main_process:
        progress_bar.write(log_str)
        if is_wandb_available():
            wandb.log(val_logs)

    # save (and optionally push) the model at the end of every epoch
    if args.output_dir is not None:
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained(
            args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
        )
        if accelerator.is_main_process:
            if args.push_to_hub:
                repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)