# Federated Audio Classification tutorial with 🤗 Transformers

In [None]:
# Install the tutorial's dependencies. datasets, transformers and numpy are
# pinned to the specific versions below.
!pip install "datasets==1.14" "transformers==4.11.3" "librosa" "torch" "ipywidgets" "numpy==1.21.5"

# Connect to the Federation

In [None]:
from openfl.interface.interactive_api.federation import Federation

# Connection settings for the director that this notebook (client "frontend")
# talks to.
client_id = "frontend"
director_node_fqdn = "localhost"
director_port = 50050

# tls=False turns TLS off for the director connection. Only keep it off on a
# trusted network, such as this localhost setup.
federation = Federation(
    client_id=client_id,
    director_node_fqdn=director_node_fqdn,
    director_port=director_port,
    tls=False,
)

In [None]:
# Fetch the shard registry through the federation client and display it.
# The bare name on the last line is rendered as the cell output.
shard_registry = federation.get_shard_registry()
shard_registry

In [None]:
# Display the federation's target shape as the cell output.
federation.target_shape

## Creating an FL experiment using the Interactive API

In [None]:
from openfl.interface.interactive_api.experiment import (
    DataInterface,
    FLExperiment,
    ModelInterface,
    TaskInterface,
)

### Register dataset

In [None]:
import datasets
import numpy as np
import torch
from torch.utils.data import Dataset
from transformers import (
    AutoFeatureExtractor,
    AutoModelForAudioClassification,
    Trainer,
    TrainingArguments,
)

In [None]:
model_checkpoint = "facebook/wav2vec2-base"

# Class names. Their position in this list defines their id.
labels = [
    "yes", "no", "up", "down", "left", "right",
    "on", "off", "stop", "go", "_silence_", "_unknown_",
]

# Mappings between label names and ids in both directions. Ids are kept as
# strings ("0", "1", ...) on both sides.
label2id = {name: str(idx) for idx, name in enumerate(labels)}
id2label = {str(idx): name for idx, name in enumerate(labels)}

In [None]:
# Feature extractor that matches the pretrained checkpoint. preprocess_function
# relies on its sampling rate and on calling it directly.
feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)
# Maximum clip length in seconds. preprocess_function turns it into a sample
# count for truncation.
max_duration = 1.0


def preprocess_function(pre_processed_data):
    """Run the module-level ``feature_extractor`` on raw audio.

    The call is capped at ``max_duration`` seconds' worth of samples, with
    ``truncation=True`` requested.

    Args:
        pre_processed_data: raw audio, passed straight to ``feature_extractor``.
            Presumably a 1-D waveform sampled at
            ``feature_extractor.sampling_rate``; TODO confirm against the
            shard descriptor.

    Returns:
        The feature extractor's output. ``SuperbShardDataset`` reads its
        ``"input_values"`` entry.
    """
    sample_rate = feature_extractor.sampling_rate
    max_samples = int(sample_rate * max_duration)
    return feature_extractor(
        pre_processed_data,
        sampling_rate=sample_rate,
        max_length=max_samples,
        truncation=True,
    )

In [None]:
class SuperbShardDataset(Dataset):
    """Torch ``Dataset`` view over one split of a local shard.

    Each item of the wrapped split is unpacked as an ``(audio, label)`` pair.
    The audio is passed through ``preprocess_function``, and the first entry
    of the resulting ``"input_values"`` is returned with the label under the
    ``"input_values"`` / ``"labels"`` keys.
    """

    def __init__(self, dataset):
        # Underlying split. It is indexed and sized directly.
        self._dataset = dataset

    def __len__(self):
        return len(self._dataset)

    def __getitem__(self, index):
        audio, label = self._dataset[index]
        features = preprocess_function(audio)
        # Keep the first (only) entry of the extractor's "input_values".
        input_values = features["input_values"][0]
        return {"input_values": input_values, "labels": label}


