In [1]:
from __future__ import print_function
import os
from dataset import Dataset
import torch
from torch.utils import data
import torch.nn.functional as F
from models.focal_loss import *
from models.metrics import *
from models.resnet import *
import torchvision
import torch
import numpy as np
import random
import time
from config import Config
from torch.nn import DataParallel
from torch.optim.lr_scheduler import StepLR
from test import *

In [2]:
def save_model(model, save_path, name, iter_cnt):
    save_name = os.path.join(save_path, name + '_' + str(iter_cnt) + '.pth')
    torch.save(model.state_dict(), save_name)
    return save_name

In [3]:
opt = Config()
device = torch.device("cuda")

In [4]:
train_dataset = Dataset(opt.web_root, opt.web_train_list, phase='train', input_shape=opt.input_shape)
trainloader = data.DataLoader(train_dataset,
                                batch_size=opt.train_batch_size,
                                shuffle=True,
                                num_workers=opt.num_workers)

In [5]:
identity_list = get_lfw_list(opt.lfw_test_list)
img_paths = [os.path.join(opt.lfw_root, each) for each in identity_list]

In [6]:
print('{} train iters per epoch:'.format(len(trainloader)))

3143 train iters per epoch:


In [7]:
if opt.loss == 'focal_loss':
    criterion = FocalLoss(gamma=2)
else:
    criterion = torch.nn.CrossEntropyLoss()

In [8]:
if opt.backbone == 'resnet18':
    model = resnet_face18(use_se=opt.use_se)
elif opt.backbone == 'resnet34':
    model = resnet34()
elif opt.backbone == 'resnet50':
    model = resnet50()

In [9]:
if opt.metric == 'add_margin':
    metric_fc = AddMarginProduct(512, opt.num_classes, s=30, m=0.35)
elif opt.metric == 'arc_margin':
    metric_fc = ArcMarginProduct(512, opt.num_classes, s=30, m=0.5, easy_margin=opt.easy_margin)
elif opt.metric == 'sphere':
    metric_fc = SphereProduct(512, opt.num_classes, m=4)
else:
    metric_fc = nn.Linear(512, opt.num_classes)

In [10]:
# view_model(model, opt.input_shape)
# print(model)
#model = DataParallel(model)
#load_model(model, opt.load_model_path)
model.to(device)
metric_fc.to(device)
#metric_fc = DataParallel(metric_fc)

ArcMarginProduct()

In [11]:
if opt.optimizer == 'sgd':
    optimizer = torch.optim.SGD([{'params': model.parameters()}, {'params': metric_fc.parameters()}],
                                lr=opt.lr, weight_decay=opt.weight_decay)
    #optimizer = torch.optim.SGD([{'params': metric_fc.parameters()}],
    #                            lr=opt.lr, weight_decay=opt.weight_decay)
else:
    optimizer = torch.optim.Adam([{'params': model.parameters()}, {'params': metric_fc.parameters()}],
                                    lr=opt.lr, weight_decay=opt.weight_decay)
scheduler = StepLR(optimizer, step_size=opt.lr_step, gamma=0.1)

In [12]:
start = time.time()
for i in range(opt.max_epoch):
    scheduler.step()

    #model.train()
    for ii, data in enumerate(trainloader):
        data_input, label = data
        data_input = data_input.to(device)
        label = label.to(device).long()
        feature = model(data_input)
        output = metric_fc(feature, label)
        loss = criterion(output, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        iters = i * len(trainloader) + ii

        if iters % opt.print_freq == 0:
            output = output.data.cpu().numpy()
            output = np.argmax(output, axis=1)
            label = label.data.cpu().numpy()
            # print(output)
            # print(label)
            acc = np.mean((output == label).astype(int))
            speed = opt.print_freq / (time.time() - start)
            time_str = time.asctime(time.localtime(time.time()))
            print('{} train epoch {} iter {} {} iters/s loss {} acc {}'.format(time_str, i, ii, speed, loss.item(), acc))
            #if opt.display:
            #    visualizer.display_current_results(iters, loss.item(), name='train_loss')
            #    visualizer.display_current_results(iters, acc, name='train_acc')

            start = time.time()

    if i % opt.save_interval == 0 or i == opt.max_epoch:
        save_model(model, opt.checkpoints_path, opt.backbone, i)
        save_model(metric_fc, opt.checkpoints_path, opt.metric, i)

    model.eval()
    acc = lfw_test(model, img_paths, identity_list, opt.lfw_test_list, opt.test_batch_size)
    #if opt.display:
    #    visualizer.display_current_results(iters, acc, name='test_acc')



Fri May 13 15:05:37 2022 train epoch 0 iter 0 362.11667323678626 iters/s loss 24.60559844970703 acc 0.0
Fri May 13 15:07:48 2022 train epoch 0 iter 500 3.8248336521308732 iters/s loss 22.752410888671875 acc 0.0
Fri May 13 15:09:59 2022 train epoch 0 iter 1000 3.7944826355374506 iters/s loss 21.197765350341797 acc 0.0
Fri May 13 15:12:11 2022 train epoch 0 iter 1500 3.788774766168061 iters/s loss 19.463258743286133 acc 0.0
Fri May 13 15:14:23 2022 train epoch 0 iter 2000 3.7902841220221815 iters/s loss 18.375158309936523 acc 0.0
Fri May 13 15:16:35 2022 train epoch 0 iter 2500 3.787931688423978 iters/s loss 16.258466720581055 acc 0.0
Fri May 13 15:18:47 2022 train epoch 0 iter 3000 3.7882527289966603 iters/s loss 16.369417190551758 acc 0.015625
(7701, 1024)
total time is 14.956380844116211, average time is 0.058196034412903545
lfw face verification accuracy:  0.9031666666666667 threshold:  0.3820949
Fri May 13 15:21:16 2022 train epoch 1 iter 357 3.361581873404786 iters/s loss 15.814413