In [None]:
import numpy as np
import transformers
import datasets
import wandb

import audiocap

In [None]:
config_name = "openai/whisper-base"
config = transformers.WhisperConfig.from_pretrained(config_name)
tokenizer = transformers.WhisperTokenizer.from_pretrained(config_name)
feature_extractor = transformers.WhisperFeatureExtractor.from_pretrained(config_name)
# Only the architecture/config of whisper-base is reused: the model is built from
# `config` directly (not `from_pretrained`), so it is not pretrained.
model = transformers.WhisperForConditionalGeneration(config)

In [None]:
CLOTHO_DIR = "../data/clotho_v2.1"
N_DEV_LOG_CLIPS = 32  # clips from the development split whose predictions get logged
N_VAL_LOG_CLIPS = 16  # clips from the validation split whose predictions get logged

# Only the real splits are loaded from disk; the *_mini logging splits are
# subsets of them and are derived below instead of being built a second time.
ds_raw = datasets.load_dataset(
    "audiofolder",
    data_files={
        "dev": f"{CLOTHO_DIR}/development/*",
        "val": f"{CLOTHO_DIR}/validation/*",
        "test": f"{CLOTHO_DIR}/evaluation/*",
    },
)

random_gen = np.random.default_rng(seed=1)
dev_log_indices = random_gen.choice(len(ds_raw["dev"]), size=N_DEV_LOG_CLIPS, replace=False)
val_log_indices = random_gen.choice(len(ds_raw["val"]), size=N_VAL_LOG_CLIPS, replace=False)

ds_raw["dev_mini"] = ds_raw["dev"].select(dev_log_indices)
ds_raw["val_mini"] = ds_raw["val"].select(val_log_indices)

In [None]:
preprocessing = audiocap.preprocess.Preprocess(tokenizer, feature_extractor)

CLOTHO_CAPTION_COLUMNS = ["caption_1", "caption_2", "caption_3", "caption_4", "caption_5"]
DEBUG_TAKE_ROWS = 100  # TODO set to None to stream full splits (this is for debugging purposes)


def flatten_and_preprocess(split_raw, shuffle):
    """Stream one raw Clotho split through caption flattening and preprocessing.

    `clotho_flatten_captions` consumes the five caption columns, producing multiple
    rows (one per caption) per audio clip; `preprocessing` then drops the `audio`
    column. When `shuffle` is true, flattened rows are shuffled (seed 42, buffer of
    100 rows) before preprocessing. Returns a lazily-evaluated iterable dataset.
    """
    split = split_raw.to_iterable_dataset().map(
        audiocap.preprocess.clotho_flatten_captions,
        batched=True,
        batch_size=10,
        remove_columns=CLOTHO_CAPTION_COLUMNS,
    )
    if shuffle:
        split = split.shuffle(seed=42, buffer_size=100)
    return split.map(
        preprocessing,
        batched=True,
        batch_size=16,
        remove_columns=["audio"],
    )


ds = datasets.IterableDatasetDict()
ds_mini = datasets.DatasetDict()
for split_name, split_raw in ds_raw.items():
    if "mini" in split_name:
        # Logging splits are materialised in full, without shuffling or a row cap:
        # a cap would silently drop whole clips once clips x captions exceeds it.
        # There are multiple rows (captions) per audio clip; keep only the first
        # row of each clip for logging predictions.
        first_row_per_clip = {}
        for row in flatten_and_preprocess(split_raw, shuffle=False):
            first_row_per_clip.setdefault(row["filename"], row)
        ds_mini[split_name] = datasets.Dataset.from_list(list(first_row_per_clip.values()))
    else:
        # TODO add augmentations, but only to development split
        split = flatten_and_preprocess(split_raw, shuffle=True)
        if DEBUG_TAKE_ROWS is not None:
            split = split.take(DEBUG_TAKE_ROWS)
        ds[split_name] = split


expected_keys = {
    'caption_idx',
    'caption',
    'path',
    'audio_array',
    'sampling_rate',
    'filename',
    'input_features',
    'labels'
}

