## 特征提取

In [2]:
import os, sys, codecs
import glob
import pandas as pd
import numpy as np
import pickle
from PIL import Image
from tqdm import tqdm

import cv2

from sklearn.preprocessing import normalize as sknormalize
from sklearn.decomposition import PCA

import torch
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data.dataset import Dataset

import logging
logging.basicConfig(level = logging.DEBUG, filename = 'example.log',
                    format = '%(asctime)s - %(filename)s[line:%(lineno)d]: %(message)s')  # 

PATH = '/home/wx/work/video_copy_detection/'
TRAIN_PATH = PATH + 'train/'
TEST_PATH = PATH + 'test/'
TRAIN_QUERY_PATH = TRAIN_PATH + 'query/'
REFER_PATH = TRAIN_PATH + 'refer/'
TRAIN_QUERY_FRAME_PATH = TRAIN_PATH + 'query_frame/'
REFER_FRAME_PATH = TRAIN_PATH + 'refer_frame/'
TEST_QUERY_PATH = TEST_PATH + 'query/'
TEST_QUERY_FRAME_PATH = TEST_PATH + 'query_frame/'
CODE_DIR = PATH + 'code/'

In [5]:
class QRDataset(Dataset):
    def __init__(self, img_path, transform = None):
        self.img_path = img_path

        self.img_label = np.zeros(len(img_path))
    
        if transform is not None:
            self.transform = transform
        else:
            self.transform = None
    
    def __getitem__(self, index):
        img = Image.open(self.img_path[index])
        
        if self.transform is not None:
            img = self.transform(img)
        
        return img, self.img_path[index]

    def __len__(self):
        return len(self.img_path)
# 比res50效果好一些
model = models.resnet18(pretrained = True)
# res18 bolck.expansion = 1
model.fc = nn.Linear(512, 1024)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)
transformer = transforms.Compose([
    transforms.Resize((224, 224)), 
    transforms.ToTensor(),
    # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

def extract_feature(path):
    if not isinstance(path, list):
        path = [path]
    
    data_loader = torch.utils.data.DataLoader(
        QRDataset(path, transformer), batch_size = 40, shuffle = False, num_workers = 16)

    img_feature = []
    with torch.no_grad():
        for batch_data in tqdm(data_loader):
            batch_x, batch_y = batch_data
            
            batch_x = Variable(batch_x, requires_grad = False).cuda()
            feature_pred = model(batch_x)

            # max-pooling
            # feature_pred = F.max_pool2d(feature_pred, kernel_size=(24, 32))
            
            # ave-pooling
            # feature_pred = F.avg_pool2d(feature_pred, kernel_size=(24, 32))[:, :, 0, 0]
            
            #print(feature_pred.shape, batch_x.shape)
            feature_pred = feature_pred.data.cpu().numpy()
            # feature_pred = feature_pred.max(-1).max(-1)
            
            # feature_pred = feature_pred.reshape((-1, 512))
            img_feature.append(feature_pred)
            
            del feature_pred
            # img_feature.append(feature_pred)
            
    img_feature = np.vstack(img_feature)
    return img_feature

In [6]:
# 读取 test_query 视频的关键帧，并按照视频和关键帧时间进行排序
test_query_imgs_path = []
for id in pd.read_csv(TEST_PATH + 'submit_example.csv')['query_id']:
    test_query_imgs_path += glob.glob(TEST_QUERY_FRAME_PATH + id + '/*.jpg')

test_query_imgs_path.sort(key = lambda x: x.lower())

In [7]:
# 读取 train_query 视频的关键帧，并按照视频和关键帧时间进行排序
train_query_imgs_path = []
for id in pd.read_csv(TRAIN_PATH + 'train.csv')['query_id']:
    train_query_imgs_path += glob.glob(TRAIN_QUERY_FRAME_PATH + id + '/*.jpg')

train_query_imgs_path.sort(key = lambda x: x.lower())

In [8]:
# 读取 refer 视频的关键帧，并按照视频和关键帧时间进行排序

refer_imgs_path = glob.glob(REFER_FRAME_PATH + '*/*.jpg')
refer_imgs_path.sort(key = lambda x: x.lower())

In [9]:
# 抽取 test_query 关键帧特征
test_query_features = extract_feature(test_query_imgs_path[:])

100%|██████████| 1564/1564 [01:05<00:00, 23.71it/s]


In [10]:
# 抽取 train_query 关键帧特征
train_query_features = extract_feature(train_query_imgs_path[:])

100%|██████████| 3128/3128 [02:12<00:00, 23.57it/s]


In [11]:
# 抽取 refer 关键帧特征
refer_features = extract_feature(list(refer_imgs_path[:]))

100%|██████████| 4527/4527 [03:17<00:00, 22.91it/s]


In [12]:
def normalize(x, copy = False):
    """
    A helper function that wraps the function of the same name in sklearn.
    This helper handles the case of a single column vector.
    """
    if type(x) == np.ndarray and len(x.shape) == 1:
        return np.squeeze(sknormalize(x.reshape(1, -1), copy = copy))
        #return np.squeeze(x / np.sqrt((x ** 2).sum(-1))[..., np.newaxis])
    else:
        return sknormalize(x, copy = copy)
        #return x / np.sqrt((x ** 2).sum(-1))[..., np.newaxis]

In [13]:
# PCA 降维
'''
pca = PCA(n_components=512)

train_query_features = pca.fit_transform(train_query_features)
test_query_features = pca.fit_transform(test_query_features)
refer_features = pca.fit_transform(refer_features)
'''

'\npca = PCA(n_components=512)\n\ntrain_query_features = pca.fit_transform(train_query_features)\ntest_query_features = pca.fit_transform(test_query_features)\nrefer_features = pca.fit_transform(refer_features)\n'

In [14]:
# L2正则化
train_query_features = normalize(train_query_features)
test_query_features = normalize(test_query_features)
refer_features = normalize(refer_features)

In [15]:
# 保存 test_query 关键帧特征

with open(PATH + 'var/test_query_features.pk', 'wb') as pk_file:
    pickle.dump(test_query_features, pk_file)

In [16]:
# 保存 train_query 关键帧特征

with open(PATH + 'var/train_query_features.pk', 'wb') as pk_file:
    pickle.dump(train_query_features, pk_file)

In [17]:
# 保存 refer 关键帧特征

with open(PATH + 'var/refer_features.pk', 'wb') as pk_file:
    pickle.dump(refer_features, pk_file)