In [1]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader

from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid

from configs.config import config
from nets.siamese_net import SiameseNet, ContrastDataset, ContrastLoss
from utils.utils import plot_history, imshow

import pickle
import time
import copy
from tqdm import tqdm

# from sklearn.model_selection import train_test_split

# 1 训练数据集

## 1.1 数据集

训练集：CASIA-WebFace

验证集：ORL-Face

In [2]:
image_datasets = {x: ImageFolder(root=config[x + "_set_root"],
                                 transform=None)
                  for x in ["train", "val"]}

dataset_size = {x: len(image_datasets[x]) for x in ["train", "val"]}
class_names = image_datasets["train"].classes

In [3]:
print("number of train identities: {}".format(len(image_datasets["train"].classes)))
print("number of train faces: {}".format(dataset_size["train"]))
print("number of val identities: {}".format(dataset_size["val"]))
print("number of val faces: {}".format(len(image_datasets["val"].classes)))

number of train identities: 10575
number of train faces: 494414
number of val identities: 400
number of val faces: 40


## 1.2 对比数据集

In [4]:
data_transforms = {
    "train": transforms.Compose([
        transforms.RandomRotation(10),
        transforms.Resize(100),
        transforms.RandomResizedCrop(90),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    "val": transforms.Compose([
        transforms.Resize(100),
        transforms.CenterCrop(90),
        transforms.ToTensor(),
        # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

contrast_datasets = {x: ContrastDataset(img_folder_dataset=image_datasets[x],
                                        transform=data_transforms[x])
                     for x in ["train", "val"]}


## 1.3 加载器

In [5]:
data_loaders = {x: torch.utils.data.DataLoader(dataset=contrast_datasets[x],
                                               batch_size=config[x + "_batch_size"],
                                               shuffle=True,
                                               num_workers=0)
                for x in ["train", "val"]}
dataset_size = {x: len(image_datasets[x]) for x in ["train", "val"]}
class_names = image_datasets["train"].classes

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 2 模型

## 2.1 孪生网络

In [6]:
# gpu
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# network
net = SiameseNet(dim_embedding=config["dim_embedding"],
                 is_rgb=config["is_rgb"])
net.to(device)

SiameseNet(
  (res_net_50): ResNet(
    (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace)
    (max_pool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
        (down_sample): Sequential(
          (0

## 2.2 对比损失

In [7]:
# loss
criterion = ContrastLoss()
optimizer = optim.Adam(net.parameters(), lr=1e-3)

# Decay LR by a factor of 0.5 every 10 epochs
exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)

## 2.3 训练

In [8]:
def train_model(model, data_loaders, criterion, optimizer, scheduler, num_epochs=25,
                early_stopping_patience=None,
                reduce_lr_on_plateau=None):

    history = dict(epoch=[],
                   train_loss=[],
                   val_loss=[])

    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = 0

    if early_stopping_patience is not None:
        early_stopping_cnt = 0

    if reduce_lr_on_plateau is not None:
        reduce_lr_on_plateau_cnt = 0

    for epoch in range(num_epochs):
        print("-" * 10)
        print("Epoch {}/{}".format(epoch, num_epochs - 1))
        print("-" * 10)

        # Each epoch has a training and validation phase
        for phase in ["train", "val"]:
            if phase == "train":
                scheduler.step()
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0

            # progress bar
            pbar = tqdm(total=len(data_loaders[phase]),
                        desc=phase,
                        ascii=True)

            # Iterate over data.
            for data in data_loaders[phase]:
                img1s = data[0].to(device)
                img2s = data[1].to(device)
                labels = data[2].to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == "train"):
                    embedding1s, embedding2s = model(img1s, img2s)
                    loss = criterion(embedding1s, embedding2s, labels)

                    # backward + optimize only if in training phase
                    if phase == "train":
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * labels.size(0)
                pbar.update(1)

            epoch_loss = running_loss / dataset_size[phase]
            pbar.close()

            print("{} Loss: {:.4f}".format(
                phase, epoch_loss))

            # history
            if phase == "train":
                history["epoch"].append(epoch)
                history["train_loss"].append(epoch_loss)
            elif phase == "val":
                history["val_loss"].append(epoch_loss)
            else:
                pass

            # early stopping
            if early_stopping_patience is not None:
                if phase == "val" and epoch_loss >= best_loss:
                    early_stopping_cnt += 1
                elif phase == "val" and epoch_loss < best_loss:
                    early_stopping_cnt = 0
                else:
                    pass

                if early_stopping_cnt >= early_stopping_patience:
                    print("Early Stopping...")
                    # load best model weights
                    model.load_state_dict(best_model_wts)
                    return model, history

            # reduce lr on plateau
            if reduce_lr_on_plateau is not None:
                if phase == "val" and epoch_loss >= best_loss:
                    reduce_lr_on_plateau_cnt += 1
                elif phase == "val" and epoch_loss < best_loss:
                    reduce_lr_on_plateau_cnt = 0
                else:
                    pass

                if reduce_lr_on_plateau_cnt >= reduce_lr_on_plateau["patience"]:
                    reduce_lr_on_plateau_cnt = 0
                    print("Error Plateau, Reducing the Learning Rate...")
                    for param_group in optimizer.param_groups:
                        param_group["lr"] *= reduce_lr_on_plateau["factor"]
                    print("Learning Rate: {}".format(param_group["lr"]))

            # best save according to val_loss
            if phase == "val" and epoch_loss < best_loss:
                print("Best Save...")
                best_loss = epoch_loss
                best_model_wts = copy.deepcopy(model.state_dict())
                torch.save(model.state_dict(),
                           "./output/best_model-epoch_{}-val_loss_{:.4f}.pth".format(
                               epoch, epoch_loss))
                print("./output/best_model-epoch_{}-val_loss_{:.4f}.pth".format(
                    epoch, epoch_loss))

        print("\n\n")

    time_elapsed = time.time() - since
    print("Training complete in {:.0f}m {:.0f}s".format(
        time_elapsed // 60, time_elapsed % 60))
    print("Best val Loss: {:4f}".format(best_loss))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model, history



In [9]:
net, history = train_model(model=net,
                           data_loaders=data_loaders,
                           criterion=criterion,
                           optimizer=optimizer,
                           scheduler=exp_lr_scheduler,
                           num_epochs=config["train_epochs"],
                           early_stopping_patience=config["early_stopping_patience"],
                           reduce_lr_on_plateau = config["reduce_lr_on_plateau"])

with open("./output/history.pickle", "wb") as fw:
    pickle.dump(history, fw)
    
plot_history(history, "./output/history.png")

----------
Epoch 0/99
----------


train: 100%|#####################################################################| 7726/7726 [4:49:38<00:00,  1.71s/it]


train Loss: 1.0318


val: 100%|###############################################################################| 1/1 [00:04<00:00,  4.59s/it]


val Loss: 1.1042
Best Save...
./output/best_model-epoch_0-val_loss_1.1042.pth



----------
Epoch 1/99
----------


train: 100%|#####################################################################| 7726/7726 [3:47:36<00:00,  1.20s/it]


train Loss: 1.0126


val: 100%|###############################################################################| 1/1 [00:04<00:00,  4.09s/it]


val Loss: 1.0909



----------
Epoch 2/99
----------


train: 100%|#####################################################################| 7726/7726 [3:18:01<00:00,  1.24s/it]


train Loss: 1.0092


val: 100%|###############################################################################| 1/1 [00:03<00:00,  3.57s/it]


val Loss: 1.0723



----------
Epoch 3/99
----------


train: 100%|#####################################################################| 7726/7726 [3:19:00<00:00,  1.18s/it]


train Loss: 1.0071


val: 100%|###############################################################################| 1/1 [00:03<00:00,  3.44s/it]


val Loss: 1.0567



----------
Epoch 4/99
----------


train: 100%|#####################################################################| 7726/7726 [3:37:45<00:00,  1.39s/it]


train Loss: 1.0064


val: 100%|###############################################################################| 1/1 [00:04<00:00,  4.34s/it]


val Loss: 1.0493



----------
Epoch 5/99
----------


train: 100%|#####################################################################| 7726/7726 [3:42:43<00:00,  1.58s/it]


train Loss: 1.0057


val: 100%|###############################################################################| 1/1 [00:04<00:00,  4.27s/it]


val Loss: 1.0585
Error Plateau, Reducing the Learning Rate...
Learning Rate: 0.0002



----------
Epoch 6/99
----------


train: 100%|#####################################################################| 7726/7726 [4:08:25<00:00,  1.42s/it]


train Loss: 1.0035


val: 100%|###############################################################################| 1/1 [00:04<00:00,  4.22s/it]


val Loss: 1.0554



----------
Epoch 7/99
----------


train: 100%|#####################################################################| 7726/7726 [3:56:24<00:00,  1.63s/it]


train Loss: 1.0031


val: 100%|###############################################################################| 1/1 [00:04<00:00,  4.56s/it]


val Loss: 1.0572



----------
Epoch 8/99
----------


train: 100%|#####################################################################| 7726/7726 [3:17:54<00:00,  1.14s/it]


train Loss: 1.0031


val: 100%|###############################################################################| 1/1 [00:03<00:00,  3.67s/it]


val Loss: 1.0521



----------
Epoch 9/99
----------


train: 100%|#####################################################################| 7726/7726 [3:29:14<00:00,  1.30s/it]


train Loss: 1.0028


val: 100%|###############################################################################| 1/1 [00:03<00:00,  3.53s/it]


val Loss: 1.0843



----------
Epoch 10/99
----------


train:  75%|###################################################7                 | 5791/7726 [2:27:39<48:59,  1.52s/it]

RuntimeError: cuda runtime error (4) : unspecified launch failure at C:/w/1/s/tmp_conda_3.7_044431/conda/conda-bld/pytorch_1556686009173/work/aten/src\THC/generic/THCTensorMathPointwise.cu:552