diff --git a/README.md b/README.md index 9e333b39aaa30..44a48bf58bf4a 100644 --- a/README.md +++ b/README.md @@ -18,10 +18,11 @@ [PLSC](https://github.com/PaddlePaddle/PLSC) is an open source repo for a collection of Paddle Large Scale Classification Tools, which supports large-scale classification model pre-training as well as finetune for downstream tasks. ## Available Models +* [Face Recognition](./task/recognition/face/) * [ViT](./task/classification/vit/) * [DeiT](./task/classification/deit/) * [CaiT](./task/classification/cait/) -* [Face Recognition](./task/recognition/face/) +* [MoCo v3](./task/ssl/mocov3/) ## Top News 🔥 diff --git a/task/ssl/mocov3/README.md b/task/ssl/mocov3/README.md new file mode 100644 index 0000000000000..c80a64c17ceb2 --- /dev/null +++ b/task/ssl/mocov3/README.md @@ -0,0 +1,117 @@ +## MoCo v3 for Self-supervised ResNet and ViT + + +PaddlePaddle reimplementation of [facebookresearch's repository for the MoCo v3 model](https://github.com/facebookresearch/moco-v3) that was released with the paper [An Empirical Study of Training Self-Supervised Vision Transformers](https://arxiv.org/abs/2104.02057). + +## Requirements +To enjoy some new features, PaddlePaddle 2.4 is required. For more installation tutorials +refer to [installation.md](../../../tutorials/get_started/installation.md) + +## Data Preparation + +Prepare the data into the following directory: +```text +dataset/ +└── ILSVRC2012 + ├── train + └── val +``` + + +## How to Self-supervised Pre-Training + +With a batch size of 4096, ViT-Base is trained with 4 nodes: + +```bash +# Note: Set the following environment variables +# and then need to run the script on each node. +unset PADDLE_TRAINER_ENDPOINTS +export PADDLE_NNODES=4 +export PADDLE_MASTER="xxx.xxx.xxx.xxx:12538" +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +export FLAGS_stop_check_timeout=3600 + +IMAGENET_DIR=./dataset/ILSVRC2012/ +python -m paddle.distributed.launch \ + --nnodes=$PADDLE_NNODES \ + --master=$PADDLE_MASTER \ + --devices=$CUDA_VISIBLE_DEVICES \ + main_moco.py \ + -a moco_vit_base \ + --optimizer=adamw --lr=1.5e-4 --weight-decay=.1 \ + --epochs=300 --warmup-epochs=40 \ + --stop-grad-conv1 --moco-m-cos --moco-t=.2 \ + ${IMAGENET_DIR} +``` + +## How to Linear Classification + +By default, we use momentum-SGD and a batch size of 1024 for linear classification on frozen features/weights. This can be done with a single 8-GPU node. + +```bash +unset PADDLE_TRAINER_ENDPOINTS +export PADDLE_NNODES=1 +export PADDLE_MASTER="127.0.0.1:12538" +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +export FLAGS_stop_check_timeout=3600 + +IMAGENET_DIR=./dataset/ILSVRC2012/ +python -m paddle.distributed.launch \ + --nnodes=$PADDLE_NNODES \ + --master=$PADDLE_MASTER \ + --devices=$CUDA_VISIBLE_DEVICES \ + main_lincls.py \ + -a moco_vit_base \ + --lr=3 \ + --pretrained pretrained/checkpoint_0299.pd \ + ${IMAGENET_DIR} +``` + +## How to End-to-End Fine-tuning +To perform end-to-end fine-tuning for ViT, use our script to convert the pre-trained ViT checkpoint to PLSC DeiT format: + +```bash +python extract_weight.py \ + --input pretrained/checkpoint_0299.pd \ + --output pretrained/moco_vit_base.pdparams +``` + +Then run the training with the converted PLSC format checkpoint: + +```bash +unset PADDLE_TRAINER_ENDPOINTS +export PADDLE_NNODES=1 +export PADDLE_MASTER="127.0.0.1:12538" +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +export FLAGS_stop_check_timeout=3600 + +python -m paddle.distributed.launch \ + --nnodes=$PADDLE_NNODES \ + --master=$PADDLE_MASTER \ + --devices=$CUDA_VISIBLE_DEVICES \ + plsc-train \ + -c ./configs/DeiT_base_patch16_224_in1k_1n8c_dp_fp16o1.yaml \ + -o Global.epochs=150 \ + -o Global.pretrained_model=pretrained/moco_vit_base \ + -o Global.finetune=True +``` + +## Models + +### ViT-Base +| Model | Phase | Dataset | Configs | GPUs | Epochs | Top1 Acc | Checkpoint | +| ------------- | ----------- | ------------ | ------------------------------------------------------------ | ---------- | ------ | -------- | ------------------------------------------------------------ | +| moco_vit_base | pretrain | ImageNet2012 | - | A100*N4C32 | 300 | - | [download](https://plsc.bj.bcebos.com/models/mocov3/v2.4/moco_vit_base_in1k_300ep.pd) | +| moco_vit_base | linear prob | ImageNet2012 | - | A100*N1C8 | 90 | 0.7662 | | +| moco_vit_base | finetune | ImageNet2012 | [config](./configs/DeiT_base_patch16_224_in1k_1n8c_dp_fp16o1.yaml) | A100*N1C8 | 150 | 0.8288 | | + +## Citations + +```bibtex +@Article{chen2021mocov3, + author = {Xinlei Chen* and Saining Xie* and Kaiming He}, + title = {An Empirical Study of Training Self-Supervised Vision Transformers}, + journal = {arXiv preprint arXiv:2104.02057}, + year = {2021}, +} +``` diff --git a/task/ssl/mocov3/builder_moco.py b/task/ssl/mocov3/builder_moco.py new file mode 100644 index 0000000000000..614d7654976b5 --- /dev/null +++ b/task/ssl/mocov3/builder_moco.py @@ -0,0 +1,159 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn + + +class MoCo(nn.Layer): + """ + Build a MoCo model with a base encoder, a momentum encoder, and two MLPs + https://arxiv.org/abs/1911.05722 + """ + + def __init__(self, base_encoder, dim=256, mlp_dim=4096, T=1.0): + """ + dim: feature dimension (default: 256) + mlp_dim: hidden dimension in MLPs (default: 4096) + T: softmax temperature (default: 1.0) + """ + super(MoCo, self).__init__() + + self.T = T + + # build encoders + self.base_encoder = base_encoder(num_classes=mlp_dim) + self.momentum_encoder = base_encoder(num_classes=mlp_dim) + + self._build_projector_and_predictor_mlps(dim, mlp_dim) + + for param_b, param_m in zip(self.base_encoder.parameters(), + self.momentum_encoder.parameters()): + param_m.copy_(param_b, False) # initialize + param_m.stop_gradient = True # not update by gradient + + def _build_mlp(self, + num_layers, + input_dim, + mlp_dim, + output_dim, + last_bn=True): + mlp = [] + for l in range(num_layers): + dim1 = input_dim if l == 0 else mlp_dim + dim2 = output_dim if l == num_layers - 1 else mlp_dim + + mlp.append(nn.Linear(dim1, dim2, bias_attr=False)) + + if l < num_layers - 1: + mlp.append(nn.BatchNorm1D(dim2)) + mlp.append(nn.ReLU()) + elif last_bn: + # follow SimCLR's design: https://github.com/google-research/simclr/blob/master/model_util.py#L157 + # for simplicity, we further removed gamma in BN + mlp.append( + nn.BatchNorm1D( + dim2, weight_attr=False, bias_attr=False)) + + return nn.Sequential(*mlp) + + def _build_projector_and_predictor_mlps(self, dim, mlp_dim): + pass + + @paddle.no_grad() + def _update_momentum_encoder(self, m): + """Momentum update of the momentum encoder""" + with paddle.amp.auto_cast(False): + for param_b, param_m in zip(self.base_encoder.parameters(), + self.momentum_encoder.parameters()): + paddle.assign((param_m * m + param_b * (1. - m)), param_m) + + def contrastive_loss(self, q, k): + # normalize + q = nn.functional.normalize(q, axis=1) + k = nn.functional.normalize(k, axis=1) + # gather all targets + k = concat_all_gather(k) + # Einstein sum is more intuitive + logits = paddle.einsum('nc,mc->nm', q, k) / self.T + N = logits.shape[0] # batch size per GPU + labels = (paddle.arange( + N, dtype=paddle.int64) + N * paddle.distributed.get_rank()) + return nn.CrossEntropyLoss()(logits, labels) * (2 * self.T) + + def forward(self, x1, x2, m): + """ + Input: + x1: first views of images + x2: second views of images + m: moco momentum + Output: + loss + """ + + # compute features + q1 = self.predictor(self.base_encoder(x1)) + q2 = self.predictor(self.base_encoder(x2)) + + with paddle.no_grad(): # no gradient + self._update_momentum_encoder(m) # update the momentum encoder + + # compute momentum features as targets + k1 = self.momentum_encoder(x1) + k2 = self.momentum_encoder(x2) + + return self.contrastive_loss(q1, k2) + self.contrastive_loss(q2, k1) + + +class MoCo_ResNet(MoCo): + def _build_projector_and_predictor_mlps(self, dim, mlp_dim): + hidden_dim = self.base_encoder.fc.weight.shape[0] + del self.base_encoder.fc, self.momentum_encoder.fc # remove original fc layer + + # projectors + self.base_encoder.fc = self._build_mlp(2, hidden_dim, mlp_dim, dim) + self.momentum_encoder.fc = self._build_mlp(2, hidden_dim, mlp_dim, dim) + + # predictor + self.predictor = self._build_mlp(2, dim, mlp_dim, dim, False) + + +class MoCo_ViT(MoCo): + def _build_projector_and_predictor_mlps(self, dim, mlp_dim): + hidden_dim = self.base_encoder.head.weight.shape[0] + del self.base_encoder.head, self.momentum_encoder.head # remove original fc layer + + # projectors + self.base_encoder.head = self._build_mlp(3, hidden_dim, mlp_dim, dim) + self.momentum_encoder.head = self._build_mlp(3, hidden_dim, mlp_dim, + dim) + + # predictor + self.predictor = self._build_mlp(2, dim, mlp_dim, dim) + + +# utils +@paddle.no_grad() +def concat_all_gather(tensor): + """ + Performs all_gather operation on the provided tensors. + """ + if paddle.distributed.get_world_size() < 2: + return tensor + + tensors_gather = [] + paddle.distributed.all_gather(tensors_gather, tensor) + + output = paddle.concat(tensors_gather, axis=0) + return output diff --git a/task/ssl/mocov3/configs/DeiT_base_patch16_224_in1k_1n8c_dp_fp16o1.yaml b/task/ssl/mocov3/configs/DeiT_base_patch16_224_in1k_1n8c_dp_fp16o1.yaml new file mode 100644 index 0000000000000..73937a74e4e1e --- /dev/null +++ b/task/ssl/mocov3/configs/DeiT_base_patch16_224_in1k_1n8c_dp_fp16o1.yaml @@ -0,0 +1,144 @@ +# global configs +Global: + checkpoint: null + pretrained_model: null + output_dir: ./output/ + device: gpu + save_interval: 1 + max_num_latest_checkpoint: 0 + eval_during_train: True + eval_interval: 1 + eval_unit: "epoch" + accum_steps: 1 + epochs: 150 + print_batch_step: 10 + use_visualdl: False + seed: 2022 + +# FP16 setting +FP16: + level: O1 + GradScaler: + init_loss_scaling: 65536.0 + +DistributedStrategy: + data_parallel: True + +# model architecture +Model: + name: DeiT_base_patch16_224 + drop_path_rate : 0.1 + drop_rate : 0.0 + class_num: 1000 + +# loss function config for traing/eval process +Loss: + Train: + - CELoss: + weight: 1.0 + Eval: + - CELoss: + weight: 1.0 + +LRScheduler: + name: TimmCosine + learning_rate: 1e-3 + eta_min: 1e-5 + warmup_epoch: 5 + warmup_start_lr: 1e-6 + decay_unit: epoch + +Optimizer: + name: AdamW + betas: (0.9, 0.999) + eps: 1e-8 + weight_decay: 0.05 + no_weight_decay_name: ["cls_token", "pos_embed", "norm", "bias"] + use_master_param: True + exp_avg_force_fp32: True + +# data loader for train and eval +DataLoader: + Train: + dataset: + name: ImageFolder + root: ./dataset/ILSVRC2012/train + transform: + - RandomResizedCrop: + size: 224 + interpolation: bicubic + - RandomHorizontalFlip: + - TimmAutoAugment: + config_str: rand-m9-mstd0.5-inc1 + interpolation: bicubic + img_size: 224 + mean: [0.485, 0.456, 0.406] + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - RandomErasing: + EPSILON: 0.25 + sl: 0.02 + sh: 1.0/3.0 + r1: 0.3 + attempt: 10 + use_log_aspect: True + mode: pixel + - ToCHWImage: + batch_transform: + - TransformOpSampler: + Mixup: + alpha: 0.8 + prob: 0.5 + epsilon: 0.1 + class_num: 1000 + Cutmix: + alpha: 1.0 + prob: 0.5 + epsilon: 0.1 + class_num: 1000 + sampler: + name: RepeatedAugSampler + batch_size: 128 # accum_steps: 1, total batchsize: 1024 + drop_last: False + shuffle: True + loader: + num_workers: 8 + use_shared_memory: True + + Eval: + dataset: + name: ImageFolder + root: ./dataset/ILSVRC2012/val + transform: + - Resize: + size: 256 + interpolation: bicubic + backend: pil + - CenterCrop: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + sampler: + name: DistributedBatchSampler + batch_size: 256 + drop_last: False + shuffle: False + loader: + num_workers: 8 + use_shared_memory: True + +Metric: + Eval: + - TopkAcc: + topk: [1, 5] + +Export: + export_type: paddle + input_shape: [None, 3, 224, 224] diff --git a/task/ssl/mocov3/extract_weight.py b/task/ssl/mocov3/extract_weight.py new file mode 100644 index 0000000000000..b60f3fe3a54b0 --- /dev/null +++ b/task/ssl/mocov3/extract_weight.py @@ -0,0 +1,56 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import paddle + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description='Convert MoCo Pre-Traind Model to DEiT') + parser.add_argument( + '--input', + default='', + type=str, + metavar='PATH', + required=True, + help='path to moco pre-trained checkpoint') + parser.add_argument( + '--output', + default='', + type=str, + metavar='PATH', + required=True, + help='path to output checkpoint in DEiT format') + args = parser.parse_args() + print(args) + + # load input + checkpoint = paddle.load(args.input) + state_dict = checkpoint['state_dict'] + for k in list(state_dict.keys()): + # retain only base_encoder up to before the embedding layer + if k.startswith('base_encoder') and not k.startswith( + 'base_encoder.head'): + # remove prefix + state_dict[k[len("base_encoder."):]] = state_dict[k] + # delete renamed or unused k + del state_dict[k] + + # make output directory if necessary + output_dir = os.path.dirname(args.output) + if not os.path.isdir(output_dir): + os.makedirs(output_dir) + # save to output + paddle.save(state_dict, args.output) diff --git a/task/ssl/mocov3/finetune.sh b/task/ssl/mocov3/finetune.sh new file mode 100644 index 0000000000000..9de7fb438f09e --- /dev/null +++ b/task/ssl/mocov3/finetune.sh @@ -0,0 +1,31 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Note: Set the following environment variables +# and then need to run the script on each node. +unset PADDLE_TRAINER_ENDPOINTS +export PADDLE_NNODES=1 +export PADDLE_MASTER="127.0.0.1:12538" +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +export FLAGS_stop_check_timeout=3600 + +python -m paddle.distributed.launch \ + --nnodes=$PADDLE_NNODES \ + --master=$PADDLE_MASTER \ + --devices=$CUDA_VISIBLE_DEVICES \ + plsc-train \ + -c ./configs/DeiT_base_patch16_224_in1k_1n8c_dp_fp16o1.yaml \ + -o Global.epochs=150 \ + -o Global.pretrained_model=pretrained/moco_vit_base \ + -o Global.finetune=True diff --git a/task/ssl/mocov3/linprob.sh b/task/ssl/mocov3/linprob.sh new file mode 100644 index 0000000000000..dec2299467d09 --- /dev/null +++ b/task/ssl/mocov3/linprob.sh @@ -0,0 +1,30 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +unset PADDLE_TRAINER_ENDPOINTS +export PADDLE_NNODES=1 +export PADDLE_MASTER="127.0.0.1:12538" +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +export FLAGS_stop_check_timeout=3600 + +IMAGENET_DIR=./dataset/ILSVRC2012/ +python -m paddle.distributed.launch \ + --nnodes=$PADDLE_NNODES \ + --master=$PADDLE_MASTER \ + --devices=$CUDA_VISIBLE_DEVICES \ + main_lincls.py \ + -a moco_vit_base \ + --lr=3 \ + --pretrained pretrained/checkpoint_0299.pd \ + ${IMAGENET_DIR} diff --git a/task/ssl/mocov3/main_lincls.py b/task/ssl/mocov3/main_lincls.py new file mode 100644 index 0000000000000..cd0e527b0c224 --- /dev/null +++ b/task/ssl/mocov3/main_lincls.py @@ -0,0 +1,512 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import builtins +import math +import os +import random +import shutil +import time +import warnings +from functools import partial + +import paddle +import paddle.nn as nn +import paddle.distributed as dist +from plsc.data import preprocess as transforms +from plsc.data import dataset as datasets +from plsc.nn import init +from visualdl import LogWriter as SummaryWriter + +import plsc + +import builder_moco +import vit_moco + +model_names = [ + 'moco_vit_small', 'moco_vit_base', 'moco_vit_conv_small', + 'moco_vit_conv_base' +] + +parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') +parser.add_argument('data', metavar='DIR', help='path to dataset') +parser.add_argument( + '-a', + '--arch', + metavar='ARCH', + default='resnet50', + choices=model_names, + help='model architecture: ' + ' | '.join(model_names) + + ' (default: resnet50)') +parser.add_argument( + '-j', + '--workers', + default=8, + type=int, + metavar='N', + help='number of data loading workers (default: 32)') +parser.add_argument( + '--epochs', + default=90, + type=int, + metavar='N', + help='number of total epochs to run') +parser.add_argument( + '--start-epoch', + default=0, + type=int, + metavar='N', + help='manual epoch number (useful on restarts)') +parser.add_argument( + '-b', + '--batch-size', + default=1024, + type=int, + metavar='N', + help='mini-batch size (default: 1024), this is the total ' + 'batch size of all GPUs on all nodes when ' + 'using Data Parallel or Distributed Data Parallel') +parser.add_argument( + '--lr', + '--learning-rate', + default=0.1, + type=float, + metavar='LR', + help='initial (base) learning rate', + dest='lr') +parser.add_argument( + '--momentum', default=0.9, type=float, metavar='M', help='momentum') +parser.add_argument( + '--wd', + '--weight-decay', + default=0., + type=float, + metavar='W', + help='weight decay (default: 0.)', + dest='weight_decay') +parser.add_argument( + '-p', + '--print-freq', + default=10, + type=int, + metavar='N', + help='print frequency (default: 10)') +parser.add_argument( + '--resume', + default='', + type=str, + metavar='PATH', + help='path to latest checkpoint (default: none)') +parser.add_argument( + '-e', + '--evaluate', + dest='evaluate', + action='store_true', + help='evaluate model on validation set') +parser.add_argument( + '--world-size', + default=-1, + type=int, + help='number of nodes for distributed training') +parser.add_argument( + '--rank', default=-1, type=int, help='node rank for distributed training') +parser.add_argument( + '--dist-url', + default='tcp://224.66.41.62:23456', + type=str, + help='url used to set up distributed training') +parser.add_argument( + '--dist-backend', default='nccl', type=str, help='distributed backend') +parser.add_argument( + '--seed', default=None, type=int, help='seed for initializing training. ') +parser.add_argument('--gpu', default=None, type=int, help='GPU id to use.') +parser.add_argument( + '--multiprocessing-distributed', + action='store_true', + help='Use multi-processing distributed training to launch ' + 'N processes per node, which has N GPUs. This is the ' + 'fastest way to use PyTorch for either single node or ' + 'multi node data parallel training') + +# additional configs: +parser.add_argument( + '--pretrained', + default='', + type=str, + help='path to moco pretrained checkpoint') + +best_acc1 = 0 + + +def main(): + args = parser.parse_args() + + if args.seed is not None: + random.seed(args.seed) + paddle.seed(args.seed) + np.random.seed(args.seed) + RELATED_FLAGS_SETTING = {} + RELATED_FLAGS_SETTING['FLAGS_cudnn_deterministic'] = 1 + paddle.fluid.set_flags(RELATED_FLAGS_SETTING) + warnings.warn('You have chosen to seed training. ' + 'This will turn on the CUDNN deterministic setting, ' + 'which can slow down your training considerably! ' + 'You may see unexpected behavior when restarting ' + 'from checkpoints.') + + device = paddle.set_device("gpu") + dist.init_parallel_env() + args.world_size = dist.get_world_size() + args.rank = dist.get_rank() + args.distributed = args.world_size > 1 + + # suppress printing if not first GPU on each node + if args.rank != 0: + + def print_pass(*args): + pass + + builtins.print = print_pass + + global best_acc1 + + # create model + print("=> creating model '{}'".format(args.arch)) + + model = vit_moco.__dict__[args.arch]() + linear_keyword = 'head' + + # freeze all layers but the last fc + for name, param in model.named_parameters(): + if name not in [ + '%s.weight' % linear_keyword, '%s.bias' % linear_keyword + ]: + param.stop_gradient = True + + init.normal_(getattr(model, linear_keyword).weight, mean=0.0, std=0.01) + init.zeros_(getattr(model, linear_keyword).bias) + + # load from pre-trained, before DistributedDataParallel constructor + if args.pretrained: + if os.path.isfile(args.pretrained): + print("=> loading checkpoint '{}'".format(args.pretrained)) + checkpoint = paddle.load(args.pretrained) + + # rename moco pre-trained keys + state_dict = checkpoint['state_dict'] + for k in list(state_dict.keys()): + # retain only base_encoder up to before the embedding layer + if k.startswith('base_encoder') and not k.startswith( + 'base_encoder.%s' % linear_keyword): + # remove prefix + state_dict[k[len("base_encoder."):]] = state_dict[k] + # delete renamed or unused k + del state_dict[k] + + args.start_epoch = 0 + msg = model.set_state_dict(state_dict) + # assert set(msg.missing_keys) == {"%s.weight" % linear_keyword, "%s.bias" % linear_keyword} + + print("=> loaded pre-trained model '{}'".format(args.pretrained)) + else: + print("=> no checkpoint found at '{}'".format(args.pretrained)) + + # infer learning rate before changing batch size + init_lr = args.lr * args.batch_size / 256 + + if args.distributed: + args.batch_size = int(args.batch_size / args.world_size) + model = paddle.DataParallel(model) + + # define loss function (criterion) and optimizer + criterion = nn.CrossEntropyLoss() + + # optimize only the linear classifier + parameters = list( + filter(lambda p: not p.stop_gradient, model.parameters())) + assert len(parameters) == 2 # weight, bias + + optimizer = plsc.optimizer.Momentum( + parameters, + init_lr, + momentum=args.momentum, + weight_decay=args.weight_decay) + + # optionally resume from a checkpoint + if args.resume: + if os.path.isfile(args.resume): + print("=> loading checkpoint '{}'".format(args.resume)) + checkpoint = paddle.load(args.resume) + args.start_epoch = checkpoint['epoch'] + best_acc1 = checkpoint['best_acc1'] + model.set_state_dict(checkpoint['state_dict']) + optimizer.set_state_dict(checkpoint['optimizer']) + print("=> loaded checkpoint '{}' (epoch {})" + .format(args.resume, checkpoint['epoch'])) + else: + print("=> no checkpoint found at '{}'".format(args.resume)) + + # Data loading code + traindir = os.path.join(args.data, 'train') + valdir = os.path.join(args.data, 'val') + normalize = transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + train_dataset = datasets.ImageFolder(traindir, + transforms.Compose([ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ])) + + train_sampler = paddle.io.DistributedBatchSampler( + train_dataset, shuffle=True, batch_size=args.batch_size) + + train_loader = paddle.io.DataLoader( + train_dataset, + batch_sampler=train_sampler, + num_workers=args.workers, + use_shared_memory=True, ) + + val_dataset = datasets.ImageFolder(valdir, + transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + normalize, + ])) + + val_sampler = paddle.io.BatchSampler( + val_dataset, shuffle=False, batch_size=256, drop_last=False) + + val_loader = paddle.io.DataLoader( + val_dataset, + batch_sampler=val_sampler, + num_workers=args.workers, + use_shared_memory=True, ) + + if args.evaluate: + validate(val_loader, model, criterion, args) + return + + for epoch in range(args.start_epoch, args.epochs): + if args.distributed: + train_loader.batch_sampler.set_epoch(epoch) + adjust_learning_rate(optimizer, init_lr, epoch, args) + + # train for one epoch + train(train_loader, model, criterion, optimizer, epoch, args) + + # evaluate on validation set + acc1 = validate(val_loader, model, criterion, args) + + # remember best acc@1 and save checkpoint + is_best = acc1 > best_acc1 + best_acc1 = max(acc1, best_acc1) + + if args.rank == 0 and epoch % 10 == 0 or epoch == args.epochs - 1: # only the first GPU saves checkpoint + save_checkpoint({ + 'epoch': epoch + 1, + 'arch': args.arch, + 'state_dict': model.state_dict(), + 'best_acc1': best_acc1, + 'optimizer': optimizer.state_dict(), + }, is_best) + if epoch == args.start_epoch: + sanity_check(model.state_dict(), args.pretrained, + linear_keyword) + + +def train(train_loader, model, criterion, optimizer, epoch, args): + batch_time = AverageMeter('Time', ':6.3f') + data_time = AverageMeter('Data', ':6.3f') + losses = AverageMeter('Loss', ':.4e') + top1 = AverageMeter('Acc@1', ':6.2f') + top5 = AverageMeter('Acc@5', ':6.2f') + progress = ProgressMeter( + len(train_loader), [batch_time, data_time, losses, top1, top5], + prefix="Epoch: [{}]".format(epoch)) + """ + Switch to eval mode: + Under the protocol of linear classification on frozen features/models, + it is not legitimate to change any part of the pre-trained model. + BatchNorm in train mode may revise running mean/std (even if it receives + no gradient), which are part of the model parameters too. + """ + model.eval() + + end = time.time() + for i, (images, target) in enumerate(train_loader): + # measure data loading time + data_time.update(time.time() - end) + + # compute output + output = model(images) + loss = criterion(output, target) + + # measure accuracy and record loss + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + losses.update(loss.item(), images.shape[0]) + top1.update(acc1[0].item(), images.shape[0]) + top5.update(acc5[0].item(), images.shape[0]) + + # compute gradient and do SGD step + optimizer.clear_grad() + loss.backward() + optimizer.step() + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % args.print_freq == 0: + progress.display(i) + + +def validate(val_loader, model, criterion, args): + batch_time = AverageMeter('Time', ':6.3f') + losses = AverageMeter('Loss', ':.4e') + top1 = AverageMeter('Acc@1', ':6.2f') + top5 = AverageMeter('Acc@5', ':6.2f') + progress = ProgressMeter( + len(val_loader), [batch_time, losses, top1, top5], prefix='Test: ') + + # switch to evaluate mode + model.eval() + + with paddle.no_grad(): + end = time.time() + for i, (images, target) in enumerate(val_loader): + + # compute output + output = model(images) + loss = criterion(output, target) + + # measure accuracy and record loss + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + losses.update(loss.item(), images.shape[0]) + top1.update(acc1[0].item(), images.shape[0]) + top5.update(acc5[0].item(), images.shape[0]) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % args.print_freq == 0: + progress.display(i) + + # TODO: this should also be done with the ProgressMeter + print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format( + top1=top1, top5=top5)) + + return top1.avg + + +def save_checkpoint(state, is_best, filename='checkpoint.pd'): + paddle.save(state, filename) + if is_best: + shutil.copyfile(filename, 'model_best.pd') + + +def sanity_check(state_dict, pretrained_weights, linear_keyword): + """ + Linear classifier should not change any weights other than the linear layer. + This sanity check asserts nothing wrong happens (e.g., BN stats updated). + """ + print("=> loading '{}' for sanity check".format(pretrained_weights)) + checkpoint = paddle.load(pretrained_weights) + state_dict_pre = checkpoint['state_dict'] + + for k in list(state_dict.keys()): + # only ignore linear layer + if '%s.weight' % linear_keyword in k or '%s.bias' % linear_keyword in k: + continue + + # name in pretrained model + k_pre = 'base_encoder.' + k + + assert ((state_dict[k].cpu() == state_dict_pre[k_pre]).all()), \ + '{} is changed in linear classifier training.'.format(k) + + print("=> sanity check passed.") + + +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self, name, fmt=':f'): + self.name = name + self.fmt = fmt + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' + return fmtstr.format(**self.__dict__) + + +class ProgressMeter(object): + def __init__(self, num_batches, meters, prefix=""): + self.batch_fmtstr = self._get_batch_fmtstr(num_batches) + self.meters = meters + self.prefix = prefix + + def display(self, batch): + entries = [self.prefix + self.batch_fmtstr.format(batch)] + entries += [str(meter) for meter in self.meters] + print('\t'.join(entries)) + + def _get_batch_fmtstr(self, num_batches): + num_digits = len(str(num_batches // 1)) + fmt = '{:' + str(num_digits) + 'd}' + return '[' + fmt + '/' + fmt.format(num_batches) + ']' + + +def adjust_learning_rate(optimizer, init_lr, epoch, args): + """Decay the learning rate based on schedule""" + cur_lr = init_lr * 0.5 * (1. + math.cos(math.pi * epoch / args.epochs)) + for param_group in optimizer.param_groups: + param_group['lr'] = cur_lr + + +@paddle.no_grad() +def accuracy(output, target, topk=(1, )): + """Computes the accuracy over the k top predictions for the specified values of k""" + maxk = min(max(topk), output.shape[1]) + batch_size = target.shape[0] + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = ( + pred == target.reshape([1, -1]).expand_as(pred)).astype(paddle.float32) + return [ + correct[:min(k, maxk)].reshape([-1]).sum(0) * 100. / batch_size + for k in topk + ] + + +if __name__ == '__main__': + main() diff --git a/task/ssl/mocov3/main_moco.py b/task/ssl/mocov3/main_moco.py new file mode 100644 index 0000000000000..38b4cdc283aef --- /dev/null +++ b/task/ssl/mocov3/main_moco.py @@ -0,0 +1,473 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import builtins +import math +import os +import random +import shutil +import time +import warnings +from functools import partial + +import paddle +import paddle.nn as nn +import paddle.distributed as dist +from plsc.data import preprocess as transforms +from plsc.data import dataset as datasets +from visualdl import LogWriter as SummaryWriter + +import plsc + +import builder_moco +import vit_moco + +model_names = [ + 'moco_vit_small', 'moco_vit_base', 'moco_vit_conv_small', + 'moco_vit_conv_base' +] + +parser = argparse.ArgumentParser(description='MoCo ImageNet Pre-Training') +parser.add_argument('data', metavar='DIR', help='path to dataset') +parser.add_argument( + '-a', + '--arch', + metavar='ARCH', + default='resnet50', + choices=model_names, + help='model architecture: ' + ' | '.join(model_names) + + ' (default: resnet50)') +parser.add_argument( + '-j', + '--workers', + default=8, + type=int, + metavar='N', + help='number of data loading workers (default: 8)') +parser.add_argument( + '--epochs', + default=100, + type=int, + metavar='N', + help='number of total epochs to run') +parser.add_argument( + '--start-epoch', + default=0, + type=int, + metavar='N', + help='manual epoch number (useful on restarts)') +parser.add_argument( + '-b', + '--batch-size', + default=4096, + type=int, + metavar='N', + help='mini-batch size (default: 4096), this is the total ' + 'batch size of all GPUs on all nodes when ' + 'using Data Parallel or Distributed Data Parallel') +parser.add_argument( + '--lr', + '--learning-rate', + default=0.6, + type=float, + metavar='LR', + help='initial (base) learning rate', + dest='lr') +parser.add_argument( + '--momentum', default=0.9, type=float, metavar='M', help='momentum') +parser.add_argument( + '--wd', + '--weight-decay', + default=1e-6, + type=float, + metavar='W', + help='weight decay (default: 1e-6)', + dest='weight_decay') +parser.add_argument( + '-p', + '--print-freq', + default=10, + type=int, + metavar='N', + help='print frequency (default: 10)') +parser.add_argument( + '--resume', + default='', + type=str, + metavar='PATH', + help='path to latest checkpoint (default: none)') +parser.add_argument( + '--world-size', + default=-1, + type=int, + help='number of nodes for distributed training') +parser.add_argument( + '--rank', default=-1, type=int, help='node rank for distributed training') +parser.add_argument( + '--dist-url', + default='tcp://224.66.41.62:23456', + type=str, + help='url used to set up distributed training') +parser.add_argument( + '--dist-backend', default='nccl', type=str, help='distributed backend') +parser.add_argument( + '--seed', default=None, type=int, help='seed for initializing training. ') +parser.add_argument('--gpu', default=None, type=int, help='GPU id to use.') +parser.add_argument( + '--multiprocessing-distributed', + action='store_true', + help='Use multi-processing distributed training to launch ' + 'N processes per node, which has N GPUs. This is the ' + 'fastest way to use PyTorch for either single node or ' + 'multi node data parallel training') + +# moco specific configs: +parser.add_argument( + '--moco-dim', + default=256, + type=int, + help='feature dimension (default: 256)') +parser.add_argument( + '--moco-mlp-dim', + default=4096, + type=int, + help='hidden dimension in MLPs (default: 4096)') +parser.add_argument( + '--moco-m', + default=0.99, + type=float, + help='moco momentum of updating momentum encoder (default: 0.99)') +parser.add_argument( + '--moco-m-cos', + action='store_true', + help='gradually increase moco momentum to 1 with a ' + 'half-cycle cosine schedule') +parser.add_argument( + '--moco-t', + default=1.0, + type=float, + help='softmax temperature (default: 1.0)') + +# vit specific configs: +parser.add_argument( + '--stop-grad-conv1', + action='store_true', + help='stop-grad after first conv, or patch embedding') + +# other upgrades +parser.add_argument( + '--optimizer', + default='lars', + type=str, + choices=['lars', 'adamw'], + help='optimizer used (default: lars)') +parser.add_argument( + '--warmup-epochs', + default=10, + type=int, + metavar='N', + help='number of warmup epochs') +parser.add_argument( + '--crop-min', + default=0.08, + type=float, + help='minimum scale for random cropping (default: 0.08)') + + +def main(): + args = parser.parse_args() + + if args.seed is not None: + random.seed(args.seed) + paddle.seed(args.seed) + np.random.seed(args.seed) + RELATED_FLAGS_SETTING = {} + RELATED_FLAGS_SETTING['FLAGS_cudnn_deterministic'] = 1 + paddle.fluid.set_flags(RELATED_FLAGS_SETTING) + warnings.warn('You have chosen to seed training. ' + 'This will turn on the CUDNN deterministic setting, ' + 'which can slow down your training considerably! ' + 'You may see unexpected behavior when restarting ' + 'from checkpoints.') + + device = paddle.set_device("gpu") + dist.init_parallel_env() + args.world_size = dist.get_world_size() + args.rank = dist.get_rank() + args.distributed = args.world_size > 1 + + # suppress printing if not first GPU on each node + if args.rank != 0: + + def print_pass(*args): + pass + + builtins.print = print_pass + + # create model + print("=> creating model '{}'".format(args.arch)) + + model = builder_moco.MoCo_ViT( + partial( + vit_moco.__dict__[args.arch], + stop_grad_conv1=args.stop_grad_conv1), + args.moco_dim, + args.moco_mlp_dim, + args.moco_t) + + # infer learning rate before changing batch size + args.lr = args.lr * args.batch_size / 256 + + if args.distributed: + # apply SyncBN + model = nn.SyncBatchNorm.convert_sync_batchnorm(model) + + args.batch_size = int(args.batch_size / args.world_size) + model = paddle.DataParallel(model) + + print(model) # print model after SyncBatchNorm + + if args.optimizer == 'lars': + optimizer = plsc.optimizer.MomentumLARS( + model.parameters(), + args.lr, + weight_decay=args.weight_decay, + momentum=args.momentum) + elif args.optimizer == 'adamw': + optimizer = plsc.optimizer.AdamW( + model.parameters(), args.lr, weight_decay=args.weight_decay) + + scaler = paddle.amp.GradScaler( + init_loss_scaling=2.**16, incr_every_n_steps=2000) + + summary_writer = SummaryWriter() if args.rank == 0 else None + + # optionally resume from a checkpoint + if args.resume: + if os.path.isfile(args.resume): + print("=> loading checkpoint '{}'".format(args.resume)) + checkpoint = paddle.load(args.resume) + args.start_epoch = checkpoint['epoch'] + model.set_state_dict(checkpoint['state_dict']) + optimizer.set_state_dict(checkpoint['optimizer']) + scaler.load_state_dict(checkpoint['scaler']) + print("=> loaded checkpoint '{}' (epoch {})" + .format(args.resume, checkpoint['epoch'])) + else: + print("=> no checkpoint found at '{}'".format(args.resume)) + + # Data loading code + traindir = os.path.join(args.data, 'train') + + # follow BYOL's augmentation recipe: https://arxiv.org/abs/2006.07733 + augmentation1 = [ + transforms.RandomResizedCrop( + 224, scale=(args.crop_min, 1.)), + transforms.RandomApply( + [ + transforms.ColorJitter(0.4, 0.4, 0.2, 0.1) # not strengthened + ], + p=0.8), + transforms.RandomGrayscale(p=0.2), + transforms.RandomApply( + [transforms.SimCLRGaussianBlur([.1, 2.])], p=1.0), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + + augmentation2 = [ + transforms.RandomResizedCrop( + 224, scale=(args.crop_min, 1.)), + transforms.RandomApply( + [ + transforms.ColorJitter(0.4, 0.4, 0.2, 0.1) # not strengthened + ], + p=0.8), + transforms.RandomGrayscale(p=0.2), + transforms.RandomApply( + [transforms.SimCLRGaussianBlur([.1, 2.])], p=0.1), + transforms.RandomApply( + [transforms.BYOLSolarize()], p=0.2), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + + train_dataset = datasets.ImageFolder( + traindir, + transforms.TwoViewsTransform( + transforms.Compose(augmentation1), + transforms.Compose(augmentation2))) + + train_sampler = paddle.io.DistributedBatchSampler( + train_dataset, + shuffle=True, + batch_size=args.batch_size, + drop_last=True) + + train_loader = paddle.io.DataLoader( + train_dataset, + batch_sampler=train_sampler, + num_workers=args.workers, + use_shared_memory=True, ) + + for epoch in range(args.start_epoch, args.epochs): + if args.distributed: + train_loader.batch_sampler.set_epoch(epoch) + + # train for one epoch + train(train_loader, model, optimizer, scaler, summary_writer, epoch, + args) + + if args.rank == 0 and epoch % 10 == 0 or epoch == args.epochs - 1: # only the first GPU saves checkpoint + save_checkpoint( + { + 'epoch': epoch + 1, + 'arch': args.arch, + 'state_dict': model.state_dict(), + 'optimizer': optimizer.state_dict(), + 'scaler': scaler.state_dict(), + }, + is_best=False, + filename='checkpoint_%04d.pd' % epoch) + + if args.rank == 0: + summary_writer.close() + + +def train(train_loader, model, optimizer, scaler, summary_writer, epoch, args): + batch_time = AverageMeter('Time', ':6.3f') + data_time = AverageMeter('Data', ':6.3f') + learning_rates = AverageMeter('LR', ':.4e') + losses = AverageMeter('Loss', ':.4e') + progress = ProgressMeter( + len(train_loader), [batch_time, data_time, learning_rates, losses], + prefix="Epoch: [{}]".format(epoch)) + + # switch to train mode + model.train() + + end = time.time() + iters_per_epoch = len(train_loader) + moco_m = args.moco_m + for i, (images, _) in enumerate(train_loader): + # measure data loading time + data_time.update(time.time() - end) + + # adjust learning rate and momentum coefficient per iteration + lr = adjust_learning_rate(optimizer, epoch + i / iters_per_epoch, args) + learning_rates.update(lr) + if args.moco_m_cos: + moco_m = adjust_moco_momentum(epoch + i / iters_per_epoch, args) + + images[0] = images[0].cuda() + images[1] = images[1].cuda() + + # compute output + with paddle.amp.auto_cast(): + loss = model(images[0], images[1], moco_m) + + losses.update(loss.item(), images[0].shape[0]) + if args.rank == 0: + summary_writer.add_scalar("loss", + loss.item(), epoch * iters_per_epoch + i) + + # compute gradient and do SGD step + optimizer.clear_grad() + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % args.print_freq == 0: + progress.display(i) + + +def save_checkpoint(state, is_best, filename='checkpoint.pd'): + paddle.save(state, filename) + if is_best: + shutil.copyfile(filename, 'model_best.pd') + + +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self, name, fmt=':f'): + self.name = name + self.fmt = fmt + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' + return fmtstr.format(**self.__dict__) + + +class ProgressMeter(object): + def __init__(self, num_batches, meters, prefix=""): + self.batch_fmtstr = self._get_batch_fmtstr(num_batches) + self.meters = meters + self.prefix = prefix + + def display(self, batch): + entries = [self.prefix + self.batch_fmtstr.format(batch)] + entries += [str(meter) for meter in self.meters] + print('\t'.join(entries)) + + def _get_batch_fmtstr(self, num_batches): + num_digits = len(str(num_batches // 1)) + fmt = '{:' + str(num_digits) + 'd}' + return '[' + fmt + '/' + fmt.format(num_batches) + ']' + + +def adjust_learning_rate(optimizer, epoch, args): + """Decays the learning rate with half-cycle cosine after warmup""" + if epoch < args.warmup_epochs: + lr = args.lr * epoch / args.warmup_epochs + else: + lr = args.lr * 0.5 * ( + 1. + math.cos(math.pi * (epoch - args.warmup_epochs) / + (args.epochs - args.warmup_epochs))) + for param_group in optimizer.param_groups: + param_group['lr'] = lr + return lr + + +def adjust_moco_momentum(epoch, args): + """Adjust moco momentum based on current epoch""" + m = 1. - 0.5 * (1. + math.cos(math.pi * epoch / args.epochs)) * ( + 1. - args.moco_m) + return m + + +if __name__ == '__main__': + main() diff --git a/task/ssl/mocov3/pretrain.sh b/task/ssl/mocov3/pretrain.sh new file mode 100644 index 0000000000000..d8b6ae022e070 --- /dev/null +++ b/task/ssl/mocov3/pretrain.sh @@ -0,0 +1,31 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#unset PADDLE_TRAINER_ENDPOINTS +#export PADDLE_NNODES=4 +#export PADDLE_MASTER="10.67.228.16:12538" +#export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +export FLAGS_stop_check_timeout=3600 + +IMAGENET_DIR=./dataset/ILSVRC2012/ +python -m paddle.distributed.launch \ + --nnodes=$PADDLE_NNODES \ + --master=$PADDLE_MASTER \ + --devices=$CUDA_VISIBLE_DEVICES \ + main_moco.py \ + -a moco_vit_base \ + --optimizer=adamw --lr=1.5e-4 --weight-decay=.1 \ + --epochs=300 --warmup-epochs=40 \ + --stop-grad-conv1 --moco-m-cos --moco-t=.2 \ + ${IMAGENET_DIR} diff --git a/task/ssl/mocov3/vit_moco.py b/task/ssl/mocov3/vit_moco.py new file mode 100644 index 0000000000000..697bcf31d1748 --- /dev/null +++ b/task/ssl/mocov3/vit_moco.py @@ -0,0 +1,197 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import paddle +import paddle.nn as nn +from functools import partial, reduce +from operator import mul + +from plsc.models.vision_transformer import VisionTransformer, PatchEmbed, to_2tuple +from plsc.nn import init + + +class VisionTransformerMoCo(VisionTransformer): + def __init__(self, stop_grad_conv1=False, **kwargs): + super().__init__(**kwargs) + # Use fixed 2D sin-cos position embedding + self.build_2d_sincos_position_embedding() + + # weight initialization + for name, m in self.named_sublayers(): + if isinstance(m, nn.Linear): + if 'qkv' in name: + # treat the weights of Q, K, V separately + val = math.sqrt( + 6. / float(m.weight.shape[1] // 3 + m.weight.shape[0])) + init.uniform_(m.weight, -val, val) + else: + init.xavier_uniform_(m.weight) + init.zeros_(m.bias) + init.normal_(self.cls_token, std=1e-6) + + if isinstance(self.patch_embed, PatchEmbed): + # xavier_uniform initialization + val = math.sqrt(6. / float(3 * reduce( + mul, self.patch_embed.patch_size, 1) + self.embed_dim)) + init.uniform_(self.patch_embed.proj.weight, -val, val) + init.zeros_(self.patch_embed.proj.bias) + + if stop_grad_conv1: + self.patch_embed.proj.weight.stop_gradient = True + self.patch_embed.proj.bias.stop_gradient = True + + def build_2d_sincos_position_embedding(self, temperature=10000.): + h = self.patch_embed.img_size[0] // self.patch_embed.patch_size[0] + w = self.patch_embed.img_size[1] // self.patch_embed.patch_size[1] + grid_w = paddle.arange(w, dtype=paddle.float32) + grid_h = paddle.arange(h, dtype=paddle.float32) + grid_w, grid_h = paddle.meshgrid(grid_w, grid_h) + assert self.embed_dim % 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding' + pos_dim = self.embed_dim // 4 + omega = paddle.arange(pos_dim, dtype=paddle.float32) / pos_dim + omega = 1. / (temperature**omega) + + out_w = grid_w.flatten()[..., None] @omega[None] + out_h = grid_h.flatten()[..., None] @omega[None] + pos_emb = paddle.concat( + [ + paddle.sin(out_w), paddle.cos(out_w), paddle.sin(out_h), + paddle.cos(out_h) + ], + axis=1)[None, :, :] + pe_token = paddle.zeros([1, 1, self.embed_dim], dtype=paddle.float32) + + pos_embed = paddle.concat([pe_token, pos_emb], axis=1) + self.pos_embed = self.create_parameter(shape=pos_embed.shape) + self.pos_embed.set_value(pos_embed) + self.pos_embed.stop_gradient = True + + +class ConvStem(nn.Layer): + """ + ConvStem, from Early Convolutions Help Transformers See Better, Tete et al. https://arxiv.org/abs/2106.14881 + """ + + def __init__(self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + norm_layer=None, + flatten=True): + super().__init__() + + assert patch_size == 16, 'ConvStem only supports patch size of 16' + assert embed_dim % 8 == 0, 'Embed dimension must be divisible by 8 for ConvStem' + + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size[0] // patch_size[0], + img_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + + # build stem, similar to the design in https://arxiv.org/abs/2106.14881 + stem = [] + input_dim, output_dim = 3, embed_dim // 8 + for l in range(4): + stem.append( + nn.Conv2D( + input_dim, + output_dim, + kernel_size=3, + stride=2, + padding=1, + bias_attr=False)) + stem.append(nn.BatchNorm2D(output_dim)) + stem.append(nn.ReLU()) + input_dim = output_dim + output_dim *= 2 + stem.append(nn.Conv2D(input_dim, embed_dim, kernel_size=1)) + self.proj = nn.Sequential(*stem) + + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x): + B, C, H, W = x.shape + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose((0, 2, 1)) # BCHW -> BNC + x = self.norm(x) + return x + + +def moco_vit_small(**kwargs): + model = VisionTransformerMoCo( + patch_size=16, + embed_dim=384, + depth=12, + num_heads=12, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial( + nn.LayerNorm, epsilon=1e-6), + **kwargs) + return model + + +def moco_vit_base(**kwargs): + model = VisionTransformerMoCo( + patch_size=16, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial( + nn.LayerNorm, epsilon=1e-6), + **kwargs) + return model + + +def moco_vit_conv_small(**kwargs): + # minus one ViT block + model = VisionTransformerMoCo( + patch_size=16, + embed_dim=384, + depth=11, + num_heads=12, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial( + nn.LayerNorm, epsilon=1e-6), + embed_layer=ConvStem, + **kwargs) + return model + + +def moco_vit_conv_base(**kwargs): + # minus one ViT block + model = VisionTransformerMoCo( + patch_size=16, + embed_dim=768, + depth=11, + num_heads=12, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial( + nn.LayerNorm, epsilon=1e-6), + embed_layer=ConvStem, + **kwargs) + return model