In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [0]:
import torch

In [3]:
torch.cuda.is_available()

True

In [0]:
import sys

import torch
from torch.utils.data import DataLoader, SubsetRandomSampler
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR
# from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
import numpy as np

In [0]:
sys.path.append('/content/drive/My Drive/my_framework/facenet-pytorch/')

In [0]:
class EarlyStopCallback(object):
    """ 早停止回调类.多少个epoch没有变好就停止训练.
    """

    def __init__(self, patience=10):
        """
        :param int patience: epoch的数量
        """
        super().__init__()
        self.patience = patience
        self.wait = 0
        # epoch计数，用于后续日志输出
        self.epoch_no = 1
        self.max_metric_value = 0

    def on_valid_end(self, metric_value, metric_key='acc'):
        """
        每次执行验证集的evaluation后会调用。

        :param metric_value 指标值。
        :param str metric_key: 指标key。
        :return:
        """
        print('======epoch : {} , early stopping : {}/{}======'.format(self.epoch_no, self.wait, self.patience))
        print('metric_key : {}, metric_value : {}, max_metric_value:{}'.format(metric_key, metric_value, self.max_metric_value))
        self.epoch_no += 1
        # 判断是否超过上次指标
        is_better_eval = False
        if metric_value > self.max_metric_value:
            is_better_eval = True
            self.max_metric_value = metric_value
            self.wait = 0
        else:
            self.wait += 1
        if not is_better_eval:
            # current result is getting worse
            if self.wait >= self.patience:
                print('reach early stopping patience, stop training.')
                raise Exception("Early stopping raised.")

In [0]:
from models.mtcnn import MTCNN  # noqa
from models.utils import training  # noqa
from models.mtcnn import fixed_image_standardization  # noqa
from models.inception_resnet_v1 import InceptionResnetV1  # noqa

In [8]:
data_dir = '/content/drive/My Drive/data/cv/face/mingxing/train'
batch_size = 32
epochs = 300
workers = 3
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Running on device: {}'.format(device))

Running on device: cuda:0


In [9]:
mtcnn = MTCNN(image_size=160,
              margin=0,
              min_face_size=20,
              thresholds=[0.6, 0.7, 0.7],
              factor=0.709,
              post_process=True,
              device=device)
print('mtcnn:{}'.format(mtcnn))

