In [9]:
import argparse
import logging
import os
import numpy as np
from PIL import Image
import random
from tqdm import tqdm

import torch

# Show INFO-level records so the status messages logged by this notebook are displayed.
logging.basicConfig(level=logging.INFO)

In [16]:
# Notebook configuration expressed as command-line style arguments. The values
# are passed to parse_args() explicitly, so sys.argv is not consulted.
parser = argparse.ArgumentParser(description='PyTorch HardNet')
parser.add_argument('--gpu-id', default='3', type=str,
                    help='value exported as CUDA_VISIBLE_DEVICES')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='do not use CUDA even when it is available')
parser.add_argument('--seed', type=int, default=0,
                    help='seed for the random, numpy and torch generators')
parser.add_argument('--data-dir', type=str,
                    help='directory holding the patch images, info file and matches file')
parser.add_argument('--data-file', type=str,
                    help='file the assembled dataset is saved to')
parser.add_argument('--image-extension', type=str,
                    help='filename suffix selecting the patch images inside the data directory')
parser.add_argument('--n-patches', type=int,
                    help='number of patches to keep')
parser.add_argument('--info-file', type=str,
                    help='label file inside the data directory; the first field of each line is kept')
parser.add_argument('--matches-files', type=str,
                    help='ground truth matches file inside the data directory')

args = parser.parse_args([
    '--data-dir', 'data/liberty',
    '--info-file', 'info.txt',
    '--matches-files', 'm50_100000_100000_0.txt',
    '--image-extension', 'bmp',
    '--n-patches', '450092',

    '--data-file', 'liberty.pt'])

In [17]:
# Select the visible GPU(s) from --gpu-id; this is set before the first CUDA
# query below. (The former hard-coded `gpu_id = 3` was never read and has been
# dropped so that --gpu-id is the single source of truth.)
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id

# CUDA is used only when it was not disabled with --no-cuda and torch reports
# that it is available.
args.cuda = not args.no_cuda and torch.cuda.is_available()

logging.info(("NOT " if not args.cuda else "") + "Using cuda")

# set random seeds
random.seed(args.seed)
torch.manual_seed(args.seed)
np.random.seed(args.seed)

INFO:root:Using cuda


In [18]:
def read_image_file(data_dir, image_ext, n):
    """Return a Tensor containing the patches.

    Every file in ``data_dir`` whose name ends with ``image_ext`` is opened in
    sorted filename order and cut into a 16 x 16 grid of 64 x 64 patches,
    scanned row by row (left to right, then top to bottom). The result is the
    slice ``patches[:n]`` of the patches collected that way.

    Args:
        data_dir: directory holding the patch images.
        image_ext: filename suffix used to select the images (e.g. 'bmp').
        n: number of patches to keep. ``None`` keeps every patch; a negative
            value keeps its slice meaning (drops patches from the end).

    Returns:
        A ``torch.ByteTensor`` built from the selected patches, each converted
        to a 64 x 64 ``uint8`` array.
    """

    def PIL2array(_img):
        """Convert PIL image type to numpy 2D array
        """
        return np.array(_img.getdata(), dtype=np.uint8).reshape(64, 64)

    def find_files(_data_dir, _image_ext):
        """Return a list with the file names of the images containing the patches
        """
        # sort files in ascending order to keep the patch <-> label relation
        return sorted(
            os.path.join(_data_dir, fname)
            for fname in os.listdir(_data_dir)
            if fname.endswith(_image_ext)
        )

    # Only a non-negative n lets us stop early; None and negative values need
    # every patch before the final slice can be taken.
    limit = n if n is not None and n >= 0 else None

    patches = []
    list_files = find_files(data_dir, image_ext)

    for fpath in tqdm(list_files):
        if limit is not None and len(patches) >= limit:
            # Enough patches collected: do not decode images whose patches
            # would be thrown away by the slice below.
            break
        # The context manager releases the file handle once the image's
        # patches have been extracted.
        with Image.open(fpath) as img:
            for y in range(0, 1024, 64):
                for x in range(0, 1024, 64):
                    patch = img.crop((x, y, x + 64, y + 64))
                    patches.append(PIL2array(patch))

    # Explicit dtype keeps an empty selection as uint8 instead of numpy's
    # default float64.
    return torch.ByteTensor(np.array(patches[:n], dtype=np.uint8))


def read_info_file(data_dir, info_file):
    """Return a Tensor containing the list of labels.

    Reads ``info_file`` inside ``data_dir`` and keeps only the first
    whitespace-separated field of each line (the ID of the 3D point),
    converted to ``int``. Blank lines are skipped.

    Args:
        data_dir: directory containing the info file.
        info_file: name of the info file inside ``data_dir``.

    Returns:
        A ``torch.LongTensor`` with one label per non-blank line, in file order.

    Raises:
        ValueError: if the first field of a line is not an integer.
    """
    with open(os.path.join(data_dir, info_file), 'r') as f:
        # `if line.strip()` guards against blank lines (e.g. a trailing empty
        # line), for which line.split() would be empty.
        labels = [int(line.split()[0]) for line in f if line.strip()]
    return torch.LongTensor(labels)


def read_matches_files(data_dir, matches_file):
    """Return a Tensor containing the ground truth matches.

    Each non-blank line of ``matches_file`` is split on whitespace. Fields 0
    and 3 are kept as integers, and a third value marks whether fields 1 and 4
    (compared as strings) are equal: matches are represented with a 1,
    non matches with a 0. Blank lines are skipped.

    Args:
        data_dir: directory containing the matches file.
        matches_file: name of the matches file inside ``data_dir``.

    Returns:
        A ``torch.LongTensor`` with one ``[field0, field3, is_match]`` row per
        non-blank line, in file order.

    Raises:
        IndexError: if a non-blank line has fewer than five fields.
        ValueError: if field 0 or field 3 is not an integer.
    """
    matches = []
    with open(os.path.join(data_dir, matches_file), 'r') as f:
        for line in f:
            fields = line.split()
            if not fields:
                # blank line (e.g. trailing newline at end of file)
                continue
            matches.append([int(fields[0]), int(fields[3]),
                            int(fields[1] == fields[4])])
    return torch.LongTensor(matches)

In [None]:
# Build the three parts of the dataset in the same order as before: image
# patches first, then the per-patch labels, then the ground truth matches.
patches = read_image_file(args.data_dir, args.image_extension, args.n_patches)
labels = read_info_file(args.data_dir, args.info_file)
matches = read_matches_files(args.data_dir, args.matches_files)

# The dataset is the (patches, labels, matches) tuple.
dataset = (patches, labels, matches)

 58%|█████▊    | 1018/1759 [10:21<07:01,  1.76it/s]

In [None]:
# Serialize the dataset tuple to the path given by --data-file. The bare name
# `data_file` was never defined in this notebook, so the parsed argument is
# used instead.
with open(args.data_file, 'wb') as f:
    torch.save(dataset, f)