for mini_name, mini_split in ds_mini.items():
    # all datasets have the same features
    assert set(mini_split[0].keys()) == expected_keys, mini_name
    # every selected clip is kept exactly once (assumes `filename` is unique per clip)
    assert len(mini_split) == len(ds_raw[mini_name]), mini_name

In [None]:
# Batch collator shared by the trainer and the prediction loggers; it gets both the
# tokenizer and the feature extractor (presumably to pad labels and audio features).
collator = audiocap.preprocess.DataCollatorAudioSeq2SeqWithPadding(
    tokenizer,
    feature_extractor,
)

In [None]:
# Metric callable handed to the trainer; it receives the tokenizer
# (presumably to decode generated token ids before scoring captions).
compute_metrics = audiocap.metrics.CaptioningMetrics(
    tokenizer,
)

In [None]:
def make_prediction_logger(log_prefix, dataset, log_every_n_steps=5, max_length=50):
    """Build a W&B prediction-logging callback for one small logging dataset.

    Args:
        log_prefix: prefix under which this logger's entries are reported.
        dataset: the (small, materialised) dataset whose predictions are logged.
        log_every_n_steps: logging interval in training steps.
            TODO change the default 5 (this is for debugging purposes).
        max_length: `max_length` passed to generation; 50 mirrors the
            `generation_max_length` used for evaluation.

    Returns:
        An `audiocap.callbacks.WandbPredictionLogger` using the shared `collator`.
    """
    return audiocap.callbacks.WandbPredictionLogger(
        log_prefix=log_prefix,
        log_every_n_steps=log_every_n_steps,
        dataset=dataset,
        collator=collator,
        generate_kwargs={"max_length": max_length},  # fresh dict per callback
    )


log_preds_callback_valid = make_prediction_logger("val", ds_mini["val_mini"])
log_preds_callback_dev = make_prediction_logger("dev", ds_mini["dev_mini"])

In [None]:
# Start the W&B run; `wandb.run.name` is used below for the checkpoint directory.
# Optional extras you can add to the call:
#   group="..."  -> for organizing runs
#   dir="..."    -> change for some tmp dir if you need
wandb.init(
    tags=["supervised"],
    project="audio-captioning",
)

In [None]:
training_args = transformers.Seq2SeqTrainingArguments(
    # --- run / output ---------------------------------------------------------
    output_dir="checkpoints/" + wandb.run.name,  # one dir per W&B run (or some tmp dir)
    report_to="wandb",
    logging_steps=3,  # TODO 50
    do_train=True,
    do_eval=True,
    fp16=True,  # TODO check if this makes training faster in our setup

    # --- optimisation ---------------------------------------------------------
    optim='adamw_torch',
    max_steps=100_000,
    warmup_steps=500,  # TODO
    learning_rate=5e-5,  # TODO

    # --- batching (TODO tune all four) ----------------------------------------
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    per_device_eval_batch_size=2,
    eval_accumulation_steps=16,

    # --- evaluation with generation -------------------------------------------
    evaluation_strategy="steps",
    eval_steps=10,  # TODO
    predict_with_generate=True,
    generation_max_length=50,
    generation_num_beams=1,  # TODO?

    # --- checkpointing / model selection --------------------------------------
    save_strategy="steps",
    save_steps=1_000,  # TODO
    save_total_limit=10,  # TODO
    load_best_model_at_end=True,
    metric_for_best_model="sacrebleu",  # TODO change
    greater_is_better=True,  # TODO change according to metric
)


# Callbacks that log predictions on the small val/dev logging subsets.
prediction_loggers = [log_preds_callback_valid, log_preds_callback_dev]

trainer = transformers.Seq2SeqTrainer(
    args=training_args,
    model=model,
    tokenizer=tokenizer,
    data_collator=collator,
    train_dataset=ds["dev"],  # streamed development split
    eval_dataset=ds["val"],   # streamed validation split
    compute_metrics=compute_metrics,
    callbacks=prediction_loggers,
)

In [None]:
# Run training; the cell displays the value returned by `trainer.train()`.
train_output = trainer.train()
train_output