mtcnn:MTCNN(
  (pnet): PNet(
    (conv1): Conv2d(3, 10, kernel_size=(3, 3), stride=(1, 1))
    (prelu1): PReLU(num_parameters=10)
    (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=True)
    (conv2): Conv2d(10, 16, kernel_size=(3, 3), stride=(1, 1))
    (prelu2): PReLU(num_parameters=16)
    (conv3): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1))
    (prelu3): PReLU(num_parameters=32)
    (conv4_1): Conv2d(32, 2, kernel_size=(1, 1), stride=(1, 1))
    (softmax4_1): Softmax(dim=1)
    (conv4_2): Conv2d(32, 4, kernel_size=(1, 1), stride=(1, 1))
  )
  (rnet): RNet(
    (conv1): Conv2d(3, 28, kernel_size=(3, 3), stride=(1, 1))
    (prelu1): PReLU(num_parameters=28)
    (pool1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)
    (conv2): Conv2d(28, 48, kernel_size=(3, 3), stride=(1, 1))
    (prelu2): PReLU(num_parameters=48)
    (pool2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)
    (conv3): Conv2d(4

In [10]:
dataset = datasets.ImageFolder(data_dir,
                                transform=transforms.Resize((512, 512)))
dataset.samples = [(p, p.replace(data_dir, data_dir + '_cropped'))
                    for p, _ in dataset.samples]

loader = DataLoader(dataset,
                    num_workers=workers,
                    batch_size=batch_size,
                    collate_fn=training.collate_pil)

for i, (x, y) in enumerate(loader):
    mtcnn(x, save_path=y)
    print('\rBatch {} of {}'.format(i + 1, len(loader)), end='')

# Remove mtcnn to reduce GPU memory usage
del mtcnn
resnet = InceptionResnetV1(
    classify=True,
    pretrained='vggface2',
    num_classes=len(dataset.class_to_idx)
).to(device)
print('resnet:{}'.format(resnet))

Batch 6 of 6resnet:InceptionResnetV1(
  (conv2d_1a): BasicConv2d(
    (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
  (conv2d_2a): BasicConv2d(
    (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
  (conv2d_2b): BasicConv2d(
    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
  (maxpool_3a): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2d_3b): BasicConv2d(
    (conv): Conv2d(64, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(80, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()


In [11]:
# optimizer = optim.Adam(resnet.parameters(), lr=0.001)
print([name for name, param in resnet.named_parameters()])
optim_params = [param for name, param in resnet.named_parameters() if name in {'logits.weight', 'logits.bias'}]
print('optim_params:{}'.format(optim_params))
optimizer = optim.Adam(optim_params)

scheduler = MultiStepLR(optimizer, [5, 10])

trans = transforms.Compose([
    np.float32,
    transforms.ToTensor(),
    fixed_image_standardization
])
dataset = datasets.ImageFolder(data_dir + '_cropped', transform=trans)
img_inds = np.arange(len(dataset))
np.random.shuffle(img_inds)
train_inds = img_inds  # [:int(0.8 * len(img_inds))]
val_inds = img_inds  # [int(0.8 * len(img_inds)):]

train_loader = DataLoader(
    dataset,
    num_workers=workers,
    batch_size=batch_size,
    sampler=SubsetRandomSampler(train_inds)
)
print('train_loader:{}'.format(len(train_loader)))
val_loader = DataLoader(
    dataset,
    num_workers=workers,
    batch_size=batch_size,
    sampler=SubsetRandomSampler(val_inds)
)
print('val_loader:{}'.format(len(val_loader)))
loss_fn = torch.nn.CrossEntropyLoss()
metrics = {
    'fps': training.BatchTimer(),
    'acc': training.accuracy
}
# writer = SummaryWriter()
# writer.iteration, writer.interval = 0, 10

print('\n\nInitial')

['conv2d_1a.conv.weight', 'conv2d_1a.bn.weight', 'conv2d_1a.bn.bias', 'conv2d_2a.conv.weight', 'conv2d_2a.bn.weight', 'conv2d_2a.bn.bias', 'conv2d_2b.conv.weight', 'conv2d_2b.bn.weight', 'conv2d_2b.bn.bias', 'conv2d_3b.conv.weight', 'conv2d_3b.bn.weight', 'conv2d_3b.bn.bias', 'conv2d_4a.conv.weight', 'conv2d_4a.bn.weight', 'conv2d_4a.bn.bias', 'conv2d_4b.conv.weight', 'conv2d_4b.bn.weight', 'conv2d_4b.bn.bias', 'repeat_1.0.branch0.conv.weight', 'repeat_1.0.branch0.bn.weight', 'repeat_1.0.branch0.bn.bias', 'repeat_1.0.branch1.0.conv.weight', 'repeat_1.0.branch1.0.bn.weight', 'repeat_1.0.branch1.0.bn.bias', 'repeat_1.0.branch1.1.conv.weight', 'repeat_1.0.branch1.1.bn.weight', 'repeat_1.0.branch1.1.bn.bias', 'repeat_1.0.branch2.0.conv.weight', 'repeat_1.0.branch2.0.bn.weight', 'repeat_1.0.branch2.0.bn.bias', 'repeat_1.0.branch2.1.conv.weight', 'repeat_1.0.branch2.1.bn.weight', 'repeat_1.0.branch2.1.bn.bias', 'repeat_1.0.branch2.2.conv.weight', 'repeat_1.0.branch2.2.bn.weight', 'repeat_1.0

In [12]:
print('-' * 10)
resnet.eval()
training.pass_epoch(
    resnet, loss_fn, val_loader,
    batch_metrics=metrics, show_running=True, device=device
)

# 早停止
early_stop_callback = EarlyStopCallback(patience=5)

for epoch in range(epochs):
    print('\nEpoch {}/{}'.format(epoch + 1, epochs))
    print('-' * 10)

    resnet.train()
    training.pass_epoch(
        resnet, loss_fn, train_loader, optimizer, scheduler,
        batch_metrics=metrics, show_running=True, device=device
    )

    resnet.eval()
    val_loss, val_metrics = training.pass_epoch(
            resnet, loss_fn, val_loader,
            batch_metrics=metrics, show_running=True, device=device
        )
    val_acc = val_metrics['acc']
    try:
        early_stop_callback.on_valid_end(metric_value=val_acc, metric_key='acc')
    except Exception:
        break
torch.save(resnet, '/content/drive/My Drive/model_save/cv/vggface2_mingxing_finetune.pt')
print('dataset.classes:{}'.format(dataset.classes))
print('dataset.class_to_idx:{}'.format(dataset.class_to_idx))
# writer.close()

----------
Valid |     6/6    | loss:    2.9973 | fps:  545.9679 | acc:    0.0649   

Epoch 1/300
----------
Train |     6/6    | loss:    2.9798 | fps:  190.0909 | acc:    0.1657   
Valid |     6/6    | loss:    2.9459 | fps:  567.8839 | acc:    0.3239   
metric_key : acc, metric_value : 0.3238636255264282, max_metric_value:0

Epoch 2/300
----------
Train |     6/6    | loss:    2.9246 | fps:  198.0014 | acc:    0.4920   
Valid |     6/6    | loss:    2.8875 | fps:  655.5583 | acc:    0.6705   
metric_key : acc, metric_value : 0.6704545617103577, max_metric_value:0.3238636255264282

Epoch 3/300
----------
Train |     6/6    | loss:    2.8710 | fps:  190.5884 | acc:    0.7713   
Valid |     6/6    | loss:    2.8319 | fps:  615.5287 | acc:    0.8778   
metric_key : acc, metric_value : 0.8778409361839294, max_metric_value:0.6704545617103577

Epoch 4/300
----------
Train |     6/6    | loss:    2.8177 | fps:  193.9848 | acc:    0.8887   
Valid |     6/6    | loss:    2.7785 | fps:  643.84

测试

In [0]:
from PIL import Image
from models.mtcnn import MTCNN  # noqa

In [0]:
mtcnn = MTCNN()

In [0]:
model_path = '/content/drive/My Drive/model_save/cv/vggface2_mingxing_finetune.pt'

In [0]:
resnet = torch.load(model_path).to('cpu').eval()
# print('resnet:{}'.format(resnet))

In [17]:
img_path = '/content/drive/My Drive/data/cv/face/mingxing/test/tongliya.jpg'
save_path = '/content/drive/My Drive/data/cv/face/mingxing/test/tongliya_cropped.jpg'
img = Image.open(img_path)
# Get cropped and prewhitened image tensor
img_cropped = mtcnn(img, save_path=save_path)
# Calculate embedding (unsqueeze to add batch dimension)
img_embedding = resnet(img_cropped.unsqueeze(0))
# Or, if using for VGGFace2 classification
resnet.classify = True
img_probs = resnet(img_cropped.unsqueeze(0))
probs = torch.nn.functional.softmax(img_probs, dim=-1)
print('probs:{}'.format(probs))
print('img_probs:{}'.format(probs.max(dim=1)))

probs:tensor([[0.0476, 0.0567, 0.0487, 0.0548, 0.0481, 0.0514, 0.0461, 0.0473, 0.0460,
         0.0454, 0.0441, 0.0489, 0.0540, 0.0645, 0.0622, 0.0503, 0.0421, 0.0475,
         0.0476, 0.0468]], grad_fn=<SoftmaxBackward>)
img_probs:torch.return_types.max(
values=tensor([0.0645], grad_fn=<MaxBackward0>),
indices=tensor([13]))


成龙：img_probs:torch.return_types.max(
values=tensor([0.1486], grad_fn=<MaxBackward0>),
indices=tensor([0]))
胡歌：img_probs:torch.return_types.max(
values=tensor([0.1181], grad_fn=<MaxBackward0>),
indices=tensor([5]))


In [18]:
img_names = dataset.classes
print(img_names)

['chenglong', 'dongxuan', 'guanzhilin', 'gulinazha', 'gutianle', 'huge', 'jindong', 'jingtian', 'lilianjie', 'liming', 'linjunjie', 'liudehua', 'sunli', 'tongliya', 'yangmi', 'zhangmin', 'zhangxueyou', 'zhoujielun', 'zhourunfa', 'zhouxingchi']


In [19]:
right_count = 0
sum_count = 0
for img_id, img_name in enumerate(img_names):
  try:
    img_path = '/content/drive/My Drive/data/cv/face/mingxing/test/{}.jpg'.format(img_name)
    save_path = '/content/drive/My Drive/data/cv/face/mingxing/test/{}_cropped.jpg'.format(img_name)
    img = Image.open(img_path)
    # Get cropped and prewhitened image tensor
    img_cropped = mtcnn(img, save_path=save_path)
    # Calculate embedding (unsqueeze to add batch dimension)
    img_embedding = resnet(img_cropped.unsqueeze(0))
    # Or, if using for VGGFace2 classification
    resnet.classify = True
    img_probs = resnet(img_cropped.unsqueeze(0))
    probs = torch.nn.functional.softmax(img_probs, dim=-1)
    label_prob, label = probs.max(dim=1)
    label_prob = label_prob.item()
    label = label.item()
    # print('probs:{}'.format(probs))
    print('label_prob:{},label:{}'.format(label_prob, label))
    print('img_id:{},label:{}'.format(img_id, label))
    if img_id == label:
      right_count += 1
    sum_count += 1
  except:
    pass

label_prob:0.06498130410909653,label:0
img_id:0,label:0
label_prob:0.06359836459159851,label:1
img_id:1,label:1
label_prob:0.07143895328044891,label:2
img_id:2,label:2
label_prob:0.06662105768918991,label:4
img_id:4,label:4
label_prob:0.06232147291302681,label:5
img_id:5,label:5
label_prob:0.064668670296669,label:6
img_id:6,label:6
label_prob:0.0635976493358612,label:7
img_id:7,label:7
label_prob:0.06679338961839676,label:8
img_id:8,label:8
label_prob:0.061744146049022675,label:9
img_id:9,label:9
label_prob:0.06533089280128479,label:10
img_id:10,label:10
label_prob:0.07053721696138382,label:11
img_id:11,label:11
label_prob:0.06830945611000061,label:12
img_id:12,label:12
label_prob:0.06450928747653961,label:13
img_id:13,label:13
label_prob:0.07102397829294205,label:14
img_id:14,label:14
label_prob:0.06275743246078491,label:15
img_id:15,label:15
label_prob:0.0657811090350151,label:16
img_id:16,label:16
label_prob:0.0658702552318573,label:17
img_id:17,label:17
label_prob:0.068779759109020

In [20]:
print('{}/{}'.format(right_count, sum_count))

19/19


In [0]:
# 预训练微调，测试集与验证集相同：16/19 14/19
# 预训练微调，测试集与验证集8:2切分：16/19