In [1]:
from __future__ import print_function
import os
import cv2
from models.resnet import *
import torch
from torch import nn 
import numpy as np
import time
from config import Config as opt
from torch.nn import DataParallel

In [2]:
class EnDecoder(nn.Module):
    """Convolutional autoencoder for single-channel images.

    Encoder: a 1-channel image passes through five stride-2 3x3 convs (with a
    stride-1 3x3 conv after each of the first four), ending in an 80-channel
    map that is flattened and projected to a 512-d code.
    Decoder: the code is projected back to 1280 = 80*4*4 features, reshaped to
    (N, 80, 4, 4) and upsampled by five stride-2 transposed convs to a
    1-channel 128x128 image, squashed to [-1, 1] by Tanh.

    NOTE(review): Linear(5*16*16 = 1280) only matches the flattened encoder
    output when the last conv yields an 80x4x4 map, which a 128x128 input
    does; the per-layer size comments below assume that input size.
    The layer order inside each nn.Sequential fixes the state_dict keys
    (e.g. "Encoder.0.weight"), so reordering or inserting layers breaks
    loading existing checkpoints.
    """
    def __init__(self):
        super().__init__()
        self.Encoder = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1, stride=2), # 128x128 => 64x64
            nn.GELU(),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(16, 2*16, kernel_size=3, padding=1, stride=2), # 64x64 => 32x32
            nn.GELU(),
            nn.Conv2d(2*16, 2*16, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(2*16, 3*16, kernel_size=3, padding=1, stride=2), # 32x32 => 16x16
            nn.GELU(),
            nn.Conv2d(3*16, 3*16, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(3*16, 4*16, kernel_size=3, padding=1, stride=2), # 16x16 => 8x8
            nn.GELU(),
            nn.Conv2d(4*16, 4*16, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(4*16, 5*16, kernel_size=3, padding=1, stride=2), # 8x8 => 4x4
            nn.GELU(),
            nn.Flatten(), # (N, 80, 4, 4) => (N, 1280): image grid to a single feature vector
            nn.Linear(5*16*16, 512) # compress the feature vector to 512 dims
        )
        
        # Expands the 512-d code back to 1280 = 80*4*4 features for the decoder.
        self.linear = nn.Sequential(
            nn.Linear(512, 5*16*16),
            nn.GELU()
        )

        self.Decoder = nn.Sequential(
            nn.ConvTranspose2d(5*16, 4*16, kernel_size=3, output_padding=1, padding=1, stride=2), # 4x4 => 8x8
            nn.GELU(),
            nn.Conv2d(4*16, 4*16, kernel_size=3, padding=1),
            nn.GELU(),
            nn.ConvTranspose2d(4*16, 3*16, kernel_size=3, output_padding=1, padding=1, stride=2), # 8x8 => 16x16
            nn.GELU(),
            nn.Conv2d(3*16, 3*16, kernel_size=3, padding=1),
            nn.GELU(),
            nn.ConvTranspose2d(3*16, 2*16, kernel_size=3, output_padding=1, padding=1, stride=2), # 16x16 => 32x32
            nn.GELU(),
            nn.Conv2d(2*16, 2*16, kernel_size=3, padding=1),
            nn.GELU(),
            nn.ConvTranspose2d(2*16, 16, kernel_size=3, output_padding=1, padding=1, stride=2), # 32x32 => 64x64
            nn.GELU(),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.GELU(),
            nn.ConvTranspose2d(16, 1, kernel_size=3, output_padding=1, padding=1, stride=2), # 64x64 => 128x128
            nn.Tanh() # input images are scaled to [-1, 1], so the output is bounded to the same range
        )

    def forward(self, x):
        # x: (N, 1, H, W) image batch -> (N, 512) code
        x = self.Encoder(x)
        # (N, 512) -> (N, 1280)
        x = self.linear(x)
        # (N, 1280) -> (N, 80, 4, 4) feature map expected by the decoder
        x = x.reshape(x.shape[0], -1, 4, 4)
        # (N, 80, 4, 4) -> (N, 1, 128, 128), values in [-1, 1]
        x = self.Decoder(x)
        return x

# Restore the trained autoencoder weights onto GPU 1 and put the module in
# eval mode; the module returned by eval() is displayed as this cell's output.
AE_path = "AEs/bAE2.pt"
AE = EnDecoder()
AE.load_state_dict(torch.load(AE_path, map_location=torch.device("cuda:1")))
AE = AE.to(torch.device("cuda:1"))
AE.eval()

EnDecoder(
  (Encoder): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): GELU()
    (2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): GELU()
    (4): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (5): GELU()
    (6): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): GELU()
    (8): Conv2d(32, 48, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (9): GELU()
    (10): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): GELU()
    (12): Conv2d(48, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (13): GELU()
    (14): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): GELU()
    (16): Conv2d(64, 80, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (17): GELU()
    (18): Flatten(start_dim=1, end_dim=-1)
    (19): Linear(in_features=1280, out_features=512, bias=True)
  )
  (linear): Sequential(
   

In [3]:
def get_lfw_list(pair_list):
    """Return the unique image paths referenced by a pair-list file.

    Each non-blank line of ``pair_list`` must start with two whitespace-separated
    image paths; any further columns (the pair label) are ignored here.
    Paths are returned once each, in order of first appearance.

    Args:
        pair_list: path to the text file listing image pairs.

    Returns:
        list[str]: unique image paths.

    Raises:
        IndexError: if a non-blank line has fewer than two columns.
    """
    data_list = []
    seen = set()  # O(1) membership test instead of scanning data_list each time
    with open(pair_list, 'r') as fd:
        for line in fd:
            splits = line.split()
            if not splits:
                # Tolerate blank lines (e.g. a trailing empty line in the file).
                continue
            for path in (splits[0], splits[1]):
                if path not in seen:
                    seen.add(path)
                    data_list.append(path)
    return data_list

def load_image(img_path):
    """Read ``img_path`` as grayscale and return it together with its mirror image.

    Returns None when OpenCV cannot read the file. Otherwise returns a float32
    array of shape (2, 1, H, W): index 0 is the image, index 1 its left-right
    flip, with pixel values mapped from [0, 255] to [-1, 1].
    """
    gray = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if gray is None:
        return None
    pair = np.stack((gray, np.fliplr(gray)), axis=0)[:, np.newaxis, :, :]
    pair = pair.astype(np.float32)
    return (pair - 127.5) / 127.5

In [4]:
def _embed_batch(AE, model, images, device):
    """Run one stacked image batch through ``AE`` then ``model``; return one feature row per image.

    ``images`` holds two consecutive rows per input image (the image, then its
    left-right flip, as produced by load_image), so the even and odd output
    rows are concatenated side by side.
    """
    data = torch.from_numpy(images).to(device)
    output = model(AE(data)).detach().cpu().numpy()
    return np.hstack((output[::2], output[1::2]))


def get_featurs(AE, model, test_list, batch_size=10, device="cuda:1"):
    """Embed every image in ``test_list`` with ``model`` applied to ``AE``'s output.

    Images are loaded with load_image (two rows each: original + flip) and
    pushed through the networks in batches. A batch is flushed whenever its
    row count is a multiple of ``batch_size``, and again after the last image.
    Inference runs under ``torch.no_grad()`` so no autograd graph is kept.

    Args:
        AE: module applied to each input batch first.
        model: face-embedding module applied to ``AE``'s output.
        test_list: image file paths, in the order features should be returned.
        batch_size: rows per forward pass (each image contributes two rows).
        device: device the input batches are moved to (default "cuda:1").

    Returns:
        (features, cnt): ``features`` has one row per image (the flipped
        image's embedding concatenated after the original's), or None if
        ``test_list`` is empty. ``cnt`` is the number of forward passes run.

    Raises:
        IOError: if an image cannot be read. Skipping it would misalign the
            feature rows with ``test_list``.
    """
    pending = []        # load_image outputs waiting for a forward pass
    pending_rows = 0    # total rows in `pending`
    chunks = []         # per-batch feature arrays, stacked once at the end
    cnt = 0
    last = len(test_list) - 1
    with torch.no_grad():
        for i, img_path in enumerate(test_list):
            image = load_image(img_path)
            if image is None:
                raise IOError('read {} error'.format(img_path))
            pending.append(image)
            pending_rows += image.shape[0]

            if pending_rows % batch_size == 0 or i == last:
                batch = np.concatenate(pending, axis=0)
                chunks.append(_embed_batch(AE, model, batch, device))
                cnt += 1
                pending = []
                pending_rows = 0

    features = np.vstack(chunks) if chunks else None
    return features, cnt

In [5]:
def get_feature_dict(test_list, features):
    """Map each key in ``test_list`` to the feature row at the same position.

    Row ``i`` of ``features`` belongs to ``test_list[i]``. If a key repeats,
    the later row wins. Raises IndexError if ``features`` has fewer rows
    than ``test_list`` has entries.
    """
    return {key: features[idx] for idx, key in enumerate(test_list)}

def cosin_metric(x1, x2):
    """Cosine similarity: dot(x1, x2) divided by the product of their L2 norms.

    A zero vector makes the denominator 0, giving nan/inf (with a NumPy warning).
    """
    norm_product = np.linalg.norm(x1) * np.linalg.norm(x2)
    return np.dot(x1, x2) / norm_product

def cal_accuracy(y_score, y_true):
    """Find the score threshold that best separates positives from negatives.

    Every observed score is tried as a threshold (predict positive when
    ``score >= th``), in input order. The first threshold reaching the
    highest accuracy is kept.

    Returns:
        (best_acc, best_th); (0, 0) when ``y_score`` is empty.
    """
    scores = np.asarray(y_score)
    truth = np.asarray(y_true)
    best = (0, 0)
    for candidate in scores:
        predicted = scores >= candidate
        accuracy = (predicted == truth).astype(int).mean()
        if accuracy > best[0]:
            best = (accuracy, candidate)
    return best

def test_performance(fe_dict, pair_list):
    """Score every pair in ``pair_list`` and return the best verification accuracy.

    Args:
        fe_dict: mapping from image key (as written in the pair list) to its
            feature vector.
        pair_list: text file whose non-blank lines are
            "<key1> <key2> <label>". The integer label is compared with the
            prediction ``similarity >= threshold``, so 1 marks a matching pair.
            Blank lines are skipped.

    Returns:
        (acc, th): the best accuracy and the cosine-similarity threshold that
        achieves it.

    Raises:
        KeyError: if a key in the pair list is missing from ``fe_dict``.
    """
    sims = []
    labels = []
    with open(pair_list, 'r') as fd:
        for line in fd:
            splits = line.split()
            if not splits:
                continue  # tolerate blank lines, e.g. at the end of the file
            fe_1 = fe_dict[splits[0]]
            fe_2 = fe_dict[splits[1]]
            label = int(splits[2])
            sims.append(cosin_metric(fe_1, fe_2))
            labels.append(label)

    acc, th = cal_accuracy(sims, labels)
    return acc, th

def lfw_test(AE, model, img_paths, identity_list, compair_list, batch_size):
    """Run the LFW verification benchmark, printing timing and accuracy.

    Args:
        AE: module applied to each image batch before ``model``.
        model: face-embedding module.
        img_paths: image files to embed, aligned index-by-index with
            ``identity_list``.
        identity_list: key for each embedded image (as written in the pair list).
        compair_list: pair-list file scored by test_performance.
        batch_size: rows per forward pass, forwarded to get_featurs.

    Returns:
        The best verification accuracy.

    Raises:
        ValueError: if no features were produced (``img_paths`` is empty).
    """
    start = time.perf_counter()
    features, cnt = get_featurs(AE, model, img_paths, batch_size=batch_size)
    if features is None or cnt == 0:
        # Without this guard the prints below fail with an opaque
        # AttributeError (None.shape) or ZeroDivisionError (t / 0).
        raise ValueError('lfw_test: no images to evaluate (img_paths is empty)')
    print(features.shape)
    t = time.perf_counter() - start
    # "average time" is per forward pass (batch), not per image.
    print('total time is {}, average time is {}'.format(t, t / cnt))
    fe_dict = get_feature_dict(identity_list, features)
    acc, th = test_performance(fe_dict, compair_list)
    print('lfw face verification accuracy: ', acc, 'threshold: ', th)
    return acc

In [6]:
def load_model(model, model_path, map_location=None):
    """Copy matching weights from a checkpoint file into ``model`` in place.

    Only checkpoint entries whose key also appears in ``model.state_dict()``
    are copied. Extra checkpoint keys are ignored, and model entries missing
    from the checkpoint keep their current values. Because a key mismatch
    would otherwise pass silently, a warning is emitted when some model keys
    are not found. A stronger warning is emitted when none match, e.g. a
    "module." prefix mismatch between a DataParallel wrapper and a plain
    checkpoint.

    Args:
        model: the nn.Module to update.
        model_path: path to a torch.save'd mapping of parameter name -> tensor.
        map_location: forwarded to torch.load to control where tensors are
            placed (e.g. "cpu"); None keeps torch.load's default behaviour.
    """
    import warnings

    model_dict = model.state_dict()
    # torch.load unpickles the file: only use it on trusted checkpoints.
    pretrained_dict = torch.load(model_path, map_location=map_location)
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}

    missing = [k for k in model_dict if k not in pretrained_dict]
    if model_dict and not pretrained_dict:
        warnings.warn(
            'load_model: no keys in {!r} match the model; weights were not '
            'changed (check for a "module." prefix mismatch)'.format(model_path))
    elif missing:
        warnings.warn(
            'load_model: {} model keys not found in {!r} and left unchanged, '
            'e.g. {}'.format(len(missing), model_path, missing[:3]))

    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)

In [7]:
# Settings pulled from the project Config. use_se is passed to resnet_face18
# (presumably toggles squeeze-and-excitation blocks — confirm in models.resnet).
use_se = False
test_model_path = opt.test_model_path
lfw_test_list = opt.lfw_test_list
lfw_root = opt.lfw_root

# Wrap the backbone in DataParallel *before* loading so the checkpoint keys
# are matched against the wrapped model's state_dict.
model = DataParallel(resnet_face18(use_se))
# NOTE(review): weights come from opt.load_model_path, while test_model_path
# above is assigned but never used here — confirm which checkpoint is intended.
load_model(model, opt.load_model_path)
model.to(torch.device("cuda"))
model.eval()

identity_list = get_lfw_list(lfw_test_list)
img_paths = [os.path.join(lfw_root, name) for name in identity_list]

# batch_size=20 rows -> 10 images per forward pass (each image adds a flipped copy).
lfw_test(AE, model, img_paths, identity_list, lfw_test_list, 20)

(7701, 1024)
total time is 46.41174650192261, average time is 0.06019681777162465
lfw face verification accuracy:  0.9626666666666667 threshold:  0.2261485


0.9626666666666667