In [1]:
import torch
import torch.optim as optim

from torch.utils.data import DataLoader, Dataset, random_split

import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid


from configs.config import config
from nets.siamese_net import SiameseNet, ContrastDataset, ContrastLoss
from utils.utils import plot_loss, imshow

# from sklearn.model_selection import train_test_split

# 1 训练数据集

## 1.1 数据集

训练集：CASIA-WebFace

验证集：ORL-Face

In [2]:
# Training images, read from the directory configured under "train_set_root".
# ImageFolder maps each sub-directory to one class, i.e. one identity here.
train_set = ImageFolder(root=config["train_set_root"])

# Number of distinct identities (classes) in the training folder; computed
# once and reused for the summary below.
n_train_identities = len(train_set.classes)
n_train_faces = len(train_set)

print("number of casia identities: {}".format(n_train_identities))
print("number of casia faces: {}".format(n_train_faces))


number of casia identities: 10575
number of casia faces: 494414


In [3]:
# Validation images, read from the directory configured under "test_set_root".
# ImageFolder maps each sub-directory to one class, i.e. one identity here.
val_set = ImageFolder(root=config["test_set_root"])

# Number of distinct identities (classes) in the validation folder; computed
# once and reused for the summary below.
n_val_identities = len(val_set.classes)
n_val_faces = len(val_set)

print("number of orl identities: {}".format(n_val_identities))
print("number of orl faces: {}".format(n_val_faces))


number of orl identities: 40
number of orl faces: 400


## 1.2 对比数据集

In [5]:
# Training-time preprocessing: resize to 100x100, apply light augmentation
# (random rotation of up to 10 degrees, random 90x90 crop), then convert the
# PIL image to a tensor.
transform = transforms.Compose([transforms.Resize((100, 100)),
                                transforms.RandomRotation(10),
                                transforms.RandomCrop((90, 90)),
                                transforms.ToTensor()])

# Validation-time preprocessing: same resize and the same final 90x90 tensor
# size as training, but deterministic (center crop, no rotation) so that
# validation results are comparable between runs.
# Previously the validation set was built with transform=None, which leaves the
# images as whatever the dataset yields (ImageFolder loads PIL images): they
# could not be batched by the default DataLoader collate function and would not
# match the 90x90 input size used for training.
# NOTE(review): assumes ContrastDataset applies `transform` to every image it
# returns and does no conversion of its own when it is None -- confirm in
# nets/siamese_net.py.
val_transform = transforms.Compose([transforms.Resize((100, 100)),
                                    transforms.CenterCrop((90, 90)),
                                    transforms.ToTensor()])

# Pair datasets built on top of the plain image folders.
train_contrast_set = ContrastDataset(img_folder_dataset=train_set,
                                     transform=transform)

val_contrast_set = ContrastDataset(img_folder_dataset=val_set,
                                   transform=val_transform)


## 1.3 加载器

In [None]:
# Options shared by both loaders: the batch size comes from the
# "train_batch_size" config entry, batches are drawn in shuffled order, and
# loading happens in the main process (num_workers=0).
_loader_options = dict(batch_size=config["train_batch_size"],
                       shuffle=True,
                       num_workers=0)

# Batch iterators over the training and validation pair datasets.
train_set_loader = DataLoader(dataset=train_contrast_set, **_loader_options)
val_set_loader = DataLoader(dataset=val_contrast_set, **_loader_options)


In [None]:
# Device: use the first GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Network
net = SiameseNet(dim_embedding=config["dim_embedding"],
                 is_rgb=config["is_rgb"])
net.to(device)

# Loss and optimizer
criterion = ContrastLoss()
optimizer = optim.Adam(net.parameters(), lr=1e-4)

# Training history: global step index and the loss recorded at that step.
counter = []
loss_history = []

# Number of batches per epoch, used to turn (epoch, idx) into a global step
# for the loss curve. The original code referenced `len_contrast_set`, a name
# that is never defined, and raised NameError at the first logging step.
batches_per_epoch = len(train_set_loader)

# Record/print the loss once every `log_every` batches.
log_every = 10

for epoch in range(config["train_epochs"]):

    for idx, data in enumerate(train_set_loader):

        # Each batch is indexed as (first images, second images, pair labels).
        img1s, img2s, labels = data[0].to(device), data[1].to(device), data[2].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        embedding1s, embedding2s = net(img1s, img2s)
        train_loss = criterion(embedding1s, embedding2s, labels)
        train_loss.backward()
        optimizer.step()

        if idx % log_every == 0:
            # Read the scalar once; it is used for both the log line and the history.
            loss_value = train_loss.item()
            print("Epoch number {}\n Current loss {}\n".format(epoch, loss_value))
            counter.append(idx + epoch * batches_per_epoch)
            loss_history.append(loss_value)

plot_loss(counter, loss_history)

# Save the trained weights.
# NOTE(review): the ".ph" extension looks like a typo for ".pth"; the path is
# left unchanged because other code may load the weights from this exact name.
torch.save(net.state_dict(), "best.siamese.ph")
