# Federated Unlearning with Knowledge Distillation

## Problem definition

In general, the problem involves two entities: the server $S$ and a group of clients $C$ participating in federated learning (FL). The server $S$ relies on the clients in $C$ to help train a global model $M$ that is used by all the clients. By the definition of FL, training data stays on the clients' side and cannot be shared with others. Each client trains the global model $M$ on its local dataset and sends its model updates to the server, which aggregates them into an updated global model for the next round of training. Suppose a client $i \in C$ needs to revoke its previous contributions (updates) to the global model $M$. The server should then be able to update the current global model $M$ to an unlearned model $M^{C \setminus \{i\}}$, i.e., the model that would have been obtained if client $i$ had never been included in the training group $C$.

Assume that a total of $N$ clients participate in FL training at round $t$. Each client trains the global model $M_{t}$ on its private dataset and sends its parameter changes to the server. The new global model $M_{t+1}$ is computed by averaging the clients' weight updates (*FedAvg*):

$$M_{t+1} = M_{t} + \frac{1}{N} \sum_{i=1}^{N} \Delta M_{t}^{i}$$

where $\Delta M_{t}^{i}$ is the parameter update contributed by client $i$ starting from the model $M_{t}$. The server repeats this process, updating the global model from $M_{1}$ to $M_{F}$, until a termination criterion (e.g., a target test accuracy) is met at round $F-1$.
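The aggregation rule, and the per-client update history the server will later need for unlearning, can be sketched in plain Python. This is a minimal illustration assuming models are flat lists of floats; `fedavg_round`, `train`, and `toy_update` are hypothetical names, and `toy_update` draws random numbers instead of performing real local training.

```python
import random

def fedavg_round(global_model, client_updates):
    """One FedAvg step: M_{t+1} = M_t + (1/N) * sum_i dM_t^i (flat parameter lists)."""
    n = len(client_updates)
    avg = [sum(u[k] for u in client_updates) / n for k in range(len(global_model))]
    return [m + d for m, d in zip(global_model, avg)]

def train(m1, num_rounds, num_clients, local_update, seed=0):
    """Run FedAvg from M_1 and record every client update.

    Returns (M_F, history) where history[t][i] is client i's update in round t+1.
    """
    rng = random.Random(seed)
    model, history = list(m1), []
    for _ in range(num_rounds):
        updates = [local_update(model, i, rng) for i in range(num_clients)]
        history.append(updates)  # the server keeps this for later unlearning
        model = fedavg_round(model, updates)
    return model, history

def toy_update(model, client_id, rng):
    """Hypothetical stand-in for local training: a small random parameter change."""
    return [rng.uniform(-0.1, 0.1) for _ in model]

# F - 1 = 5 rounds with N = 4 clients, starting from M_1 = 0.
m1 = [0.0, 0.0, 0.0]
m_f, history = train(m1, num_rounds=5, num_clients=4, local_update=toy_update)
```

Since each round only adds the averaged update to the current model, $M_{F}$ equals $M_{1}$ plus the sum of the per-round averages, which is the decomposition the erasure step relies on.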

The federated unlearning task is to completely remove from the final model $M_{F}$ the influence of all updates $\Delta M_{t}^{i}$, $t \in [1, F-1]$, made by a target client $i$. In other words, it removes client $i$'s contribution to the final global model $M_{F}$ and produces a new model $M_{F}^{C \setminus \{i\}}$ as if client $i$ had never participated in the training process.

## Unlearning design

![dist](https://drive.google.com/uc?export=view&id=1GXOU3yTWW8l_iG-YvrWH11UgEuxD7znA)

This method requires the server to store the history of parameter updates from each contributing client and to have access to some extra outsourced unlabeled data. The key idea is to first erase the target client's historical parameter updates and then repair the resulting damage with knowledge distillation.

### Erase historical parameter updates

To completely remove the contribution of client $i$ to the final global model $M_{F}$, we want to erase all of this client's historical updates $\Delta M_{t}^{i}$ for $t \in [1, F-1]$. The update of the global model at round $t$ is the average of the participating clients' weight updates. Denoting this update by $\Delta M_{t}$, the final global model $M_{F}$ can be viewed as the initial model weights $M_{1}$ plus the global model updates from round $1$ to round $F-1$:

$$M_{F} = M_{1} + \sum_{t=1}^{F-1} \Delta M_{t}$$

For simplicity, we assume that $N$ clients participate in every round and that client $N$ is the target client to be unlearned from the global model. The problem then reduces to removing the contribution $\Delta M_{t}^{N}$ of the target client $N$ from the global model update $\Delta M_{t}$ at each round $t$:

$$\Delta M_{t} = \frac{1}{N} \sum_{i=1}^{N} \Delta M_{t}^{i} = \frac{1}{N} \sum_{i=1}^{N-1} \Delta M_{t}^{i} + \frac{1}{N} \Delta M_{t}^{N}$$

There are two ways to calculate the new global model update $\Delta M_{t}'$ at round $t$. The first is to assume that only $N-1$ clients participated in FL at round $t$, in which case the new global model update $\Delta M_{t}'$ becomes:

$$\Delta M_{t}' = \frac{1}{N-1} \sum_{i=1}^{N-1} \Delta M_{t}^{i} = \frac{N}{N-1} \Delta M_{t} - \frac{1}{N-1} \Delta M_{t}^{N}$$

However, we cannot directly accumulate these new updates to reconstruct the unlearned model because of the incremental learning property of FL: each global model $M_{t}$ is the starting point for the next round of client training, so any change to $M_{t}$ would change all of the model updates computed afterward. We therefore use $\epsilon_{t}$ to denote the amendment (skew correction) needed for the global model at each round $t$. Combining the above equations gives the unlearned version $M_{F}'$ of the final global model:

$$M_{F}' = M_{1} + \frac{N}{N-1} \sum_{t=1}^{F-1} \Delta M_{t} - \frac{1}{N-1} \sum_{t=1}^{F-1} \Delta M_{t}^{N} + \sum_{t=1}^{F-1} \epsilon_{t}$$

where $\epsilon_{t}$ is the correction needed to amend the skew produced by changes to the model in previous rounds. Because of this incremental property of FL training, the skew $\epsilon_{t}$ grows as more training rounds follow a modification of the global model. The above unlearning rule therefore has a shortcoming: even when the target client $N$ contributes little to the model at round $t$ (e.g., $\Delta M_{t}^{N} \approx 0$), the global model update $\Delta M_{t}$ is still altered, because it is rescaled by the factor $\frac{N}{N-1}$. This introduces additional skew $\epsilon_{t}$ into the global model in the following rounds.
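The rescaling problem is easy to check numerically. The sketch below assumes a single round with four clients and made-up update values, and applies the $N/(N-1)$ rule when the target client's update is exactly zero. The result equals the plain average of the remaining three clients, yet it differs from the original $\Delta M_{t}$ by a factor of $4/3$ (the first coordinate moves from about 0.15 to about 0.2). `mean_update` and `retrain_style_update` are illustrative names.

```python
def mean_update(updates):
    """Element-wise average of a list of parameter updates."""
    n = len(updates)
    return [sum(u[k] for u in updates) / n for k in range(len(updates[0]))]

def retrain_style_update(global_update, target_update, n):
    """Delta M_t' = N/(N-1) * Delta M_t - 1/(N-1) * Delta M_t^N, per parameter."""
    return [n / (n - 1) * g - d / (n - 1) for g, d in zip(global_update, target_update)]

# One round with N = 4 clients; the target client (last) sends an all-zero update.
client_updates = [[0.3, -0.1], [0.1, 0.2], [0.2, 0.2], [0.0, 0.0]]
n = len(client_updates)
global_update = mean_update(client_updates)  # Delta M_t
new_update = retrain_style_update(global_update, client_updates[-1], n)
```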

To mitigate this problem, we propose a lazy learning strategy to eliminate the influence of target client $N$. Specifically, we assume that client $N$ still participated in the training process but set its updates to $\Delta M_{t}^{N} = 0$ for all rounds $t \in [1, F-1]$. The unlearned global model update then simplifies to:

$$\Delta M_{t}' = \frac{1}{N} \sum_{i=1}^{N-1} \Delta M_{t}^{i} = \Delta M_{t} - \frac{1}{N} \Delta M_{t}^{N}$$

Substituting this update into the decomposition of $M_{F}$ as a sum of round updates, and adding the skew corrections $\epsilon_{t}$, gives the unlearned final global model $M_{F}'$:

$$
\begin{aligned}
M_{F}' &= M_{1} + \sum_{t=1}^{F-1} \Delta M_{t}' + \sum_{t=1}^{F-1} \epsilon_{t} \\
&= M_{1} + \sum_{t=1}^{F-1} \Delta M_{t} - \frac{1}{N} \sum_{t=1}^{F-1} \Delta M_{t}^{N} + \sum_{t=1}^{F-1} \epsilon_{t} \\
&= M_{F} - \frac{1}{N} \sum_{t=1}^{F-1} \Delta M_{t}^{N} + \sum_{t=1}^{F-1} \epsilon_{t}
\end{aligned}
$$

The unlearning rule is now straightforward: we subtract all of the target client $N$'s historical averaged updates from the final global model $M_{F}$, and then remedy the skew $\epsilon_{t}$ that this subtraction causes because of the incremental learning characteristic of FL.
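Under the same flat-list assumption, the erasure step can be sketched as follows. `erase_target_client` and `replay` are hypothetical helpers, and the two-round history is made up. The sketch returns only the skewed model $M_{F} - \frac{1}{N}\sum_{t}\Delta M_{t}^{N}$; the correction $\sum_{t}\epsilon_{t}$ is left to the distillation step. The usage lines check that subtracting the stored history gives the same result as replaying every round with the lazy update $\Delta M_{t}'$ (the target's update counted as zero, still divided by $N$).

```python
def erase_target_client(final_model, history, target):
    """Return M_F - sum_t (1/N) * Delta M_t^target (the skewed unlearned model).

    history[t][i] is the stored update of client i in round t+1.
    The skew terms epsilon_t are NOT estimated here.
    """
    erased = list(final_model)
    for round_updates in history:
        n = len(round_updates)
        for k, delta in enumerate(round_updates[target]):
            erased[k] -= delta / n
    return erased

def replay(m1, history, drop=None):
    """Re-apply stored rounds; the `drop` client's update counts as zero (lazy rule)."""
    model = list(m1)
    for round_updates in history:
        n = len(round_updates)  # still divide by N: the client "participates" with zero
        kept = [u for i, u in enumerate(round_updates) if i != drop]
        model = [m + sum(u[k] for u in kept) / n for k, m in enumerate(model)]
    return model

m1 = [0.0, 0.0]
history = [
    [[0.3, -0.1], [0.1, 0.2], [0.2, 0.2], [0.4, 0.0]],  # round 1
    [[0.1, 0.1], [0.0, 0.3], [0.2, -0.2], [0.0, 0.0]],  # round 2: target sends zero
]
m_f = replay(m1, history)
m_erased = erase_target_client(m_f, history, target=3)
```

Replaying stored updates ignores that the other clients would have trained on a different model once the target's updates are removed; that gap is exactly what $\epsilon_{t}$ stands for.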

### Remedy with knowledge distillation

There is no existing method to calculate the skew $\epsilon_{t}$ without retraining the updated model on the original dataset. However, one of the challenges in FL is that we cannot rely on clients to keep their datasets indefinitely in case unlearning is requested. It is therefore crucial to remedy this skew without relying on the clients' training data or their participation.

To tackle this problem, we propose to leverage knowledge distillation to train the unlearned model using the original global model.
The distillation technique was originally motivated by the goal of reducing the size of DNN architectures or ensembles of models.

It uses the class probabilities predicted by an ensemble of models or a complex DNN to train another DNN with fewer parameters without much loss of accuracy. The intuition is that the knowledge a DNN acquires during training is encoded not only in its weight parameters but also in the class probabilities it outputs.

Distillation training uses these soft probabilities instead of hard ground-truth labels, which provides additional information about each class and about the model's prediction logic.
For example, given an image of "7" from the MNIST dataset, distillation can train the new model with more information than the ground-truth label that the image belongs to class "7", such as the fact that this image is close to "1" and "9" but very different from "3" and "8".

To perform knowledge distillation for unlearning, we treat the original global model as the teacher model and the skewed unlearned model as the student model. The server can then use any unlabeled data to train the unlearned model and remedy the skew $\epsilon_{t}$ caused by the preceding subtraction. Specifically, the original global model produces class prediction probabilities through a softmax output layer that converts the logit $z_{i}$ computed for each class into a probability $q_{i}$ by comparing $z_{i}$ with the other logits:

$$q_{i} = \frac{\exp (z_{i} / T)}{\sum_{j} \exp (z_{j} / T)}$$

where $T$ is a parameter called the *temperature*, shared across the softmax layer. $T$ is normally set to 1 for conventional ML training and prediction. A higher temperature $T$ makes the DNN produce a softer probability distribution over classes: dividing by a large $T$ shrinks the scaled logits $z_{i}/T$ toward zero, so the differences between them matter less and every class receives a relatively large probability. In the limit $T \rightarrow \infty$, the output probability of each class converges to $1 / Z$, assuming there are $Z$ classes. In summary, a higher temperature $T$ yields a more ambiguous (flatter) probability distribution, while a lower temperature yields a sharper one.
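The effect of the temperature can be checked with a small pure-Python implementation of the formula above. The logits are made-up values for four classes and `softmax_with_temperature` is an illustrative name; subtracting the largest scaled logit is a standard numerical-stability trick that does not change $q_{i}$.

```python
import math

def softmax_with_temperature(logits, temperature=1.0):
    """q_i = exp(z_i / T) / sum_j exp(z_j / T)."""
    scaled = [z / temperature for z in logits]
    shift = max(scaled)  # stability: exp never overflows, and q_i is unchanged
    exps = [math.exp(s - shift) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 8.0, 1.0, 5.0]  # hypothetical logits z_i for Z = 4 classes
sharp = softmax_with_temperature(logits, temperature=1.0)  # T = 1: usual softmax
soft = softmax_with_temperature(logits, temperature=20.0)  # high T: softer labels
flat = softmax_with_temperature(logits, temperature=1e6)   # T -> infinity: about 1/Z each
```

At every temperature the most likely class stays the same; only the spread of probability mass over the other classes changes.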

We use the soft class probabilities produced by the original global model $M_{F}$ to label the server's unlabeled dataset, and the skewed unlearned model is then trained on this dataset with soft labels (at a high temperature $T$). If the server also possesses a labeled dataset, we can additionally use hard labels (ground truth, with temperature $T = 1$) alongside the soft labels produced by the global model. A weighted average of these two objective functions, with a considerably lower weight on the hard-label objective, can produce the best results. After distillation training, the temperature is set back to $1$ so that the unlearned model $M_{F}'$ produces sharper class prediction probabilities at test time.
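The training objective described above can be sketched for a single example. This is an illustrative sketch, not the authors' code: it uses the cross-entropy between the teacher's and the student's softened outputs as the soft-label loss, and `hard_weight = 0.1` is an assumed value for the "considerably lower weight" on the hard-label term. `log_softmax` and `distillation_loss` are hypothetical helpers; computing gradients and updating the student would require an autodiff framework and is not shown.

```python
import math

def log_softmax(logits, temperature=1.0):
    """log q_i at temperature T, computed stably via the log-sum-exp trick."""
    scaled = [z / temperature for z in logits]
    shift = max(scaled)
    log_total = math.log(sum(math.exp(s - shift) for s in scaled))
    return [s - shift - log_total for s in scaled]

def distillation_loss(student_logits, teacher_logits, temperature,
                      hard_label=None, hard_weight=0.1):
    """Soft-label loss of the student (skewed unlearned model) against the teacher (M_F).

    With a ground-truth hard_label, a T = 1 cross-entropy term is mixed in with a
    small weight (hard_weight is an assumed value).
    """
    teacher_soft = [math.exp(lp) for lp in log_softmax(teacher_logits, temperature)]
    student_log = log_softmax(student_logits, temperature)
    soft_loss = -sum(p * lq for p, lq in zip(teacher_soft, student_log))
    if hard_label is None:
        return soft_loss  # unlabeled server data: soft labels only
    hard_loss = -log_softmax(student_logits, 1.0)[hard_label]
    return (1.0 - hard_weight) * soft_loss + hard_weight * hard_loss

teacher = [2.0, 8.0, 1.0, 5.0]  # hypothetical logits from the original model M_F
student = [1.0, 1.0, 1.0, 1.0]  # hypothetical logits from the skewed model
loss_unlabeled = distillation_loss(student, teacher, temperature=4.0)
loss_labeled = distillation_loss(student, teacher, temperature=4.0, hard_label=1)
```

The soft-label loss is smallest when the student's softened output matches the teacher's. Hinton et al.'s original distillation formulation additionally multiplies the soft-label term by $T^{2}$ so that its gradient magnitude stays comparable when hard labels are mixed in; the sketch omits this factor.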

## Experiments

In our FL settings, the clients and the server share the same model for each dataset. We use three different models for the three classification tasks.
Specifically, the model for MNIST consists of 2 convolutional layers, 2 max-pooling layers, and 2 fully connected layers at the end for the prediction output.
For CIFAR-10, we use the well-known VGG11 network, which consists of 8 convolutional layers and 5 max-pooling layers followed by one fully connected layer that produces the probability prediction output. For the GTSRB dataset, we use the well-known AlexNet, which is composed of 5 convolutional layers, 3 max-pooling layers, and 3 fully connected layers for the output.

The target client injects backdoor attacks into its updates to the global model, so that the unlearning effect can be measured intuitively through the attack success rate of the unlearned global model.
The backdoor attack is triggered by a backdoor pattern in the input image: the target client changes some pixels in benign inputs to create this pattern. The backdoor targets in the experiments are making digit "1" predicted as digit "9" on MNIST, "truck" predicted as "car" on CIFAR-10, and "Stop Sign" predicted as "Speed limit Sign (120 km/h)" on GTSRB.
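The trigger-based poisoning and the attack-success-rate metric can be sketched as follows. The 3x3 white square in the bottom-right corner, the relabeling of every source-class image, and all function names (`add_trigger`, `poison_dataset`, `attack_success_rate`) are assumptions for illustration; the source only states that some pixels are changed to form the pattern. Images are 2-D lists of grayscale values, and `predict` stands for any classifier.

```python
def add_trigger(image, size=3, value=255):
    """Copy of a grayscale image with a size x size patch (hypothetical trigger) in the corner."""
    poisoned = [list(row) for row in image]
    h, w = len(poisoned), len(poisoned[0])
    for r in range(h - size, h):
        for c in range(w - size, w):
            poisoned[r][c] = value
    return poisoned

def poison_dataset(samples, source_label, target_label):
    """Target client's data: stamp the trigger on source-class images and relabel them."""
    return [(add_trigger(x), target_label) if y == source_label else (x, y)
            for x, y in samples]

def attack_success_rate(predict, samples, source_label, target_label):
    """Fraction of triggered source-class inputs that are classified as the backdoor target."""
    triggered = [add_trigger(x) for x, y in samples if y == source_label]
    if not triggered:
        return 0.0
    return sum(predict(x) == target_label for x in triggered) / len(triggered)

# Hypothetical MNIST-sized example: blank 28x28 images labeled "1" and "7".
blank = [[0] * 28 for _ in range(28)]
samples = [(blank, 1), (blank, 7), (blank, 1)]
poisoned = poison_dataset(samples, source_label=1, target_label=9)

def backdoored(x):
    """Toy model that has learned the backdoor: trigger present -> predict 9."""
    return 9 if x[27][27] == 255 else 1

def unlearned(x):
    """Toy model whose backdoor has been removed."""
    return 1
```

A successful unlearning procedure should drive the attack success rate of the global model from the `backdoored` behavior toward the `unlearned` behavior while keeping clean accuracy.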

![acc](https://drive.google.com/uc?export=view&id=1-F7MzAb71gkakjmi3NhZeDppKRMPF4KU)

![val](https://drive.google.com/uc?export=view&id=1XYeNSr86bB7EAVJvnnDibJ2O6jY22wlL)

# References

- C. Wu, S. Zhu, and P. Mitra, Federated Unlearning with Knowledge Distillation. arXiv, 2022. [[Paper](https://arxiv.org/abs/2201.09441)]
