
add moco v3 (PaddlePaddle#176)
GuoxiaWang committed Jan 9, 2023
1 parent 79ae3c0 commit 7a53dad
Showing 11 changed files with 1,752 additions and 1 deletion.
3 changes: 2 additions & 1 deletion README.md
@@ -18,10 +18,11 @@
[PLSC](https://github.com/PaddlePaddle/PLSC) is an open-source repo for a collection of Paddle Large Scale Classification Tools, which supports large-scale classification model pre-training as well as fine-tuning for downstream tasks.

## Available Models
* [Face Recognition](./task/recognition/face/)
* [ViT](./task/classification/vit/)
* [DeiT](./task/classification/deit/)
* [CaiT](./task/classification/cait/)
* [Face Recognition](./task/recognition/face/)
* [MoCo v3](./task/ssl/mocov3/)

## Top News 🔥

117 changes: 117 additions & 0 deletions task/ssl/mocov3/README.md
@@ -0,0 +1,117 @@
## MoCo v3 for Self-supervised ResNet and ViT


PaddlePaddle reimplementation of [facebookresearch's repository for the MoCo v3 model](https://github.com/facebookresearch/moco-v3) that was released with the paper [An Empirical Study of Training Self-Supervised Vision Transformers](https://arxiv.org/abs/2104.02057).

## Requirements
PaddlePaddle 2.4 or later is required. For installation instructions,
refer to [installation.md](../../../tutorials/get_started/installation.md).

## Data Preparation

Organize the ImageNet dataset in the following directory structure:
```text
dataset/
└── ILSVRC2012
├── train
└── val
```
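
MoCo v3 learns by contrasting two randomly augmented views ("crops") of each image. The actual augmentation pipeline lives in `main_moco.py`; purely for illustration, a two-crop wrapper in Paddle might look like the sketch below (the class and transform choices here are ours, not the script's):

```python
import paddle.vision.transforms as T


class TwoCropsTransform:
    """Apply two augmentation pipelines to one image and return both views."""

    def __init__(self, transform1, transform2):
        self.transform1 = transform1
        self.transform2 = transform2

    def __call__(self, x):
        return [self.transform1(x), self.transform2(x)]


# Example: the same simple pipeline for both views; the real recipe also
# uses color jitter, grayscale, blur, and solarization.
augment = T.Compose([
    T.RandomResizedCrop(224),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
])
two_crops = TwoCropsTransform(augment, augment)
```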


## Self-supervised Pre-training

With a batch size of 4096, ViT-Base is pre-trained on 4 nodes (8 GPUs each):

```bash
# Note: Set the following environment variables,
# then run this script on each of the 4 nodes.
unset PADDLE_TRAINER_ENDPOINTS
export PADDLE_NNODES=4
export PADDLE_MASTER="xxx.xxx.xxx.xxx:12538"
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export FLAGS_stop_check_timeout=3600

IMAGENET_DIR=./dataset/ILSVRC2012/
python -m paddle.distributed.launch \
--nnodes=$PADDLE_NNODES \
--master=$PADDLE_MASTER \
--devices=$CUDA_VISIBLE_DEVICES \
main_moco.py \
-a moco_vit_base \
--optimizer=adamw --lr=1.5e-4 --weight-decay=.1 \
--epochs=300 --warmup-epochs=40 \
--stop-grad-conv1 --moco-m-cos --moco-t=.2 \
${IMAGENET_DIR}
```
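
The `--moco-m-cos` flag ramps the momentum coefficient `m` (the `m` passed to `MoCo.forward`) from its base value toward 1.0 over training with a half-cycle cosine schedule. A minimal sketch of that schedule, assuming the same formulation as the reference implementation (the base value of 0.99 is illustrative):

```python
import math


def adjust_moco_momentum(epoch, total_epochs, base_m=0.99):
    """Half-cycle cosine ramp of the MoCo momentum from base_m toward 1.0."""
    return 1.0 - (1.0 - base_m) * (1.0 + math.cos(math.pi * epoch / total_epochs)) / 2.0
```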

## Linear Classification

By default, linear classification on frozen features/weights uses momentum SGD and a batch size of 1024; it can be done on a single 8-GPU node.

```bash
unset PADDLE_TRAINER_ENDPOINTS
export PADDLE_NNODES=1
export PADDLE_MASTER="127.0.0.1:12538"
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export FLAGS_stop_check_timeout=3600

IMAGENET_DIR=./dataset/ILSVRC2012/
python -m paddle.distributed.launch \
--nnodes=$PADDLE_NNODES \
--master=$PADDLE_MASTER \
--devices=$CUDA_VISIBLE_DEVICES \
main_lincls.py \
-a moco_vit_base \
--lr=3 \
--pretrained pretrained/checkpoint_0299.pd \
${IMAGENET_DIR}
```
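
Linear classification trains only the final linear layer on top of frozen features. A minimal sketch of the freezing step, using Paddle's `stop_gradient` mechanism (the actual setup is in `main_lincls.py`, and the head prefix depends on the architecture):

```python
import paddle.nn as nn


def freeze_backbone(model: nn.Layer, head_prefix='head'):
    """Stop gradients on every parameter outside the classification head.

    Illustrative sketch: 'head' for ViT-style models, 'fc' for ResNet-style.
    """
    for name, param in model.named_parameters():
        if not name.startswith(head_prefix):
            param.stop_gradient = True
```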

## End-to-End Fine-tuning
To perform end-to-end fine-tuning for ViT, use our script to convert the pre-trained ViT checkpoint to PLSC DeiT format:

```bash
python extract_weight.py \
--input pretrained/checkpoint_0299.pd \
--output pretrained/moco_vit_base.pdparams
```
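
Conceptually, the conversion keeps the base encoder's backbone weights and renames them so PLSC can load them. A rough sketch of the idea (the actual logic lives in `extract_weight.py`; the flat-dict assumption is ours):

```python
import paddle

# Illustrative only: assuming the checkpoint stores a flat parameter dict,
# keep the base encoder's backbone weights, drop the MoCo projector head,
# and strip the "base_encoder." prefix.
state = paddle.load('pretrained/checkpoint_0299.pd')
backbone = {
    name[len('base_encoder.'):]: tensor
    for name, tensor in state.items()
    if name.startswith('base_encoder.') and not name.startswith('base_encoder.head')
}
paddle.save(backbone, 'pretrained/moco_vit_base.pdparams')
```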

Then run the fine-tuning with the converted PLSC-format checkpoint:

```bash
unset PADDLE_TRAINER_ENDPOINTS
export PADDLE_NNODES=1
export PADDLE_MASTER="127.0.0.1:12538"
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export FLAGS_stop_check_timeout=3600

python -m paddle.distributed.launch \
--nnodes=$PADDLE_NNODES \
--master=$PADDLE_MASTER \
--devices=$CUDA_VISIBLE_DEVICES \
plsc-train \
-c ./configs/DeiT_base_patch16_224_in1k_1n8c_dp_fp16o1.yaml \
-o Global.epochs=150 \
-o Global.pretrained_model=pretrained/moco_vit_base \
-o Global.finetune=True
```

## Models

### ViT-Base
| Model | Phase | Dataset | Configs | GPUs | Epochs | Top1 Acc | Checkpoint |
| ------------- | ----------- | ------------ | ------------------------------------------------------------ | ---------- | ------ | -------- | ------------------------------------------------------------ |
| moco_vit_base | pretrain | ImageNet2012 | - | A100*N4C32 | 300 | - | [download](https://plsc.bj.bcebos.com/models/mocov3/v2.4/moco_vit_base_in1k_300ep.pd) |
| moco_vit_base | linear probe | ImageNet2012 | - | A100*N1C8 | 90 | 0.7662 | |
| moco_vit_base | finetune | ImageNet2012 | [config](./configs/DeiT_base_patch16_224_in1k_1n8c_dp_fp16o1.yaml) | A100*N1C8 | 150 | 0.8288 | |

## Citations

```bibtex
@Article{chen2021mocov3,
author = {Xinlei Chen* and Saining Xie* and Kaiming He},
title = {An Empirical Study of Training Self-Supervised Vision Transformers},
journal = {arXiv preprint arXiv:2104.02057},
year = {2021},
}
```
159 changes: 159 additions & 0 deletions task/ssl/mocov3/builder_moco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn


class MoCo(nn.Layer):
"""
Build a MoCo model with a base encoder, a momentum encoder, and two MLPs
https://arxiv.org/abs/1911.05722
"""

def __init__(self, base_encoder, dim=256, mlp_dim=4096, T=1.0):
"""
dim: feature dimension (default: 256)
mlp_dim: hidden dimension in MLPs (default: 4096)
T: softmax temperature (default: 1.0)
"""
super(MoCo, self).__init__()

self.T = T

# build encoders
self.base_encoder = base_encoder(num_classes=mlp_dim)
self.momentum_encoder = base_encoder(num_classes=mlp_dim)

self._build_projector_and_predictor_mlps(dim, mlp_dim)

for param_b, param_m in zip(self.base_encoder.parameters(),
self.momentum_encoder.parameters()):
param_m.copy_(param_b, False) # initialize
param_m.stop_gradient = True # not update by gradient

def _build_mlp(self,
num_layers,
input_dim,
mlp_dim,
output_dim,
last_bn=True):
mlp = []
for l in range(num_layers):
dim1 = input_dim if l == 0 else mlp_dim
dim2 = output_dim if l == num_layers - 1 else mlp_dim

mlp.append(nn.Linear(dim1, dim2, bias_attr=False))

if l < num_layers - 1:
mlp.append(nn.BatchNorm1D(dim2))
mlp.append(nn.ReLU())
elif last_bn:
# follow SimCLR's design: https://github.com/google-research/simclr/blob/master/model_util.py#L157
# for simplicity, we further removed gamma in BN
mlp.append(
nn.BatchNorm1D(
dim2, weight_attr=False, bias_attr=False))

return nn.Sequential(*mlp)

def _build_projector_and_predictor_mlps(self, dim, mlp_dim):
pass

@paddle.no_grad()
def _update_momentum_encoder(self, m):
"""Momentum update of the momentum encoder"""
with paddle.amp.auto_cast(False):
for param_b, param_m in zip(self.base_encoder.parameters(),
self.momentum_encoder.parameters()):
paddle.assign((param_m * m + param_b * (1. - m)), param_m)

def contrastive_loss(self, q, k):
# normalize
q = nn.functional.normalize(q, axis=1)
k = nn.functional.normalize(k, axis=1)
# gather all targets
k = concat_all_gather(k)
# Einstein sum is more intuitive
logits = paddle.einsum('nc,mc->nm', q, k) / self.T
N = logits.shape[0] # batch size per GPU
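        # the positive key for this rank's i-th query sits at index
        # N * rank + i in the gathered key batch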
labels = (paddle.arange(
N, dtype=paddle.int64) + N * paddle.distributed.get_rank())
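        # rescale by 2 * T, matching the contrastive loss definition
        # in the MoCo v3 paper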
return nn.CrossEntropyLoss()(logits, labels) * (2 * self.T)

def forward(self, x1, x2, m):
"""
Input:
x1: first views of images
x2: second views of images
m: moco momentum
Output:
loss
"""

# compute features
q1 = self.predictor(self.base_encoder(x1))
q2 = self.predictor(self.base_encoder(x2))

with paddle.no_grad(): # no gradient
self._update_momentum_encoder(m) # update the momentum encoder

# compute momentum features as targets
k1 = self.momentum_encoder(x1)
k2 = self.momentum_encoder(x2)

return self.contrastive_loss(q1, k2) + self.contrastive_loss(q2, k1)


class MoCo_ResNet(MoCo):
def _build_projector_and_predictor_mlps(self, dim, mlp_dim):
hidden_dim = self.base_encoder.fc.weight.shape[0]
del self.base_encoder.fc, self.momentum_encoder.fc # remove original fc layer

# projectors
self.base_encoder.fc = self._build_mlp(2, hidden_dim, mlp_dim, dim)
self.momentum_encoder.fc = self._build_mlp(2, hidden_dim, mlp_dim, dim)

# predictor
self.predictor = self._build_mlp(2, dim, mlp_dim, dim, False)


class MoCo_ViT(MoCo):
def _build_projector_and_predictor_mlps(self, dim, mlp_dim):
hidden_dim = self.base_encoder.head.weight.shape[0]
del self.base_encoder.head, self.momentum_encoder.head # remove original fc layer

# projectors
self.base_encoder.head = self._build_mlp(3, hidden_dim, mlp_dim, dim)
self.momentum_encoder.head = self._build_mlp(3, hidden_dim, mlp_dim,
dim)

# predictor
self.predictor = self._build_mlp(2, dim, mlp_dim, dim)


# utils
@paddle.no_grad()
def concat_all_gather(tensor):
"""
Performs all_gather operation on the provided tensors.
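    Note: this helper runs under @paddle.no_grad(), so the gathered tensors
    carry no gradient and are used only as contrastive targets.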
"""
if paddle.distributed.get_world_size() < 2:
return tensor

tensors_gather = []
paddle.distributed.all_gather(tensors_gather, tensor)

output = paddle.concat(tensors_gather, axis=0)
return output
