In [None]:
import transformers
import datasets
import wandb

import audiocap

In [None]:
# Build the model, tokenizer and feature extractor from the `openai/whisper-base`
# checkpoint definition. Only the *configuration* (architecture hyper-parameters)
# is reused for the model: constructing `WhisperForConditionalGeneration` directly
# from a config does not load the checkpoint weights, so training starts from a
# freshly initialised model.
config_name = "openai/whisper-base"
config = transformers.WhisperConfig.from_pretrained(config_name)
tokenizer = transformers.WhisperTokenizer.from_pretrained(config_name)
feature_extractor = transformers.WhisperFeatureExtractor.from_pretrained(config_name)
model = transformers.WhisperForConditionalGeneration(config)  # not pretrained

In [None]:
# Load the three Clotho v2.1 splits with the `audiofolder` builder; every split
# is read from the glob pattern `../data/clotho_v2.1/<split>/*`.
ds_raw = datasets.load_dataset(
    "audiofolder",
    data_files={
        name: f"../data/clotho_v2.1/{name}/*"
        for name in ("development", "validation", "evaluation")
    },
)

In [None]:
# Callable applied to batches of examples; built from the tokenizer and the
# feature extractor (presumably it produces model inputs and label ids --
# confirm in `audiocap.preprocess`).
preprocessing = audiocap.preprocess.Preprocess(tokenizer, feature_extractor)

# The five per-clip caption columns that are dropped by the flattening step.
caption_columns = [f"caption_{i}" for i in range(1, 6)]

ds = datasets.DatasetDict()
for split in ["development", "validation", "evaluation"]:
    # TODO add augmentations, but only to development split

    # Stream the split lazily instead of materialising it.
    streamed = ds_raw[split].to_iterable_dataset()

    # Flatten the caption columns in batches of 10 rows.
    flattened = streamed.map(
        audiocap.preprocess.clotho_flatten_captions,
        batched=True,
        batch_size=10,
        remove_columns=caption_columns,
    )

    # Approximate shuffling through a 100-element buffer with a fixed seed.
    shuffled = flattened.shuffle(seed=42, buffer_size=100)

    # Run the preprocessing callable in batches of 16 and drop the raw audio.
    encoded = shuffled.map(
        preprocessing,
        batched=True,
        batch_size=16,
        remove_columns=["audio"],
    )

    # TODO remove (this is for debugging purposes)
    ds[split] = encoded.take(100)

In [None]:
# Batch collator handed to the trainer as `data_collator`. It is constructed from
# the tokenizer and the feature extractor; judging by its name it pads audio
# features and label sequences per batch -- confirm in `audiocap.preprocess`.
collator = audiocap.preprocess.DataCollatorAudioSeq2SeqWithPadding(tokenizer, feature_extractor)

In [None]:
# Metric callable handed to the trainer as `compute_metrics`. It is constructed
# with the tokenizer (presumably to decode predicted/label ids back to text --
# confirm in `audiocap.metrics`).
compute_metrics = audiocap.metrics.CaptioningMetrics(tokenizer)

In [None]:
# Start the Weights & Biases run for this training session.
# Optional arguments that are currently left at their defaults:
#   group="..."  -> for organizing runs
#   dir="..."    -> change for some tmp dir if you need
wandb.init(project="audio-captioning", tags=["supervised"])

In [None]:
# Hyper-parameters and bookkeeping options for the seq2seq trainer.
# Most numeric values are placeholders (see the TODO markers).
training_args = transformers.Seq2SeqTrainingArguments(
    # --- run mode and output location -------------------------------------
    # Checkpoints go to a directory named after the wandb run (or some tmp dir).
    output_dir="checkpoints/" + wandb.run.name,
    do_train=True,
    do_eval=True,

    # --- optimisation -----------------------------------------------------
    optim="adamw_torch",
    max_steps=100_000,
    learning_rate=5e-5,  # TODO
    warmup_steps=500,  # TODO
    # TODO check if this makes training faster in our setup
    fp16=True,

    # --- batch sizes (TODO all of these) ----------------------------------
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    per_device_eval_batch_size=2,
    eval_accumulation_steps=16,

    # --- logging ----------------------------------------------------------
    report_to="wandb",
    logging_steps=50,

    # --- evaluation -------------------------------------------------------
    evaluation_strategy="steps",
    eval_steps=10,  # TODO
    predict_with_generate=True,
    generation_max_length=50,

    # --- best-model selection ---------------------------------------------
    load_best_model_at_end=True,
    metric_for_best_model="sacrebleu",  # TODO change
    greater_is_better=True,  # TODO change according to metric

    # --- checkpointing ----------------------------------------------------
    save_strategy="steps",
    save_steps=1_000,  # TODO
    save_total_limit=10,  # TODO
)


# Wire everything together: the model, the training arguments, the streamed
# development/validation splits, and the batching and metric helpers.
trainer = transformers.Seq2SeqTrainer(
    model=model,
    args=training_args,
    # Train on the development split, evaluate on the validation split.
    train_dataset=ds["development"],
    eval_dataset=ds["validation"],
    # Batch construction and text (de)tokenisation helpers.
    data_collator=collator,
    tokenizer=tokenizer,
    # Metric callable used during evaluation.
    compute_metrics=compute_metrics,
)

In [None]:
# Launch the training loop using the settings stored in `training_args`.
trainer.train()