class SuperbFedDataset(DataInterface):
    """OpenFL data interface over a collaborator's local SUPERB shard.

    Assigning ``shard_descriptor`` wraps the descriptor's "train", "val" and
    "test" splits in ``SuperbShardDataset``. The train and validation
    datasets are handed out as-is by the ``get_*_loader`` methods, not as
    torch ``DataLoader`` objects, because the FL tasks feed them straight to
    a Hugging Face ``Trainer``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @property
    def shard_descriptor(self):
        return self._shard_descriptor

    @shard_descriptor.setter
    def shard_descriptor(self, shard_descriptor):
        """
        Describe per-collaborator procedures for sharding.

        This method will be called during a collaborator initialization.
        Local shard_descriptor will be set by Envoy.
        """
        self._shard_descriptor = shard_descriptor
        descriptor = self._shard_descriptor
        # Splits are fetched in train -> val -> test order.
        train_split = descriptor.get_dataset("train")
        self.train_set = SuperbShardDataset(train_split)
        val_split = descriptor.get_dataset("val")
        self.valid_set = SuperbShardDataset(val_split)
        test_split = descriptor.get_dataset("test")
        self.test_set = SuperbShardDataset(test_split)

    def __getitem__(self, index):
        # Delegate indexing and sizing to the shard descriptor itself.
        return self.shard_descriptor[index]

    def __len__(self):
        return len(self.shard_descriptor)

    def get_train_loader(self):
        """Return the wrapped training split (a Dataset, not a DataLoader)."""
        return self.train_set

    def get_valid_loader(self):
        """Return the wrapped validation split (a Dataset, not a DataLoader)."""
        return self.valid_set

    def get_train_data_size(self):
        """Number of training samples in the local shard."""
        return len(self.train_set)

    def get_valid_data_size(self):
        """Number of validation samples in the local shard."""
        return len(self.valid_set)

In [None]:
# Data interface passed to the experiment. Its shard_descriptor, and with it
# the train/valid/test splits, is set later on each collaborator by the Envoy.
fed_dataset = SuperbFedDataset()

### Describe the model and optimizer

In [None]:
# Download the pretrained checkpoint through AutoModelForAudioClassification,
# configured with this notebook's label set and id mappings. Fine-tuning
# happens later, inside the federated `train` task.
num_labels = len(id2label)

model = AutoModelForAudioClassification.from_pretrained(
    model_checkpoint,
    id2label=id2label,
    label2id=label2id,
    num_labels=num_labels,
)

In [None]:
from transformers import AdamW

# Give the optimizer only the trainable parameters; anything with
# requires_grad=False is left out. requires_grad is a bool, so test its
# truthiness directly rather than comparing with `== True`.
params_to_update = [param for param in model.parameters() if param.requires_grad]

# NOTE(review): newer transformers releases deprecate transformers.AdamW in
# favour of torch.optim.AdamW, which has a different default weight_decay.
# This notebook pins transformers==4.11.3; revisit if the pin is lifted.
optimizer = AdamW(params_to_update, lr=3e-5)

#### Register model

In [None]:
# Dotted import path of OpenFL's PyTorch framework adapter plugin, given to
# ModelInterface below.
framework_adapter = (
    "openfl.plugins.frameworks_adapters.pytorch_adapter.FrameworkAdapterPlugin"
)
# Bundle the model, its optimizer and the adapter for the experiment.
MI = ModelInterface(
    model=model, optimizer=optimizer, framework_plugin=framework_adapter
)

### Define and register FL tasks

In [None]:
# Shared Hugging Face Trainer configuration, used by both the `train` and
# `validate` tasks. Output (including the per-epoch saves requested by
# save_strategy="epoch") goes under "finetuned_model". Each `train` call runs
# num_train_epochs=1 epoch, and nothing is pushed to the Hub.
batch_size = 16
args = TrainingArguments(
    "finetuned_model",
    save_strategy="epoch",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=1,
    warmup_ratio=0.1,
    logging_steps=10,
    push_to_hub=False,
)

In [None]:
from datasets import load_metric

# Accuracy metric from datasets.load_metric. compute_metrics feeds it argmax
# predictions.
metric = load_metric("accuracy")


def compute_metrics(eval_pred):
    """Compute accuracy for a Trainer evaluation pass.

    Args:
        eval_pred: object exposing ``predictions`` (per-class scores; argmax
            is taken over axis 1) and ``label_ids`` (reference labels).

    Returns:
        The result of the module-level ``metric.compute`` call.
    """
    logits = eval_pred.predictions
    references = eval_pred.label_ids
    predicted_ids = np.argmax(logits, axis=1)
    return metric.compute(predictions=predicted_ids, references=references)

In [None]:
# Registry for the FL tasks defined below with @TI.register_fl_task.
TI = TaskInterface()

# NOTE(review): `nn` and `tqdm` are not referenced anywhere in this notebook
# and look removable; confirm before deleting.
import torch.nn as nn
import tqdm


@TI.register_fl_task(
    model="model", data_loader="train_loader", device="device", optimizer="optimizer"
)
def train(model, train_loader, optimizer, device):
    """Fine-tune ``model`` locally with a Hugging Face ``Trainer``.

    Args:
        model: the model being trained; it is moved to ``device`` first.
        train_loader: training ``Dataset``. Presumably the one returned by
            ``SuperbFedDataset.get_train_loader``; confirm against OpenFL.
        optimizer: optimizer passed to the Trainer. The scheduler slot is
            left as ``None``.
        device: target device supplied by the task runner.

    Returns:
        ``{"train_loss": ...}``, taken from the Trainer's training metrics.
    """
    print(f"\n\n TASK TRAIN GOT DEVICE {device}\n\n")

    model_on_device = model.to(device)
    trainer = Trainer(
        model=model_on_device,
        args=args,
        train_dataset=train_loader,
        tokenizer=feature_extractor,
        optimizers=(optimizer, None),
        compute_metrics=compute_metrics,
    )
    train_output = trainer.train()
    loss = train_output.metrics["train_loss"]
    return {"train_loss": loss}


@TI.register_fl_task(model="model", data_loader="val_loader", device="device")
def validate(model, val_loader, device):
    """Evaluate ``model`` on the local validation split.

    Args:
        model: the model to evaluate; it is moved to ``device`` first.
        val_loader: validation ``Dataset``. Presumably the one returned by
            ``SuperbFedDataset.get_valid_loader``; confirm against OpenFL.
        device: target device supplied by the task runner.

    Returns:
        ``{"eval_accuracy": ...}``, taken from ``Trainer.evaluate()``.
    """
    print(f"\n\n TASK VALIDATE GOT DEVICE {device}\n\n")

    model_on_device = model.to(device)
    trainer = Trainer(
        model=model_on_device,
        args=args,
        eval_dataset=val_loader,
        tokenizer=feature_extractor,
        compute_metrics=compute_metrics,
    )
    results = trainer.evaluate()
    accuracy = results["eval_accuracy"]
    return {"eval_accuracy": accuracy}

## Time to start a federated learning experiment

In [None]:
# Create the experiment handle on the federation connected above.
experiment_name = "HF_audio_test_experiment"
fl_experiment = FLExperiment(federation=federation, experiment_name=experiment_name)

In [None]:
# Launch the experiment for 2 rounds, taking the model/optimizer from MI, the
# tasks from TI and the data interface from fed_dataset.
# NOTE(review): judging by their names, opt_treatment="CONTINUE_GLOBAL" keeps
# optimizer state in the globally aggregated state across rounds, and
# device_assignment_policy="CUDA_PREFERRED" uses CUDA when available. Confirm
# both against the OpenFL docs.
fl_experiment.start(
    model_provider=MI,
    task_keeper=TI,
    data_loader=fed_dataset,
    rounds_to_train=2,
    opt_treatment="CONTINUE_GLOBAL",
    device_assignment_policy="CUDA_PREFERRED",
)

In [None]:
# Stream the experiment's reported metrics into this notebook.
fl_experiment.stream_metrics()